Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 53 additions & 16 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,20 @@
import logging
import os
import warnings
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, ValuesView, cast
from typing import (
Any,
AsyncGenerator,
Callable,
Iterable,
Literal,
Mapping,
Optional,
Type,
TypeVar,
Union,
ValuesView,
cast,
)

import boto3
from botocore.config import Config as BotocoreConfig
Expand Down Expand Up @@ -493,23 +506,16 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html
if "citationsContent" in content:
citations = content["citationsContent"]
result = {}
citations_result: dict[str, Any] = {}

if "citations" in citations:
result["citations"] = []
citations_result["citations"] = []
for citation in citations["citations"]:
filtered_citation: dict[str, Any] = {}
if "location" in citation:
location = citation["location"]
filtered_location = {}
# Filter location fields to only include Bedrock-supported ones
if "documentIndex" in location:
filtered_location["documentIndex"] = location["documentIndex"]
if "start" in location:
filtered_location["start"] = location["start"]
if "end" in location:
filtered_location["end"] = location["end"]
filtered_citation["location"] = filtered_location
filtered_location = self._format_citation_location(citation["location"])
if filtered_location:
filtered_citation["location"] = filtered_location
if "sourceContent" in citation:
filtered_source_content: list[dict[str, Any]] = []
for source_content in citation["sourceContent"]:
Expand All @@ -519,20 +525,51 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
filtered_citation["sourceContent"] = filtered_source_content
if "title" in citation:
filtered_citation["title"] = citation["title"]
result["citations"].append(filtered_citation)
citations_result["citations"].append(filtered_citation)

if "content" in citations:
filtered_content: list[dict[str, Any]] = []
for generated_content in citations["content"]:
if "text" in generated_content:
filtered_content.append({"text": generated_content["text"]})
if filtered_content:
result["content"] = filtered_content
citations_result["content"] = filtered_content

return {"citationsContent": result}
return {"citationsContent": citations_result}

raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")

def _format_citation_location(self, location: Mapping[str, Any]) -> dict[str, Any]:
"""Format a citation location preserving the tagged union structure.

The Bedrock API requires CitationLocation to be a tagged union with exactly one
of the following keys: documentChar, documentPage, or documentChunk.

Args:
location: Citation location to format.

Returns:
Formatted location with tagged union structure preserved, or empty dict if invalid.

See:
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationLocation.html
"""
# Allowed fields for each tagged union type
allowed_fields = {
"documentChar": ("documentIndex", "start", "end"),
"documentPage": ("documentIndex", "start", "end"),
"documentChunk": ("documentIndex", "start", "end"),
}

for location_type, fields in allowed_fields.items():
if location_type in location:
inner = location[location_type]
filtered = {k: v for k, v in inner.items() if k in fields}
return {location_type: filtered} if filtered else {}

logger.debug("location_type=<unknown> | unrecognized citation location type, skipping")
return {}

def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
"""Check if guardrail data contains any blocked policies.

Expand Down
53 changes: 38 additions & 15 deletions src/strands/types/citations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Citation type definitions for the SDK.

These types are modeled after the Bedrock API.
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationLocation.html
"""

from typing import List, Union
Expand All @@ -18,11 +19,8 @@ class CitationsConfig(TypedDict):
enabled: bool


class DocumentCharLocation(TypedDict, total=False):
"""Specifies a character-level location within a document.

Provides precise positioning information for cited content using
start and end character indices.
class DocumentCharLocationInner(TypedDict, total=False):
"""Inner content for character-level location within a document.

Attributes:
documentIndex: The index of the document within the array of documents
Expand All @@ -38,11 +36,8 @@ class DocumentCharLocation(TypedDict, total=False):
end: int


class DocumentChunkLocation(TypedDict, total=False):
"""Specifies a chunk-level location within a document.

Provides positioning information for cited content using logical
document segments or chunks.
class DocumentChunkLocationInner(TypedDict, total=False):
"""Inner content for chunk-level location within a document.

Attributes:
documentIndex: The index of the document within the array of documents
Expand All @@ -58,10 +53,8 @@ class DocumentChunkLocation(TypedDict, total=False):
end: int


class DocumentPageLocation(TypedDict, total=False):
"""Specifies a page-level location within a document.

Provides positioning information for cited content using page numbers.
class DocumentPageLocationInner(TypedDict, total=False):
"""Inner content for page-level location within a document.

Attributes:
documentIndex: The index of the document within the array of documents
Expand All @@ -77,7 +70,37 @@ class DocumentPageLocation(TypedDict, total=False):
end: int


# Union type for citation locations
class DocumentCharLocation(TypedDict, total=False):
    """Tagged union wrapper for character-level document location.

    Nests the character-level payload under the single ``documentChar`` key, which
    distinguishes this variant from the other members of ``CitationLocation``.
    Declared with ``total=False``, so the type checker treats the key as optional.

    Attributes:
        documentChar: The character-level location data.
    """

    documentChar: DocumentCharLocationInner


