From 34840ef1fdb00c1b310154fb128a3852cad3049b Mon Sep 17 00:00:00 2001 From: Danilo Poccia Date: Thu, 11 Dec 2025 15:30:58 +0100 Subject: [PATCH] feat(citations): Add support for web-based citations in Bedrock Converse API Add support for web-based citations in addition to document-based citations: - Added WebLocation TypedDict to citations.py with url and domain fields - Updated CitationLocation union to include WebLocation - Updated bedrock.py to filter web citation fields (url, domain only) - Handle optional citation fields gracefully (title, location, sourceContent) - Added tests for web citations, document citations, and edge cases --- src/strands/models/bedrock.py | 42 ++++++++-------- src/strands/types/citations.py | 16 ++++++- tests/strands/models/test_bedrock.py | 72 ++++++++++++++++++++++++++++ 3 files changed, 110 insertions(+), 20 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 4a7c81672..22e0168ff 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -493,7 +493,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html if "citationsContent" in content: citations = content["citationsContent"] - result = {} + result: dict[str, Any] = {} if "citations" in citations: result["citations"] = [] @@ -501,15 +501,18 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An filtered_citation: dict[str, Any] = {} if "location" in citation: location = citation["location"] - filtered_location = {} - # Filter location fields to only include Bedrock-supported ones - if "documentIndex" in location: - filtered_location["documentIndex"] = location["documentIndex"] - if "start" in location: - filtered_location["start"] = location["start"] - if "end" in location: - filtered_location["end"] = location["end"] - filtered_citation["location"] = filtered_location + filtered_location: dict[str, Any] = {} + # Handle web-based citations + if "web" in location: + filtered_location["web"] = { + k: v for k, v in location["web"].items() if k in ("url", "domain") + } + # Handle document-based citations + for field in ("documentIndex", "start", "end"): + if field in location: + filtered_location[field] = location[field] + if filtered_location: + filtered_citation["location"] = filtered_location if "sourceContent" in citation: filtered_source_content: list[dict[str, Any]] = [] for source_content in citation["sourceContent"]: @@ -831,20 +834,21 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera # For non-streaming citations, emit text and metadata deltas in sequence # to match streaming behavior where they flow naturally if "content" in content["citationsContent"]: - text_content = "".join([content["text"] for content in content["citationsContent"]["content"]]) + text_content = "".join([c["text"] for c in content["citationsContent"]["content"]]) yield { "contentBlockDelta": {"delta": {"text": text_content}}, } for citation in content["citationsContent"]["citations"]: - # Then emit citation metadata (for structure) - - citation_metadata: CitationsDelta = { - "title": citation["title"], - "location": citation["location"], - "sourceContent": citation["sourceContent"], - } - yield {"contentBlockDelta": {"delta": {"citation": citation_metadata}}} + # Emit citation metadata with only present fields + citation_metadata: dict[str, Any] = {} + if "title" in citation: + citation_metadata["title"] = citation["title"] + if "location" in citation: + citation_metadata["location"] = citation["location"] + if "sourceContent" in citation: + citation_metadata["sourceContent"] = citation["sourceContent"] + yield {"contentBlockDelta": {"delta": {"citation": cast(CitationsDelta, citation_metadata)}}} # Yield contentBlockStop event yield {"contentBlockStop": {}} diff --git a/src/strands/types/citations.py b/src/strands/types/citations.py index b0e28f655..130ad2d9a 100644 --- a/src/strands/types/citations.py +++ b/src/strands/types/citations.py @@ -77,8 +77,22 @@ class DocumentPageLocation(TypedDict, total=False): end: int +class WebLocation(TypedDict, total=False): + """Specifies a web-based location for cited content. + + Provides location information for content cited from web sources. + + Attributes: + url: The URL of the web page containing the cited content. + domain: The domain of the web page containing the cited content. + """ + + url: str + domain: str + + # Union type for citation locations -CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation] +CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation, WebLocation] class CitationSourceContent(TypedDict, total=False): diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 2809e8a72..866ee3f03 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2070,3 +2070,75 @@ async def test_stream_backward_compatibility_system_prompt(bedrock_client, model "system": [{"text": system_prompt}], } bedrock_client.converse_stream.assert_called_once_with(**expected_request) + + +def test_format_request_message_content_web_citation(model): + """Test that web citations are correctly filtered to include only url and domain.""" + content = { + "citationsContent": { + "citations": [ + { + "title": "Web Citation", + "location": {"web": {"url": "https://example.com", "domain": "example.com", "extra": "ignored"}}, + "sourceContent": [{"text": "Content"}], + } + ], + "content": [{"text": "Generated text"}], + } + } + + result = model._format_request_message_content(content) + + citation = result["citationsContent"]["citations"][0] + assert citation["location"]["web"] == {"url": "https://example.com", "domain": "example.com"} + + +def test_format_request_message_content_document_citation(model): + """Test that document citations preserve documentIndex, start, and end fields.""" + content = { + "citationsContent": { + "citations": [ + { + "title": "Doc Citation", + "location": {"documentIndex": 0, "start": 100, "end": 200}, + "sourceContent": [{"text": "Excerpt"}], + } + ], + "content": [{"text": "Generated text"}], + } + } + + result = model._format_request_message_content(content) + + assert result["citationsContent"]["citations"][0]["location"] == {"documentIndex": 0, "start": 100, "end": 200} + + +def test_format_request_message_content_citation_optional_fields(model): + """Test that citations with missing optional fields are handled correctly.""" + content = { + "citationsContent": { + "citations": [{"title": "Minimal", "location": {"web": {"url": "https://example.com"}}}], + "content": [{"text": "Text"}], + } + } + + result = model._format_request_message_content(content) + + citation = result["citationsContent"]["citations"][0] + assert citation["title"] == "Minimal" + assert citation["location"]["web"]["url"] == "https://example.com" + assert "sourceContent" not in citation + + +def test_format_request_message_content_citation_empty_location(model): + """Test that citations with invalid locations exclude the location field.""" + content = { + "citationsContent": { + "citations": [{"title": "No valid location", "location": {"unknown": "value"}}], + "content": [{"text": "Text"}], + } + } + + result = model._format_request_message_content(content) + + assert "location" not in result["citationsContent"]["citations"][0]