3 changes: 3 additions & 0 deletions pyproject.toml
@@ -94,6 +94,9 @@ testpaths = [
"tests"
]
asyncio_mode = "auto"
markers = [
"integration: marks tests as integration tests (require AWS credentials, deselect with '-m \"not integration\"')",
]

[tool.coverage.run]
branch = true
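For context, the new marker is applied by decorating a test; the snippet below is a minimal sketch (the test name and body are hypothetical, not taken from this PR). As the marker description notes, such tests can be deselected locally with pytest -m "not integration".

    import pytest

    @pytest.mark.integration  # marker registered in pyproject.toml above; needs AWS credentials
    def test_agentcore_memory_round_trip():
        # Hypothetical integration-test body; a real test would call AgentCore Memory APIs.
        assert True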
@@ -83,6 +83,20 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]:
elif "blob" in payload_item:
try:
blob_data = json.loads(payload_item["blob"])
# New format: dict with _type marker
if isinstance(blob_data, dict):
if blob_data.get("_type") == "agent_state":
continue # Skip agent state payloads
if blob_data.get("_type") == "message":
data = blob_data.get("data", [])
if isinstance(data, (tuple, list)) and len(data) == 2:
session_msg = SessionMessage.from_dict(json.loads(data[0]))
filtered = AgentCoreMemoryConverter._filter_empty_text(session_msg.message)
session_msg.message = filtered
if session_msg.message.get("content"):
messages.append(session_msg)
continue
# Legacy format: tuple [json_string, role]
if isinstance(blob_data, (tuple, list)) and len(blob_data) == 2:
try:
session_msg = SessionMessage.from_dict(json.loads(blob_data[0]))
@@ -91,8 +105,8 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]:
messages.append(session_msg)
except (json.JSONDecodeError, ValueError):
logger.error("This is not a SessionMessage but just a blob message. Ignoring")
except (json.JSONDecodeError, ValueError):
logger.error("Failed to parse blob content: %s", payload_item)
except Exception as e:
logger.error("Failed to parse blob content: %s, error: %s", payload_item, e)
return list(reversed(messages))

@staticmethod
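To make the accepted formats concrete, here is a standalone sketch of the three blob shapes the updated events_to_messages handles: the new "_type": "message" and "_type": "agent_state" dicts, plus the legacy two-element list. The sample message dict and the "USER" role string are illustrative assumptions, not the converter's exact output.

    import json

    session_message_dict = {"message": {"role": "user", "content": [{"text": "hi"}]}, "message_id": 0}

    # New batched format: dicts tagged with a _type marker.
    message_blob = {"_type": "message", "data": [json.dumps(session_message_dict), "USER"]}
    agent_state_blob = {"_type": "agent_state", "_agent_id": "agent-1", "state": {}}

    # Legacy format: two-element list [json_string, role], still handled by the fallback branch.
    legacy_blob = [json.dumps(session_message_dict), "USER"]

    event = {
        "payload": [
            {"blob": json.dumps(message_blob)},
            {"blob": json.dumps(agent_state_blob)},
            {"blob": json.dumps(legacy_blob)},
        ]
    }
    # events_to_messages([event]) skips the agent_state blob and routes the other
    # two blobs to SessionMessage.from_dict.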
228 changes: 198 additions & 30 deletions src/bedrock_agentcore/memory/integrations/strands/session_manager.py
@@ -9,7 +9,12 @@

import boto3
from botocore.config import Config as BotocoreConfig
from strands.hooks import MessageAddedEvent
from strands.experimental.hooks.multiagent import (
AfterMultiAgentInvocationEvent,
AfterNodeCallEvent,
MultiAgentInitializedEvent,
)
from strands.hooks import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent
from strands.hooks.registry import HookRegistry
from strands.session.repository_session_manager import RepositorySessionManager
from strands.session.session_repository import SessionRepository
@@ -33,6 +38,10 @@
MESSAGE_PREFIX = "message_"
MAX_FETCH_ALL_RESULTS = 10000

# Payload type markers for batched storage
PAYLOAD_TYPE_MESSAGE = "message"
PAYLOAD_TYPE_AGENT_STATE = "agent_state"


class AgentCoreMemorySessionManager(RepositorySessionManager, SessionRepository):
"""AgentCore Memory-based session manager for Bedrock AgentCore Memory integration.
@@ -103,6 +112,7 @@ def __init__(
self.memory_client = MemoryClient(region_name=region_name)
session = boto_session or boto3.Session(region_name=region_name)
self.has_existing_agent = False
self._last_synced_state_hash: Optional[int] = None

# Override the clients if custom boto session or config is provided
# Add strands-agents to the request user agent
@@ -157,6 +167,122 @@ def _get_full_agent_id(self, agent_id: str) -> str:
)
return full_agent_id

# region Internal Storage Methods

def _build_agent_state_payload(self, agent: "Agent") -> dict:
"""Create agent state payload for unified storage.

Creates a SessionAgent-compatible payload that can be reconstructed
via SessionAgent.from_dict().

Args:
agent (Agent): The agent whose state to capture.

Returns:
dict: Agent state payload with type markers.
"""
session_agent = SessionAgent.from_agent(agent)
return {
"_type": PAYLOAD_TYPE_AGENT_STATE,
"_agent_id": agent.agent_id,
**session_agent.to_dict(),
}

def _compute_state_hash(self, agent: "Agent") -> int:
"""Compute hash of agent state for change detection.

Excludes timestamps (created_at, updated_at) as they change on every call.

Args:
agent (Agent): The agent whose state to hash.

Returns:
int: Hash of the agent state.
"""
session_agent = SessionAgent.from_agent(agent)
state_dict = session_agent.to_dict()
# Exclude timestamps that change on every call
state_dict.pop("created_at", None)
state_dict.pop("updated_at", None)
return hash(json.dumps(state_dict, sort_keys=True))

def _save_message_with_state(self, message: Message, agent: "Agent") -> None:
"""Save message and agent state in a single batched API call.

Combines message and agent state into one API call instead of two separate calls.

Args:
message (Message): The message to save.
agent (Agent): The agent whose state to sync.
"""
session_message = SessionMessage.from_message(message, 0)
messages = AgentCoreMemoryConverter.message_to_payload(session_message)
if not messages:
return

message_tuple = messages[0]
message_payload = {"_type": PAYLOAD_TYPE_MESSAGE, "data": list(message_tuple)}
agent_state_payload = self._build_agent_state_payload(agent)

original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00"))
monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp)

try:
event = self.memory_client.gmdp_client.create_event(
memoryId=self.config.memory_id,
actorId=self.config.actor_id,
sessionId=self.session_id,
payload=[
{"blob": json.dumps(message_payload)},
{"blob": json.dumps(agent_state_payload)},
],
eventTimestamp=monotonic_timestamp,
)
logger.debug(
"Saved message and agent state in single call: event=%s, agent=%s",
event.get("event", {}).get("eventId"),
agent.agent_id,
)

session_message = SessionMessage.from_message(message, event.get("event", {}).get("eventId"))
self._latest_agent_message[agent.agent_id] = session_message
self._last_synced_state_hash = self._compute_state_hash(agent)

except Exception as e:
logger.error("Failed to save message with state: %s", e)
raise SessionException(f"Failed to save message with state: {e}") from e

def _sync_agent_state_if_changed(self, agent: "Agent") -> None:
"""Sync agent state to AgentCore Memory only if state changed since last sync.

Args:
agent (Agent): The agent to sync.
"""
current_hash = self._compute_state_hash(agent)
if current_hash == self._last_synced_state_hash:
logger.debug("Agent state unchanged, skipping sync for agent=%s", agent.agent_id)
return

agent_state_payload = self._build_agent_state_payload(agent)

try:
event = self.memory_client.gmdp_client.create_event(
memoryId=self.config.memory_id,
actorId=self.config.actor_id,
sessionId=self.session_id,
payload=[{"blob": json.dumps(agent_state_payload)}],
eventTimestamp=self._get_monotonic_timestamp(),
)
self._last_synced_state_hash = current_hash
logger.debug(
"Synced agent state: event=%s, agent=%s", event.get("event", {}).get("eventId"), agent.agent_id
)
except Exception as e:
logger.error("Failed to sync agent state: %s", e)
Contributor:
small: This silently swallows the exception, but save_message_with_state (line 243) raises SessionException on failure. Should this also raise to maintain consistent error handling? Silent failures here could lead to undetected data loss.

Suggested fix:

except Exception as e:
    logger.error("Failed to sync agent state: %s", e)
    raise SessionException(f"Failed to sync agent state: {e}") from e

Contributor Author:
Fixed - raises SessionException on failure

raise SessionException(f"Failed to sync agent state: {e}") from e

# endregion Internal Storage Methods

# region SessionRepository interface implementation
def create_session(self, session: Session, **kwargs: Any) -> Session:
"""Create a new session in AgentCore Memory.
@@ -237,6 +363,8 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A
The agent's existence is inferred from the presence of events/messages in the memory system,
but we validate the session_id matches our config.

Uses unified actorId with type markers for optimized storage.

Args:
session_id (str): The session ID to create the agent in.
session_agent (SessionAgent): The agent to create.
@@ -248,15 +376,19 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A
if session_id != self.config.session_id:
raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}")

agent_state_payload = {
"_type": PAYLOAD_TYPE_AGENT_STATE,
"_agent_id": session_agent.agent_id,
**session_agent.to_dict(),
}
event = self.memory_client.gmdp_client.create_event(
memoryId=self.config.memory_id,
actorId=self._get_full_agent_id(session_agent.agent_id),
actorId=self.config.actor_id,
sessionId=self.session_id,
payload=[
{"blob": json.dumps(session_agent.to_dict())},
],
payload=[{"blob": json.dumps(agent_state_payload)}],
eventTimestamp=self._get_monotonic_timestamp(),
)

logger.info(
"Created agent: %s in session: %s with event %s",
session_agent.agent_id,
@@ -267,7 +399,8 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A
def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]:
"""Read agent data from AgentCore Memory events.

We reconstruct the agent state from the conversation history.
Uses dual-read approach: tries new format first (unified actorId with type markers),
falls back to legacy format (separate actorId) for backward compatibility.

Args:
session_id (str): The session ID to read from.
@@ -279,22 +412,43 @@ def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[
"""
if session_id != self.config.session_id:
return None

try:
events = self.memory_client.list_events(
memory_id=self.config.memory_id,
actor_id=self.config.actor_id,
session_id=session_id,
max_results=MAX_FETCH_ALL_RESULTS,
)

# Events are returned oldest-first, so reverse to get latest state
for event in reversed(events):
for payload_item in event.get("payload", []):
blob = payload_item.get("blob")
if blob:
try:
data = json.loads(blob)
if data.get("_type") == PAYLOAD_TYPE_AGENT_STATE and data.get("_agent_id") == agent_id:
agent_data = {k: v for k, v in data.items() if k not in ("_type", "_agent_id")}
return SessionAgent.from_dict(agent_data)
except json.JSONDecodeError:
continue

# Fallback to legacy format for backward compatibility
legacy_events = self.memory_client.list_events(
memory_id=self.config.memory_id,
actor_id=self._get_full_agent_id(agent_id),
session_id=session_id,
max_results=1,
)
if legacy_events:
agent_data = json.loads(legacy_events[0].get("payload", {})[0].get("blob"))
return SessionAgent.from_dict(agent_data)

if not events:
return None

agent_data = json.loads(events[0].get("payload", {})[0].get("blob"))
return SessionAgent.from_dict(agent_data)
except Exception as e:
logger.error("Failed to read agent %s", e)
return None
logger.error("Failed to read agent: %s", e)

return None

def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None:
"""Update agent data.
Expand All @@ -316,7 +470,7 @@ def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A
# Create a new agent as AgentCore Memory is immutable. We always get the latest one in `read_agent`
self.create_agent(session_id, session_agent)

def create_message(
def create_message( # type: ignore[override]
self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any
) -> Optional[dict[str, Any]]:
"""Create a new message in AgentCore Memory.
Expand Down Expand Up @@ -353,9 +507,8 @@ def create_message(
try:
messages = AgentCoreMemoryConverter.message_to_payload(session_message)
if not messages:
return
return None

# Parse the original timestamp and use it as desired timestamp
original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00"))
monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp)

@@ -477,16 +630,28 @@ def list_messages(
# region RepositorySessionManager overrides
@override
def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None:
"""Append a message to the agent's session using AgentCore's eventId as message_id.
"""Append a message to the agent's session with batched agent state sync.

Saves message and agent state in a single API call for optimization.

Args:
message: Message to add to the agent in the session
agent: Agent to append the message to
**kwargs: Additional keyword arguments for future extensibility.
"""
created_message = self.create_message(self.session_id, agent.agent_id, SessionMessage.from_message(message, 0))
session_message = SessionMessage.from_message(message, created_message.get("eventId"))
self._latest_agent_message[agent.agent_id] = session_message
self._save_message_with_state(message, agent)

@override
def sync_agent(self, agent: "Agent", **kwargs: Any) -> None:
"""Sync agent state only if it changed since last sync.

Skips sync if agent state is unchanged, avoiding redundant API calls.

Args:
agent: Agent to sync to the session.
**kwargs: Additional keyword arguments for future extensibility.
"""
self._sync_agent_state_if_changed(agent)

def retrieve_customer_context(self, event: MessageAddedEvent) -> None:
"""Retrieve customer LTM context before processing support query.
Expand All @@ -495,15 +660,16 @@ def retrieve_customer_context(self, event: MessageAddedEvent) -> None:
event (MessageAddedEvent): The message added event containing the agent and message data.
"""
messages = event.agent.messages
if not messages or messages[-1].get("role") != "user" or "toolResult" in messages[-1].get("content")[0]:
content = messages[-1].get("content") if messages else None
if not messages or messages[-1].get("role") != "user" or not content or "toolResult" in content[0]:
return None
if not self.config.retrieval_config:
# Only retrieve LTM
return None

user_query = messages[-1]["content"][0]["text"]

def retrieve_for_namespace(namespace: str, retrieval_config: RetrievalConfig):
def retrieve_for_namespace(namespace: str, retrieval_config: RetrievalConfig) -> list[str]:
"""Helper function to retrieve memories for a single namespace."""
resolved_namespace = namespace.format(
actorId=self.config.actor_id,
@@ -565,16 +731,18 @@ def retrieve_for_namespace(namespace: str, retrieval_config: RetrievalConfig):
logger.error("Failed to retrieve customer context: %s", e)

@override
def register_hooks(self, registry: HookRegistry, **kwargs) -> None:
"""Register additional hooks.

Args:
registry (HookRegistry): The hook registry to register callbacks with.
**kwargs: Additional keyword arguments.
"""
RepositorySessionManager.register_hooks(self, registry, **kwargs)
def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
"""Register hooks for session management with optimized storage."""
registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent))
registry.add_callback(MessageAddedEvent, lambda e: self._save_message_with_state(e.message, e.agent))
registry.add_callback(AfterInvocationEvent, lambda event: self._sync_agent_state_if_changed(event.agent))
registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event))

# Multi-agent hooks
registry.add_callback(MultiAgentInitializedEvent, lambda event: self.initialize_multi_agent(event.source))
registry.add_callback(AfterNodeCallEvent, lambda event: self.sync_multi_agent(event.source))
registry.add_callback(AfterMultiAgentInvocationEvent, lambda event: self.sync_multi_agent(event.source))

@override
def initialize(self, agent: "Agent", **kwargs: Any) -> None:
if self.has_existing_agent:
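As a closing illustration of the change detection added above, the standalone sketch below mirrors the logic of _compute_state_hash and the comparison in _sync_agent_state_if_changed. The sample state dicts are illustrative, not the exact SessionAgent.to_dict() output.

    import json

    def compute_state_hash(state_dict: dict) -> int:
        # Drop per-call timestamps, then hash a canonical JSON serialization,
        # mirroring _compute_state_hash above.
        stable = {k: v for k, v in state_dict.items() if k not in ("created_at", "updated_at")}
        return hash(json.dumps(stable, sort_keys=True))

    state_before = {"agent_id": "agent-1", "state": {"step": 1}, "created_at": "2024-01-01T00:00:00Z"}
    state_after = {"agent_id": "agent-1", "state": {"step": 1}, "created_at": "2024-01-01T00:00:05Z"}

    # Only the timestamp changed, so the hashes match and sync_agent would skip the create_event call.
    assert compute_state_hash(state_before) == compute_state_hash(state_after)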