diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 43836fe34..804f90a1d 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -289,12 +289,13 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: state["current_tool_use"] = {} elif text: - content.append({"text": text}) - state["text"] = "" if citations_content: - citations_block: CitationsContentBlock = {"citations": citations_content} + citations_block: CitationsContentBlock = {"citations": citations_content, "content": [{"text": text}]} content.append({"citationsContent": citations_block}) state["citationsContent"] = [] + else: + content.append({"text": text}) + state["text"] = "" elif reasoning_text: content_block: ContentBlock = { diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 4a7c81672..08d8f400c 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -500,16 +500,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An for citation in citations["citations"]: filtered_citation: dict[str, Any] = {} if "location" in citation: - location = citation["location"] - filtered_location = {} - # Filter location fields to only include Bedrock-supported ones - if "documentIndex" in location: - filtered_location["documentIndex"] = location["documentIndex"] - if "start" in location: - filtered_location["start"] = location["start"] - if "end" in location: - filtered_location["end"] = location["end"] - filtered_citation["location"] = filtered_location + filtered_citation["location"] = citation["location"] if "sourceContent" in citation: filtered_source_content: list[dict[str, Any]] = [] for source_content in citation["sourceContent"]: diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index c3890f428..d64357cf8 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -161,7 +161,7 @@ class CitationStreamEvent(ModelStreamEvent): def __init__(self, delta: ContentBlockDelta, citation: Citation) -> None: """Initialize with delta and citation content.""" - super().__init__({"callback": {"citation": citation, "delta": delta}}) + super().__init__({"citation": citation, "delta": delta}) class ReasoningTextStreamEvent(ModelStreamEvent): diff --git a/src/strands/types/citations.py b/src/strands/types/citations.py index b0e28f655..41f2fa4e0 100644 --- a/src/strands/types/citations.py +++ b/src/strands/types/citations.py @@ -3,7 +3,7 @@ These types are modeled after the Bedrock API. """ -from typing import List, Union +from typing import List, Literal, Union from typing_extensions import TypedDict @@ -77,8 +77,17 @@ class DocumentPageLocation(TypedDict, total=False): end: int -# Union type for citation locations -CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation] +# Tagged union type aliases following the ToolChoice pattern +DocumentCharLocationDict = dict[Literal["documentChar"], DocumentCharLocation] +DocumentPageLocationDict = dict[Literal["documentPage"], DocumentPageLocation] +DocumentChunkLocationDict = dict[Literal["documentChunk"], DocumentChunkLocation] + +# Union type for citation locations - tagged union format matching AWS Bedrock API +CitationLocation = Union[ + DocumentCharLocationDict, + DocumentPageLocationDict, + DocumentChunkLocationDict, +] class CitationSourceContent(TypedDict, total=False): diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 02be400b1..c6e44b78a 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -215,6 +215,59 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) {}, {}, ), + # Citation - New + ( + { + "delta": { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + } + }, + {}, + {}, + { + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"} + ] + }, + { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + }, + ), + # Citation - Existing + ( + { + "delta": { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + } + } + }, + {}, + { + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"} + ] + }, + { + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"}, + {"location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, "title": "Another Doc"}, + ] + }, + { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + } + }, + ), # Empty ( {"delta": {}}, @@ -294,14 +347,49 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, event_type, s "redactedContent": b"", }, ), - # Citations + # Text with Citations + ( + { + "content": [], + "current_tool_use": {}, + "text": "This is cited text", + "reasoningText": "", + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"} + ], + "redactedContent": b"", + }, + { + "content": [ + { + "citationsContent": { + "citations": [ + { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + ], + "content": [{"text": "This is cited text"}], + } + } + ], + "current_tool_use": {}, + "text": "", + "reasoningText": "", + "citationsContent": [], + "redactedContent": b"", + }, + ), + # Citations without text (should not create content block) ( { "content": [], "current_tool_use": {}, "text": "", "reasoningText": "", - "citationsContent": [{"citations": [{"text": "test", "source": "test"}]}], + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"} + ], "redactedContent": b"", }, { @@ -309,7 +397,9 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, event_type, s "current_tool_use": {}, "text": "", "reasoningText": "", - "citationsContent": [{"citations": [{"text": "test", "source": "test"}]}], + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"} + ], "redactedContent": b"", }, ), @@ -578,6 +668,137 @@ def test_extract_usage_metrics_empty_metadata(): }, ], ), + # Message with Citations + ( + [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "This is cited text"}}}, + { + "contentBlockDelta": { + "delta": { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + } + } + }, + { + "contentBlockDelta": { + "delta": { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + } + } + } + }, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + { + "metadata": { + "usage": {"inputTokens": 5, "outputTokens": 10, "totalTokens": 15}, + "metrics": {"latencyMs": 100}, + } + }, + ], + [ + {"event": {"messageStart": {"role": "assistant"}}}, + {"event": {"contentBlockStart": {"start": {}}}}, + {"event": {"contentBlockDelta": {"delta": {"text": "This is cited text"}}}}, + {"data": "This is cited text", "delta": {"text": "This is cited text"}}, + { + "event": { + "contentBlockDelta": { + "delta": { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + } + } + } + }, + { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + }, + "delta": { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + }, + }, + { + "event": { + "contentBlockDelta": { + "delta": { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + } + } + } + } + }, + { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + }, + "delta": { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + } + }, + }, + {"event": {"contentBlockStop": {}}}, + {"event": {"messageStop": {"stopReason": "end_turn"}}}, + { + "event": { + "metadata": { + "usage": {"inputTokens": 5, "outputTokens": 10, "totalTokens": 15}, + "metrics": {"latencyMs": 100}, + } + } + }, + { + "stop": ( + "end_turn", + { + "role": "assistant", + "content": [ + { + "citationsContent": { + "citations": [ + { + "location": { + "documentChar": {"documentIndex": 0, "start": 10, "end": 20} + }, + "title": "Test Doc", + }, + { + "location": { + "documentPage": {"documentIndex": 1, "start": 5, "end": 6} + }, + "title": "Another Doc", + }, + ], + "content": [{"text": "This is cited text"}], + } + } + ], + }, + {"inputTokens": 5, "outputTokens": 10, "totalTokens": 15}, + {"latencyMs": 100}, + ) + }, + ], + ), # Empty Message ( [{}], diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 2809e8a72..5ec5a7072 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2070,3 +2070,76 @@ async def test_stream_backward_compatibility_system_prompt(bedrock_client, model "system": [{"text": system_prompt}], } bedrock_client.converse_stream.assert_called_once_with(**expected_request) + + +@pytest.mark.asyncio +async def test_citations_content_preserves_tagged_union_structure(bedrock_client, model, alist): + """Test that citationsContent preserves AWS Bedrock's required tagged union structure for citation locations. + + This test verifies that when messages contain citationsContent with tagged union CitationLocation objects, + the structure is preserved when sent to AWS Bedrock API. AWS Bedrock expects CitationLocation to be a + tagged union with exactly one wrapper key (documentChar, documentPage, etc.) containing the location fields. + """ + # Mock the Bedrock response + bedrock_client.converse_stream.return_value = {"stream": []} + + # Messages with citationsContent using tagged union CitationLocation structure + messages = [ + {"role": "user", "content": [{"text": "Analyze this document"}]}, + { + "role": "assistant", + "content": [ + { + "citationsContent": { + "citations": [ + { + "location": {"documentChar": {"documentIndex": 0, "start": 150, "end": 300}}, + "sourceContent": [ + {"text": "Employee benefits include health insurance and retirement plans"} + ], + "title": "Benefits Section", + }, + { + "location": {"documentPage": {"documentIndex": 0, "start": 2, "end": 3}}, + "sourceContent": [{"text": "Vacation policy allows 15 days per year"}], + "title": "Vacation Policy", + }, + ], + "content": [{"text": "Based on the document, employees receive comprehensive benefits."}], + } + } + ], + }, + ] + + # Call the public stream method + await alist(model.stream(messages)) + + # Verify the request sent to Bedrock preserves the tagged union structure + bedrock_client.converse_stream.assert_called_once() + call_args = bedrock_client.converse_stream.call_args[1] + + # Extract the citationsContent from the formatted messages + formatted_messages = call_args["messages"] + citations_content = formatted_messages[1]["content"][0]["citationsContent"] + + # Verify the tagged union structure is preserved + expected_citations = [ + { + "location": {"documentChar": {"documentIndex": 0, "start": 150, "end": 300}}, + "sourceContent": [{"text": "Employee benefits include health insurance and retirement plans"}], + "title": "Benefits Section", + }, + { + "location": {"documentPage": {"documentIndex": 0, "start": 2, "end": 3}}, + "sourceContent": [{"text": "Vacation policy allows 15 days per year"}], + "title": "Vacation Policy", + }, + ] + + assert citations_content["citations"] == expected_citations, ( + "Citation location tagged union structure was not preserved. " + "AWS Bedrock requires CitationLocation to have exactly one wrapper key " + "(documentChar, documentPage, documentChunk, searchResultLocation, or web) " + "with the location fields nested inside." + ) diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index e72aebd92..f3fdfc87a 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -533,7 +533,7 @@ def test_stop_closes_event_loop(): mock_thread.join = MagicMock() mock_event_loop = MagicMock() mock_event_loop.close = MagicMock() - + client._background_thread = mock_thread client._background_thread_event_loop = mock_event_loop @@ -542,7 +542,7 @@ def test_stop_closes_event_loop(): # Verify thread was joined mock_thread.join.assert_called_once() - + # Verify event loop was closed mock_event_loop.close.assert_called_once() diff --git a/tests/strands/types/test__events.py b/tests/strands/types/test__events.py index d64cabb83..6163faeb6 100644 --- a/tests/strands/types/test__events.py +++ b/tests/strands/types/test__events.py @@ -195,8 +195,8 @@ def test_initialization(self): delta = Mock(spec=ContentBlockDelta) citation = Mock(spec=Citation) event = CitationStreamEvent(delta, citation) - assert event["callback"]["citation"] == citation - assert event["callback"]["delta"] == delta + assert event["citation"] == citation + assert event["delta"] == delta class TestReasoningTextStreamEvent: diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index 2c2e125ad..b31f23663 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -210,6 +210,9 @@ def test_document_citations(non_streaming_agent, letter_pdf): assert any("citationsContent" in content for content in non_streaming_agent.messages[-1]["content"]) + # Validate message structure is valid in multi-turn + non_streaming_agent("What is your favorite part?") + def test_document_citations_streaming(streaming_agent, letter_pdf): content: list[ContentBlock] = [ @@ -228,6 +231,9 @@ def test_document_citations_streaming(streaming_agent, letter_pdf): assert any("citationsContent" in content for content in streaming_agent.messages[-1]["content"]) + # Validate message structure is valid in multi-turn + streaming_agent("What is your favorite part?") + def test_structured_output_multi_modal_input(streaming_agent, yellow_img, yellow_color): content = [