class DocumentChunkLocation(TypedDict, total=False):
    """Tagged union wrapper for chunk-level document location.

    Nests the chunk-level payload under the single ``documentChunk`` key, which
    distinguishes this variant from the other members of ``CitationLocation``.
    Declared with ``total=False``, so the type checker treats the key as optional.

    Attributes:
        documentChunk: The chunk-level location data.
    """

    documentChunk: DocumentChunkLocationInner


class DocumentPageLocation(TypedDict, total=False):
    """Tagged union wrapper for page-level document location.

    Nests the page-level payload under the single ``documentPage`` key, which
    distinguishes this variant from the other members of ``CitationLocation``.
    Declared with ``total=False``, so the type checker treats the key as optional.

    Attributes:
        documentPage: The page-level location data.
    """

    documentPage: DocumentPageLocationInner


# Union type for citation locations - tagged union where exactly one key is present
CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation]


Expand Down
103 changes: 103 additions & 0 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2070,3 +2070,106 @@ async def test_stream_backward_compatibility_system_prompt(bedrock_client, model
"system": [{"text": system_prompt}],
}
bedrock_client.converse_stream.assert_called_once_with(**expected_request)


def test_format_request_message_content_document_char_citation(model):
    """A documentChar location must come back still wrapped in its union tag."""
    citation = {
        "title": "Doc Citation",
        "location": {"documentChar": {"documentIndex": 0, "start": 100, "end": 200}},
        "sourceContent": [{"text": "Excerpt"}],
    }
    request_block = {
        "citationsContent": {
            "citations": [citation],
            "content": [{"text": "Generated text"}],
        }
    }

    formatted = model._format_request_message_content(request_block)

    # Expected value is a fresh literal so the check does not depend on object identity.
    expected_location = {"documentChar": {"documentIndex": 0, "start": 100, "end": 200}}
    first_citation = formatted["citationsContent"]["citations"][0]
    assert first_citation["location"] == expected_location


def test_format_request_message_content_document_page_citation(model):
    """A documentPage location must come back still wrapped in its union tag."""
    citation = {
        "title": "Page Citation",
        "location": {"documentPage": {"documentIndex": 0, "start": 2, "end": 3}},
        "sourceContent": [{"text": "Page content"}],
    }
    request_block = {
        "citationsContent": {
            "citations": [citation],
            "content": [{"text": "Generated text"}],
        }
    }

    formatted = model._format_request_message_content(request_block)

    # Expected value is a fresh literal so the check does not depend on object identity.
    expected_location = {"documentPage": {"documentIndex": 0, "start": 2, "end": 3}}
    first_citation = formatted["citationsContent"]["citations"][0]
    assert first_citation["location"] == expected_location


def test_format_request_message_content_document_chunk_citation(model):
    """A documentChunk location must come back still wrapped in its union tag."""
    citation = {
        "title": "Chunk Citation",
        "location": {"documentChunk": {"documentIndex": 1, "start": 5, "end": 10}},
        "sourceContent": [{"text": "Chunk content"}],
    }
    request_block = {
        "citationsContent": {
            "citations": [citation],
            "content": [{"text": "Generated text"}],
        }
    }

    formatted = model._format_request_message_content(request_block)

    # Expected value is a fresh literal so the check does not depend on object identity.
    expected_location = {"documentChunk": {"documentIndex": 1, "start": 5, "end": 10}}
    first_citation = formatted["citationsContent"]["citations"][0]
    assert first_citation["location"] == expected_location


def test_format_request_message_content_citation_filters_extra_fields(model):
    """Unsupported keys inside a location's inner payload must be stripped."""
    citation = {
        "title": "Citation with extra fields",
        "location": {"documentChar": {"documentIndex": 0, "start": 0, "end": 50, "extraField": "ignored"}},
        "sourceContent": [{"text": "Content"}],
    }
    request_block = {
        "citationsContent": {
            "citations": [citation],
            "content": [{"text": "Text"}],
        }
    }

    formatted = model._format_request_message_content(request_block)

    # Only the supported fields survive; "extraField" must be gone.
    expected_location = {"documentChar": {"documentIndex": 0, "start": 0, "end": 50}}
    first_citation = formatted["citationsContent"]["citations"][0]
    assert first_citation["location"] == expected_location


def test_format_request_message_content_citation_unknown_location_type(model):
    """An unrecognized location tag yields a citation without a location key."""
    unrecognized_citation = {"title": "Unknown location", "location": {"unknownType": {"field": "value"}}}
    request_block = {
        "citationsContent": {
            "citations": [unrecognized_citation],
            "content": [{"text": "Text"}],
        }
    }

    formatted = model._format_request_message_content(request_block)

    first_citation = formatted["citationsContent"]["citations"][0]
    assert "location" not in first_citation
4 changes: 2 additions & 2 deletions tests/strands/tools/mcp/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ def test_stop_closes_event_loop():
mock_thread.join = MagicMock()
mock_event_loop = MagicMock()
mock_event_loop.close = MagicMock()

client._background_thread = mock_thread
client._background_thread_event_loop = mock_event_loop

Expand All @@ -542,7 +542,7 @@ def test_stop_closes_event_loop():

# Verify thread was joined
mock_thread.join.assert_called_once()

# Verify event loop was closed
mock_event_loop.close.assert_called_once()

Expand Down