diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 4a7c81672..14bcc3cbb 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -8,7 +8,20 @@ import logging import os import warnings -from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, ValuesView, cast +from typing import ( + Any, + AsyncGenerator, + Callable, + Iterable, + Literal, + Mapping, + Optional, + Type, + TypeVar, + Union, + ValuesView, + cast, +) import boto3 from botocore.config import Config as BotocoreConfig @@ -493,23 +506,16 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html if "citationsContent" in content: citations = content["citationsContent"] - result = {} + citations_result: dict[str, Any] = {} if "citations" in citations: - result["citations"] = [] + citations_result["citations"] = [] for citation in citations["citations"]: filtered_citation: dict[str, Any] = {} if "location" in citation: - location = citation["location"] - filtered_location = {} - # Filter location fields to only include Bedrock-supported ones - if "documentIndex" in location: - filtered_location["documentIndex"] = location["documentIndex"] - if "start" in location: - filtered_location["start"] = location["start"] - if "end" in location: - filtered_location["end"] = location["end"] - filtered_citation["location"] = filtered_location + filtered_location = self._format_citation_location(citation["location"]) + if filtered_location: + filtered_citation["location"] = filtered_location if "sourceContent" in citation: filtered_source_content: list[dict[str, Any]] = [] for source_content in citation["sourceContent"]: @@ -519,7 +525,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An filtered_citation["sourceContent"] = filtered_source_content if "title" in citation: 
def _format_citation_location(self, location: Mapping[str, Any]) -> dict[str, Any]:
    """Format a citation location, preserving its tagged union structure.

    The Bedrock API models CitationLocation as a tagged union: a single-key mapping
    whose key names the variant (documentChar, documentPage or documentChunk) and
    whose value carries the position fields. The first recognized variant found in
    ``location`` is kept and only its supported inner fields are copied over.

    Args:
        location: Citation location to format.

    Returns:
        A single-key dict ``{variant: filtered_fields}``, or an empty dict when no
        recognized variant is present, the variant payload is not a mapping, or no
        supported field remains after filtering.

    See:
        https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationLocation.html
    """
    # Every supported variant currently exposes the same three position fields.
    position_fields = ("documentIndex", "start", "end")
    allowed_fields = {
        "documentChar": position_fields,
        "documentPage": position_fields,
        "documentChunk": position_fields,
    }

    for location_type, fields in allowed_fields.items():
        if location_type not in location:
            continue

        inner = location[location_type]
        if not isinstance(inner, Mapping):
            # A malformed payload (e.g. None) is treated as invalid rather than raising on .items().
            logger.debug(
                "location_type=<%s>, payload_type=<%s> | citation location payload is not a mapping, skipping",
                location_type,
                type(inner).__name__,
            )
            return {}

        filtered = {k: v for k, v in inner.items() if k in fields}
        # An empty variant is not a valid union member, so drop the location entirely.
        return {location_type: filtered} if filtered else {}

    # Report the offending keys so the skipped location can be diagnosed from the logs.
    logger.debug("location_type=<%s> | unrecognized citation location type, skipping", list(location))
    return {}
diff --git a/src/strands/types/citations.py b/src/strands/types/citations.py index b0e28f655..12c918e34 100644 --- a/src/strands/types/citations.py +++ b/src/strands/types/citations.py @@ -1,6 +1,7 @@ """Citation type definitions for the SDK. These types are modeled after the Bedrock API. +https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationLocation.html """ from typing import List, Union @@ -18,11 +19,8 @@ class CitationsConfig(TypedDict): enabled: bool -class DocumentCharLocation(TypedDict, total=False): - """Specifies a character-level location within a document. - - Provides precise positioning information for cited content using - start and end character indices. +class DocumentCharLocationInner(TypedDict, total=False): + """Inner content for character-level location within a document. Attributes: documentIndex: The index of the document within the array of documents @@ -38,11 +36,8 @@ class DocumentCharLocation(TypedDict, total=False): end: int -class DocumentChunkLocation(TypedDict, total=False): - """Specifies a chunk-level location within a document. - - Provides positioning information for cited content using logical - document segments or chunks. +class DocumentChunkLocationInner(TypedDict, total=False): + """Inner content for chunk-level location within a document. Attributes: documentIndex: The index of the document within the array of documents @@ -58,10 +53,8 @@ class DocumentChunkLocation(TypedDict, total=False): end: int -class DocumentPageLocation(TypedDict, total=False): - """Specifies a page-level location within a document. - - Provides positioning information for cited content using page numbers. +class DocumentPageLocationInner(TypedDict, total=False): + """Inner content for page-level location within a document. 
class DocumentCharLocation(TypedDict, total=False):
    """Character-level citation location, wrapped under its union tag.

    Attributes:
        documentChar: Inner payload describing the character-level location.
    """

    documentChar: DocumentCharLocationInner


