From 2c5542e912873e112a460d9347e60a9a12d6c457 Mon Sep 17 00:00:00 2001 From: Kihyeon Myung Date: Fri, 23 Jan 2026 08:28:21 -0700 Subject: [PATCH 1/4] feat(memory): Add storage_version parameter for API call optimization Add opt-in storage_version="v2" parameter to AgentCoreMemorySessionManager that enables batched API calls by using unified actorId. Changes: - Add storage_version parameter ("v1" default, "v2" opt-in) - v2 batches message + agent state in single API call --- .../integrations/strands/bedrock_converter.py | 17 +- .../integrations/strands/session_manager.py | 218 +++++++++++++++--- 2 files changed, 204 insertions(+), 31 deletions(-) diff --git a/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py b/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py index 1f0905c..35cfaa2 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py +++ b/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py @@ -83,6 +83,19 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]: elif "blob" in payload_item: try: blob_data = json.loads(payload_item["blob"]) + # V2 format: dict with _type marker + if isinstance(blob_data, dict): + if blob_data.get("_type") == "agent_state": + continue # Skip agent state payloads + if blob_data.get("_type") == "message": + data = blob_data.get("data", []) + if isinstance(data, (tuple, list)) and len(data) == 2: + session_msg = SessionMessage.from_dict(json.loads(data[0])) + session_msg.message = AgentCoreMemoryConverter._filter_empty_text(session_msg.message) + if session_msg.message.get("content"): + messages.append(session_msg) + continue + # Legacy format: tuple [json_string, role] if isinstance(blob_data, (tuple, list)) and len(blob_data) == 2: try: session_msg = SessionMessage.from_dict(json.loads(blob_data[0])) @@ -91,8 +104,8 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]: messages.append(session_msg) except (json.JSONDecodeError, ValueError): logger.error("This is not a SessionMessage but just a blob message. Ignoring") - except (json.JSONDecodeError, ValueError): - logger.error("Failed to parse blob content: %s", payload_item) + except Exception as e: + logger.error("Failed to parse blob content: %s, error: %s", payload_item, e) return list(reversed(messages)) @staticmethod diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index c3912a7..540dbcd 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -5,11 +5,11 @@ import threading from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Literal, Optional import boto3 from botocore.config import Config as BotocoreConfig -from strands.hooks import MessageAddedEvent +from strands.hooks import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent from strands.hooks.registry import HookRegistry from strands.session.repository_session_manager import RepositorySessionManager from strands.session.session_repository import SessionRepository @@ -33,6 +33,10 @@ MESSAGE_PREFIX = "message_" MAX_FETCH_ALL_RESULTS = 10000 +# Payload type markers for v2 storage +PAYLOAD_TYPE_MESSAGE = "message" +PAYLOAD_TYPE_AGENT_STATE = "agent_state" + class AgentCoreMemorySessionManager(RepositorySessionManager, SessionRepository): """AgentCore Memory-based session manager for Bedrock AgentCore Memory integration. @@ -87,6 +91,7 @@ def __init__( region_name: Optional[str] = None, boto_session: Optional[boto3.Session] = None, boto_client_config: Optional[BotocoreConfig] = None, + storage_version: Literal["v1", "v2"] = "v1", **kwargs: Any, ): """Initialize AgentCoreMemorySessionManager with Bedrock AgentCore Memory. @@ -97,12 +102,17 @@ def __init__( boto_session (Optional[boto3.Session], optional): Optional boto3 session. Defaults to None. boto_client_config (Optional[BotocoreConfig], optional): Optional boto3 client configuration. Defaults to None. + storage_version (Literal["v1", "v2"], optional): Storage version for API optimization. + - "v1" (default): Original behavior where agent state uses separate actorId. + - "v2": Unified actorId enabling batched API calls to reduce redundant requests. + Defaults to "v1" for backward compatibility. **kwargs (Any): Additional keyword arguments. """ self.config = agentcore_memory_config self.memory_client = MemoryClient(region_name=region_name) session = boto_session or boto3.Session(region_name=region_name) self.has_existing_agent = False + self.storage_version = storage_version # Override the clients if custom boto session or config is provided # Add strands-agents to the request user agent @@ -157,6 +167,105 @@ def _get_full_agent_id(self, agent_id: str) -> str: ) return full_agent_id + # region Optimized Storage Methods (storage_version="v2") + + def _build_agent_state_payload(self, agent: "Agent") -> dict: + """Create agent state payload for unified storage. + + Creates a SessionAgent-compatible payload that can be reconstructed + via SessionAgent.from_dict(). + + Args: + agent (Agent): The agent whose state to capture. + + Returns: + dict: Agent state payload with type markers. + """ + session_agent = SessionAgent.from_agent(agent) + return { + "_type": PAYLOAD_TYPE_AGENT_STATE, + "_agent_id": agent.agent_id, + **session_agent.to_dict(), + } + + def save_message_with_state(self, message: Message, agent: "Agent") -> None: + """Save message and agent state in a single batched API call (v2 only). + + Combines message and agent state into one API call instead of two separate calls. + + Args: + message (Message): The message to save. + agent (Agent): The agent whose state to sync. + """ + if self.storage_version != "v2": + raise RuntimeError("save_message_with_state is only available in storage_version='v2'") + + session_message = SessionMessage.from_message(message, 0) + messages = AgentCoreMemoryConverter.message_to_payload(session_message) + if not messages: + return + + # Convert message to v2 payload format with type marker + message_tuple = messages[0] + message_payload = {"_type": PAYLOAD_TYPE_MESSAGE, "data": list(message_tuple)} + + # Prepare agent state payload with type marker + agent_state_payload = self._build_agent_state_payload(agent) + + # Parse the original timestamp and use it as desired timestamp + original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00")) + monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp) + + try: + # Single batched API call with both payloads + event = self.memory_client.gmdp_client.create_event( + memoryId=self.config.memory_id, + actorId=self.config.actor_id, + sessionId=self.session_id, + payload=[ + {"blob": json.dumps(message_payload)}, + {"blob": json.dumps(agent_state_payload)}, + ], + eventTimestamp=monotonic_timestamp, + ) + logger.debug( + "Saved message and agent state in single call: event=%s, agent=%s", + event.get("event", {}).get("eventId"), + agent.agent_id, + ) + + # Update latest message tracking + session_message = SessionMessage.from_message(message, event.get("event", {}).get("eventId")) + self._latest_agent_message[agent.agent_id] = session_message + + except Exception as e: + logger.error("Failed to save message with state: %s", e) + raise SessionException(f"Failed to save message with state: {e}") from e + + def _sync_agent_state(self, agent: "Agent") -> None: + """Sync agent state to AgentCore Memory using unified actorId. + + Args: + agent (Agent): The agent to sync. + """ + agent_state_payload = self._build_agent_state_payload(agent) + + try: + event = self.memory_client.gmdp_client.create_event( + memoryId=self.config.memory_id, + actorId=self.config.actor_id, + sessionId=self.session_id, + payload=[{"blob": json.dumps(agent_state_payload)}], + eventTimestamp=self._get_monotonic_timestamp(), + ) + logger.debug( + "Synced agent state: event=%s, agent=%s", event.get("event", {}).get("eventId"), agent.agent_id + ) + except Exception as e: + logger.error("Failed to sync agent state: %s", e) + + # endregion Optimized Storage Methods + # region SessionRepository interface implementation def create_session(self, session: Session, **kwargs: Any) -> Session: """Create a new session in AgentCore Memory. @@ -237,6 +346,8 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A The agent's existence is inferred from the presence of events/messages in the memory system, but we validate the session_id matches our config. + For storage_version="v2", uses unified actorId with type markers. + Args: session_id (str): The session ID to create the agent in. session_agent (SessionAgent): The agent to create. @@ -248,15 +359,30 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A if session_id != self.config.session_id: raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}") - event = self.memory_client.gmdp_client.create_event( - memoryId=self.config.memory_id, - actorId=self._get_full_agent_id(session_agent.agent_id), - sessionId=self.session_id, - payload=[ - {"blob": json.dumps(session_agent.to_dict())}, - ], - eventTimestamp=self._get_monotonic_timestamp(), - ) + if self.storage_version == "v2": + # V2: Use unified actorId with type marker + agent_state_payload = { + "_type": PAYLOAD_TYPE_AGENT_STATE, + "_agent_id": session_agent.agent_id, + **session_agent.to_dict(), + } + event = self.memory_client.gmdp_client.create_event( + memoryId=self.config.memory_id, + actorId=self.config.actor_id, + sessionId=self.session_id, + payload=[{"blob": json.dumps(agent_state_payload)}], + eventTimestamp=self._get_monotonic_timestamp(), + ) + else: + # V1: Use separate actorId for agent + event = self.memory_client.gmdp_client.create_event( + memoryId=self.config.memory_id, + actorId=self._get_full_agent_id(session_agent.agent_id), + sessionId=self.session_id, + payload=[{"blob": json.dumps(session_agent.to_dict())}], + eventTimestamp=self._get_monotonic_timestamp(), + ) + logger.info( "Created agent: %s in session: %s with event %s", session_agent.agent_id, @@ -267,7 +393,9 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: """Read agent data from AgentCore Memory events. - We reconstruct the agent state from the conversation history. + Uses the storage version specified at initialization: + - v1: Reads from separate actorId (agent_{id}) + - v2: Reads from unified actorId with type markers Args: session_id (str): The session ID to read from. @@ -279,22 +407,42 @@ def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[ """ if session_id != self.config.session_id: return None - try: - events = self.memory_client.list_events( - memory_id=self.config.memory_id, - actor_id=self._get_full_agent_id(agent_id), - session_id=session_id, - max_results=1, - ) - - if not events: - return None - agent_data = json.loads(events[0].get("payload", {})[0].get("blob")) - return SessionAgent.from_dict(agent_data) + try: + if self.storage_version == "v2": + # V2: Read from unified actorId, filter by type marker + events = self.memory_client.list_events( + memory_id=self.config.memory_id, + actor_id=self.config.actor_id, + session_id=session_id, + max_results=MAX_FETCH_ALL_RESULTS, + ) + for event in events: + for payload_item in event.get("payload", []): + blob = payload_item.get("blob") + if blob: + try: + data = json.loads(blob) + if data.get("_type") == PAYLOAD_TYPE_AGENT_STATE and data.get("_agent_id") == agent_id: + agent_data = {k: v for k, v in data.items() if k not in ("_type", "_agent_id")} + return SessionAgent.from_dict(agent_data) + except json.JSONDecodeError: + continue + else: + # V1: Read from separate actorId + events = self.memory_client.list_events( + memory_id=self.config.memory_id, + actor_id=self._get_full_agent_id(agent_id), + session_id=session_id, + max_results=1, + ) + if events: + agent_data = json.loads(events[0].get("payload", {})[0].get("blob")) + return SessionAgent.from_dict(agent_data) except Exception as e: - logger.error("Failed to read agent %s", e) - return None + logger.error("Failed to read agent: %s", e) + + return None def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None: """Update agent data. @@ -566,14 +714,26 @@ def retrieve_for_namespace(namespace: str, retrieval_config: RetrievalConfig): @override def register_hooks(self, registry: HookRegistry, **kwargs) -> None: - """Register additional hooks. + """Register hooks for session management. + + For storage_version="v1" (default): + Uses parent class behavior with separate API calls for message and agent state. + + For storage_version="v2": + Uses batched API calls to reduce redundant API calls. Args: registry (HookRegistry): The hook registry to register callbacks with. **kwargs: Additional keyword arguments. """ - RepositorySessionManager.register_hooks(self, registry, **kwargs) - registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event)) + if self.storage_version == "v2": + registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent)) + registry.add_callback(MessageAddedEvent, lambda event: self.save_message_with_state(event.message, event.agent)) + registry.add_callback(AfterInvocationEvent, lambda event: self._sync_agent_state(event.agent)) + registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event)) + else: + RepositorySessionManager.register_hooks(self, registry, **kwargs) + registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event)) @override def initialize(self, agent: "Agent", **kwargs: Any) -> None: From 46ce4698b6f65a9531101e555e8e4cf5394dfb78 Mon Sep 17 00:00:00 2001 From: Kihyeon Myung Date: Fri, 23 Jan 2026 09:42:09 -0700 Subject: [PATCH 2/4] test(memory): Add unit and integration tests for storage_version v2 --- pyproject.toml | 3 + .../integrations/strands/bedrock_converter.py | 3 +- .../integrations/strands/session_manager.py | 4 +- .../test_agentcore_memory_session_manager.py | 253 ++++++++++++ .../strands/test_bedrock_converter.py | 98 +++++ .../test_storage_version_integration.py | 373 ++++++++++++++++++ 6 files changed, 732 insertions(+), 2 deletions(-) create mode 100644 tests/bedrock_agentcore/memory/integrations/strands/test_storage_version_integration.py diff --git a/pyproject.toml b/pyproject.toml index a061fc2..d646ae4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,6 +94,9 @@ testpaths = [ "tests" ] asyncio_mode = "auto" +markers = [ + "integration: marks tests as integration tests (require AWS credentials, deselect with '-m \"not integration\"')", +] [tool.coverage.run] branch = true diff --git a/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py b/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py index 35cfaa2..eafa973 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py +++ b/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py @@ -91,7 +91,8 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]: data = blob_data.get("data", []) if isinstance(data, (tuple, list)) and len(data) == 2: session_msg = SessionMessage.from_dict(json.loads(data[0])) - session_msg.message = AgentCoreMemoryConverter._filter_empty_text(session_msg.message) + filtered = AgentCoreMemoryConverter._filter_empty_text(session_msg.message) + session_msg.message = filtered if session_msg.message.get("content"): messages.append(session_msg) continue diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index 540dbcd..832872d 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -728,7 +728,9 @@ def register_hooks(self, registry: HookRegistry, **kwargs) -> None: """ if self.storage_version == "v2": registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent)) - registry.add_callback(MessageAddedEvent, lambda event: self.save_message_with_state(event.message, event.agent)) + registry.add_callback( + MessageAddedEvent, lambda event: self.save_message_with_state(event.message, event.agent) + ) registry.add_callback(AfterInvocationEvent, lambda event: self._sync_agent_state(event.agent)) registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event)) else: diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py index a01973c..576c97e 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py @@ -1116,3 +1116,256 @@ def test_list_messages_with_limit_calculates_max_results(self, session_manager, mock_memory_client.list_events.assert_called_once() call_kwargs = mock_memory_client.list_events.call_args[1] assert call_kwargs["max_results"] == 550 # limit + offset + + +class TestStorageVersionV2: + """Tests for storage_version='v2' functionality.""" + + @pytest.fixture + def agentcore_config(self): + """Create a test AgentCore Memory configuration.""" + return AgentCoreMemoryConfig( + memory_id="test-memory-123", + session_id="test-session-456", + actor_id="test-actor-789" + ) + + @pytest.fixture + def mock_memory_client(self): + """Create a mock MemoryClient.""" + client = Mock() + client.create_event.return_value = {"eventId": "event_123456"} + client.list_events.return_value = [] + client.retrieve_memories.return_value = [] + client.gmcp_client = Mock() + client.gmdp_client = Mock() + client.gmdp_client.create_event.return_value = {"event": {"eventId": "event_123456"}} + return client + + @pytest.fixture + def session_manager_v2(self, agentcore_config, mock_memory_client): + """Create AgentCoreMemorySessionManager with storage_version='v2'.""" + with patch( + "bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient", + return_value=mock_memory_client + ): + with patch("boto3.Session") as mock_boto_session: + mock_session = Mock() + mock_session.region_name = "us-west-2" + mock_session.client.return_value = Mock() + mock_boto_session.return_value = mock_session + + with patch( + "strands.session.repository_session_manager.RepositorySessionManager.__init__", + return_value=None + ): + manager = AgentCoreMemorySessionManager( + agentcore_config, + storage_version="v2" + ) + manager.session_id = agentcore_config.session_id + manager.session = Session( + session_id=agentcore_config.session_id, + session_type=SessionType.AGENT + ) + manager.memory_client = mock_memory_client + manager._latest_agent_message = {} + return manager + + def test_init_storage_version_default(self, agentcore_config): + """Test default storage_version is 'v1'.""" + with patch("bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient"): + with patch("boto3.Session") as mock_boto_session: + mock_session = Mock() + mock_session.region_name = "us-west-2" + mock_boto_session.return_value = mock_session + + with patch( + "strands.session.repository_session_manager.RepositorySessionManager.__init__", + return_value=None + ): + manager = AgentCoreMemorySessionManager(agentcore_config) + assert manager.storage_version == "v1" + + def test_init_storage_version_v2(self, agentcore_config): + """Test storage_version can be set to 'v2'.""" + with patch("bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient"): + with patch("boto3.Session") as mock_boto_session: + mock_session = Mock() + mock_session.region_name = "us-west-2" + mock_boto_session.return_value = mock_session + + with patch( + "strands.session.repository_session_manager.RepositorySessionManager.__init__", + return_value=None + ): + manager = AgentCoreMemorySessionManager( + agentcore_config, + storage_version="v2" + ) + assert manager.storage_version == "v2" + + def test_register_hooks_v2_registers_agent_initialized(self, session_manager_v2): + """Test v2 registers AgentInitializedEvent hook.""" + from strands.hooks import AgentInitializedEvent + from strands.hooks.registry import HookRegistry + + registry = HookRegistry() + session_manager_v2.register_hooks(registry) + + assert AgentInitializedEvent in registry._registered_callbacks + assert len(registry._registered_callbacks[AgentInitializedEvent]) == 1 + + def test_register_hooks_v2_message_callbacks(self, session_manager_v2): + """Test v2 registers correct MessageAddedEvent callbacks.""" + from strands.hooks import MessageAddedEvent + from strands.hooks.registry import HookRegistry + + registry = HookRegistry() + session_manager_v2.register_hooks(registry) + + # v2: save_message_with_state + retrieve_customer_context + assert MessageAddedEvent in registry._registered_callbacks + assert len(registry._registered_callbacks[MessageAddedEvent]) == 2 + + def test_save_message_with_state_batches_payloads(self, session_manager_v2, mock_memory_client): + """Test save_message_with_state creates single API call with batched payloads.""" + import json + + mock_agent = Mock() + mock_agent.agent_id = "test-agent" + + mock_session_agent = SessionAgent( + agent_id="test-agent", + state={"key": "value"}, + conversation_manager_state={"cm_key": "cm_value"} + ) + + message = {"role": "user", "content": [{"text": "Hello"}]} + + with patch.object(SessionAgent, "from_agent", return_value=mock_session_agent): + session_manager_v2.save_message_with_state(message, mock_agent) + + mock_memory_client.gmdp_client.create_event.assert_called_once() + call_kwargs = mock_memory_client.gmdp_client.create_event.call_args[1] + + # Should have 2 payloads: message + agent_state + assert len(call_kwargs["payload"]) == 2 + + # Verify message payload format + msg_payload = json.loads(call_kwargs["payload"][0]["blob"]) + assert msg_payload["_type"] == "message" + assert "data" in msg_payload + + # Verify agent_state payload format + agent_payload = json.loads(call_kwargs["payload"][1]["blob"]) + assert agent_payload["_type"] == "agent_state" + assert agent_payload["_agent_id"] == "test-agent" + + def test_read_agent_v2_format(self, session_manager_v2, mock_memory_client): + """Test read_agent correctly parses v2 format with _type marker.""" + import json + + agent_state_payload = { + "_type": "agent_state", + "_agent_id": "test-agent", + "agent_id": "test-agent", + "state": {"key": "value"}, + "conversation_manager_state": {"cm_key": "cm_value"} + } + + mock_memory_client.list_events.return_value = [ + { + "eventId": "event-1", + "payload": [{"blob": json.dumps(agent_state_payload)}] + } + ] + + result = session_manager_v2.read_agent("test-session-456", "test-agent") + + assert result is not None + assert result.agent_id == "test-agent" + assert result.state == {"key": "value"} + + def test_read_agent_v2_skips_message_payloads(self, session_manager_v2, mock_memory_client): + """Test read_agent skips message payloads when looking for agent state.""" + import json + + from strands.types.session import SessionMessage + + msg = SessionMessage( + message_id=1, + message={"role": "user", "content": [{"text": "Hello"}]}, + created_at="2023-01-01T00:00:00Z" + ) + message_payload = { + "_type": "message", + "data": [json.dumps(msg.to_dict()), "user"] + } + agent_state_payload = { + "_type": "agent_state", + "_agent_id": "test-agent", + "agent_id": "test-agent", + "state": {}, + "conversation_manager_state": {} + } + + mock_memory_client.list_events.return_value = [ + { + "eventId": "event-1", + "payload": [ + {"blob": json.dumps(message_payload)}, + {"blob": json.dumps(agent_state_payload)} + ] + } + ] + + result = session_manager_v2.read_agent("test-session-456", "test-agent") + + assert result is not None + assert result.agent_id == "test-agent" + + def test_create_agent_v2_format(self, session_manager_v2, mock_memory_client): + """Test create_agent saves in v2 format with _type marker.""" + import json + + session_agent = SessionAgent( + agent_id="test-agent", + state={"key": "value"}, + conversation_manager_state={} + ) + + session_manager_v2.create_agent("test-session-456", session_agent) + + mock_memory_client.gmdp_client.create_event.assert_called_once() + call_kwargs = mock_memory_client.gmdp_client.create_event.call_args[1] + + payload_data = json.loads(call_kwargs["payload"][0]["blob"]) + assert payload_data["_type"] == "agent_state" + assert payload_data["_agent_id"] == "test-agent" + assert payload_data["agent_id"] == "test-agent" + + def test_list_messages_parses_v2_format(self, session_manager_v2, mock_memory_client): + """Test list_messages correctly parses v2 format messages.""" + import json + + from strands.types.session import SessionMessage + + msg = SessionMessage( + message_id=1, + message={"role": "user", "content": [{"text": "Hello from v2"}]}, + created_at="2023-01-01T00:00:00Z" + ) + message_payload = { + "_type": "message", + "data": [json.dumps(msg.to_dict()), "user"] + } + + mock_memory_client.list_events.return_value = [ + {"eventId": "event-1", "payload": [{"blob": json.dumps(message_payload)}]} + ] + + result = session_manager_v2.list_messages("test-session-456", "test-agent") + + assert len(result) == 1 + assert result[0].message["content"][0]["text"] == "Hello from v2" diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py b/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py index e15a457..f845ba9 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py @@ -221,3 +221,101 @@ def test_message_to_payload_with_bytes_encodes_before_filtering(self): assert isinstance(encoded_bytes, dict) assert encoded_bytes.get("__bytes_encoded__") is True assert "data" in encoded_bytes + + def test_events_to_messages_v2_format(self): + """Test parsing v2 format with _type marker.""" + session_message = SessionMessage( + message_id=1, message={"role": "user", "content": [{"text": "Hello"}]}, created_at="2023-01-01T00:00:00Z" + ) + + v2_payload = { + "_type": "message", + "data": [json.dumps(session_message.to_dict()), "user"] + } + events = [{"payload": [{"blob": json.dumps(v2_payload)}]}] + + result = AgentCoreMemoryConverter.events_to_messages(events) + + assert len(result) == 1 + assert result[0].message["role"] == "user" + assert result[0].message["content"][0]["text"] == "Hello" + + def test_events_to_messages_v2_skips_agent_state(self): + """Test v2 format skips agent_state payloads.""" + session_message = SessionMessage( + message_id=1, message={"role": "user", "content": [{"text": "Hello"}]}, created_at="2023-01-01T00:00:00Z" + ) + + message_payload = { + "_type": "message", + "data": [json.dumps(session_message.to_dict()), "user"] + } + agent_state_payload = { + "_type": "agent_state", + "_agent_id": "test-agent", + "agent_id": "test-agent", + "state": {}, + "conversation_manager_state": {} + } + events = [ + {"payload": [ + {"blob": json.dumps(message_payload)}, + {"blob": json.dumps(agent_state_payload)} + ]} + ] + + result = AgentCoreMemoryConverter.events_to_messages(events) + + assert len(result) == 1 + assert result[0].message["content"][0]["text"] == "Hello" + + def test_events_to_messages_v2_multiple_messages_reversed(self): + """Test v2 format returns messages in chronological order (reversed from API).""" + msg1 = SessionMessage( + message_id=1, + message={"role": "user", "content": [{"text": "First"}]}, + created_at="2023-01-01T00:00:00Z", + ) + msg2 = SessionMessage( + message_id=2, + message={"role": "assistant", "content": [{"text": "Second"}]}, + created_at="2023-01-01T00:00:01Z", + ) + + # API returns newest first + v2_msg2 = {"_type": "message", "data": [json.dumps(msg2.to_dict()), "assistant"]} + v2_msg1 = {"_type": "message", "data": [json.dumps(msg1.to_dict()), "user"]} + events = [ + {"payload": [{"blob": json.dumps(v2_msg2)}]}, + {"payload": [{"blob": json.dumps(v2_msg1)}]}, + ] + + result = AgentCoreMemoryConverter.events_to_messages(events) + + assert len(result) == 2 + assert result[0].message["content"][0]["text"] == "First" + assert result[1].message["content"][0]["text"] == "Second" + + def test_events_to_messages_mixed_v2_and_legacy(self): + """Test handling mixed v2 and legacy formats.""" + msg_v2 = SessionMessage( + message_id=1, + message={"role": "user", "content": [{"text": "V2 message"}]}, + created_at="2023-01-01T00:00:00Z", + ) + msg_legacy = SessionMessage( + message_id=2, + message={"role": "assistant", "content": [{"text": "Legacy message"}]}, + created_at="2023-01-01T00:00:01Z", + ) + + v2_payload = {"_type": "message", "data": [json.dumps(msg_v2.to_dict()), "user"]} + legacy_payload = [json.dumps(msg_legacy.to_dict()), "assistant"] + events = [ + {"payload": [{"blob": json.dumps(v2_payload)}]}, + {"payload": [{"blob": json.dumps(legacy_payload)}]}, + ] + + result = AgentCoreMemoryConverter.events_to_messages(events) + + assert len(result) == 2 diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_storage_version_integration.py b/tests/bedrock_agentcore/memory/integrations/strands/test_storage_version_integration.py new file mode 100644 index 0000000..7f99a1d --- /dev/null +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_storage_version_integration.py @@ -0,0 +1,373 @@ +"""Integration tests for storage_version parameter. + +These tests require AWS credentials and make actual API calls. +Run with: pytest -m integration tests/.../test_storage_version_integration.py -v +Skip with: pytest -m "not integration" + +To skip these tests in normal runs, they are marked with @pytest.mark.integration. +""" + +import json +import os +import uuid + +import pytest + +# Skip all tests in this module if --run-integration is not provided +pytestmark = pytest.mark.integration + + +def get_memory_id(): + """Get Memory ID from environment.""" + return os.environ.get("MEMORY_ID") + + +def get_region(): + """Get AWS region from environment.""" + return os.environ.get("AWS_REGION", "us-west-2") + + +@pytest.fixture(scope="module") +def memory_id(): + """Get memory ID, skip if not available.""" + mid = get_memory_id() + if not mid: + pytest.skip("MEMORY_ID environment variable not set") + return mid + + +@pytest.fixture(scope="module") +def region(): + """Get AWS region.""" + return get_region() + + +@pytest.fixture +def unique_session_id(): + """Generate unique session ID for each test.""" + return f"test-{uuid.uuid4().hex[:12]}" + + +class APICallCounter: + """Count API calls to AgentCore Memory.""" + + def __init__(self): + self.call_count = 0 + self.call_details = [] + + def reset(self): + self.call_count = 0 + self.call_details = [] + + def increment(self, method_name: str): + self.call_count += 1 + self.call_details.append(method_name) + + +class TestStorageVersionIntegration: + """Integration tests for storage_version parameter.""" + + def test_v2_reduces_api_calls_simple_message(self, memory_id, region, unique_session_id): + """Test that v2 reduces API calls for simple message.""" + from strands import Agent + from strands.models import BedrockModel + + from bedrock_agentcore.memory.integrations.strands.config import AgentCoreMemoryConfig + from bedrock_agentcore.memory.integrations.strands.session_manager import AgentCoreMemorySessionManager + + actor_id = "test-user" + + # Test v1 + config_v1 = AgentCoreMemoryConfig( + memory_id=memory_id, + session_id=f"v1-{unique_session_id}", + actor_id=actor_id, + ) + session_manager_v1 = AgentCoreMemorySessionManager( + agentcore_memory_config=config_v1, + region_name=region, + storage_version="v1", + ) + + model = BedrockModel( + model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", + region_name=region + ) + + agent_v1 = Agent( + model=model, + session_manager=session_manager_v1, + system_prompt="You are a helpful assistant. Keep responses very brief." + ) + + # Count API calls for v1 + counter_v1 = APICallCounter() + original_create_event_v1 = session_manager_v1.memory_client.gmdp_client.create_event + + def counted_v1(*args, **kwargs): + counter_v1.increment("create_event") + return original_create_event_v1(*args, **kwargs) + + session_manager_v1.memory_client.gmdp_client.create_event = counted_v1 + agent_v1("Say hello briefly.") + v1_calls = counter_v1.call_count + + # Test v2 + config_v2 = AgentCoreMemoryConfig( + memory_id=memory_id, + session_id=f"v2-{unique_session_id}", + actor_id=actor_id, + ) + session_manager_v2 = AgentCoreMemorySessionManager( + agentcore_memory_config=config_v2, + region_name=region, + storage_version="v2", + ) + + agent_v2 = Agent( + model=model, + session_manager=session_manager_v2, + system_prompt="You are a helpful assistant. Keep responses very brief." + ) + + # Count API calls for v2 + counter_v2 = APICallCounter() + original_create_event_v2 = session_manager_v2.memory_client.gmdp_client.create_event + + def counted_v2(*args, **kwargs): + counter_v2.increment("create_event") + return original_create_event_v2(*args, **kwargs) + + session_manager_v2.memory_client.gmdp_client.create_event = counted_v2 + agent_v2("Say hello briefly.") + v2_calls = counter_v2.call_count + + # v2 should have fewer API calls than v1 + assert v2_calls < v1_calls, f"v2 ({v2_calls}) should have fewer calls than v1 ({v1_calls})" + # Expected: v1=5, v2=3 + assert v1_calls == 5, f"Expected v1=5 calls, got {v1_calls}" + assert v2_calls == 3, f"Expected v2=3 calls, got {v2_calls}" + + def test_v2_multi_turn_preserves_context(self, memory_id, region, unique_session_id): + """Test that v2 correctly preserves context across turns.""" + from strands import Agent + from strands.models import BedrockModel + + from bedrock_agentcore.memory.integrations.strands.config import AgentCoreMemoryConfig + from bedrock_agentcore.memory.integrations.strands.session_manager import AgentCoreMemorySessionManager + + actor_id = "test-user" + session_id = f"v2-multi-{unique_session_id}" + + config = AgentCoreMemoryConfig( + memory_id=memory_id, + session_id=session_id, + actor_id=actor_id, + ) + + model = BedrockModel( + model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", + region_name=region + ) + + # Turn 1: Tell the agent something to remember + session_manager_1 = AgentCoreMemorySessionManager( + agentcore_memory_config=config, + region_name=region, + storage_version="v2", + ) + + agent_1 = Agent( + model=model, + session_manager=session_manager_1, + system_prompt="You are a helpful assistant. Remember what the user tells you." + ) + + agent_1("My favorite color is blue. Remember this.") + + # Turn 2: Create NEW agent instance with SAME session ID + session_manager_2 = AgentCoreMemorySessionManager( + agentcore_memory_config=config, + region_name=region, + storage_version="v2", + ) + + agent_2 = Agent( + model=model, + session_manager=session_manager_2, + system_prompt="You are a helpful assistant. Remember what the user tells you." + ) + + # Verify messages were loaded + assert len(agent_2.messages) >= 2, f"Expected at least 2 messages loaded, got {len(agent_2.messages)}" + + # Ask about the remembered information + response = agent_2("What is my favorite color?") + + # Extract response text + response_text = "" + if response.message and response.message.get("content"): + for block in response.message["content"]: + if block.get("text"): + response_text = block["text"].lower() + break + + assert "blue" in response_text, f"Expected 'blue' in response, got: {response_text}" + + def test_v2_parses_messages_correctly(self, memory_id, region, unique_session_id): + """Test that v2 format messages are correctly saved and loaded.""" + from strands import Agent + from strands.models import BedrockModel + + from bedrock_agentcore.memory.client import MemoryClient + from bedrock_agentcore.memory.integrations.strands.config import AgentCoreMemoryConfig + from bedrock_agentcore.memory.integrations.strands.session_manager import AgentCoreMemorySessionManager + + actor_id = "test-user" + session_id = f"v2-parse-{unique_session_id}" + + config = AgentCoreMemoryConfig( + memory_id=memory_id, + session_id=session_id, + actor_id=actor_id, + ) + + model = BedrockModel( + model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", + region_name=region + ) + + session_manager = AgentCoreMemorySessionManager( + agentcore_memory_config=config, + region_name=region, + storage_version="v2", + ) + + agent = Agent( + model=model, + session_manager=session_manager, + system_prompt="You are a helpful assistant." + ) + + agent("Test message for v2 format.") + + # Verify stored format + memory_client = MemoryClient(region_name=region) + events = memory_client.list_events( + memory_id=memory_id, + actor_id=actor_id, + session_id=session_id, + max_results=10, + ) + + assert len(events) > 0, "Expected events to be saved" + + # Check that at least one event has v2 format + found_v2_message = False + found_v2_agent_state = False + + for event in events: + for payload_item in event.get("payload", []): + if "blob" in payload_item: + blob_data = json.loads(payload_item["blob"]) + if isinstance(blob_data, dict): + if blob_data.get("_type") == "message": + found_v2_message = True + if blob_data.get("_type") == "agent_state": + found_v2_agent_state = True + + assert found_v2_message, "Expected v2 format message with _type marker" + assert found_v2_agent_state, "Expected v2 format agent_state with _type marker" + + def test_v2_reduces_api_calls_with_tools(self, memory_id, region, unique_session_id): + """Test that v2 reduces API calls when using tools.""" + from datetime import datetime + + from strands import Agent, tool + from strands.models import BedrockModel + + from bedrock_agentcore.memory.integrations.strands.config import AgentCoreMemoryConfig + from bedrock_agentcore.memory.integrations.strands.session_manager import AgentCoreMemorySessionManager + + @tool + def get_time() -> str: + """Get the current time.""" + return datetime.now().strftime("%H:%M:%S") + + @tool + def add_numbers(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + actor_id = "test-user" + + model = BedrockModel( + model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", + region_name=region + ) + + # Test v1 with tools + config_v1 = AgentCoreMemoryConfig( + memory_id=memory_id, + session_id=f"v1-tools-{unique_session_id}", + actor_id=actor_id, + ) + session_manager_v1 = AgentCoreMemorySessionManager( + agentcore_memory_config=config_v1, + region_name=region, + storage_version="v1", + ) + + agent_v1 = Agent( + model=model, + session_manager=session_manager_v1, + tools=[get_time, add_numbers], + system_prompt="You are a helpful assistant. Use tools when asked. Keep responses brief." + ) + + counter_v1 = APICallCounter() + original_create_event_v1 = session_manager_v1.memory_client.gmdp_client.create_event + + def counted_v1(*args, **kwargs): + counter_v1.increment("create_event") + return original_create_event_v1(*args, **kwargs) + + session_manager_v1.memory_client.gmdp_client.create_event = counted_v1 + agent_v1("What time is it? And what is 10 + 20?") + v1_calls = counter_v1.call_count + + # Test v2 with tools + config_v2 = AgentCoreMemoryConfig( + memory_id=memory_id, + session_id=f"v2-tools-{unique_session_id}", + actor_id=actor_id, + ) + session_manager_v2 = AgentCoreMemorySessionManager( + agentcore_memory_config=config_v2, + region_name=region, + storage_version="v2", + ) + + agent_v2 = Agent( + model=model, + session_manager=session_manager_v2, + tools=[get_time, add_numbers], + system_prompt="You are a helpful assistant. Use tools when asked. Keep responses brief." + ) + + counter_v2 = APICallCounter() + original_create_event_v2 = session_manager_v2.memory_client.gmdp_client.create_event + + def counted_v2(*args, **kwargs): + counter_v2.increment("create_event") + return original_create_event_v2(*args, **kwargs) + + session_manager_v2.memory_client.gmdp_client.create_event = counted_v2 + agent_v2("What time is it? And what is 10 + 20?") + v2_calls = counter_v2.call_count + + # v2 should have fewer API calls than v1 + assert v2_calls < v1_calls, f"v2 ({v2_calls}) should have fewer calls than v1 ({v1_calls})" + # Expected: v1=9, v2=5 (with 2 tool calls) + assert v1_calls == 9, f"Expected v1=9 calls with tools, got {v1_calls}" + assert v2_calls == 5, f"Expected v2=5 calls with tools, got {v2_calls}" From 8c6984165969e925f104ffb233fdb813e7d0c67a Mon Sep 17 00:00:00 2001 From: Kihyeon Myung Date: Fri, 23 Jan 2026 14:06:15 -0700 Subject: [PATCH 3/4] Remove storage_version param, add state hash tracking --- .../integrations/strands/session_manager.py | 220 ++++++------ .../test_agentcore_memory_session_manager.py | 324 +++++++++++++----- ...py => test_session_manager_integration.py} | 185 +++------- 3 files changed, 404 insertions(+), 325 deletions(-) rename tests/bedrock_agentcore/memory/integrations/strands/{test_storage_version_integration.py => test_session_manager_integration.py} (54%) diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index 832872d..3089e62 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -5,7 +5,7 @@ import threading from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Optional import boto3 from botocore.config import Config as BotocoreConfig @@ -33,7 +33,7 @@ MESSAGE_PREFIX = "message_" MAX_FETCH_ALL_RESULTS = 10000 -# Payload type markers for v2 storage +# Payload type markers for batched storage PAYLOAD_TYPE_MESSAGE = "message" PAYLOAD_TYPE_AGENT_STATE = "agent_state" @@ -91,7 +91,6 @@ def __init__( region_name: Optional[str] = None, boto_session: Optional[boto3.Session] = None, boto_client_config: Optional[BotocoreConfig] = None, - storage_version: Literal["v1", "v2"] = "v1", **kwargs: Any, ): """Initialize AgentCoreMemorySessionManager with Bedrock AgentCore Memory. @@ -102,17 +101,13 @@ def __init__( boto_session (Optional[boto3.Session], optional): Optional boto3 session. Defaults to None. boto_client_config (Optional[BotocoreConfig], optional): Optional boto3 client configuration. Defaults to None. - storage_version (Literal["v1", "v2"], optional): Storage version for API optimization. - - "v1" (default): Original behavior where agent state uses separate actorId. - - "v2": Unified actorId enabling batched API calls to reduce redundant requests. - Defaults to "v1" for backward compatibility. **kwargs (Any): Additional keyword arguments. """ self.config = agentcore_memory_config self.memory_client = MemoryClient(region_name=region_name) session = boto_session or boto3.Session(region_name=region_name) self.has_existing_agent = False - self.storage_version = storage_version + self._last_synced_state_hash: Optional[int] = None # Override the clients if custom boto session or config is provided # Add strands-agents to the request user agent @@ -167,7 +162,7 @@ def _get_full_agent_id(self, agent_id: str) -> str: ) return full_agent_id - # region Optimized Storage Methods (storage_version="v2") + # region Internal Storage Methods def _build_agent_state_payload(self, agent: "Agent") -> dict: """Create agent state payload for unified storage. @@ -188,8 +183,26 @@ def _build_agent_state_payload(self, agent: "Agent") -> dict: **session_agent.to_dict(), } - def save_message_with_state(self, message: Message, agent: "Agent") -> None: - """Save message and agent state in a single batched API call (v2 only). + def _compute_state_hash(self, agent: "Agent") -> int: + """Compute hash of agent state for change detection. + + Excludes timestamps (created_at, updated_at) as they change on every call. + + Args: + agent (Agent): The agent whose state to hash. + + Returns: + int: Hash of the agent state. + """ + session_agent = SessionAgent.from_agent(agent) + state_dict = session_agent.to_dict() + # Exclude timestamps that change on every call + state_dict.pop("created_at", None) + state_dict.pop("updated_at", None) + return hash(json.dumps(state_dict, sort_keys=True)) + + def _save_message_with_state(self, message: Message, agent: "Agent") -> None: + """Save message and agent state in a single batched API call. Combines message and agent state into one API call instead of two separate calls. @@ -197,27 +210,19 @@ def save_message_with_state(self, message: Message, agent: "Agent") -> None: message (Message): The message to save. agent (Agent): The agent whose state to sync. """ - if self.storage_version != "v2": - raise RuntimeError("save_message_with_state is only available in storage_version='v2'") - session_message = SessionMessage.from_message(message, 0) messages = AgentCoreMemoryConverter.message_to_payload(session_message) if not messages: return - # Convert message to v2 payload format with type marker message_tuple = messages[0] message_payload = {"_type": PAYLOAD_TYPE_MESSAGE, "data": list(message_tuple)} - - # Prepare agent state payload with type marker agent_state_payload = self._build_agent_state_payload(agent) - # Parse the original timestamp and use it as desired timestamp original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00")) monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp) try: - # Single batched API call with both payloads event = self.memory_client.gmdp_client.create_event( memoryId=self.config.memory_id, actorId=self.config.actor_id, @@ -234,20 +239,25 @@ def save_message_with_state(self, message: Message, agent: "Agent") -> None: agent.agent_id, ) - # Update latest message tracking session_message = SessionMessage.from_message(message, event.get("event", {}).get("eventId")) self._latest_agent_message[agent.agent_id] = session_message + self._last_synced_state_hash = self._compute_state_hash(agent) except Exception as e: logger.error("Failed to save message with state: %s", e) raise SessionException(f"Failed to save message with state: {e}") from e - def _sync_agent_state(self, agent: "Agent") -> None: - """Sync agent state to AgentCore Memory using unified actorId. + def _sync_agent_state_if_changed(self, agent: "Agent") -> None: + """Sync agent state to AgentCore Memory only if state changed since last sync. Args: agent (Agent): The agent to sync. """ + current_hash = self._compute_state_hash(agent) + if current_hash == self._last_synced_state_hash: + logger.debug("Agent state unchanged, skipping sync for agent=%s", agent.agent_id) + return + agent_state_payload = self._build_agent_state_payload(agent) try: @@ -258,13 +268,15 @@ def _sync_agent_state(self, agent: "Agent") -> None: payload=[{"blob": json.dumps(agent_state_payload)}], eventTimestamp=self._get_monotonic_timestamp(), ) + self._last_synced_state_hash = current_hash logger.debug( "Synced agent state: event=%s, agent=%s", event.get("event", {}).get("eventId"), agent.agent_id ) except Exception as e: logger.error("Failed to sync agent state: %s", e) + raise SessionException(f"Failed to sync agent state: {e}") from e - # endregion Optimized Storage Methods + # endregion Internal Storage Methods # region SessionRepository interface implementation def create_session(self, session: Session, **kwargs: Any) -> Session: @@ -346,7 +358,7 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A The agent's existence is inferred from the presence of events/messages in the memory system, but we validate the session_id matches our config. - For storage_version="v2", uses unified actorId with type markers. + Uses unified actorId with type markers for optimized storage. Args: session_id (str): The session ID to create the agent in. @@ -359,29 +371,18 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A if session_id != self.config.session_id: raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}") - if self.storage_version == "v2": - # V2: Use unified actorId with type marker - agent_state_payload = { - "_type": PAYLOAD_TYPE_AGENT_STATE, - "_agent_id": session_agent.agent_id, - **session_agent.to_dict(), - } - event = self.memory_client.gmdp_client.create_event( - memoryId=self.config.memory_id, - actorId=self.config.actor_id, - sessionId=self.session_id, - payload=[{"blob": json.dumps(agent_state_payload)}], - eventTimestamp=self._get_monotonic_timestamp(), - ) - else: - # V1: Use separate actorId for agent - event = self.memory_client.gmdp_client.create_event( - memoryId=self.config.memory_id, - actorId=self._get_full_agent_id(session_agent.agent_id), - sessionId=self.session_id, - payload=[{"blob": json.dumps(session_agent.to_dict())}], - eventTimestamp=self._get_monotonic_timestamp(), - ) + agent_state_payload = { + "_type": PAYLOAD_TYPE_AGENT_STATE, + "_agent_id": session_agent.agent_id, + **session_agent.to_dict(), + } + event = self.memory_client.gmdp_client.create_event( + memoryId=self.config.memory_id, + actorId=self.config.actor_id, + sessionId=self.session_id, + payload=[{"blob": json.dumps(agent_state_payload)}], + eventTimestamp=self._get_monotonic_timestamp(), + ) logger.info( "Created agent: %s in session: %s with event %s", @@ -393,9 +394,8 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: """Read agent data from AgentCore Memory events. - Uses the storage version specified at initialization: - - v1: Reads from separate actorId (agent_{id}) - - v2: Reads from unified actorId with type markers + Uses dual-read approach: tries new format first (unified actorId with type markers), + falls back to legacy format (separate actorId) for backward compatibility. Args: session_id (str): The session ID to read from. @@ -409,36 +409,37 @@ def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[ return None try: - if self.storage_version == "v2": - # V2: Read from unified actorId, filter by type marker - events = self.memory_client.list_events( - memory_id=self.config.memory_id, - actor_id=self.config.actor_id, - session_id=session_id, - max_results=MAX_FETCH_ALL_RESULTS, - ) - for event in events: - for payload_item in event.get("payload", []): - blob = payload_item.get("blob") - if blob: - try: - data = json.loads(blob) - if data.get("_type") == PAYLOAD_TYPE_AGENT_STATE and data.get("_agent_id") == agent_id: - agent_data = {k: v for k, v in data.items() if k not in ("_type", "_agent_id")} - return SessionAgent.from_dict(agent_data) - except json.JSONDecodeError: - continue - else: - # V1: Read from separate actorId - events = self.memory_client.list_events( - memory_id=self.config.memory_id, - actor_id=self._get_full_agent_id(agent_id), - session_id=session_id, - max_results=1, - ) - if events: - agent_data = json.loads(events[0].get("payload", {})[0].get("blob")) - return SessionAgent.from_dict(agent_data) + events = self.memory_client.list_events( + memory_id=self.config.memory_id, + actor_id=self.config.actor_id, + session_id=session_id, + max_results=MAX_FETCH_ALL_RESULTS, + ) + + # Events are returned oldest-first, so reverse to get latest state + for event in reversed(events): + for payload_item in event.get("payload", []): + blob = payload_item.get("blob") + if blob: + try: + data = json.loads(blob) + if data.get("_type") == PAYLOAD_TYPE_AGENT_STATE and data.get("_agent_id") == agent_id: + agent_data = {k: v for k, v in data.items() if k not in ("_type", "_agent_id")} + return SessionAgent.from_dict(agent_data) + except json.JSONDecodeError: + continue + + # Fallback to legacy format for backward compatibility + legacy_events = self.memory_client.list_events( + memory_id=self.config.memory_id, + actor_id=self._get_full_agent_id(agent_id), + session_id=session_id, + max_results=1, + ) + if legacy_events: + agent_data = json.loads(legacy_events[0].get("payload", {})[0].get("blob")) + return SessionAgent.from_dict(agent_data) + except Exception as e: logger.error("Failed to read agent: %s", e) @@ -464,7 +465,7 @@ def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A # Create a new agent as AgentCore Memory is immutable. We always get the latest one in `read_agent` self.create_agent(session_id, session_agent) - def create_message( + def create_message( # type: ignore[override] self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any ) -> Optional[dict[str, Any]]: """Create a new message in AgentCore Memory. @@ -501,9 +502,8 @@ def create_message( try: messages = AgentCoreMemoryConverter.message_to_payload(session_message) if not messages: - return + return None - # Parse the original timestamp and use it as desired timestamp original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00")) monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp) @@ -625,16 +625,28 @@ def list_messages( # region RepositorySessionManager overrides @override def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None: - """Append a message to the agent's session using AgentCore's eventId as message_id. + """Append a message to the agent's session with batched agent state sync. + + Saves message and agent state in a single API call for optimization. Args: message: Message to add to the agent in the session agent: Agent to append the message to **kwargs: Additional keyword arguments for future extensibility. """ - created_message = self.create_message(self.session_id, agent.agent_id, SessionMessage.from_message(message, 0)) - session_message = SessionMessage.from_message(message, created_message.get("eventId")) - self._latest_agent_message[agent.agent_id] = session_message + self._save_message_with_state(message, agent) + + @override + def sync_agent(self, agent: "Agent", **kwargs: Any) -> None: + """Sync agent state only if it changed since last sync. + + Skips sync if agent state is unchanged, avoiding redundant API calls. + + Args: + agent: Agent to sync to the session. + **kwargs: Additional keyword arguments for future extensibility. + """ + self._sync_agent_state_if_changed(agent) def retrieve_customer_context(self, event: MessageAddedEvent) -> None: """Retrieve customer LTM context before processing support query. @@ -643,7 +655,8 @@ def retrieve_customer_context(self, event: MessageAddedEvent) -> None: event (MessageAddedEvent): The message added event containing the agent and message data. """ messages = event.agent.messages - if not messages or messages[-1].get("role") != "user" or "toolResult" in messages[-1].get("content")[0]: + content = messages[-1].get("content") if messages else None + if not messages or messages[-1].get("role") != "user" or not content or "toolResult" in content[0]: return None if not self.config.retrieval_config: # Only retrieve LTM @@ -651,7 +664,7 @@ def retrieve_customer_context(self, event: MessageAddedEvent) -> None: user_query = messages[-1]["content"][0]["text"] - def retrieve_for_namespace(namespace: str, retrieval_config: RetrievalConfig): + def retrieve_for_namespace(namespace: str, retrieval_config: RetrievalConfig) -> list[str]: """Helper function to retrieve memories for a single namespace.""" resolved_namespace = namespace.format( actorId=self.config.actor_id, @@ -713,29 +726,12 @@ def retrieve_for_namespace(namespace: str, retrieval_config: RetrievalConfig): logger.error("Failed to retrieve customer context: %s", e) @override - def register_hooks(self, registry: HookRegistry, **kwargs) -> None: - """Register hooks for session management. - - For storage_version="v1" (default): - Uses parent class behavior with separate API calls for message and agent state. - - For storage_version="v2": - Uses batched API calls to reduce redundant API calls. - - Args: - registry (HookRegistry): The hook registry to register callbacks with. - **kwargs: Additional keyword arguments. - """ - if self.storage_version == "v2": - registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent)) - registry.add_callback( - MessageAddedEvent, lambda event: self.save_message_with_state(event.message, event.agent) - ) - registry.add_callback(AfterInvocationEvent, lambda event: self._sync_agent_state(event.agent)) - registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event)) - else: - RepositorySessionManager.register_hooks(self, registry, **kwargs) - registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event)) + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + """Register hooks for session management with optimized storage.""" + registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent)) + registry.add_callback(MessageAddedEvent, lambda e: self._save_message_with_state(e.message, e.agent)) + registry.add_callback(AfterInvocationEvent, lambda event: self._sync_agent_state_if_changed(event.agent)) + registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event)) @override def initialize(self, agent: "Agent", **kwargs: Any) -> None: diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py index 576c97e..1100e96 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py @@ -1118,8 +1118,8 @@ def test_list_messages_with_limit_calculates_max_results(self, session_manager, assert call_kwargs["max_results"] == 550 # limit + offset -class TestStorageVersionV2: - """Tests for storage_version='v2' functionality.""" +class TestOptimizedStorage: + """Tests for optimized storage functionality.""" @pytest.fixture def agentcore_config(self): @@ -1143,8 +1143,8 @@ def mock_memory_client(self): return client @pytest.fixture - def session_manager_v2(self, agentcore_config, mock_memory_client): - """Create AgentCoreMemorySessionManager with storage_version='v2'.""" + def session_manager(self, agentcore_config, mock_memory_client): + """Create AgentCoreMemorySessionManager.""" with patch( "bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient", return_value=mock_memory_client @@ -1159,10 +1159,7 @@ def session_manager_v2(self, agentcore_config, mock_memory_client): "strands.session.repository_session_manager.RepositorySessionManager.__init__", return_value=None ): - manager = AgentCoreMemorySessionManager( - agentcore_config, - storage_version="v2" - ) + manager = AgentCoreMemorySessionManager(agentcore_config) manager.session_id = agentcore_config.session_id manager.session = Session( session_id=agentcore_config.session_id, @@ -1172,64 +1169,8 @@ def session_manager_v2(self, agentcore_config, mock_memory_client): manager._latest_agent_message = {} return manager - def test_init_storage_version_default(self, agentcore_config): - """Test default storage_version is 'v1'.""" - with patch("bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient"): - with patch("boto3.Session") as mock_boto_session: - mock_session = Mock() - mock_session.region_name = "us-west-2" - mock_boto_session.return_value = mock_session - - with patch( - "strands.session.repository_session_manager.RepositorySessionManager.__init__", - return_value=None - ): - manager = AgentCoreMemorySessionManager(agentcore_config) - assert manager.storage_version == "v1" - - def test_init_storage_version_v2(self, agentcore_config): - """Test storage_version can be set to 'v2'.""" - with patch("bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient"): - with patch("boto3.Session") as mock_boto_session: - mock_session = Mock() - mock_session.region_name = "us-west-2" - mock_boto_session.return_value = mock_session - - with patch( - "strands.session.repository_session_manager.RepositorySessionManager.__init__", - return_value=None - ): - manager = AgentCoreMemorySessionManager( - agentcore_config, - storage_version="v2" - ) - assert manager.storage_version == "v2" - - def test_register_hooks_v2_registers_agent_initialized(self, session_manager_v2): - """Test v2 registers AgentInitializedEvent hook.""" - from strands.hooks import AgentInitializedEvent - from strands.hooks.registry import HookRegistry - - registry = HookRegistry() - session_manager_v2.register_hooks(registry) - - assert AgentInitializedEvent in registry._registered_callbacks - assert len(registry._registered_callbacks[AgentInitializedEvent]) == 1 - - def test_register_hooks_v2_message_callbacks(self, session_manager_v2): - """Test v2 registers correct MessageAddedEvent callbacks.""" - from strands.hooks import MessageAddedEvent - from strands.hooks.registry import HookRegistry - - registry = HookRegistry() - session_manager_v2.register_hooks(registry) - - # v2: save_message_with_state + retrieve_customer_context - assert MessageAddedEvent in registry._registered_callbacks - assert len(registry._registered_callbacks[MessageAddedEvent]) == 2 - - def test_save_message_with_state_batches_payloads(self, session_manager_v2, mock_memory_client): - """Test save_message_with_state creates single API call with batched payloads.""" + def test_append_message_batches_with_state(self, session_manager, mock_memory_client): + """Test append_message saves message and agent state in single API call.""" import json mock_agent = Mock() @@ -1244,7 +1185,7 @@ def test_save_message_with_state_batches_payloads(self, session_manager_v2, mock message = {"role": "user", "content": [{"text": "Hello"}]} with patch.object(SessionAgent, "from_agent", return_value=mock_session_agent): - session_manager_v2.save_message_with_state(message, mock_agent) + session_manager.append_message(message, mock_agent) mock_memory_client.gmdp_client.create_event.assert_called_once() call_kwargs = mock_memory_client.gmdp_client.create_event.call_args[1] @@ -1252,18 +1193,112 @@ def test_save_message_with_state_batches_payloads(self, session_manager_v2, mock # Should have 2 payloads: message + agent_state assert len(call_kwargs["payload"]) == 2 - # Verify message payload format + # Verify payload formats msg_payload = json.loads(call_kwargs["payload"][0]["blob"]) assert msg_payload["_type"] == "message" - assert "data" in msg_payload - # Verify agent_state payload format agent_payload = json.loads(call_kwargs["payload"][1]["blob"]) assert agent_payload["_type"] == "agent_state" - assert agent_payload["_agent_id"] == "test-agent" - def test_read_agent_v2_format(self, session_manager_v2, mock_memory_client): - """Test read_agent correctly parses v2 format with _type marker.""" + def test_sync_agent_skips_unchanged_state(self, session_manager, mock_memory_client): + """Test sync_agent skips API call when state is unchanged.""" + mock_agent = Mock() + mock_agent.agent_id = "test-agent" + + mock_session_agent = SessionAgent( + agent_id="test-agent", + state={"key": "value"}, + conversation_manager_state={} + ) + + with patch.object(SessionAgent, "from_agent", return_value=mock_session_agent): + # Set initial state hash + session_manager._last_synced_state_hash = session_manager._compute_state_hash(mock_agent) + + # Reset mock to track new calls + mock_memory_client.gmdp_client.create_event.reset_mock() + + # Sync should skip since state unchanged + session_manager.sync_agent(mock_agent) + + mock_memory_client.gmdp_client.create_event.assert_not_called() + + def test_compute_state_hash_excludes_timestamps(self, session_manager, mock_memory_client): + """Test that _compute_state_hash excludes timestamps from hash computation.""" + mock_agent = Mock() + mock_agent.agent_id = "test-agent" + + # Create two SessionAgents with different timestamps but same state + session_agent_1 = SessionAgent( + agent_id="test-agent", + state={"key": "value"}, + conversation_manager_state={"cm_key": "cm_value"}, + created_at="2023-01-01T00:00:00Z", + updated_at="2023-01-01T00:00:00Z" + ) + session_agent_2 = SessionAgent( + agent_id="test-agent", + state={"key": "value"}, + conversation_manager_state={"cm_key": "cm_value"}, + created_at="2023-06-15T12:00:00Z", # Different timestamp + updated_at="2023-12-31T23:59:59Z" # Different timestamp + ) + + with patch.object(SessionAgent, "from_agent", return_value=session_agent_1): + hash1 = session_manager._compute_state_hash(mock_agent) + + with patch.object(SessionAgent, "from_agent", return_value=session_agent_2): + hash2 = session_manager._compute_state_hash(mock_agent) + + # Hashes should be equal since timestamps are excluded + assert hash1 == hash2, "Hash should be same when only timestamps differ" + + def test_compute_state_hash_changes_with_state(self, session_manager, mock_memory_client): + """Test that _compute_state_hash changes when actual state changes.""" + mock_agent = Mock() + mock_agent.agent_id = "test-agent" + + session_agent_1 = SessionAgent( + agent_id="test-agent", + state={"key": "value1"}, + conversation_manager_state={"cm_key": "cm_value"} + ) + session_agent_2 = SessionAgent( + agent_id="test-agent", + state={"key": "value2"}, # Different state + conversation_manager_state={"cm_key": "cm_value"} + ) + + with patch.object(SessionAgent, "from_agent", return_value=session_agent_1): + hash1 = session_manager._compute_state_hash(mock_agent) + + with patch.object(SessionAgent, "from_agent", return_value=session_agent_2): + hash2 = session_manager._compute_state_hash(mock_agent) + + # Hashes should be different since state changed + assert hash1 != hash2, "Hash should differ when state changes" + + def test_sync_agent_saves_changed_state(self, session_manager, mock_memory_client): + """Test sync_agent saves when state has changed.""" + mock_agent = Mock() + mock_agent.agent_id = "test-agent" + + mock_session_agent = SessionAgent( + agent_id="test-agent", + state={"key": "new_value"}, + conversation_manager_state={} + ) + + # Set different initial hash + session_manager._last_synced_state_hash = hash("different") + + with patch.object(SessionAgent, "from_agent", return_value=mock_session_agent): + session_manager.sync_agent(mock_agent) + + mock_memory_client.gmdp_client.create_event.assert_called_once() + + def test_read_agent_new_format(self, session_manager, mock_memory_client): + """Test read_agent correctly parses new format with _type marker.""" import json agent_state_payload = { @@ -1281,13 +1316,13 @@ def test_read_agent_v2_format(self, session_manager_v2, mock_memory_client): } ] - result = session_manager_v2.read_agent("test-session-456", "test-agent") + result = session_manager.read_agent("test-session-456", "test-agent") assert result is not None assert result.agent_id == "test-agent" assert result.state == {"key": "value"} - def test_read_agent_v2_skips_message_payloads(self, session_manager_v2, mock_memory_client): + def test_read_agent_skips_message_payloads(self, session_manager, mock_memory_client): """Test read_agent skips message payloads when looking for agent state.""" import json @@ -1320,13 +1355,46 @@ def test_read_agent_v2_skips_message_payloads(self, session_manager_v2, mock_mem } ] - result = session_manager_v2.read_agent("test-session-456", "test-agent") + result = session_manager.read_agent("test-session-456", "test-agent") assert result is not None assert result.agent_id == "test-agent" - def test_create_agent_v2_format(self, session_manager_v2, mock_memory_client): - """Test create_agent saves in v2 format with _type marker.""" + def test_read_agent_falls_back_to_legacy(self, session_manager, mock_memory_client): + """Test read_agent falls back to legacy format when new format not found. + + This tests the v1→v2 migration scenario where existing data was stored + with actor_id='agent_{id}' (v1) but new code queries with unified actor_id. + """ + import json + + legacy_agent_data = { + "agent_id": "test-agent", + "state": {"legacy": "data"}, + "conversation_manager_state": {} + } + + # First call (new format) returns empty, second call (legacy) returns data + mock_memory_client.list_events.side_effect = [ + [], # New format query with unified actor_id + [{"eventId": "event-1", "payload": [{"blob": json.dumps(legacy_agent_data)}]}] # Legacy + ] + + result = session_manager.read_agent("test-session-456", "test-agent") + + assert result is not None + assert result.state == {"legacy": "data"} + assert mock_memory_client.list_events.call_count == 2 + + # Verify correct actor_ids were used in each call + calls = mock_memory_client.list_events.call_args_list + # First call: unified actor_id (config.actor_id) + assert calls[0][1]["actor_id"] == "test-actor-789" + # Second call: legacy actor_id (agent_{agent_id}) + assert calls[1][1]["actor_id"] == "agent_test-agent" + + def test_create_agent_uses_new_format(self, session_manager, mock_memory_client): + """Test create_agent saves in new format with _type marker.""" import json session_agent = SessionAgent( @@ -1335,7 +1403,7 @@ def test_create_agent_v2_format(self, session_manager_v2, mock_memory_client): conversation_manager_state={} ) - session_manager_v2.create_agent("test-session-456", session_agent) + session_manager.create_agent("test-session-456", session_agent) mock_memory_client.gmdp_client.create_event.assert_called_once() call_kwargs = mock_memory_client.gmdp_client.create_event.call_args[1] @@ -1343,17 +1411,16 @@ def test_create_agent_v2_format(self, session_manager_v2, mock_memory_client): payload_data = json.loads(call_kwargs["payload"][0]["blob"]) assert payload_data["_type"] == "agent_state" assert payload_data["_agent_id"] == "test-agent" - assert payload_data["agent_id"] == "test-agent" - def test_list_messages_parses_v2_format(self, session_manager_v2, mock_memory_client): - """Test list_messages correctly parses v2 format messages.""" + def test_list_messages_parses_new_format(self, session_manager, mock_memory_client): + """Test list_messages correctly parses new format messages.""" import json from strands.types.session import SessionMessage msg = SessionMessage( message_id=1, - message={"role": "user", "content": [{"text": "Hello from v2"}]}, + message={"role": "user", "content": [{"text": "Hello"}]}, created_at="2023-01-01T00:00:00Z" ) message_payload = { @@ -1365,7 +1432,94 @@ def test_list_messages_parses_v2_format(self, session_manager_v2, mock_memory_cl {"eventId": "event-1", "payload": [{"blob": json.dumps(message_payload)}]} ] - result = session_manager_v2.list_messages("test-session-456", "test-agent") + result = session_manager.list_messages("test-session-456", "test-agent") assert len(result) == 1 - assert result[0].message["content"][0]["text"] == "Hello from v2" + assert result[0].message["content"][0]["text"] == "Hello" + + def test_read_agent_returns_latest_state_multi_turn(self, session_manager, mock_memory_client): + """Test read_agent returns the latest agent state when multiple states exist (multi-turn).""" + import json + + # Simulate multiple agent states from different turns (oldest first, as API returns) + old_state = { + "_type": "agent_state", + "_agent_id": "test-agent", + "agent_id": "test-agent", + "state": {"turn": 1}, + "conversation_manager_state": {"removed_message_count": 0} + } + latest_state = { + "_type": "agent_state", + "_agent_id": "test-agent", + "agent_id": "test-agent", + "state": {"turn": 2}, + "conversation_manager_state": {"removed_message_count": 0} + } + + # Events returned oldest-first by API + mock_memory_client.list_events.return_value = [ + {"eventId": "event-1", "payload": [{"blob": json.dumps(old_state)}]}, + {"eventId": "event-2", "payload": [{"blob": json.dumps(latest_state)}]}, + ] + + result = session_manager.read_agent("test-session-456", "test-agent") + + assert result is not None + assert result.state == {"turn": 2}, "Should return latest state, not oldest" + + def test_read_agent_multi_agent_returns_latest_per_agent(self, session_manager, mock_memory_client): + """Test read_agent returns latest state for each agent in multi-agent multi-turn scenario.""" + import json + + # Interleaved states: agent-1 turn 1, agent-2 turn 1, agent-1 turn 2 (oldest first) + events = [ + {"eventId": "event-1", "payload": [{"blob": json.dumps({ + "_type": "agent_state", "_agent_id": "agent-1", "agent_id": "agent-1", + "state": {"turn": 1}, "conversation_manager_state": {} + })}]}, + {"eventId": "event-2", "payload": [{"blob": json.dumps({ + "_type": "agent_state", "_agent_id": "agent-2", "agent_id": "agent-2", + "state": {"turn": 1}, "conversation_manager_state": {} + })}]}, + {"eventId": "event-3", "payload": [{"blob": json.dumps({ + "_type": "agent_state", "_agent_id": "agent-1", "agent_id": "agent-1", + "state": {"turn": 2}, "conversation_manager_state": {} + })}]}, + ] + + mock_memory_client.list_events.return_value = events + + # agent-1 should return turn 2 (latest), not turn 1 + result = session_manager.read_agent("test-session-456", "agent-1") + assert result is not None + assert result.state == {"turn": 2}, "Should return agent-1's latest state" + + # agent-2 should return turn 1 (only state) + result = session_manager.read_agent("test-session-456", "agent-2") + assert result is not None + assert result.state == {"turn": 1} + + def test_read_agent_multi_agent_not_found(self, session_manager, mock_memory_client): + """Test read_agent returns None when agent_id not found in multi-agent scenario.""" + import json + + agent1_state = { + "_type": "agent_state", + "_agent_id": "agent-1", + "agent_id": "agent-1", + "state": {}, + "conversation_manager_state": {} + } + + mock_memory_client.list_events.return_value = [ + {"eventId": "event-1", "payload": [{"blob": json.dumps(agent1_state)}]}, + ] + + # agent-3 doesn't exist, and no legacy fallback + mock_memory_client.list_events.side_effect = [ + [{"eventId": "event-1", "payload": [{"blob": json.dumps(agent1_state)}]}], + [] # Legacy fallback returns empty + ] + result = session_manager.read_agent("test-session-456", "agent-3") + assert result is None diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_storage_version_integration.py b/tests/bedrock_agentcore/memory/integrations/strands/test_session_manager_integration.py similarity index 54% rename from tests/bedrock_agentcore/memory/integrations/strands/test_storage_version_integration.py rename to tests/bedrock_agentcore/memory/integrations/strands/test_session_manager_integration.py index 7f99a1d..96535d2 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_storage_version_integration.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_session_manager_integration.py @@ -1,7 +1,7 @@ -"""Integration tests for storage_version parameter. +"""Integration tests for AgentCoreMemorySessionManager. These tests require AWS credentials and make actual API calls. -Run with: pytest -m integration tests/.../test_storage_version_integration.py -v +Run with: pytest -m integration tests/.../test_session_manager_integration.py -v Skip with: pytest -m "not integration" To skip these tests in normal runs, they are marked with @pytest.mark.integration. @@ -64,11 +64,11 @@ def increment(self, method_name: str): self.call_details.append(method_name) -class TestStorageVersionIntegration: - """Integration tests for storage_version parameter.""" +class TestOptimizedStorageIntegration: + """Integration tests for optimized storage.""" - def test_v2_reduces_api_calls_simple_message(self, memory_id, region, unique_session_id): - """Test that v2 reduces API calls for simple message.""" + def test_batched_api_calls_simple_message(self, memory_id, region, unique_session_id): + """Test that batched storage reduces API calls for simple message.""" from strands import Agent from strands.models import BedrockModel @@ -77,16 +77,14 @@ def test_v2_reduces_api_calls_simple_message(self, memory_id, region, unique_ses actor_id = "test-user" - # Test v1 - config_v1 = AgentCoreMemoryConfig( + config = AgentCoreMemoryConfig( memory_id=memory_id, - session_id=f"v1-{unique_session_id}", + session_id=unique_session_id, actor_id=actor_id, ) - session_manager_v1 = AgentCoreMemorySessionManager( - agentcore_memory_config=config_v1, + session_manager = AgentCoreMemorySessionManager( + agentcore_memory_config=config, region_name=region, - storage_version="v1", ) model = BedrockModel( @@ -94,62 +92,30 @@ def test_v2_reduces_api_calls_simple_message(self, memory_id, region, unique_ses region_name=region ) - agent_v1 = Agent( - model=model, - session_manager=session_manager_v1, - system_prompt="You are a helpful assistant. Keep responses very brief." - ) - - # Count API calls for v1 - counter_v1 = APICallCounter() - original_create_event_v1 = session_manager_v1.memory_client.gmdp_client.create_event - - def counted_v1(*args, **kwargs): - counter_v1.increment("create_event") - return original_create_event_v1(*args, **kwargs) - - session_manager_v1.memory_client.gmdp_client.create_event = counted_v1 - agent_v1("Say hello briefly.") - v1_calls = counter_v1.call_count - - # Test v2 - config_v2 = AgentCoreMemoryConfig( - memory_id=memory_id, - session_id=f"v2-{unique_session_id}", - actor_id=actor_id, - ) - session_manager_v2 = AgentCoreMemorySessionManager( - agentcore_memory_config=config_v2, - region_name=region, - storage_version="v2", - ) - - agent_v2 = Agent( + agent = Agent( model=model, - session_manager=session_manager_v2, + session_manager=session_manager, system_prompt="You are a helpful assistant. Keep responses very brief." ) - # Count API calls for v2 - counter_v2 = APICallCounter() - original_create_event_v2 = session_manager_v2.memory_client.gmdp_client.create_event + # Count API calls + counter = APICallCounter() + original_create_event = session_manager.memory_client.gmdp_client.create_event - def counted_v2(*args, **kwargs): - counter_v2.increment("create_event") - return original_create_event_v2(*args, **kwargs) + def counted(*args, **kwargs): + counter.increment("create_event") + return original_create_event(*args, **kwargs) - session_manager_v2.memory_client.gmdp_client.create_event = counted_v2 - agent_v2("Say hello briefly.") - v2_calls = counter_v2.call_count + session_manager.memory_client.gmdp_client.create_event = counted + agent("Say hello briefly.") + api_calls = counter.call_count - # v2 should have fewer API calls than v1 - assert v2_calls < v1_calls, f"v2 ({v2_calls}) should have fewer calls than v1 ({v1_calls})" - # Expected: v1=5, v2=3 - assert v1_calls == 5, f"Expected v1=5 calls, got {v1_calls}" - assert v2_calls == 3, f"Expected v2=3 calls, got {v2_calls}" + # Batched storage: 1 (session) + 1 (user msg+state) + 1 (assistant msg+state) = 3 + # AfterInvocation sync skips due to state hash tracking + assert api_calls <= 4, f"Expected at most 4 API calls, got {api_calls}" - def test_v2_multi_turn_preserves_context(self, memory_id, region, unique_session_id): - """Test that v2 correctly preserves context across turns.""" + def test_multi_turn_preserves_context(self, memory_id, region, unique_session_id): + """Test that context is preserved across turns.""" from strands import Agent from strands.models import BedrockModel @@ -157,7 +123,7 @@ def test_v2_multi_turn_preserves_context(self, memory_id, region, unique_session from bedrock_agentcore.memory.integrations.strands.session_manager import AgentCoreMemorySessionManager actor_id = "test-user" - session_id = f"v2-multi-{unique_session_id}" + session_id = f"multi-{unique_session_id}" config = AgentCoreMemoryConfig( memory_id=memory_id, @@ -174,7 +140,6 @@ def test_v2_multi_turn_preserves_context(self, memory_id, region, unique_session session_manager_1 = AgentCoreMemorySessionManager( agentcore_memory_config=config, region_name=region, - storage_version="v2", ) agent_1 = Agent( @@ -189,7 +154,6 @@ def test_v2_multi_turn_preserves_context(self, memory_id, region, unique_session session_manager_2 = AgentCoreMemorySessionManager( agentcore_memory_config=config, region_name=region, - storage_version="v2", ) agent_2 = Agent( @@ -214,8 +178,8 @@ def test_v2_multi_turn_preserves_context(self, memory_id, region, unique_session assert "blue" in response_text, f"Expected 'blue' in response, got: {response_text}" - def test_v2_parses_messages_correctly(self, memory_id, region, unique_session_id): - """Test that v2 format messages are correctly saved and loaded.""" + def test_messages_correctly_parsed(self, memory_id, region, unique_session_id): + """Test that messages are correctly saved and loaded.""" from strands import Agent from strands.models import BedrockModel @@ -224,7 +188,7 @@ def test_v2_parses_messages_correctly(self, memory_id, region, unique_session_id from bedrock_agentcore.memory.integrations.strands.session_manager import AgentCoreMemorySessionManager actor_id = "test-user" - session_id = f"v2-parse-{unique_session_id}" + session_id = f"parse-{unique_session_id}" config = AgentCoreMemoryConfig( memory_id=memory_id, @@ -240,7 +204,6 @@ def test_v2_parses_messages_correctly(self, memory_id, region, unique_session_id session_manager = AgentCoreMemorySessionManager( agentcore_memory_config=config, region_name=region, - storage_version="v2", ) agent = Agent( @@ -249,7 +212,7 @@ def test_v2_parses_messages_correctly(self, memory_id, region, unique_session_id system_prompt="You are a helpful assistant." ) - agent("Test message for v2 format.") + agent("Test message for format.") # Verify stored format memory_client = MemoryClient(region_name=region) @@ -262,9 +225,9 @@ def test_v2_parses_messages_correctly(self, memory_id, region, unique_session_id assert len(events) > 0, "Expected events to be saved" - # Check that at least one event has v2 format - found_v2_message = False - found_v2_agent_state = False + # Check that at least one event has type markers + found_message = False + found_agent_state = False for event in events: for payload_item in event.get("payload", []): @@ -272,15 +235,15 @@ def test_v2_parses_messages_correctly(self, memory_id, region, unique_session_id blob_data = json.loads(payload_item["blob"]) if isinstance(blob_data, dict): if blob_data.get("_type") == "message": - found_v2_message = True + found_message = True if blob_data.get("_type") == "agent_state": - found_v2_agent_state = True + found_agent_state = True - assert found_v2_message, "Expected v2 format message with _type marker" - assert found_v2_agent_state, "Expected v2 format agent_state with _type marker" + assert found_message, "Expected message with _type marker" + assert found_agent_state, "Expected agent_state with _type marker" - def test_v2_reduces_api_calls_with_tools(self, memory_id, region, unique_session_id): - """Test that v2 reduces API calls when using tools.""" + def test_batched_api_calls_with_tools(self, memory_id, region, unique_session_id): + """Test that batched storage works with tool calls.""" from datetime import datetime from strands import Agent, tool @@ -306,68 +269,34 @@ def add_numbers(a: int, b: int) -> int: region_name=region ) - # Test v1 with tools - config_v1 = AgentCoreMemoryConfig( - memory_id=memory_id, - session_id=f"v1-tools-{unique_session_id}", - actor_id=actor_id, - ) - session_manager_v1 = AgentCoreMemorySessionManager( - agentcore_memory_config=config_v1, - region_name=region, - storage_version="v1", - ) - - agent_v1 = Agent( - model=model, - session_manager=session_manager_v1, - tools=[get_time, add_numbers], - system_prompt="You are a helpful assistant. Use tools when asked. Keep responses brief." - ) - - counter_v1 = APICallCounter() - original_create_event_v1 = session_manager_v1.memory_client.gmdp_client.create_event - - def counted_v1(*args, **kwargs): - counter_v1.increment("create_event") - return original_create_event_v1(*args, **kwargs) - - session_manager_v1.memory_client.gmdp_client.create_event = counted_v1 - agent_v1("What time is it? And what is 10 + 20?") - v1_calls = counter_v1.call_count - - # Test v2 with tools - config_v2 = AgentCoreMemoryConfig( + config = AgentCoreMemoryConfig( memory_id=memory_id, - session_id=f"v2-tools-{unique_session_id}", + session_id=f"tools-{unique_session_id}", actor_id=actor_id, ) - session_manager_v2 = AgentCoreMemorySessionManager( - agentcore_memory_config=config_v2, + session_manager = AgentCoreMemorySessionManager( + agentcore_memory_config=config, region_name=region, - storage_version="v2", ) - agent_v2 = Agent( + agent = Agent( model=model, - session_manager=session_manager_v2, + session_manager=session_manager, tools=[get_time, add_numbers], system_prompt="You are a helpful assistant. Use tools when asked. Keep responses brief." ) - counter_v2 = APICallCounter() - original_create_event_v2 = session_manager_v2.memory_client.gmdp_client.create_event + counter = APICallCounter() + original_create_event = session_manager.memory_client.gmdp_client.create_event - def counted_v2(*args, **kwargs): - counter_v2.increment("create_event") - return original_create_event_v2(*args, **kwargs) + def counted(*args, **kwargs): + counter.increment("create_event") + return original_create_event(*args, **kwargs) - session_manager_v2.memory_client.gmdp_client.create_event = counted_v2 - agent_v2("What time is it? And what is 10 + 20?") - v2_calls = counter_v2.call_count + session_manager.memory_client.gmdp_client.create_event = counted + agent("What time is it? And what is 10 + 20?") + api_calls = counter.call_count - # v2 should have fewer API calls than v1 - assert v2_calls < v1_calls, f"v2 ({v2_calls}) should have fewer calls than v1 ({v1_calls})" - # Expected: v1=9, v2=5 (with 2 tool calls) - assert v1_calls == 9, f"Expected v1=9 calls with tools, got {v1_calls}" - assert v2_calls == 5, f"Expected v2=5 calls with tools, got {v2_calls}" + # With tools: more messages but still batched + # Should be less than 10 calls (unbatched would be ~9+ calls per message) + assert api_calls <= 7, f"Expected at most 7 API calls with tools, got {api_calls}" From 4ff0e93385339c8b4f5baa903d0378653825cbf3 Mon Sep 17 00:00:00 2001 From: Kihyeon Myung Date: Fri, 23 Jan 2026 14:12:52 -0700 Subject: [PATCH 4/4] Add multi-agent hooks --- .../integrations/strands/bedrock_converter.py | 2 +- .../integrations/strands/session_manager.py | 10 ++ .../test_agentcore_memory_session_manager.py | 152 +++++++++--------- .../strands/test_bedrock_converter.py | 19 +-- .../test_session_manager_integration.py | 34 ++-- 5 files changed, 98 insertions(+), 119 deletions(-) diff --git a/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py b/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py index eafa973..7923b19 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py +++ b/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py @@ -83,7 +83,7 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]: elif "blob" in payload_item: try: blob_data = json.loads(payload_item["blob"]) - # V2 format: dict with _type marker + # New format: dict with _type marker if isinstance(blob_data, dict): if blob_data.get("_type") == "agent_state": continue # Skip agent state payloads diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index 3089e62..8d44c4d 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -9,6 +9,11 @@ import boto3 from botocore.config import Config as BotocoreConfig +from strands.experimental.hooks.multiagent import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + MultiAgentInitializedEvent, +) from strands.hooks import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent from strands.hooks.registry import HookRegistry from strands.session.repository_session_manager import RepositorySessionManager @@ -733,6 +738,11 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: registry.add_callback(AfterInvocationEvent, lambda event: self._sync_agent_state_if_changed(event.agent)) registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event)) + # Multi-agent hooks + registry.add_callback(MultiAgentInitializedEvent, lambda event: self.initialize_multi_agent(event.source)) + registry.add_callback(AfterNodeCallEvent, lambda event: self.sync_multi_agent(event.source)) + registry.add_callback(AfterMultiAgentInvocationEvent, lambda event: self.sync_multi_agent(event.source)) + @override def initialize(self, agent: "Agent", **kwargs: Any) -> None: if self.has_existing_agent: diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py index 1100e96..293a525 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py @@ -1125,9 +1125,7 @@ class TestOptimizedStorage: def agentcore_config(self): """Create a test AgentCore Memory configuration.""" return AgentCoreMemoryConfig( - memory_id="test-memory-123", - session_id="test-session-456", - actor_id="test-actor-789" + memory_id="test-memory-123", session_id="test-session-456", actor_id="test-actor-789" ) @pytest.fixture @@ -1147,7 +1145,7 @@ def session_manager(self, agentcore_config, mock_memory_client): """Create AgentCoreMemorySessionManager.""" with patch( "bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient", - return_value=mock_memory_client + return_value=mock_memory_client, ): with patch("boto3.Session") as mock_boto_session: mock_session = Mock() @@ -1156,15 +1154,11 @@ def session_manager(self, agentcore_config, mock_memory_client): mock_boto_session.return_value = mock_session with patch( - "strands.session.repository_session_manager.RepositorySessionManager.__init__", - return_value=None + "strands.session.repository_session_manager.RepositorySessionManager.__init__", return_value=None ): manager = AgentCoreMemorySessionManager(agentcore_config) manager.session_id = agentcore_config.session_id - manager.session = Session( - session_id=agentcore_config.session_id, - session_type=SessionType.AGENT - ) + manager.session = Session(session_id=agentcore_config.session_id, session_type=SessionType.AGENT) manager.memory_client = mock_memory_client manager._latest_agent_message = {} return manager @@ -1177,9 +1171,7 @@ def test_append_message_batches_with_state(self, session_manager, mock_memory_cl mock_agent.agent_id = "test-agent" mock_session_agent = SessionAgent( - agent_id="test-agent", - state={"key": "value"}, - conversation_manager_state={"cm_key": "cm_value"} + agent_id="test-agent", state={"key": "value"}, conversation_manager_state={"cm_key": "cm_value"} ) message = {"role": "user", "content": [{"text": "Hello"}]} @@ -1205,11 +1197,7 @@ def test_sync_agent_skips_unchanged_state(self, session_manager, mock_memory_cli mock_agent = Mock() mock_agent.agent_id = "test-agent" - mock_session_agent = SessionAgent( - agent_id="test-agent", - state={"key": "value"}, - conversation_manager_state={} - ) + mock_session_agent = SessionAgent(agent_id="test-agent", state={"key": "value"}, conversation_manager_state={}) with patch.object(SessionAgent, "from_agent", return_value=mock_session_agent): # Set initial state hash @@ -1234,14 +1222,14 @@ def test_compute_state_hash_excludes_timestamps(self, session_manager, mock_memo state={"key": "value"}, conversation_manager_state={"cm_key": "cm_value"}, created_at="2023-01-01T00:00:00Z", - updated_at="2023-01-01T00:00:00Z" + updated_at="2023-01-01T00:00:00Z", ) session_agent_2 = SessionAgent( agent_id="test-agent", state={"key": "value"}, conversation_manager_state={"cm_key": "cm_value"}, created_at="2023-06-15T12:00:00Z", # Different timestamp - updated_at="2023-12-31T23:59:59Z" # Different timestamp + updated_at="2023-12-31T23:59:59Z", # Different timestamp ) with patch.object(SessionAgent, "from_agent", return_value=session_agent_1): @@ -1259,14 +1247,12 @@ def test_compute_state_hash_changes_with_state(self, session_manager, mock_memor mock_agent.agent_id = "test-agent" session_agent_1 = SessionAgent( - agent_id="test-agent", - state={"key": "value1"}, - conversation_manager_state={"cm_key": "cm_value"} + agent_id="test-agent", state={"key": "value1"}, conversation_manager_state={"cm_key": "cm_value"} ) session_agent_2 = SessionAgent( agent_id="test-agent", state={"key": "value2"}, # Different state - conversation_manager_state={"cm_key": "cm_value"} + conversation_manager_state={"cm_key": "cm_value"}, ) with patch.object(SessionAgent, "from_agent", return_value=session_agent_1): @@ -1284,9 +1270,7 @@ def test_sync_agent_saves_changed_state(self, session_manager, mock_memory_clien mock_agent.agent_id = "test-agent" mock_session_agent = SessionAgent( - agent_id="test-agent", - state={"key": "new_value"}, - conversation_manager_state={} + agent_id="test-agent", state={"key": "new_value"}, conversation_manager_state={} ) # Set different initial hash @@ -1306,14 +1290,11 @@ def test_read_agent_new_format(self, session_manager, mock_memory_client): "_agent_id": "test-agent", "agent_id": "test-agent", "state": {"key": "value"}, - "conversation_manager_state": {"cm_key": "cm_value"} + "conversation_manager_state": {"cm_key": "cm_value"}, } mock_memory_client.list_events.return_value = [ - { - "eventId": "event-1", - "payload": [{"blob": json.dumps(agent_state_payload)}] - } + {"eventId": "event-1", "payload": [{"blob": json.dumps(agent_state_payload)}]} ] result = session_manager.read_agent("test-session-456", "test-agent") @@ -1329,29 +1310,21 @@ def test_read_agent_skips_message_payloads(self, session_manager, mock_memory_cl from strands.types.session import SessionMessage msg = SessionMessage( - message_id=1, - message={"role": "user", "content": [{"text": "Hello"}]}, - created_at="2023-01-01T00:00:00Z" + message_id=1, message={"role": "user", "content": [{"text": "Hello"}]}, created_at="2023-01-01T00:00:00Z" ) - message_payload = { - "_type": "message", - "data": [json.dumps(msg.to_dict()), "user"] - } + message_payload = {"_type": "message", "data": [json.dumps(msg.to_dict()), "user"]} agent_state_payload = { "_type": "agent_state", "_agent_id": "test-agent", "agent_id": "test-agent", "state": {}, - "conversation_manager_state": {} + "conversation_manager_state": {}, } mock_memory_client.list_events.return_value = [ { "eventId": "event-1", - "payload": [ - {"blob": json.dumps(message_payload)}, - {"blob": json.dumps(agent_state_payload)} - ] + "payload": [{"blob": json.dumps(message_payload)}, {"blob": json.dumps(agent_state_payload)}], } ] @@ -1368,16 +1341,12 @@ def test_read_agent_falls_back_to_legacy(self, session_manager, mock_memory_clie """ import json - legacy_agent_data = { - "agent_id": "test-agent", - "state": {"legacy": "data"}, - "conversation_manager_state": {} - } + legacy_agent_data = {"agent_id": "test-agent", "state": {"legacy": "data"}, "conversation_manager_state": {}} # First call (new format) returns empty, second call (legacy) returns data mock_memory_client.list_events.side_effect = [ [], # New format query with unified actor_id - [{"eventId": "event-1", "payload": [{"blob": json.dumps(legacy_agent_data)}]}] # Legacy + [{"eventId": "event-1", "payload": [{"blob": json.dumps(legacy_agent_data)}]}], # Legacy ] result = session_manager.read_agent("test-session-456", "test-agent") @@ -1397,11 +1366,7 @@ def test_create_agent_uses_new_format(self, session_manager, mock_memory_client) """Test create_agent saves in new format with _type marker.""" import json - session_agent = SessionAgent( - agent_id="test-agent", - state={"key": "value"}, - conversation_manager_state={} - ) + session_agent = SessionAgent(agent_id="test-agent", state={"key": "value"}, conversation_manager_state={}) session_manager.create_agent("test-session-456", session_agent) @@ -1419,14 +1384,9 @@ def test_list_messages_parses_new_format(self, session_manager, mock_memory_clie from strands.types.session import SessionMessage msg = SessionMessage( - message_id=1, - message={"role": "user", "content": [{"text": "Hello"}]}, - created_at="2023-01-01T00:00:00Z" + message_id=1, message={"role": "user", "content": [{"text": "Hello"}]}, created_at="2023-01-01T00:00:00Z" ) - message_payload = { - "_type": "message", - "data": [json.dumps(msg.to_dict()), "user"] - } + message_payload = {"_type": "message", "data": [json.dumps(msg.to_dict()), "user"]} mock_memory_client.list_events.return_value = [ {"eventId": "event-1", "payload": [{"blob": json.dumps(message_payload)}]} @@ -1447,14 +1407,14 @@ def test_read_agent_returns_latest_state_multi_turn(self, session_manager, mock_ "_agent_id": "test-agent", "agent_id": "test-agent", "state": {"turn": 1}, - "conversation_manager_state": {"removed_message_count": 0} + "conversation_manager_state": {"removed_message_count": 0}, } latest_state = { "_type": "agent_state", "_agent_id": "test-agent", "agent_id": "test-agent", "state": {"turn": 2}, - "conversation_manager_state": {"removed_message_count": 0} + "conversation_manager_state": {"removed_message_count": 0}, } # Events returned oldest-first by API @@ -1474,18 +1434,54 @@ def test_read_agent_multi_agent_returns_latest_per_agent(self, session_manager, # Interleaved states: agent-1 turn 1, agent-2 turn 1, agent-1 turn 2 (oldest first) events = [ - {"eventId": "event-1", "payload": [{"blob": json.dumps({ - "_type": "agent_state", "_agent_id": "agent-1", "agent_id": "agent-1", - "state": {"turn": 1}, "conversation_manager_state": {} - })}]}, - {"eventId": "event-2", "payload": [{"blob": json.dumps({ - "_type": "agent_state", "_agent_id": "agent-2", "agent_id": "agent-2", - "state": {"turn": 1}, "conversation_manager_state": {} - })}]}, - {"eventId": "event-3", "payload": [{"blob": json.dumps({ - "_type": "agent_state", "_agent_id": "agent-1", "agent_id": "agent-1", - "state": {"turn": 2}, "conversation_manager_state": {} - })}]}, + { + "eventId": "event-1", + "payload": [ + { + "blob": json.dumps( + { + "_type": "agent_state", + "_agent_id": "agent-1", + "agent_id": "agent-1", + "state": {"turn": 1}, + "conversation_manager_state": {}, + } + ) + } + ], + }, + { + "eventId": "event-2", + "payload": [ + { + "blob": json.dumps( + { + "_type": "agent_state", + "_agent_id": "agent-2", + "agent_id": "agent-2", + "state": {"turn": 1}, + "conversation_manager_state": {}, + } + ) + } + ], + }, + { + "eventId": "event-3", + "payload": [ + { + "blob": json.dumps( + { + "_type": "agent_state", + "_agent_id": "agent-1", + "agent_id": "agent-1", + "state": {"turn": 2}, + "conversation_manager_state": {}, + } + ) + } + ], + }, ] mock_memory_client.list_events.return_value = events @@ -1509,7 +1505,7 @@ def test_read_agent_multi_agent_not_found(self, session_manager, mock_memory_cli "_agent_id": "agent-1", "agent_id": "agent-1", "state": {}, - "conversation_manager_state": {} + "conversation_manager_state": {}, } mock_memory_client.list_events.return_value = [ @@ -1519,7 +1515,7 @@ def test_read_agent_multi_agent_not_found(self, session_manager, mock_memory_cli # agent-3 doesn't exist, and no legacy fallback mock_memory_client.list_events.side_effect = [ [{"eventId": "event-1", "payload": [{"blob": json.dumps(agent1_state)}]}], - [] # Legacy fallback returns empty + [], # Legacy fallback returns empty ] result = session_manager.read_agent("test-session-456", "agent-3") assert result is None diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py b/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py index f845ba9..9c2041d 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py @@ -228,10 +228,7 @@ def test_events_to_messages_v2_format(self): message_id=1, message={"role": "user", "content": [{"text": "Hello"}]}, created_at="2023-01-01T00:00:00Z" ) - v2_payload = { - "_type": "message", - "data": [json.dumps(session_message.to_dict()), "user"] - } + v2_payload = {"_type": "message", "data": [json.dumps(session_message.to_dict()), "user"]} events = [{"payload": [{"blob": json.dumps(v2_payload)}]}] result = AgentCoreMemoryConverter.events_to_messages(events) @@ -246,23 +243,15 @@ def test_events_to_messages_v2_skips_agent_state(self): message_id=1, message={"role": "user", "content": [{"text": "Hello"}]}, created_at="2023-01-01T00:00:00Z" ) - message_payload = { - "_type": "message", - "data": [json.dumps(session_message.to_dict()), "user"] - } + message_payload = {"_type": "message", "data": [json.dumps(session_message.to_dict()), "user"]} agent_state_payload = { "_type": "agent_state", "_agent_id": "test-agent", "agent_id": "test-agent", "state": {}, - "conversation_manager_state": {} + "conversation_manager_state": {}, } - events = [ - {"payload": [ - {"blob": json.dumps(message_payload)}, - {"blob": json.dumps(agent_state_payload)} - ]} - ] + events = [{"payload": [{"blob": json.dumps(message_payload)}, {"blob": json.dumps(agent_state_payload)}]}] result = AgentCoreMemoryConverter.events_to_messages(events) diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_session_manager_integration.py b/tests/bedrock_agentcore/memory/integrations/strands/test_session_manager_integration.py index 96535d2..87d45fa 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_session_manager_integration.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_session_manager_integration.py @@ -87,15 +87,12 @@ def test_batched_api_calls_simple_message(self, memory_id, region, unique_sessio region_name=region, ) - model = BedrockModel( - model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", - region_name=region - ) + model = BedrockModel(model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", region_name=region) agent = Agent( model=model, session_manager=session_manager, - system_prompt="You are a helpful assistant. Keep responses very brief." + system_prompt="You are a helpful assistant. Keep responses very brief.", ) # Count API calls @@ -131,10 +128,7 @@ def test_multi_turn_preserves_context(self, memory_id, region, unique_session_id actor_id=actor_id, ) - model = BedrockModel( - model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", - region_name=region - ) + model = BedrockModel(model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", region_name=region) # Turn 1: Tell the agent something to remember session_manager_1 = AgentCoreMemorySessionManager( @@ -145,7 +139,7 @@ def test_multi_turn_preserves_context(self, memory_id, region, unique_session_id agent_1 = Agent( model=model, session_manager=session_manager_1, - system_prompt="You are a helpful assistant. Remember what the user tells you." + system_prompt="You are a helpful assistant. Remember what the user tells you.", ) agent_1("My favorite color is blue. Remember this.") @@ -159,7 +153,7 @@ def test_multi_turn_preserves_context(self, memory_id, region, unique_session_id agent_2 = Agent( model=model, session_manager=session_manager_2, - system_prompt="You are a helpful assistant. Remember what the user tells you." + system_prompt="You are a helpful assistant. Remember what the user tells you.", ) # Verify messages were loaded @@ -196,21 +190,14 @@ def test_messages_correctly_parsed(self, memory_id, region, unique_session_id): actor_id=actor_id, ) - model = BedrockModel( - model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", - region_name=region - ) + model = BedrockModel(model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", region_name=region) session_manager = AgentCoreMemorySessionManager( agentcore_memory_config=config, region_name=region, ) - agent = Agent( - model=model, - session_manager=session_manager, - system_prompt="You are a helpful assistant." - ) + agent = Agent(model=model, session_manager=session_manager, system_prompt="You are a helpful assistant.") agent("Test message for format.") @@ -264,10 +251,7 @@ def add_numbers(a: int, b: int) -> int: actor_id = "test-user" - model = BedrockModel( - model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", - region_name=region - ) + model = BedrockModel(model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", region_name=region) config = AgentCoreMemoryConfig( memory_id=memory_id, @@ -283,7 +267,7 @@ def add_numbers(a: int, b: int) -> int: model=model, session_manager=session_manager, tools=[get_time, add_numbers], - system_prompt="You are a helpful assistant. Use tools when asked. Keep responses brief." + system_prompt="You are a helpful assistant. Use tools when asked. Keep responses brief.", ) counter = APICallCounter()