From 9413c0b17e881107fd78fa72596e8c61459596de Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Tue, 16 Dec 2025 13:54:15 +1100 Subject: [PATCH 1/6] Fixes CitationLocation UnionType, and when streaming is being used --- src/strands/event_loop/streaming.py | 7 +++-- src/strands/models/bedrock.py | 22 +++++++++------ src/strands/types/citations.py | 32 +++++++++++++++++++++- tests/strands/tools/mcp/test_mcp_client.py | 4 +-- tests_integ/models/test_model_bedrock.py | 3 ++ 5 files changed, 54 insertions(+), 14 deletions(-) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 43836fe34..804f90a1d 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -289,12 +289,13 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: state["current_tool_use"] = {} elif text: - content.append({"text": text}) - state["text"] = "" if citations_content: - citations_block: CitationsContentBlock = {"citations": citations_content} + citations_block: CitationsContentBlock = {"citations": citations_content, "content": [{"text": text}]} content.append({"citationsContent": citations_block}) state["citationsContent"] = [] + else: + content.append({"text": text}) + state["text"] = "" elif reasoning_text: content_block: ContentBlock = { diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 4a7c81672..8c075cda9 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -501,15 +501,21 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An filtered_citation: dict[str, Any] = {} if "location" in citation: location = citation["location"] - filtered_location = {} # Filter location fields to only include Bedrock-supported ones - if "documentIndex" in location: - filtered_location["documentIndex"] = location["documentIndex"] - if "start" in location: - filtered_location["start"] = location["start"] - if "end" in location: - filtered_location["end"] = location["end"] - filtered_citation["location"] = filtered_location + allowed = { + "documentChar": {"documentIndex", "start", "end"}, + "documentPage": {"documentIndex", "start", "end"}, + "documentChunk": {"documentIndex", "start", "end"}, + "searchResultLocation": {"searchResultIndex", "start", "end"}, + "web": {"url", "domain"}, + } + for union_type, fields in allowed.items(): + if union_type in location: + inner = location[union_type] + filtered_citation["location"] = { + union_type: {k: v for k, v in inner.items() if k in fields} + } + break if "sourceContent" in citation: filtered_source_content: list[dict[str, Any]] = [] for source_content in citation["sourceContent"]: diff --git a/src/strands/types/citations.py b/src/strands/types/citations.py index b0e28f655..66a14e497 100644 --- a/src/strands/types/citations.py +++ b/src/strands/types/citations.py @@ -77,8 +77,38 @@ class DocumentPageLocation(TypedDict, total=False): end: int +class SearchResultLocation(TypedDict, total=False): + """Specifies a search result location within the content array. + + Provides positioning information for cited content using search result index and block positions. + + Attributes: + searchResultIndex: The index of the search result content block where the cited content is found. Minimum value of 0. + start: The starting position in the content array where the cited content begins. Minimum value of 0. + end: The ending position in the content array where the cited content ends. Minimum value of 0. + """ + + searchResultIndex: int + start: int + end: int + + +class WebLocation(TypedDict, total=False): + """Provides the URL and domain information for the website that was cited when performing a web search. + + Attributes: + domain: The domain that was cited when performing a web search. + url: The URL that was cited when performing a web search. + """ + + domain: str + url: str + + # Union type for citation locations -CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation] +CitationLocation = Union[ + DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation, SearchResultLocation, WebLocation +] class CitationSourceContent(TypedDict, total=False): diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index e72aebd92..f3fdfc87a 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -533,7 +533,7 @@ def test_stop_closes_event_loop(): mock_thread.join = MagicMock() mock_event_loop = MagicMock() mock_event_loop.close = MagicMock() - + client._background_thread = mock_thread client._background_thread_event_loop = mock_event_loop @@ -542,7 +542,7 @@ def test_stop_closes_event_loop(): # Verify thread was joined mock_thread.join.assert_called_once() - + # Verify event loop was closed mock_event_loop.close.assert_called_once() diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index 2c2e125ad..f9e536481 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -228,6 +228,9 @@ def test_document_citations_streaming(streaming_agent, letter_pdf): assert any("citationsContent" in content for content in streaming_agent.messages[-1]["content"]) + # Validate message structure is valid in multi-turn + streaming_agent("What is your favorite part?") + def test_structured_output_multi_modal_input(streaming_agent, yellow_img, yellow_color): content = [ From f5dfbe77c6689895f9d399502a17525d638a900a Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Tue, 16 Dec 2025 14:31:03 +1100 Subject: [PATCH 2/6] Modification to citations.py for linter --- src/strands/types/citations.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/strands/types/citations.py b/src/strands/types/citations.py index 66a14e497..33ccad84c 100644 --- a/src/strands/types/citations.py +++ b/src/strands/types/citations.py @@ -80,12 +80,16 @@ class DocumentPageLocation(TypedDict, total=False): class SearchResultLocation(TypedDict, total=False): """Specifies a search result location within the content array. - Provides positioning information for cited content using search result index and block positions. + Provides positioning information for cited content using search result index + and block positions. Attributes: - searchResultIndex: The index of the search result content block where the cited content is found. Minimum value of 0. - start: The starting position in the content array where the cited content begins. Minimum value of 0. - end: The ending position in the content array where the cited content ends. Minimum value of 0. + searchResultIndex: The index of the search result content block where + the cited content is found. Minimum value of 0. + start: The starting position in the content array where the cited content + begins. Minimum value of 0. + end: The ending position in the content array where the cited content ends. + Minimum value of 0. """ searchResultIndex: int From 88ac264562d4d307dc13a7f9101a6eab0774c46c Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Tue, 16 Dec 2025 14:52:45 +1100 Subject: [PATCH 3/6] Ignore dynamic dict construction in bedrock.py for mypy --- src/strands/models/bedrock.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 8c075cda9..aeea2d293 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -511,7 +511,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An } for union_type, fields in allowed.items(): if union_type in location: - inner = location[union_type] + inner = location[union_type] # type: ignore[literal-required] filtered_citation["location"] = { union_type: {k: v for k, v in inner.items() if k in fields} } From 937e8f4065e3ccc24affbeea64e8207fffc3397c Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 16 Dec 2025 15:40:56 -0500 Subject: [PATCH 4/6] feat: align citation content block pattern with ToolChoice --- src/strands/models/bedrock.py | 17 +----- src/strands/types/citations.py | 43 +++----------- tests/strands/models/test_bedrock.py | 73 ++++++++++++++++++++++++ tests_integ/models/test_model_bedrock.py | 3 + 4 files changed, 86 insertions(+), 50 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index aeea2d293..08d8f400c 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -500,22 +500,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An for citation in citations["citations"]: filtered_citation: dict[str, Any] = {} if "location" in citation: - location = citation["location"] - # Filter location fields to only include Bedrock-supported ones - allowed = { - "documentChar": {"documentIndex", "start", "end"}, - "documentPage": {"documentIndex", "start", "end"}, - "documentChunk": {"documentIndex", "start", "end"}, - "searchResultLocation": {"searchResultIndex", "start", "end"}, - "web": {"url", "domain"}, - } - for union_type, fields in allowed.items(): - if union_type in location: - inner = location[union_type] # type: ignore[literal-required] - filtered_citation["location"] = { - union_type: {k: v for k, v in inner.items() if k in fields} - } - break + filtered_citation["location"] = citation["location"] if "sourceContent" in citation: filtered_source_content: list[dict[str, Any]] = [] for source_content in citation["sourceContent"]: diff --git a/src/strands/types/citations.py b/src/strands/types/citations.py index 33ccad84c..41f2fa4e0 100644 --- a/src/strands/types/citations.py +++ b/src/strands/types/citations.py @@ -3,7 +3,7 @@ These types are modeled after the Bedrock API. """ -from typing import List, Union +from typing import List, Literal, Union from typing_extensions import TypedDict @@ -77,41 +77,16 @@ class DocumentPageLocation(TypedDict, total=False): end: int -class SearchResultLocation(TypedDict, total=False): - """Specifies a search result location within the content array. +# Tagged union type aliases following the ToolChoice pattern +DocumentCharLocationDict = dict[Literal["documentChar"], DocumentCharLocation] +DocumentPageLocationDict = dict[Literal["documentPage"], DocumentPageLocation] +DocumentChunkLocationDict = dict[Literal["documentChunk"], DocumentChunkLocation] - Provides positioning information for cited content using search result index - and block positions. - - Attributes: - searchResultIndex: The index of the search result content block where - the cited content is found. Minimum value of 0. - start: The starting position in the content array where the cited content - begins. Minimum value of 0. - end: The ending position in the content array where the cited content ends. - Minimum value of 0. - """ - - searchResultIndex: int - start: int - end: int - - -class WebLocation(TypedDict, total=False): - """Provides the URL and domain information for the website that was cited when performing a web search. - - Attributes: - domain: The domain that was cited when performing a web search. - url: The URL that was cited when performing a web search. - """ - - domain: str - url: str - - -# Union type for citation locations +# Union type for citation locations - tagged union format matching AWS Bedrock API CitationLocation = Union[ - DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation, SearchResultLocation, WebLocation + DocumentCharLocationDict, + DocumentPageLocationDict, + DocumentChunkLocationDict, ] diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 2809e8a72..5ec5a7072 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2070,3 +2070,76 @@ async def test_stream_backward_compatibility_system_prompt(bedrock_client, model "system": [{"text": system_prompt}], } bedrock_client.converse_stream.assert_called_once_with(**expected_request) + + +@pytest.mark.asyncio +async def test_citations_content_preserves_tagged_union_structure(bedrock_client, model, alist): + """Test that citationsContent preserves AWS Bedrock's required tagged union structure for citation locations. + + This test verifies that when messages contain citationsContent with tagged union CitationLocation objects, + the structure is preserved when sent to AWS Bedrock API. AWS Bedrock expects CitationLocation to be a + tagged union with exactly one wrapper key (documentChar, documentPage, etc.) containing the location fields. + """ + # Mock the Bedrock response + bedrock_client.converse_stream.return_value = {"stream": []} + + # Messages with citationsContent using tagged union CitationLocation structure + messages = [ + {"role": "user", "content": [{"text": "Analyze this document"}]}, + { + "role": "assistant", + "content": [ + { + "citationsContent": { + "citations": [ + { + "location": {"documentChar": {"documentIndex": 0, "start": 150, "end": 300}}, + "sourceContent": [ + {"text": "Employee benefits include health insurance and retirement plans"} + ], + "title": "Benefits Section", + }, + { + "location": {"documentPage": {"documentIndex": 0, "start": 2, "end": 3}}, + "sourceContent": [{"text": "Vacation policy allows 15 days per year"}], + "title": "Vacation Policy", + }, + ], + "content": [{"text": "Based on the document, employees receive comprehensive benefits."}], + } + } + ], + }, + ] + + # Call the public stream method + await alist(model.stream(messages)) + + # Verify the request sent to Bedrock preserves the tagged union structure + bedrock_client.converse_stream.assert_called_once() + call_args = bedrock_client.converse_stream.call_args[1] + + # Extract the citationsContent from the formatted messages + formatted_messages = call_args["messages"] + citations_content = formatted_messages[1]["content"][0]["citationsContent"] + + # Verify the tagged union structure is preserved + expected_citations = [ + { + "location": {"documentChar": {"documentIndex": 0, "start": 150, "end": 300}}, + "sourceContent": [{"text": "Employee benefits include health insurance and retirement plans"}], + "title": "Benefits Section", + }, + { + "location": {"documentPage": {"documentIndex": 0, "start": 2, "end": 3}}, + "sourceContent": [{"text": "Vacation policy allows 15 days per year"}], + "title": "Vacation Policy", + }, + ] + + assert citations_content["citations"] == expected_citations, ( + "Citation location tagged union structure was not preserved. " + "AWS Bedrock requires CitationLocation to have exactly one wrapper key " + "(documentChar, documentPage, documentChunk, searchResultLocation, or web) " + "with the location fields nested inside." + ) diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index f9e536481..b31f23663 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -210,6 +210,9 @@ def test_document_citations(non_streaming_agent, letter_pdf): assert any("citationsContent" in content for content in non_streaming_agent.messages[-1]["content"]) + # Validate message structure is valid in multi-turn + non_streaming_agent("What is your favorite part?") + def test_document_citations_streaming(streaming_agent, letter_pdf): content: list[ContentBlock] = [ From 35e70dae4eda02f00565dae13c571a90a8c4974a Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 16 Dec 2025 16:08:08 -0500 Subject: [PATCH 5/6] fix(citations): remove unnecessary key from CitationStreamEvent --- src/strands/types/_events.py | 2 +- tests/strands/event_loop/test_streaming.py | 227 ++++++++++++++++++++- tests/strands/types/test__events.py | 5 +- 3 files changed, 228 insertions(+), 6 deletions(-) diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index c3890f428..d64357cf8 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -161,7 +161,7 @@ class CitationStreamEvent(ModelStreamEvent): def __init__(self, delta: ContentBlockDelta, citation: Citation) -> None: """Initialize with delta and citation content.""" - super().__init__({"callback": {"citation": citation, "delta": delta}}) + super().__init__({"citation": citation, "delta": delta}) class ReasoningTextStreamEvent(ModelStreamEvent): diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 02be400b1..c6e44b78a 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -215,6 +215,59 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) {}, {}, ), + # Citation - New + ( + { + "delta": { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + } + }, + {}, + {}, + { + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"} + ] + }, + { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + }, + ), + # Citation - Existing + ( + { + "delta": { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + } + } + }, + {}, + { + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"} + ] + }, + { + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"}, + {"location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, "title": "Another Doc"}, + ] + }, + { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + } + }, + ), # Empty ( {"delta": {}}, @@ -294,14 +347,49 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, event_type, s "redactedContent": b"", }, ), - # Citations + # Text with Citations + ( + { + "content": [], + "current_tool_use": {}, + "text": "This is cited text", + "reasoningText": "", + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"} + ], + "redactedContent": b"", + }, + { + "content": [ + { + "citationsContent": { + "citations": [ + { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + ], + "content": [{"text": "This is cited text"}], + } + } + ], + "current_tool_use": {}, + "text": "", + "reasoningText": "", + "citationsContent": [], + "redactedContent": b"", + }, + ), + # Citations without text (should not create content block) ( { "content": [], "current_tool_use": {}, "text": "", "reasoningText": "", - "citationsContent": [{"citations": [{"text": "test", "source": "test"}]}], + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"} + ], "redactedContent": b"", }, { @@ -309,7 +397,9 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, event_type, s "current_tool_use": {}, "text": "", "reasoningText": "", - "citationsContent": [{"citations": [{"text": "test", "source": "test"}]}], + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"} + ], "redactedContent": b"", }, ), @@ -578,6 +668,137 @@ def test_extract_usage_metrics_empty_metadata(): }, ], ), + # Message with Citations + ( + [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "This is cited text"}}}, + { + "contentBlockDelta": { + "delta": { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + } + } + }, + { + "contentBlockDelta": { + "delta": { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + } + } + } + }, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + { + "metadata": { + "usage": {"inputTokens": 5, "outputTokens": 10, "totalTokens": 15}, + "metrics": {"latencyMs": 100}, + } + }, + ], + [ + {"event": {"messageStart": {"role": "assistant"}}}, + {"event": {"contentBlockStart": {"start": {}}}}, + {"event": {"contentBlockDelta": {"delta": {"text": "This is cited text"}}}}, + {"data": "This is cited text", "delta": {"text": "This is cited text"}}, + { + "event": { + "contentBlockDelta": { + "delta": { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + } + } + } + }, + { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + }, + "delta": { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + }, + }, + { + "event": { + "contentBlockDelta": { + "delta": { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + } + } + } + } + }, + { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + }, + "delta": { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + } + }, + }, + {"event": {"contentBlockStop": {}}}, + {"event": {"messageStop": {"stopReason": "end_turn"}}}, + { + "event": { + "metadata": { + "usage": {"inputTokens": 5, "outputTokens": 10, "totalTokens": 15}, + "metrics": {"latencyMs": 100}, + } + } + }, + { + "stop": ( + "end_turn", + { + "role": "assistant", + "content": [ + { + "citationsContent": { + "citations": [ + { + "location": { + "documentChar": {"documentIndex": 0, "start": 10, "end": 20} + }, + "title": "Test Doc", + }, + { + "location": { + "documentPage": {"documentIndex": 1, "start": 5, "end": 6} + }, + "title": "Another Doc", + }, + ], + "content": [{"text": "This is cited text"}], + } + } + ], + }, + {"inputTokens": 5, "outputTokens": 10, "totalTokens": 15}, + {"latencyMs": 100}, + ) + }, + ], + ), # Empty Message ( [{}], diff --git a/tests/strands/types/test__events.py b/tests/strands/types/test__events.py index d64cabb83..6a14ba10c 100644 --- a/tests/strands/types/test__events.py +++ b/tests/strands/types/test__events.py @@ -195,8 +195,9 @@ def test_initialization(self): delta = Mock(spec=ContentBlockDelta) citation = Mock(spec=Citation) event = CitationStreamEvent(delta, citation) - assert event["callback"]["citation"] == citation - assert event["callback"]["delta"] == delta + print(event) + assert event["citation"] == citation + assert event["delta"] == delta class TestReasoningTextStreamEvent: From 93836221e8090d0b8cf3f00ae8d1ad3a8098356f Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 17 Dec 2025 12:18:57 -0500 Subject: [PATCH 6/6] Update test__events.py --- tests/strands/types/test__events.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/strands/types/test__events.py b/tests/strands/types/test__events.py index 6a14ba10c..6163faeb6 100644 --- a/tests/strands/types/test__events.py +++ b/tests/strands/types/test__events.py @@ -195,7 +195,6 @@ def test_initialization(self): delta = Mock(spec=ContentBlockDelta) citation = Mock(spec=Citation) event = CitationStreamEvent(delta, citation) - print(event) assert event["citation"] == citation assert event["delta"] == delta