class DocumentChunkLocation(TypedDict, total=False):
    """Chunk-level citation location, wrapped under its union tag.

    Attributes:
        documentChunk: Inner payload describing the chunk-level location.
    """

    documentChunk: DocumentChunkLocationInner


class DocumentPageLocation(TypedDict, total=False):
    """Page-level citation location, wrapped under its union tag.

    Attributes:
        documentPage: Inner payload describing the page-level location.
    """

    documentPage: DocumentPageLocationInner


# A citation location is a tagged union: exactly one of the wrapper keys above is present.
CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation]
def test_format_request_message_content_document_char_citation(model):
    """A documentChar location keeps its tagged-union wrapper when formatted."""
    citation = {
        "title": "Doc Citation",
        "location": {"documentChar": {"documentIndex": 0, "start": 100, "end": 200}},
        "sourceContent": [{"text": "Excerpt"}],
    }
    payload = {"citationsContent": {"citations": [citation], "content": [{"text": "Generated text"}]}}

    formatted = model._format_request_message_content(payload)

    formatted_location = formatted["citationsContent"]["citations"][0]["location"]
    assert formatted_location == {"documentChar": {"documentIndex": 0, "start": 100, "end": 200}}


def test_format_request_message_content_document_page_citation(model):
    """A documentPage location keeps its tagged-union wrapper when formatted."""
    citation = {
        "title": "Page Citation",
        "location": {"documentPage": {"documentIndex": 0, "start": 2, "end": 3}},
        "sourceContent": [{"text": "Page content"}],
    }
    payload = {"citationsContent": {"citations": [citation], "content": [{"text": "Generated text"}]}}

    formatted = model._format_request_message_content(payload)

    formatted_location = formatted["citationsContent"]["citations"][0]["location"]
    assert formatted_location == {"documentPage": {"documentIndex": 0, "start": 2, "end": 3}}


def test_format_request_message_content_document_chunk_citation(model):
    """A documentChunk location keeps its tagged-union wrapper when formatted."""
    citation = {
        "title": "Chunk Citation",
        "location": {"documentChunk": {"documentIndex": 1, "start": 5, "end": 10}},
        "sourceContent": [{"text": "Chunk content"}],
    }
    payload = {"citationsContent": {"citations": [citation], "content": [{"text": "Generated text"}]}}

    formatted = model._format_request_message_content(payload)

    formatted_location = formatted["citationsContent"]["citations"][0]["location"]
    assert formatted_location == {"documentChunk": {"documentIndex": 1, "start": 5, "end": 10}}


def test_format_request_message_content_citation_filters_extra_fields(model):
    """Unsupported keys inside the location payload are dropped from the request."""
    citation = {
        "title": "Citation with extra fields",
        "location": {"documentChar": {"documentIndex": 0, "start": 0, "end": 50, "extraField": "ignored"}},
        "sourceContent": [{"text": "Content"}],
    }
    payload = {"citationsContent": {"citations": [citation], "content": [{"text": "Text"}]}}

    formatted = model._format_request_message_content(payload)

    # Only the supported position fields survive; extraField must be gone.
    formatted_location = formatted["citationsContent"]["citations"][0]["location"]
    assert formatted_location == {"documentChar": {"documentIndex": 0, "start": 0, "end": 50}}


def test_format_request_message_content_citation_unknown_location_type(model):
    """A citation whose location type is not recognized is sent without a location."""
    citation = {"title": "Unknown location", "location": {"unknownType": {"field": "value"}}}
    payload = {"citationsContent": {"citations": [citation], "content": [{"text": "Text"}]}}

    formatted = model._format_request_message_content(payload)

    formatted_citation = formatted["citationsContent"]["citations"][0]
    assert "location" not in formatted_citation