diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 43836fe34..1670a4fd3 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -336,8 +336,96 @@ def handle_redact_content(event: RedactContentEvent, state: dict[str, Any]) -> N event: Redact Content Event. state: The current state of message processing. """ + # Store both messages for later decision based on trace + # AWS Bedrock sends both messages regardless of which guardrail was triggered + if event.get("redactUserContentMessage") is not None: + state["redactUserContentMessage"] = event["redactUserContentMessage"] if event.get("redactAssistantContentMessage") is not None: - state["message"]["content"] = [{"text": event["redactAssistantContentMessage"]}] + state["redactAssistantContentMessage"] = event["redactAssistantContentMessage"] + + +def _check_if_blocked(assessment: dict[str, Any]) -> bool: + """Check if any policy in the assessment has BLOCKED action. + + Args: + assessment: Guardrail assessment data + + Returns: + True if any policy has BLOCKED action + """ + # Check word policy + word_policy = assessment.get("wordPolicy", {}) + custom_words = word_policy.get("customWords", []) + for word in custom_words: + if word.get("action") == "BLOCKED" and word.get("detected"): + return True + + # Check content policy + content_policy = assessment.get("contentPolicy", {}) + filters = content_policy.get("filters", []) + for filter_item in filters: + if filter_item.get("action") == "BLOCKED": + return True + + # Check sensitive information policy + pii_entities = assessment.get("sensitiveInformationPolicy", {}).get("piiEntities", []) + for entity in pii_entities: + if entity.get("action") == "BLOCKED": + return True + + return False + + +def finalize_redact_message(event: MetadataEvent, state: dict[str, Any]) -> None: + """Finalize the redacted message based on trace information. 
+ + AWS Bedrock sends both redactUserContentMessage and redactAssistantContentMessage + regardless of which guardrail was triggered. We need to check the trace to determine + which one to use. + + Args: + event: Metadata event containing trace information + state: The current state of message processing + """ + # Check if we have redact messages stored + if "redactUserContentMessage" not in state and "redactAssistantContentMessage" not in state: + return + + # Get trace information + trace = event.get("trace", {}) + guardrail = trace.get("guardrail", {}) + + # Check input assessment + input_blocked = False + input_assessment = guardrail.get("inputAssessment", {}) + for guardrail_id, assessment in input_assessment.items(): + if _check_if_blocked(assessment): + input_blocked = True + break + + # Check output assessments + output_blocked = False + output_assessments = guardrail.get("outputAssessments", {}) + # outputAssessments is a dict with guardrail IDs as keys + for guardrail_id, assessments in output_assessments.items(): + if isinstance(assessments, list): + for assessment in assessments: + if _check_if_blocked(assessment): + output_blocked = True + break + if output_blocked: + break + + # Select the appropriate message based on trace + if output_blocked and "redactAssistantContentMessage" in state: + state["message"]["content"] = [{"text": state["redactAssistantContentMessage"]}] + elif input_blocked and "redactUserContentMessage" in state: + state["message"]["content"] = [{"text": state["redactUserContentMessage"]}] + # Fallback: use input message if trace is unclear but we have redact messages + elif "redactUserContentMessage" in state: + state["message"]["content"] = [{"text": state["redactUserContentMessage"]}] + elif "redactAssistantContentMessage" in state: + state["message"]["content"] = [{"text": state["redactAssistantContentMessage"]}] def extract_usage_metrics(event: MetadataEvent, time_to_first_byte_ms: int | None = None) -> tuple[Usage, Metrics]: @@ 
def _generate_redaction_events(self, guardrail_data: dict[str, Any] | None = None) -> "list[StreamEvent]":
    """Generate redaction events based on configuration and which guardrail was triggered.

    When the trace identifies the blocked side(s), a redaction event is produced for
    each blocked side whose config flag is enabled. When no trace is given, or the
    trace does not identify either side, the config flags alone decide.

    Args:
        guardrail_data: Guardrail trace data to determine which guardrail (input/output) was blocked.
            If None, falls back to legacy behavior using config flags only.

    Returns:
        List of redaction events to yield.
    """
    input_blocked = False
    output_blocked = False
    if guardrail_data:
        input_blocked = any(
            self._find_detected_and_blocked_policy(assessment)
            for assessment in guardrail_data.get("inputAssessment", {}).values()
        )
        output_blocked = any(
            self._find_detected_and_blocked_policy(assessment)
            for assessment in guardrail_data.get("outputAssessments", {}).values()
        )

    # No trace, or a trace that pins down neither side: keep the legacy behavior and
    # let the config flags alone decide, rather than silently emitting no redaction.
    if not (input_blocked or output_blocked):
        input_blocked = output_blocked = True

    events: list[StreamEvent] = []

    # The two checks are independent so that a request blocked on both sides
    # gets both redactions.
    if input_blocked and self.config.get("guardrail_redact_input", True):
        logger.debug("Redacting user input due to guardrail.")
        events.append(
            {
                "redactContent": {
                    "redactUserContentMessage": self.config.get(
                        "guardrail_redact_input_message", "[User input redacted.]"
                    )
                }
            }
        )

    if output_blocked and self.config.get("guardrail_redact_output", False):
        logger.debug("Redacting assistant output due to guardrail.")
        events.append(
            {
                "redactContent": {
                    "redactAssistantContentMessage": self.config.get(
                        "guardrail_redact_output_message",
                        "[Assistant output redacted.]",
                    )
                }
            }
        )

    return events
"content": [{"text": "REDACTED."}]}, + {"role": "assistant", "content": [{"text": "REDACTED"}]}, + {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, + {"latencyMs": 1}, + ) + }, + ], + ), + # Redacted Message - Input only (redactUserContentMessage) + ( + [ + {"messageStart": {"role": "assistant"}}, + { + "contentBlockStart": {"start": {}}, + }, + { + "contentBlockDelta": {"delta": {"text": "Hello!"}}, + }, + {"contentBlockStop": {}}, + { + "messageStop": {"stopReason": "guardrail_intervened"}, + }, + { + "redactContent": { + "redactUserContentMessage": "INPUT_BLOCKED", + } + }, + { + "metadata": { + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, + }, + "metrics": {"latencyMs": 1}, + } + }, + ], + [ + {"event": {"messageStart": {"role": "assistant"}}}, + {"event": {"contentBlockStart": {"start": {}}}}, + {"event": {"contentBlockDelta": {"delta": {"text": "Hello!"}}}}, + {"data": "Hello!", "delta": {"text": "Hello!"}}, + {"event": {"contentBlockStop": {}}}, + {"event": {"messageStop": {"stopReason": "guardrail_intervened"}}}, + { + "event": { + "redactContent": { + "redactUserContentMessage": "INPUT_BLOCKED", + } + } + }, + { + "event": { + "metadata": { + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, + }, + "metrics": {"latencyMs": 1}, + } + } + }, + { + "stop": ( + "guardrail_intervened", + {"role": "assistant", "content": [{"text": "INPUT_BLOCKED"}]}, + {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, + {"latencyMs": 1}, + ) + }, + ], + ), + # Redacted Message - Output only (redactAssistantContentMessage) + ( + [ + {"messageStart": {"role": "assistant"}}, + { + "contentBlockStart": {"start": {}}, + }, + { + "contentBlockDelta": {"delta": {"text": "Hello!"}}, + }, + {"contentBlockStop": {}}, + { + "messageStop": {"stopReason": "guardrail_intervened"}, + }, + { + "redactContent": { + "redactAssistantContentMessage": "OUTPUT_BLOCKED", + } + }, + { + "metadata": { + "usage": { + "inputTokens": 1, + 
"outputTokens": 1, + "totalTokens": 1, + }, + "metrics": {"latencyMs": 1}, + } + }, + ], + [ + {"event": {"messageStart": {"role": "assistant"}}}, + {"event": {"contentBlockStart": {"start": {}}}}, + {"event": {"contentBlockDelta": {"delta": {"text": "Hello!"}}}}, + {"data": "Hello!", "delta": {"text": "Hello!"}}, + {"event": {"contentBlockStop": {}}}, + {"event": {"messageStop": {"stopReason": "guardrail_intervened"}}}, + { + "event": { + "redactContent": { + "redactAssistantContentMessage": "OUTPUT_BLOCKED", + } + } + }, + { + "event": { + "metadata": { + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, + }, + "metrics": {"latencyMs": 1}, + } + } + }, + { + "stop": ( + "guardrail_intervened", + {"role": "assistant", "content": [{"text": "OUTPUT_BLOCKED"}]}, {"inputTokens": 1, "outputTokens": 1, "totalTokens": 1}, {"latencyMs": 1}, )