diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py
index 6156d332c..870b6568c 100644
--- a/src/strands/multiagent/graph.py
+++ b/src/strands/multiagent/graph.py
@@ -34,11 +34,13 @@
     MultiAgentInitializedEvent,
 )
 from ..hooks import HookProvider, HookRegistry
+from ..interrupt import Interrupt, _InterruptState
 from ..session import SessionManager
 from ..telemetry import get_tracer
 from ..types._events import (
     MultiAgentHandoffEvent,
     MultiAgentNodeCancelEvent,
+    MultiAgentNodeInterruptEvent,
     MultiAgentNodeStartEvent,
     MultiAgentNodeStopEvent,
     MultiAgentNodeStreamEvent,
@@ -162,6 +164,7 @@ class GraphNode:
     execution_status: Status = Status.PENDING
     result: NodeResult | None = None
     execution_time: int = 0
+    graph: Optional["Graph"] = None
     _initial_messages: Messages = field(default_factory=list, init=False)
     _initial_state: AgentState = field(default_factory=AgentState, init=False)

@@ -177,9 +180,18 @@ def __post_init__(self) -> None:
     def reset_executor_state(self) -> None:
         """Reset GraphNode executor state to initial state when graph was created.

-        This is useful when nodes are executed multiple times and need to start
-        fresh on each execution, providing stateless behavior.
+        If Graph is resuming from an interrupt, we reset the executor state from the interrupt context.
         """
+        if self.graph and self.graph._interrupt_state.activated and self.node_id in self.graph._interrupt_state.context:
+            context = self.graph._interrupt_state.context[self.node_id]
+            if hasattr(self.executor, "messages"):
+                self.executor.messages = context["messages"]
+            if hasattr(self.executor, "state"):
+                self.executor.state = AgentState(context["state"])
+            if hasattr(self.executor, "_interrupt_state"):
+                self.executor._interrupt_state = _InterruptState.from_dict(context["interrupt_state"])
+            return
+
         if hasattr(self.executor, "messages"):
             self.executor.messages = copy.deepcopy(self._initial_messages)
@@ -440,11 +452,16 @@ def __init__(
         self.nodes = nodes
         self.edges = edges
         self.entry_points = entry_points
+
+        # Set graph reference on all nodes for interrupt state restoration
+        for node in self.nodes.values():
+            node.graph = self
         self.max_node_executions = max_node_executions
         self.execution_timeout = execution_timeout
         self.node_timeout = node_timeout
         self.reset_on_revisit = reset_on_revisit
         self.state = GraphState()
+        self._interrupt_state = _InterruptState()
         self.tracer = get_tracer()
         self.trace_attributes: dict[str, AttributeValue] = self._parse_trace_attributes(trace_attributes)
         self.session_manager = session_manager
@@ -519,6 +536,8 @@ async def stream_async(
             - multi_agent_node_stop: When a node stops execution
             - result: Final graph result
         """
+        self._interrupt_state.resume(task)
+
         if invocation_state is None:
             invocation_state = {}

@@ -544,6 +563,8 @@
         span = self.tracer.start_multiagent_span(task, "graph", custom_trace_attributes=self.trace_attributes)

         with trace_api.use_span(span, end_on_exit=True):
+            interrupts = []
+
             try:
                 logger.debug(
                     "max_node_executions=<%s>, execution_timeout=<%s>s, node_timeout=<%s>s | graph execution config",
@@ -553,6 +574,9 @@
                 )

                 async for event in self._execute_graph(invocation_state):
+                    if isinstance(event, MultiAgentNodeInterruptEvent):
+                        interrupts = event.interrupts
+
                     yield event.as_dict()

                 # Set final status based on execution results
@@ -564,7 +588,7 @@
                 logger.debug("status=<%s> | graph execution completed", self.state.status)

                 # Yield final result (consistent with Agent's AgentResultEvent format)
-                result = self._build_result()
+                result = self._build_result(interrupts)

                 # Use the same event format as Agent for consistency
                 yield MultiAgentResultEvent(result=result).as_dict()
@@ -576,8 +600,11 @@
             finally:
                 self.state.execution_time = round((time.time() - start_time) * 1000)
                 await self.hooks.invoke_callbacks_async(AfterMultiAgentInvocationEvent(self))
-                self._resume_from_session = False
-                self._resume_next_nodes.clear()
+                # Don't clear resume flags here - they should only be cleared when consumed
+                # at the start of _execute_graph, or when status is COMPLETED/FAILED
+                if self.state.status in (Status.COMPLETED, Status.FAILED):
+                    self._resume_from_session = False
+                    self._resume_next_nodes.clear()

     def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
         """Validate graph nodes for duplicate instances."""
@@ -593,7 +620,22 @@

     async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
         """Execute graph and yield TypedEvent objects."""
-        ready_nodes = self._resume_next_nodes if self._resume_from_session else list(self.entry_points)
+        # Make a copy to avoid clearing the list we're about to use
+        ready_nodes = self._resume_next_nodes.copy() if self._resume_from_session else list(self.entry_points)
+
+        logger.debug(
+            "resume_from_session=<%s>, resume_next_nodes=<%s>, entry_points=<%s>, ready_nodes=<%s> | "
+            "starting execution",
+            self._resume_from_session,
+            [n.node_id for n in self._resume_next_nodes] if self._resume_next_nodes else [],
+            [n.node_id for n in self.entry_points],
+            [n.node_id for n in ready_nodes],
+        )
+
+        # Clear resume flags after consuming them once
+        if self._resume_from_session:
+            self._resume_from_session = False
+            self._resume_next_nodes.clear()

         while ready_nodes:
             # Check execution limits before continuing
@@ -609,15 +651,31 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato
             current_batch = ready_nodes.copy()
             ready_nodes.clear()

+            # Track if interrupt occurred
+            interrupt_detected = False
+
             # Execute current batch
             async for event in self._execute_nodes_parallel(current_batch, invocation_state):
+                # Check for interrupt event
+                if isinstance(event, MultiAgentNodeInterruptEvent):
+                    interrupt_detected = True
                 yield event

+            # Stop execution if interrupted
+            if interrupt_detected:
+                break
+
             # Find newly ready nodes after batch execution
             # We add all nodes in current batch as completed batch,
             # because a failure would throw exception and code would not make it here
             newly_ready = self._find_newly_ready_nodes(current_batch)

+            logger.debug(
+                "completed_batch=<%s>, newly_ready=<%s> | finding next nodes",
+                [n.node_id for n in current_batch],
+                [n.node_id for n in newly_ready],
+            )
+
             # Emit handoff event for batch transition if there are nodes to transition to
             if newly_ready:
                 handoff_event = MultiAgentHandoffEvent(
@@ -797,10 +855,14 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
         )
         yield start_event

-        before_event, _ = await self.hooks.invoke_callbacks_async(
+        before_event, interrupts = await self.hooks.invoke_callbacks_async(
             BeforeNodeCallEvent(self, node.node_id, invocation_state)
         )

+        if interrupts:
+            yield self._activate_interrupt(node, interrupts)
+            return
+
         start_time = time.time()
         try:
             if before_event.cancel_node:
@@ -811,8 +873,22 @@
                 yield MultiAgentNodeCancelEvent(node.node_id, cancel_message)
                 raise RuntimeError(cancel_message)

-            # Build node input from satisfied dependencies
-            node_input = self._build_node_input(node)
+            # Check if resuming from interrupt with response
+            # For hook interrupts: activated=False at agent level → use original task input
+            # For agent interrupts: activated=True at agent level → use interrupt responses
+            if (
+                self._interrupt_state.activated
+                and node.node_id in self._interrupt_state.context
+                and self._interrupt_state.context[node.node_id].get("activated")
+            ):
+                # Agent was interrupted - use interrupt response as input
+                node_input = self._interrupt_state.context.get("responses", [])
+
+                # Restore node state from interrupt context (this restores the agent's interrupt state)
+                node.reset_executor_state()
+            else:
+                # Normal execution or hook interrupt - build normal input
+                node_input = self._build_node_input(node)

             # Execute and stream events (timeout handled at task level)
             if isinstance(node.executor, MultiAgentBase):
@@ -830,13 +906,15 @@
                 if multi_agent_result is None:
                     raise ValueError(f"Node '{node.node_id}' did not produce a result event")

+                # Check the actual status from multi_agent_result
                 node_result = NodeResult(
                     result=multi_agent_result,
                     execution_time=multi_agent_result.execution_time,
-                    status=Status.COMPLETED,
+                    status=multi_agent_result.status,
                     accumulated_usage=multi_agent_result.accumulated_usage,
                     accumulated_metrics=multi_agent_result.accumulated_metrics,
                     execution_count=multi_agent_result.execution_count,
+                    interrupts=multi_agent_result.interrupts or [],
                 )

             elif isinstance(node.executor, Agent):
@@ -854,13 +932,6 @@
                 if agent_response is None:
                     raise ValueError(f"Node '{node.node_id}' did not produce a result event")

-                # Check for interrupt (from main branch)
-                if agent_response.stop_reason == "interrupt":
-                    node.executor.messages.pop()  # remove interrupted tool use message
-                    node.executor._interrupt_state.deactivate()
-
-                    raise RuntimeError("user raised interrupt from agent | interrupts are not yet supported in graphs")
-
                 # Extract metrics with defaults
                 response_metrics = getattr(agent_response, "metrics", None)
                 usage = getattr(
@@ -868,17 +939,30 @@
                 )
                 metrics = getattr(response_metrics, "accumulated_metrics", Metrics(latencyMs=0))

+                # Check for interrupt and set appropriate status
+                execution_time = round((time.time() - start_time) * 1000)
+                status = Status.INTERRUPTED if agent_response.stop_reason == "interrupt" else Status.COMPLETED
+
                 node_result = NodeResult(
                     result=agent_response,
-                    execution_time=round((time.time() - start_time) * 1000),
-                    status=Status.COMPLETED,
+                    execution_time=execution_time,
+                    status=status,
                     accumulated_usage=usage,
                     accumulated_metrics=metrics,
                     execution_count=1,
+                    interrupts=agent_response.interrupts or [],
                 )
             else:
                 raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported")

+            # Check if node was interrupted
+            if node_result.status == Status.INTERRUPTED:
+                yield self._activate_interrupt(node, node_result.interrupts)
+                return
+
+            # Deactivate interrupt state after successful execution
+            self._interrupt_state.deactivate()
+
             # Mark as completed
             node.execution_status = Status.COMPLETED
             node.result = node_result
@@ -979,7 +1063,13 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
             if isinstance(self.state.task, str):
                 return [ContentBlock(text=self.state.task)]
             else:
-                return cast(list[ContentBlock], self.state.task)
+                # Filter out interruptResponse dicts from task
+                task_list = cast(list[ContentBlock], self.state.task)
+                filtered_task = []
+                for item in task_list:
+                    if not (isinstance(item, dict) and "interruptResponse" in item):
+                        filtered_task.append(item)
+                return filtered_task

         # Combine task with dependency outputs
         node_input = []
@@ -988,9 +1078,12 @@
         if isinstance(self.state.task, str):
             node_input.append(ContentBlock(text=f"Original Task: {self.state.task}"))
         else:
-            # Add task content blocks with a prefix
+            # Add task content blocks with a prefix (filter out interruptResponse dicts)
             node_input.append(ContentBlock(text="Original Task:"))
-            node_input.extend(cast(list[ContentBlock], self.state.task))
+            task_list = cast(list[ContentBlock], self.state.task)
+            for item in task_list:
+                if not (isinstance(item, dict) and "interruptResponse" in item):
+                    node_input.append(item)

         # Add dependency outputs
         node_input.append(ContentBlock(text="\nInputs from previous nodes:"))
@@ -1006,8 +1099,49 @@

         return node_input

-    def _build_result(self) -> GraphResult:
+    def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> MultiAgentNodeInterruptEvent:
+        """Activate the interrupt state.
+
+        A Graph may be interrupted either from a BeforeNodeCallEvent hook or from within an agent node.
+        In either case, we must manage the interrupt state of both the Graph and the individual agent nodes.
+
+        Args:
+            node: The interrupted node.
+            interrupts: The interrupts raised by the user.
+
+        Returns:
+            MultiAgentNodeInterruptEvent
+        """
+        logger.debug("node=<%s> | node interrupted", node.node_id)
+        self.state.status = Status.INTERRUPTED
+
+        # Save node and graph state for resumption
+        node_executor = node.executor
+        self._interrupt_state.context[node.node_id] = {
+            "activated": node_executor._interrupt_state.activated
+            if hasattr(node_executor, "_interrupt_state")
+            else False,
+            "interrupt_state": node_executor._interrupt_state.to_dict()
+            if hasattr(node_executor, "_interrupt_state")
+            else {},
+            "state": node_executor.state.get() if hasattr(node_executor, "state") else {},
+            "messages": node_executor.messages if hasattr(node_executor, "messages") else [],
+        }
+
+        self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts})
+        self._interrupt_state.activate()
+
+        # Set resume state so graph continues from this node instead of entry points
+        self._resume_next_nodes = [node]
+        self._resume_from_session = True
+
+        return MultiAgentNodeInterruptEvent(node.node_id, interrupts)
+
+    def _build_result(self, interrupts: list[Interrupt] | None = None) -> GraphResult:
         """Build graph result from current state."""
+        if interrupts is None:
+            interrupts = []
+
         return GraphResult(
             status=self.state.status,
             results=self.state.results,
@@ -1021,6 +1155,7 @@
             execution_order=self.state.execution_order,
             edges=self.state.edges,
             entry_points=self.state.entry_points,
+            interrupts=interrupts,
         )

     def serialize_state(self) -> dict[str, Any]:
@@ -1037,6 +1172,9 @@
             "next_nodes_to_execute": next_nodes,
             "current_task": self.state.task,
             "execution_order": [n.node_id for n in self.state.execution_order],
+            "_internal_state": {
+                "interrupt_state": self._interrupt_state.to_dict(),
+            },
         }

     def deserialize_state(self, payload: dict[str, Any]) -> None:
@@ -1052,6 +1190,11 @@
             payload: Dictionary containing persisted state data including status,
                 completed nodes, results, and next nodes to execute.
         """
+        # Restore interrupt state if present
+        if "_internal_state" in payload:
+            internal_state = payload["_internal_state"]
+            self._interrupt_state = _InterruptState.from_dict(internal_state.get("interrupt_state", {}))
+
         if not payload.get("next_nodes_to_execute"):
             # Reset all nodes
             for node in self.nodes.values():
diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py
index 4875d1bec..a125ae234 100644
--- a/tests/strands/multiagent/test_graph.py
+++ b/tests/strands/multiagent/test_graph.py
@@ -1,6 +1,6 @@
 import asyncio
 import time
-from unittest.mock import AsyncMock, MagicMock, Mock, call, patch
+from unittest.mock import ANY, AsyncMock, MagicMock, Mock, call, patch

 import pytest

@@ -9,6 +9,7 @@
 from strands.experimental.hooks.multiagent import BeforeNodeCallEvent
 from strands.hooks import AgentInitializedEvent
 from strands.hooks.registry import HookProvider, HookRegistry
+from strands.interrupt import Interrupt, _InterruptState
 from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult
 from strands.multiagent.graph import Graph, GraphBuilder, GraphEdge, GraphNode, GraphResult, GraphState, Status
 from strands.session.file_session_manager import FileSessionManager
@@ -2068,3 +2069,130 @@ def cancel_callback(event):
     tru_status = graph.state.status
     exp_status = Status.FAILED
     assert tru_status == exp_status
+
+
+def test_graph_interrupt_on_before_node_call_event(interrupt_hook):
+    agent = create_mock_agent("test_agent", "Task completed")
+    agent.state = AgentState()
+    agent.messages = []
+    agent._interrupt_state = _InterruptState()
+
+    builder = GraphBuilder()
+    builder.add_node(agent, "test_agent")
+    builder.set_entry_point("test_agent")
+    builder.set_hook_providers([interrupt_hook])
+
+    graph = builder.build()
+
+    multiagent_result = graph("Test task")
+
+    tru_status = multiagent_result.status
+    exp_status = Status.INTERRUPTED
+    assert tru_status == exp_status
+
+    tru_interrupts = multiagent_result.interrupts
+    exp_interrupts = [
+        Interrupt(
+            id=ANY,
+            name="test_name",
+            reason="test_reason",
+        ),
+    ]
+    assert tru_interrupts == exp_interrupts
+
+    interrupt = multiagent_result.interrupts[0]
+    responses = [
+        {
+            "interruptResponse": {
+                "interruptId": interrupt.id,
+                "response": "test_response",
+            },
+        },
+    ]
+    multiagent_result = graph(responses)
+
+    tru_status = multiagent_result.status
+    exp_status = Status.COMPLETED
+    assert tru_status == exp_status
+
+    assert len(multiagent_result.results) == 1
+    agent_result = multiagent_result.results["test_agent"]
+
+    tru_message = agent_result.result.message["content"][0]["text"]
+    exp_message = "Task completed"
+    assert tru_message == exp_message
+
+
+def test_graph_interrupt_on_agent(agenerator):
+    exp_interrupts = [
+        Interrupt(
+            id="test_id",
+            name="test_name",
+            reason="test_reason",
+        ),
+    ]
+
+    agent = create_mock_agent("test_agent", "Task completed")
+    agent.state = AgentState()
+    agent.messages = []
+    agent._interrupt_state = _InterruptState()
+
+    builder = GraphBuilder()
+    builder.add_node(agent, "test_agent")
+    builder.set_entry_point("test_agent")
+    graph = builder.build()
+
+    agent.stream_async = Mock()
+    agent.stream_async.return_value = agenerator(
+        [
+            {
+                "result": AgentResult(
+                    message={},
+                    stop_reason="interrupt",
+                    state={},
+                    metrics=None,
+                    interrupts=exp_interrupts,
+                ),
+            },
+        ],
+    )
+    multiagent_result = graph("Test task")
+
+    tru_status = multiagent_result.status
+    exp_status = Status.INTERRUPTED
+    assert tru_status == exp_status
+
+    tru_interrupts = multiagent_result.interrupts
+    assert tru_interrupts == exp_interrupts
+
+    agent.stream_async = Mock()
+    agent.stream_async.return_value = agenerator(
+        [
+            {
+                "result": AgentResult(
+                    message={},
+                    stop_reason="end_turn",
+                    state={},
+                    metrics=None,
+                ),
+            },
+        ],
+    )
+    graph._interrupt_state.context["test_agent"]["activated"] = True
+
+    interrupt = multiagent_result.interrupts[0]
+    responses = [
+        {
+            "interruptResponse": {
+                "interruptId": interrupt.id,
+                "response": "test_response",
+            },
+        },
+    ]
+    multiagent_result = graph(responses)
+
+    tru_status = multiagent_result.status
+    exp_status = Status.COMPLETED
+    assert tru_status == exp_status
+
+    agent.stream_async.assert_called_once_with(responses, invocation_state={})
diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py
index e72aebd92..f3fdfc87a 100644
--- a/tests/strands/tools/mcp/test_mcp_client.py
+++ b/tests/strands/tools/mcp/test_mcp_client.py
@@ -533,7 +533,7 @@ def test_stop_closes_event_loop():
     mock_thread.join = MagicMock()
     mock_event_loop = MagicMock()
     mock_event_loop.close = MagicMock()
-    
+
     client._background_thread = mock_thread
     client._background_thread_event_loop = mock_event_loop

@@ -542,7 +542,7 @@ def test_stop_closes_event_loop():

     # Verify thread was joined
     mock_thread.join.assert_called_once()
-    
+
     # Verify event loop was closed
     mock_event_loop.close.assert_called_once()

diff --git a/tests_integ/interrupts/multiagent/test_hook.py b/tests_integ/interrupts/multiagent/test_hook.py
index be7682082..374e64a80 100644
--- a/tests_integ/interrupts/multiagent/test_hook.py
+++ b/tests_integ/interrupts/multiagent/test_hook.py
@@ -131,3 +131,112 @@ async def test_swarm_interrupt_reject(swarm):
     tru_node_id = multiagent_result.node_history[0].node_id
     exp_node_id = "info"
     assert tru_node_id == exp_node_id
+
+
+# Graph tests
+
+
+@pytest.fixture
+def graph(interrupt_hook, weather_tool):
+    from strands.multiagent import GraphBuilder
+
+    info_agent = Agent(name="info")
+    weather_agent = Agent(name="weather", tools=[weather_tool])
+
+    builder = GraphBuilder()
+    builder.add_node(info_agent, "info")
+    builder.add_node(weather_agent, "weather")
+    builder.add_edge("info", "weather")
+    builder.set_entry_point("info")
+    builder.set_hook_providers([interrupt_hook])
+
+    return builder.build()
+
+
+def test_graph_interrupt(graph):
+    multiagent_result = graph("What is the weather?")
+
+    tru_status = multiagent_result.status
+    exp_status = Status.INTERRUPTED
+    assert tru_status == exp_status
+
+    tru_interrupts = multiagent_result.interrupts
+    exp_interrupts = [
+        Interrupt(
+            id=ANY,
+            name="test_interrupt",
+            reason="need approval",
+        ),
+    ]
+    assert tru_interrupts == exp_interrupts
+
+    interrupt = multiagent_result.interrupts[0]
+
+    responses = [
+        {
+            "interruptResponse": {
+                "interruptId": interrupt.id,
+                "response": "APPROVE",
+            },
+        },
+    ]
+    multiagent_result = graph(responses)
+
+    tru_status = multiagent_result.status
+    exp_status = Status.COMPLETED
+    assert tru_status == exp_status
+
+    assert len(multiagent_result.results) == 2
+    weather_result = multiagent_result.results["weather"]
+
+    weather_message = json.dumps(weather_result.result.message).lower()
+    assert "sunny" in weather_message
+
+
+@pytest.mark.asyncio
+async def test_graph_interrupt_reject(graph):
+    multiagent_result = graph("What is the weather?")
+
+    tru_status = multiagent_result.status
+    exp_status = Status.INTERRUPTED
+    assert tru_status == exp_status
+
+    tru_interrupts = multiagent_result.interrupts
+    exp_interrupts = [
+        Interrupt(
+            id=ANY,
+            name="test_interrupt",
+            reason="need approval",
+        ),
+    ]
+    assert tru_interrupts == exp_interrupts
+
+    interrupt = multiagent_result.interrupts[0]
+
+    responses = [
+        {
+            "interruptResponse": {
+                "interruptId": interrupt.id,
+                "response": "REJECT",
+            },
+        },
+    ]
+    tru_cancel_id = None
+
+    # Graph raises RuntimeError for cancel_node
+    with pytest.raises(RuntimeError, match="node rejected"):
+        async for event in graph.stream_async(responses):
+            if event.get("type") == "multiagent_node_cancel":
+                tru_cancel_id = event["node_id"]
+
+    exp_cancel_id = "weather"
+    assert tru_cancel_id == exp_cancel_id
+
+    tru_status = graph.state.status
+    exp_status = Status.FAILED
+    assert tru_status == exp_status
+
+    assert len(graph.state.execution_order) == 1
+    tru_node_id = graph.state.execution_order[0].node_id
+    exp_node_id = "info"
+    assert tru_node_id == exp_node_id
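
Usage sketch: a minimal caller-side view of the interrupt/resume flow this patch enables, condensed from the integration tests above. The `interrupt_hook` object is a hypothetical stand-in for the `interrupt_hook` fixture in tests_integ (a HookProvider that raises an Interrupt from BeforeNodeCallEvent); the builder API, `Status.INTERRUPTED` check, `interruptResponse` content-block shape, and resume-by-calling-the-graph-again pattern come directly from this diff. Import paths follow the test modules in this patch.

    from strands import Agent
    from strands.multiagent import GraphBuilder
    from strands.multiagent.graph import Status

    # interrupt_hook is assumed to be defined elsewhere, analogous to the
    # tests_integ fixture that raises an Interrupt before a node runs.
    builder = GraphBuilder()
    builder.add_node(Agent(name="info"), "info")
    builder.add_node(Agent(name="weather"), "weather")
    builder.add_edge("info", "weather")
    builder.set_entry_point("info")
    builder.set_hook_providers([interrupt_hook])
    graph = builder.build()

    result = graph("What is the weather?")

    if result.status == Status.INTERRUPTED:
        interrupt = result.interrupts[0]
        # Resume by answering the interrupt; _execute_graph re-enters at the
        # interrupted node rather than restarting from the entry points.
        result = graph(
            [
                {
                    "interruptResponse": {
                        "interruptId": interrupt.id,
                        "response": "APPROVE",
                    },
                },
            ]
        )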