diff --git a/pyproject.toml b/pyproject.toml index a061fc2..d646ae4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,6 +94,9 @@ testpaths = [ "tests" ] asyncio_mode = "auto" +markers = [ + "integration: marks tests as integration tests (require AWS credentials, deselect with '-m \"not integration\"')", +] [tool.coverage.run] branch = true diff --git a/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py b/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py index 1f0905c..7923b19 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py +++ b/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py @@ -83,6 +83,20 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]: elif "blob" in payload_item: try: blob_data = json.loads(payload_item["blob"]) + # New format: dict with _type marker + if isinstance(blob_data, dict): + if blob_data.get("_type") == "agent_state": + continue # Skip agent state payloads + if blob_data.get("_type") == "message": + data = blob_data.get("data", []) + if isinstance(data, (tuple, list)) and len(data) == 2: + session_msg = SessionMessage.from_dict(json.loads(data[0])) + filtered = AgentCoreMemoryConverter._filter_empty_text(session_msg.message) + session_msg.message = filtered + if session_msg.message.get("content"): + messages.append(session_msg) + continue + # Legacy format: tuple [json_string, role] if isinstance(blob_data, (tuple, list)) and len(blob_data) == 2: try: session_msg = SessionMessage.from_dict(json.loads(blob_data[0])) @@ -91,8 +105,8 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]: messages.append(session_msg) except (json.JSONDecodeError, ValueError): logger.error("This is not a SessionMessage but just a blob message. 
Ignoring") - except (json.JSONDecodeError, ValueError): - logger.error("Failed to parse blob content: %s", payload_item) + except Exception as e: + logger.error("Failed to parse blob content: %s, error: %s", payload_item, e) return list(reversed(messages)) @staticmethod diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index c3912a7..8d44c4d 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -9,7 +9,12 @@ import boto3 from botocore.config import Config as BotocoreConfig -from strands.hooks import MessageAddedEvent +from strands.experimental.hooks.multiagent import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + MultiAgentInitializedEvent, +) +from strands.hooks import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent from strands.hooks.registry import HookRegistry from strands.session.repository_session_manager import RepositorySessionManager from strands.session.session_repository import SessionRepository @@ -33,6 +38,10 @@ MESSAGE_PREFIX = "message_" MAX_FETCH_ALL_RESULTS = 10000 +# Payload type markers for batched storage +PAYLOAD_TYPE_MESSAGE = "message" +PAYLOAD_TYPE_AGENT_STATE = "agent_state" + class AgentCoreMemorySessionManager(RepositorySessionManager, SessionRepository): """AgentCore Memory-based session manager for Bedrock AgentCore Memory integration. 
@@ -103,6 +112,7 @@ def __init__( self.memory_client = MemoryClient(region_name=region_name) session = boto_session or boto3.Session(region_name=region_name) self.has_existing_agent = False + self._last_synced_state_hash: Optional[int] = None # Override the clients if custom boto session or config is provided # Add strands-agents to the request user agent @@ -157,6 +167,122 @@ def _get_full_agent_id(self, agent_id: str) -> str: ) return full_agent_id + # region Internal Storage Methods + + def _build_agent_state_payload(self, agent: "Agent") -> dict: + """Create agent state payload for unified storage. + + Creates a SessionAgent-compatible payload that can be reconstructed + via SessionAgent.from_dict(). + + Args: + agent (Agent): The agent whose state to capture. + + Returns: + dict: Agent state payload with type markers. + """ + session_agent = SessionAgent.from_agent(agent) + return { + "_type": PAYLOAD_TYPE_AGENT_STATE, + "_agent_id": agent.agent_id, + **session_agent.to_dict(), + } + + def _compute_state_hash(self, agent: "Agent") -> int: + """Compute hash of agent state for change detection. + + Excludes timestamps (created_at, updated_at) as they change on every call. + + Args: + agent (Agent): The agent whose state to hash. + + Returns: + int: Hash of the agent state. + """ + session_agent = SessionAgent.from_agent(agent) + state_dict = session_agent.to_dict() + # Exclude timestamps that change on every call + state_dict.pop("created_at", None) + state_dict.pop("updated_at", None) + return hash(json.dumps(state_dict, sort_keys=True)) + + def _save_message_with_state(self, message: Message, agent: "Agent") -> None: + """Save message and agent state in a single batched API call. + + Combines message and agent state into one API call instead of two separate calls. + + Args: + message (Message): The message to save. + agent (Agent): The agent whose state to sync. 
+ """ + session_message = SessionMessage.from_message(message, 0) + messages = AgentCoreMemoryConverter.message_to_payload(session_message) + if not messages: + return + + message_tuple = messages[0] + message_payload = {"_type": PAYLOAD_TYPE_MESSAGE, "data": list(message_tuple)} + agent_state_payload = self._build_agent_state_payload(agent) + + original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00")) + monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp) + + try: + event = self.memory_client.gmdp_client.create_event( + memoryId=self.config.memory_id, + actorId=self.config.actor_id, + sessionId=self.session_id, + payload=[ + {"blob": json.dumps(message_payload)}, + {"blob": json.dumps(agent_state_payload)}, + ], + eventTimestamp=monotonic_timestamp, + ) + logger.debug( + "Saved message and agent state in single call: event=%s, agent=%s", + event.get("event", {}).get("eventId"), + agent.agent_id, + ) + + session_message = SessionMessage.from_message(message, event.get("event", {}).get("eventId")) + self._latest_agent_message[agent.agent_id] = session_message + self._last_synced_state_hash = self._compute_state_hash(agent) + + except Exception as e: + logger.error("Failed to save message with state: %s", e) + raise SessionException(f"Failed to save message with state: {e}") from e + + def _sync_agent_state_if_changed(self, agent: "Agent") -> None: + """Sync agent state to AgentCore Memory only if state changed since last sync. + + Args: + agent (Agent): The agent to sync. 
+ """ + current_hash = self._compute_state_hash(agent) + if current_hash == self._last_synced_state_hash: + logger.debug("Agent state unchanged, skipping sync for agent=%s", agent.agent_id) + return + + agent_state_payload = self._build_agent_state_payload(agent) + + try: + event = self.memory_client.gmdp_client.create_event( + memoryId=self.config.memory_id, + actorId=self.config.actor_id, + sessionId=self.session_id, + payload=[{"blob": json.dumps(agent_state_payload)}], + eventTimestamp=self._get_monotonic_timestamp(), + ) + self._last_synced_state_hash = current_hash + logger.debug( + "Synced agent state: event=%s, agent=%s", event.get("event", {}).get("eventId"), agent.agent_id + ) + except Exception as e: + logger.error("Failed to sync agent state: %s", e) + raise SessionException(f"Failed to sync agent state: {e}") from e + + # endregion Internal Storage Methods + # region SessionRepository interface implementation def create_session(self, session: Session, **kwargs: Any) -> Session: """Create a new session in AgentCore Memory. @@ -237,6 +363,8 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A The agent's existence is inferred from the presence of events/messages in the memory system, but we validate the session_id matches our config. + Uses unified actorId with type markers for optimized storage. + Args: session_id (str): The session ID to create the agent in. session_agent (SessionAgent): The agent to create. 
@@ -248,15 +376,19 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A if session_id != self.config.session_id: raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}") + agent_state_payload = { + "_type": PAYLOAD_TYPE_AGENT_STATE, + "_agent_id": session_agent.agent_id, + **session_agent.to_dict(), + } event = self.memory_client.gmdp_client.create_event( memoryId=self.config.memory_id, - actorId=self._get_full_agent_id(session_agent.agent_id), + actorId=self.config.actor_id, sessionId=self.session_id, - payload=[ - {"blob": json.dumps(session_agent.to_dict())}, - ], + payload=[{"blob": json.dumps(agent_state_payload)}], eventTimestamp=self._get_monotonic_timestamp(), ) + logger.info( "Created agent: %s in session: %s with event %s", session_agent.agent_id, @@ -267,7 +399,8 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: """Read agent data from AgentCore Memory events. - We reconstruct the agent state from the conversation history. + Uses dual-read approach: tries new format first (unified actorId with type markers), + falls back to legacy format (separate actorId) for backward compatibility. Args: session_id (str): The session ID to read from. 
@@ -279,22 +412,43 @@ def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[ """ if session_id != self.config.session_id: return None + try: events = self.memory_client.list_events( + memory_id=self.config.memory_id, + actor_id=self.config.actor_id, + session_id=session_id, + max_results=MAX_FETCH_ALL_RESULTS, + ) + + # Events are returned oldest-first, so reverse to get latest state + for event in reversed(events): + for payload_item in event.get("payload", []): + blob = payload_item.get("blob") + if blob: + try: + data = json.loads(blob) + if data.get("_type") == PAYLOAD_TYPE_AGENT_STATE and data.get("_agent_id") == agent_id: + agent_data = {k: v for k, v in data.items() if k not in ("_type", "_agent_id")} + return SessionAgent.from_dict(agent_data) + except json.JSONDecodeError: + continue + + # Fallback to legacy format for backward compatibility + legacy_events = self.memory_client.list_events( memory_id=self.config.memory_id, actor_id=self._get_full_agent_id(agent_id), session_id=session_id, max_results=1, ) + if legacy_events: + agent_data = json.loads(legacy_events[0].get("payload", {})[0].get("blob")) + return SessionAgent.from_dict(agent_data) - if not events: - return None - - agent_data = json.loads(events[0].get("payload", {})[0].get("blob")) - return SessionAgent.from_dict(agent_data) except Exception as e: - logger.error("Failed to read agent %s", e) - return None + logger.error("Failed to read agent: %s", e) + + return None def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None: """Update agent data. @@ -316,7 +470,7 @@ def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A # Create a new agent as AgentCore Memory is immutable. 
We always get the latest one in `read_agent` self.create_agent(session_id, session_agent) - def create_message( + def create_message( # type: ignore[override] self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any ) -> Optional[dict[str, Any]]: """Create a new message in AgentCore Memory. @@ -353,9 +507,8 @@ def create_message( try: messages = AgentCoreMemoryConverter.message_to_payload(session_message) if not messages: - return + return None - # Parse the original timestamp and use it as desired timestamp original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00")) monotonic_timestamp = self._get_monotonic_timestamp(original_timestamp) @@ -477,16 +630,28 @@ def list_messages( # region RepositorySessionManager overrides @override def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None: - """Append a message to the agent's session using AgentCore's eventId as message_id. + """Append a message to the agent's session with batched agent state sync. + + Saves message and agent state in a single API call for optimization. Args: message: Message to add to the agent in the session agent: Agent to append the message to **kwargs: Additional keyword arguments for future extensibility. """ - created_message = self.create_message(self.session_id, agent.agent_id, SessionMessage.from_message(message, 0)) - session_message = SessionMessage.from_message(message, created_message.get("eventId")) - self._latest_agent_message[agent.agent_id] = session_message + self._save_message_with_state(message, agent) + + @override + def sync_agent(self, agent: "Agent", **kwargs: Any) -> None: + """Sync agent state only if it changed since last sync. + + Skips sync if agent state is unchanged, avoiding redundant API calls. + + Args: + agent: Agent to sync to the session. + **kwargs: Additional keyword arguments for future extensibility. 
+ """ + self._sync_agent_state_if_changed(agent) def retrieve_customer_context(self, event: MessageAddedEvent) -> None: """Retrieve customer LTM context before processing support query. @@ -495,7 +660,8 @@ def retrieve_customer_context(self, event: MessageAddedEvent) -> None: event (MessageAddedEvent): The message added event containing the agent and message data. """ messages = event.agent.messages - if not messages or messages[-1].get("role") != "user" or "toolResult" in messages[-1].get("content")[0]: + content = messages[-1].get("content") if messages else None + if not messages or messages[-1].get("role") != "user" or not content or "toolResult" in content[0]: return None if not self.config.retrieval_config: # Only retrieve LTM @@ -503,7 +669,7 @@ def retrieve_customer_context(self, event: MessageAddedEvent) -> None: user_query = messages[-1]["content"][0]["text"] - def retrieve_for_namespace(namespace: str, retrieval_config: RetrievalConfig): + def retrieve_for_namespace(namespace: str, retrieval_config: RetrievalConfig) -> list[str]: """Helper function to retrieve memories for a single namespace.""" resolved_namespace = namespace.format( actorId=self.config.actor_id, @@ -565,16 +731,18 @@ def retrieve_for_namespace(namespace: str, retrieval_config: RetrievalConfig): logger.error("Failed to retrieve customer context: %s", e) @override - def register_hooks(self, registry: HookRegistry, **kwargs) -> None: - """Register additional hooks. - - Args: - registry (HookRegistry): The hook registry to register callbacks with. - **kwargs: Additional keyword arguments. 
- """ - RepositorySessionManager.register_hooks(self, registry, **kwargs) + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + """Register hooks for session management with optimized storage.""" + registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent)) + registry.add_callback(MessageAddedEvent, lambda e: self._save_message_with_state(e.message, e.agent)) + registry.add_callback(AfterInvocationEvent, lambda event: self._sync_agent_state_if_changed(event.agent)) registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event)) + # Multi-agent hooks + registry.add_callback(MultiAgentInitializedEvent, lambda event: self.initialize_multi_agent(event.source)) + registry.add_callback(AfterNodeCallEvent, lambda event: self.sync_multi_agent(event.source)) + registry.add_callback(AfterMultiAgentInvocationEvent, lambda event: self.sync_multi_agent(event.source)) + @override def initialize(self, agent: "Agent", **kwargs: Any) -> None: if self.has_existing_agent: diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py index a01973c..293a525 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py @@ -1116,3 +1116,406 @@ def test_list_messages_with_limit_calculates_max_results(self, session_manager, mock_memory_client.list_events.assert_called_once() call_kwargs = mock_memory_client.list_events.call_args[1] assert call_kwargs["max_results"] == 550 # limit + offset + + +class TestOptimizedStorage: + """Tests for optimized storage functionality.""" + + @pytest.fixture + def agentcore_config(self): + """Create a test AgentCore Memory configuration.""" + return AgentCoreMemoryConfig( + 
memory_id="test-memory-123", session_id="test-session-456", actor_id="test-actor-789" + ) + + @pytest.fixture + def mock_memory_client(self): + """Create a mock MemoryClient.""" + client = Mock() + client.create_event.return_value = {"eventId": "event_123456"} + client.list_events.return_value = [] + client.retrieve_memories.return_value = [] + client.gmcp_client = Mock() + client.gmdp_client = Mock() + client.gmdp_client.create_event.return_value = {"event": {"eventId": "event_123456"}} + return client + + @pytest.fixture + def session_manager(self, agentcore_config, mock_memory_client): + """Create AgentCoreMemorySessionManager.""" + with patch( + "bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient", + return_value=mock_memory_client, + ): + with patch("boto3.Session") as mock_boto_session: + mock_session = Mock() + mock_session.region_name = "us-west-2" + mock_session.client.return_value = Mock() + mock_boto_session.return_value = mock_session + + with patch( + "strands.session.repository_session_manager.RepositorySessionManager.__init__", return_value=None + ): + manager = AgentCoreMemorySessionManager(agentcore_config) + manager.session_id = agentcore_config.session_id + manager.session = Session(session_id=agentcore_config.session_id, session_type=SessionType.AGENT) + manager.memory_client = mock_memory_client + manager._latest_agent_message = {} + return manager + + def test_append_message_batches_with_state(self, session_manager, mock_memory_client): + """Test append_message saves message and agent state in single API call.""" + import json + + mock_agent = Mock() + mock_agent.agent_id = "test-agent" + + mock_session_agent = SessionAgent( + agent_id="test-agent", state={"key": "value"}, conversation_manager_state={"cm_key": "cm_value"} + ) + + message = {"role": "user", "content": [{"text": "Hello"}]} + + with patch.object(SessionAgent, "from_agent", return_value=mock_session_agent): + session_manager.append_message(message, 
mock_agent) + + mock_memory_client.gmdp_client.create_event.assert_called_once() + call_kwargs = mock_memory_client.gmdp_client.create_event.call_args[1] + + # Should have 2 payloads: message + agent_state + assert len(call_kwargs["payload"]) == 2 + + # Verify payload formats + msg_payload = json.loads(call_kwargs["payload"][0]["blob"]) + assert msg_payload["_type"] == "message" + + agent_payload = json.loads(call_kwargs["payload"][1]["blob"]) + assert agent_payload["_type"] == "agent_state" + + def test_sync_agent_skips_unchanged_state(self, session_manager, mock_memory_client): + """Test sync_agent skips API call when state is unchanged.""" + mock_agent = Mock() + mock_agent.agent_id = "test-agent" + + mock_session_agent = SessionAgent(agent_id="test-agent", state={"key": "value"}, conversation_manager_state={}) + + with patch.object(SessionAgent, "from_agent", return_value=mock_session_agent): + # Set initial state hash + session_manager._last_synced_state_hash = session_manager._compute_state_hash(mock_agent) + + # Reset mock to track new calls + mock_memory_client.gmdp_client.create_event.reset_mock() + + # Sync should skip since state unchanged + session_manager.sync_agent(mock_agent) + + mock_memory_client.gmdp_client.create_event.assert_not_called() + + def test_compute_state_hash_excludes_timestamps(self, session_manager, mock_memory_client): + """Test that _compute_state_hash excludes timestamps from hash computation.""" + mock_agent = Mock() + mock_agent.agent_id = "test-agent" + + # Create two SessionAgents with different timestamps but same state + session_agent_1 = SessionAgent( + agent_id="test-agent", + state={"key": "value"}, + conversation_manager_state={"cm_key": "cm_value"}, + created_at="2023-01-01T00:00:00Z", + updated_at="2023-01-01T00:00:00Z", + ) + session_agent_2 = SessionAgent( + agent_id="test-agent", + state={"key": "value"}, + conversation_manager_state={"cm_key": "cm_value"}, + created_at="2023-06-15T12:00:00Z", # Different timestamp 
+ updated_at="2023-12-31T23:59:59Z", # Different timestamp + ) + + with patch.object(SessionAgent, "from_agent", return_value=session_agent_1): + hash1 = session_manager._compute_state_hash(mock_agent) + + with patch.object(SessionAgent, "from_agent", return_value=session_agent_2): + hash2 = session_manager._compute_state_hash(mock_agent) + + # Hashes should be equal since timestamps are excluded + assert hash1 == hash2, "Hash should be same when only timestamps differ" + + def test_compute_state_hash_changes_with_state(self, session_manager, mock_memory_client): + """Test that _compute_state_hash changes when actual state changes.""" + mock_agent = Mock() + mock_agent.agent_id = "test-agent" + + session_agent_1 = SessionAgent( + agent_id="test-agent", state={"key": "value1"}, conversation_manager_state={"cm_key": "cm_value"} + ) + session_agent_2 = SessionAgent( + agent_id="test-agent", + state={"key": "value2"}, # Different state + conversation_manager_state={"cm_key": "cm_value"}, + ) + + with patch.object(SessionAgent, "from_agent", return_value=session_agent_1): + hash1 = session_manager._compute_state_hash(mock_agent) + + with patch.object(SessionAgent, "from_agent", return_value=session_agent_2): + hash2 = session_manager._compute_state_hash(mock_agent) + + # Hashes should be different since state changed + assert hash1 != hash2, "Hash should differ when state changes" + + def test_sync_agent_saves_changed_state(self, session_manager, mock_memory_client): + """Test sync_agent saves when state has changed.""" + mock_agent = Mock() + mock_agent.agent_id = "test-agent" + + mock_session_agent = SessionAgent( + agent_id="test-agent", state={"key": "new_value"}, conversation_manager_state={} + ) + + # Set different initial hash + session_manager._last_synced_state_hash = hash("different") + + with patch.object(SessionAgent, "from_agent", return_value=mock_session_agent): + session_manager.sync_agent(mock_agent) + + 
mock_memory_client.gmdp_client.create_event.assert_called_once() + + def test_read_agent_new_format(self, session_manager, mock_memory_client): + """Test read_agent correctly parses new format with _type marker.""" + import json + + agent_state_payload = { + "_type": "agent_state", + "_agent_id": "test-agent", + "agent_id": "test-agent", + "state": {"key": "value"}, + "conversation_manager_state": {"cm_key": "cm_value"}, + } + + mock_memory_client.list_events.return_value = [ + {"eventId": "event-1", "payload": [{"blob": json.dumps(agent_state_payload)}]} + ] + + result = session_manager.read_agent("test-session-456", "test-agent") + + assert result is not None + assert result.agent_id == "test-agent" + assert result.state == {"key": "value"} + + def test_read_agent_skips_message_payloads(self, session_manager, mock_memory_client): + """Test read_agent skips message payloads when looking for agent state.""" + import json + + from strands.types.session import SessionMessage + + msg = SessionMessage( + message_id=1, message={"role": "user", "content": [{"text": "Hello"}]}, created_at="2023-01-01T00:00:00Z" + ) + message_payload = {"_type": "message", "data": [json.dumps(msg.to_dict()), "user"]} + agent_state_payload = { + "_type": "agent_state", + "_agent_id": "test-agent", + "agent_id": "test-agent", + "state": {}, + "conversation_manager_state": {}, + } + + mock_memory_client.list_events.return_value = [ + { + "eventId": "event-1", + "payload": [{"blob": json.dumps(message_payload)}, {"blob": json.dumps(agent_state_payload)}], + } + ] + + result = session_manager.read_agent("test-session-456", "test-agent") + + assert result is not None + assert result.agent_id == "test-agent" + + def test_read_agent_falls_back_to_legacy(self, session_manager, mock_memory_client): + """Test read_agent falls back to legacy format when new format not found. 
+ + This tests the v1→v2 migration scenario where existing data was stored + with actor_id='agent_{id}' (v1) but new code queries with unified actor_id. + """ + import json + + legacy_agent_data = {"agent_id": "test-agent", "state": {"legacy": "data"}, "conversation_manager_state": {}} + + # First call (new format) returns empty, second call (legacy) returns data + mock_memory_client.list_events.side_effect = [ + [], # New format query with unified actor_id + [{"eventId": "event-1", "payload": [{"blob": json.dumps(legacy_agent_data)}]}], # Legacy + ] + + result = session_manager.read_agent("test-session-456", "test-agent") + + assert result is not None + assert result.state == {"legacy": "data"} + assert mock_memory_client.list_events.call_count == 2 + + # Verify correct actor_ids were used in each call + calls = mock_memory_client.list_events.call_args_list + # First call: unified actor_id (config.actor_id) + assert calls[0][1]["actor_id"] == "test-actor-789" + # Second call: legacy actor_id (agent_{agent_id}) + assert calls[1][1]["actor_id"] == "agent_test-agent" + + def test_create_agent_uses_new_format(self, session_manager, mock_memory_client): + """Test create_agent saves in new format with _type marker.""" + import json + + session_agent = SessionAgent(agent_id="test-agent", state={"key": "value"}, conversation_manager_state={}) + + session_manager.create_agent("test-session-456", session_agent) + + mock_memory_client.gmdp_client.create_event.assert_called_once() + call_kwargs = mock_memory_client.gmdp_client.create_event.call_args[1] + + payload_data = json.loads(call_kwargs["payload"][0]["blob"]) + assert payload_data["_type"] == "agent_state" + assert payload_data["_agent_id"] == "test-agent" + + def test_list_messages_parses_new_format(self, session_manager, mock_memory_client): + """Test list_messages correctly parses new format messages.""" + import json + + from strands.types.session import SessionMessage + + msg = SessionMessage( + message_id=1, 
message={"role": "user", "content": [{"text": "Hello"}]}, created_at="2023-01-01T00:00:00Z" + ) + message_payload = {"_type": "message", "data": [json.dumps(msg.to_dict()), "user"]} + + mock_memory_client.list_events.return_value = [ + {"eventId": "event-1", "payload": [{"blob": json.dumps(message_payload)}]} + ] + + result = session_manager.list_messages("test-session-456", "test-agent") + + assert len(result) == 1 + assert result[0].message["content"][0]["text"] == "Hello" + + def test_read_agent_returns_latest_state_multi_turn(self, session_manager, mock_memory_client): + """Test read_agent returns the latest agent state when multiple states exist (multi-turn).""" + import json + + # Simulate multiple agent states from different turns (oldest first, as API returns) + old_state = { + "_type": "agent_state", + "_agent_id": "test-agent", + "agent_id": "test-agent", + "state": {"turn": 1}, + "conversation_manager_state": {"removed_message_count": 0}, + } + latest_state = { + "_type": "agent_state", + "_agent_id": "test-agent", + "agent_id": "test-agent", + "state": {"turn": 2}, + "conversation_manager_state": {"removed_message_count": 0}, + } + + # Events returned oldest-first by API + mock_memory_client.list_events.return_value = [ + {"eventId": "event-1", "payload": [{"blob": json.dumps(old_state)}]}, + {"eventId": "event-2", "payload": [{"blob": json.dumps(latest_state)}]}, + ] + + result = session_manager.read_agent("test-session-456", "test-agent") + + assert result is not None + assert result.state == {"turn": 2}, "Should return latest state, not oldest" + + def test_read_agent_multi_agent_returns_latest_per_agent(self, session_manager, mock_memory_client): + """Test read_agent returns latest state for each agent in multi-agent multi-turn scenario.""" + import json + + # Interleaved states: agent-1 turn 1, agent-2 turn 1, agent-1 turn 2 (oldest first) + events = [ + { + "eventId": "event-1", + "payload": [ + { + "blob": json.dumps( + { + "_type": 
"agent_state", + "_agent_id": "agent-1", + "agent_id": "agent-1", + "state": {"turn": 1}, + "conversation_manager_state": {}, + } + ) + } + ], + }, + { + "eventId": "event-2", + "payload": [ + { + "blob": json.dumps( + { + "_type": "agent_state", + "_agent_id": "agent-2", + "agent_id": "agent-2", + "state": {"turn": 1}, + "conversation_manager_state": {}, + } + ) + } + ], + }, + { + "eventId": "event-3", + "payload": [ + { + "blob": json.dumps( + { + "_type": "agent_state", + "_agent_id": "agent-1", + "agent_id": "agent-1", + "state": {"turn": 2}, + "conversation_manager_state": {}, + } + ) + } + ], + }, + ] + + mock_memory_client.list_events.return_value = events + + # agent-1 should return turn 2 (latest), not turn 1 + result = session_manager.read_agent("test-session-456", "agent-1") + assert result is not None + assert result.state == {"turn": 2}, "Should return agent-1's latest state" + + # agent-2 should return turn 1 (only state) + result = session_manager.read_agent("test-session-456", "agent-2") + assert result is not None + assert result.state == {"turn": 1} + + def test_read_agent_multi_agent_not_found(self, session_manager, mock_memory_client): + """Test read_agent returns None when agent_id not found in multi-agent scenario.""" + import json + + agent1_state = { + "_type": "agent_state", + "_agent_id": "agent-1", + "agent_id": "agent-1", + "state": {}, + "conversation_manager_state": {}, + } + + mock_memory_client.list_events.return_value = [ + {"eventId": "event-1", "payload": [{"blob": json.dumps(agent1_state)}]}, + ] + + # agent-3 doesn't exist, and no legacy fallback + mock_memory_client.list_events.side_effect = [ + [{"eventId": "event-1", "payload": [{"blob": json.dumps(agent1_state)}]}], + [], # Legacy fallback returns empty + ] + result = session_manager.read_agent("test-session-456", "agent-3") + assert result is None diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py 
b/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py index e15a457..9c2041d 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_bedrock_converter.py @@ -221,3 +221,90 @@ def test_message_to_payload_with_bytes_encodes_before_filtering(self): assert isinstance(encoded_bytes, dict) assert encoded_bytes.get("__bytes_encoded__") is True assert "data" in encoded_bytes + + def test_events_to_messages_v2_format(self): + """Test parsing v2 format with _type marker.""" + session_message = SessionMessage( + message_id=1, message={"role": "user", "content": [{"text": "Hello"}]}, created_at="2023-01-01T00:00:00Z" + ) + + v2_payload = {"_type": "message", "data": [json.dumps(session_message.to_dict()), "user"]} + events = [{"payload": [{"blob": json.dumps(v2_payload)}]}] + + result = AgentCoreMemoryConverter.events_to_messages(events) + + assert len(result) == 1 + assert result[0].message["role"] == "user" + assert result[0].message["content"][0]["text"] == "Hello" + + def test_events_to_messages_v2_skips_agent_state(self): + """Test v2 format skips agent_state payloads.""" + session_message = SessionMessage( + message_id=1, message={"role": "user", "content": [{"text": "Hello"}]}, created_at="2023-01-01T00:00:00Z" + ) + + message_payload = {"_type": "message", "data": [json.dumps(session_message.to_dict()), "user"]} + agent_state_payload = { + "_type": "agent_state", + "_agent_id": "test-agent", + "agent_id": "test-agent", + "state": {}, + "conversation_manager_state": {}, + } + events = [{"payload": [{"blob": json.dumps(message_payload)}, {"blob": json.dumps(agent_state_payload)}]}] + + result = AgentCoreMemoryConverter.events_to_messages(events) + + assert len(result) == 1 + assert result[0].message["content"][0]["text"] == "Hello" + + def test_events_to_messages_v2_multiple_messages_reversed(self): + """Test v2 format returns messages in 
chronological order (reversed from API).""" + msg1 = SessionMessage( + message_id=1, + message={"role": "user", "content": [{"text": "First"}]}, + created_at="2023-01-01T00:00:00Z", + ) + msg2 = SessionMessage( + message_id=2, + message={"role": "assistant", "content": [{"text": "Second"}]}, + created_at="2023-01-01T00:00:01Z", + ) + + # API returns newest first + v2_msg2 = {"_type": "message", "data": [json.dumps(msg2.to_dict()), "assistant"]} + v2_msg1 = {"_type": "message", "data": [json.dumps(msg1.to_dict()), "user"]} + events = [ + {"payload": [{"blob": json.dumps(v2_msg2)}]}, + {"payload": [{"blob": json.dumps(v2_msg1)}]}, + ] + + result = AgentCoreMemoryConverter.events_to_messages(events) + + assert len(result) == 2 + assert result[0].message["content"][0]["text"] == "First" + assert result[1].message["content"][0]["text"] == "Second" + + def test_events_to_messages_mixed_v2_and_legacy(self): + """Test handling mixed v2 and legacy formats.""" + msg_v2 = SessionMessage( + message_id=1, + message={"role": "user", "content": [{"text": "V2 message"}]}, + created_at="2023-01-01T00:00:00Z", + ) + msg_legacy = SessionMessage( + message_id=2, + message={"role": "assistant", "content": [{"text": "Legacy message"}]}, + created_at="2023-01-01T00:00:01Z", + ) + + v2_payload = {"_type": "message", "data": [json.dumps(msg_v2.to_dict()), "user"]} + legacy_payload = [json.dumps(msg_legacy.to_dict()), "assistant"] + events = [ + {"payload": [{"blob": json.dumps(v2_payload)}]}, + {"payload": [{"blob": json.dumps(legacy_payload)}]}, + ] + + result = AgentCoreMemoryConverter.events_to_messages(events) + + assert len(result) == 2 diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_session_manager_integration.py b/tests/bedrock_agentcore/memory/integrations/strands/test_session_manager_integration.py new file mode 100644 index 0000000..87d45fa --- /dev/null +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_session_manager_integration.py @@ -0,0 
"""Integration tests for AgentCoreMemorySessionManager.

These tests require AWS credentials and make actual API calls.
Run with: pytest -m integration tests/.../test_session_manager_integration.py -v
Skip with: pytest -m "not integration"

Every test in this module carries the ``integration`` marker (see ``pytestmark``)
and is skipped when the MEMORY_ID environment variable is not set.
"""

import json
import os
import uuid
from contextlib import contextmanager

import pytest

# Mark every test in this module as an integration test. There is no dedicated
# command-line switch: select these tests with `-m integration` and deselect
# them with `-m "not integration"`.
pytestmark = pytest.mark.integration

DEFAULT_REGION = "us-west-2"
MODEL_ID = "us.anthropic.claude-haiku-4-5-20251001-v1:0"
ACTOR_ID = "test-user"


def get_memory_id():
    """Get Memory ID from environment (None when MEMORY_ID is unset)."""
    return os.environ.get("MEMORY_ID")


def get_region():
    """Get AWS region from environment, falling back to DEFAULT_REGION.

    An AWS_REGION that is set but empty is treated the same as unset.
    """
    return os.environ.get("AWS_REGION") or DEFAULT_REGION


@pytest.fixture(scope="module")
def memory_id():
    """Get memory ID, skip if not available."""
    mid = get_memory_id()
    if not mid:
        pytest.skip("MEMORY_ID environment variable not set")
    return mid


@pytest.fixture(scope="module")
def region():
    """Get AWS region."""
    return get_region()


@pytest.fixture
def unique_session_id():
    """Generate unique session ID for each test."""
    return f"test-{uuid.uuid4().hex[:12]}"


class APICallCounter:
    """Count API calls to AgentCore Memory."""

    def __init__(self) -> None:
        self.call_count = 0
        self.call_details: list[str] = []

    def reset(self) -> None:
        """Forget every call recorded so far."""
        self.call_count = 0
        self.call_details = []

    def increment(self, method_name: str) -> None:
        """Record one call to ``method_name``."""
        self.call_count += 1
        self.call_details.append(method_name)


@contextmanager
def count_create_event_calls(session_manager):
    """Count ``create_event`` calls made through ``session_manager`` inside the ``with`` body.

    Wraps ``session_manager.memory_client.gmdp_client.create_event`` with a
    counting pass-through and yields the ``APICallCounter``. The original
    callable is put back on exit, even if the body raises, so the wrapper
    never outlives the measurement.
    """
    counter = APICallCounter()
    client = session_manager.memory_client.gmdp_client
    original_create_event = client.create_event

    def counted(*args, **kwargs):
        counter.increment("create_event")
        return original_create_event(*args, **kwargs)

    client.create_event = counted
    try:
        yield counter
    finally:
        client.create_event = original_create_event


def extract_response_text(message) -> str:
    """Return all text blocks of ``message`` joined by spaces and lower-cased.

    Blocks without a non-empty ``"text"`` entry are ignored; a missing message
    or missing/empty ``"content"`` yields an empty string.
    """
    if not message:
        return ""
    blocks = message.get("content") or []
    return " ".join(block["text"] for block in blocks if block.get("text")).lower()


def collect_payload_types(events) -> set:
    """Return the set of ``_type`` markers found in the JSON blobs of ``events``.

    Only payload items that carry a ``"blob"`` whose decoded JSON is a dict
    with a string ``_type`` contribute; other blob shapes (for example a
    two-element list) are ignored.
    """
    found = set()
    for event in events:
        for payload_item in event.get("payload", []):
            if "blob" not in payload_item:
                continue
            blob_data = json.loads(payload_item["blob"])
            if isinstance(blob_data, dict) and isinstance(blob_data.get("_type"), str):
                found.add(blob_data["_type"])
    return found


def make_session_manager(memory_id, session_id, region, actor_id=ACTOR_ID):
    """Build an AgentCoreMemorySessionManager for the given memory, session and actor."""
    # Imported lazily so that merely collecting this module does not require the SDK.
    from bedrock_agentcore.memory.integrations.strands.config import AgentCoreMemoryConfig
    from bedrock_agentcore.memory.integrations.strands.session_manager import AgentCoreMemorySessionManager

    config = AgentCoreMemoryConfig(
        memory_id=memory_id,
        session_id=session_id,
        actor_id=actor_id,
    )
    return AgentCoreMemorySessionManager(
        agentcore_memory_config=config,
        region_name=region,
    )


def make_model(region):
    """Build the Bedrock model shared by every test in this module."""
    from strands.models import BedrockModel

    return BedrockModel(model_id=MODEL_ID, region_name=region)


class TestOptimizedStorageIntegration:
    """Integration tests for optimized storage."""

    def test_batched_api_calls_simple_message(self, memory_id, region, unique_session_id):
        """Test that batched storage reduces API calls for simple message."""
        from strands import Agent

        session_manager = make_session_manager(memory_id, unique_session_id, region)
        agent = Agent(
            model=make_model(region),
            session_manager=session_manager,
            system_prompt="You are a helpful assistant. Keep responses very brief.",
        )

        with count_create_event_calls(session_manager) as counter:
            agent("Say hello briefly.")
        api_calls = counter.call_count

        # Expected budget with batched storage:
        #   1 (session) + 1 (user msg+state) + 1 (assistant msg+state) = 3,
        # with the AfterInvocation sync expected to be skipped by state hash tracking.
        # The assertion allows one extra call of headroom.
        # NOTE(review): counting starts after the Agent is constructed - confirm
        # whether the session-creation call is really part of the counted total.
        assert api_calls <= 4, f"Expected at most 4 API calls, got {api_calls}"

    def test_multi_turn_preserves_context(self, memory_id, region, unique_session_id):
        """Test that context is preserved across turns."""
        from strands import Agent

        session_id = f"multi-{unique_session_id}"
        model = make_model(region)
        system_prompt = "You are a helpful assistant. Remember what the user tells you."

        # Turn 1: Tell the agent something to remember
        agent_1 = Agent(
            model=model,
            session_manager=make_session_manager(memory_id, session_id, region),
            system_prompt=system_prompt,
        )
        agent_1("My favorite color is blue. Remember this.")

        # Turn 2: Create NEW agent instance (and session manager) with SAME session ID
        agent_2 = Agent(
            model=model,
            session_manager=make_session_manager(memory_id, session_id, region),
            system_prompt=system_prompt,
        )

        # Verify messages were loaded
        assert len(agent_2.messages) >= 2, f"Expected at least 2 messages loaded, got {len(agent_2.messages)}"

        # Ask about the remembered information
        response = agent_2("What is my favorite color?")

        # Look at every text block: the answer is not guaranteed to be in the first one.
        response_text = extract_response_text(response.message)

        assert "blue" in response_text, f"Expected 'blue' in response, got: {response_text}"

    def test_messages_correctly_parsed(self, memory_id, region, unique_session_id):
        """Test that messages are correctly saved and loaded."""
        from strands import Agent

        from bedrock_agentcore.memory.client import MemoryClient

        session_id = f"parse-{unique_session_id}"
        session_manager = make_session_manager(memory_id, session_id, region)

        agent = Agent(
            model=make_model(region),
            session_manager=session_manager,
            system_prompt="You are a helpful assistant.",
        )
        agent("Test message for format.")

        # Verify stored format by reading the events back through a separate client
        memory_client = MemoryClient(region_name=region)
        events = memory_client.list_events(
            memory_id=memory_id,
            actor_id=ACTOR_ID,
            session_id=session_id,
            max_results=10,
        )

        assert len(events) > 0, "Expected events to be saved"

        # Check that both kinds of type markers were written
        payload_types = collect_payload_types(events)

        assert "message" in payload_types, "Expected message with _type marker"
        assert "agent_state" in payload_types, "Expected agent_state with _type marker"

    def test_batched_api_calls_with_tools(self, memory_id, region, unique_session_id):
        """Test that batched storage works with tool calls."""
        from datetime import datetime

        from strands import Agent, tool

        @tool
        def get_time() -> str:
            """Get the current time."""
            return datetime.now().strftime("%H:%M:%S")

        @tool
        def add_numbers(a: int, b: int) -> int:
            """Add two numbers."""
            return a + b

        session_manager = make_session_manager(memory_id, f"tools-{unique_session_id}", region)
        agent = Agent(
            model=make_model(region),
            session_manager=session_manager,
            tools=[get_time, add_numbers],
            system_prompt="You are a helpful assistant. Use tools when asked. Keep responses brief.",
        )

        with count_create_event_calls(session_manager) as counter:
            agent("What time is it? And what is 10 + 20?")
        api_calls = counter.call_count

        # With tools there are more messages, but they are still batched.
        # The budget enforced here is 7 create_event calls
        # (the original author's estimate for unbatched storage was ~9+ calls).
        assert api_calls <= 7, f"Expected at most 7 API calls with tools, got {api_calls}"