diff --git a/docs/truncation-strategies.md b/docs/truncation-strategies.md new file mode 100644 index 00000000..99076387 --- /dev/null +++ b/docs/truncation-strategies.md @@ -0,0 +1,218 @@ +# Truncation Strategies + +This guide explains how conversation history management works in AskUI Vision Agent and how to optimize token usage for long-running agents. + +## Table of Contents + +- [Overview](#overview) +- [Available Strategies](#available-strategies) + - [SimpleTruncationStrategy (Default)](#simpletruncationstrategy-default) + - [LatestImageOnlyTruncationStrategy (Experimental)](#latestimageonlytruncationstrategy-experimental) +- [When to Use Each Strategy](#when-to-use-each-strategy) +- [Configuration](#configuration) +- [Cost Implications](#cost-implications) +- [Troubleshooting](#troubleshooting) + +## Overview + +As vision agents execute tasks, they maintain a conversation history including: +- User instructions +- Assistant responses with reasoning +- Tool calls and results +- Screenshots (consuming 1,000-3,000 tokens each) + +Without proper management, long conversations can: +- Exceed the 100,000 token API limit +- Increase API costs significantly +- Slow down response times + +Truncation strategies automatically manage conversation history by removing less important messages while preserving critical context. + +## Available Strategies + +### SimpleTruncationStrategy (Default) + +The default strategy that provides stable, reliable truncation. + +**How it works:** +- Monitors token count and message count +- When thresholds are exceeded (75% of limits), removes messages in priority order: + 1. Tool calling loops from closed conversations + 2. Entire closed conversation loops (except first and last) + 3. The first conversation loop if not the active one + 4. Tool calling turns from the open loop (except first and last) + +**Characteristics:** +- Conservative: Preserves all images until limits approached +- Stable: Well-tested for all use cases +- Higher token cost: All screenshots remain in context + +**Usage:** + +```python +from askui import VisionAgent + +# Default - no configuration needed +with VisionAgent() as agent: + agent.act("Complete the task") +``` + +### LatestImageOnlyTruncationStrategy (Experimental) + +An experimental strategy that aggressively reduces token usage by keeping only the most recent screenshot. + +**How it works:** +- Applies same message truncation as SimpleTruncationStrategy +- **Additionally**: Keeps only the most recent screenshot +- Replaces older images with text: `"[Image removed to save tokens]"` +- Preserves all text, tool calls, and non-image tool results + +**Characteristics:** +- Aggressive savings: Up to 90% token reduction in multi-step scenarios +- Experimental: May affect performance when historical visual context is needed +- Cost-effective: Dramatically reduces API costs + +**Token savings example:** + +| Scenario | SimpleTruncationStrategy | LatestImageOnlyTruncationStrategy | Savings | +|----------|-------------------------|-----------------------------------|---------| +| 10-step task with screenshots | ~25,000 tokens | ~7,100 tokens | 72% | +| 20-step task with screenshots | ~50,000 tokens | ~12,000 tokens | 76% | + +**Usage:** + +```python +from askui import VisionAgent, ActSettings, TruncationStrategySettings + +with VisionAgent() as agent: + # Configure to use latest image only strategy + settings = ActSettings( + truncation=TruncationStrategySettings(strategy="latest_image_only") + ) + agent.act("Complete the task", settings=settings) +``` + +**Warning:** On first use, you'll see a warning log: +``` +WARNING: Using experimental LatestImageOnlyTruncationStrategy. This strategy removes old images from conversation history to save tokens. Only the most recent image is kept. This may affect model performance in scenarios requiring historical visual context. +``` + +## When to Use Each Strategy + +### Use SimpleTruncationStrategy (Default) when: +- Starting a new project +- Requiring maximum stability +- Visual history is important +- Debugging complex issues +- Running critical production workloads + +### Use LatestImageOnlyTruncationStrategy when: +- ✅ Tasks have many sequential actions with screenshots +- ✅ Token costs are a primary concern +- ✅ Only current screen state matters (e.g., form filling, navigation) +- ✅ Long-running agents would exceed token limits + +### Avoid LatestImageOnlyTruncationStrategy when: +- ❌ Tasks require comparison between multiple screens +- ❌ Debugging scenarios need full visual history +- ❌ Complex workflows depend on historical visual context +- ❌ Stability is more important than cost + +## Configuration + +You can easily configure truncation strategies through the `act()` method using settings: + +```python +from askui import VisionAgent, ActSettings, TruncationStrategySettings + +with VisionAgent() as agent: + # Use default (simple) strategy + agent.act("Complete the task") + + # Use experimental latest image only strategy + settings = ActSettings( + truncation=TruncationStrategySettings(strategy="latest_image_only") + ) + agent.act("Complete another task", settings=settings) + + # Mix and match - use simple for some tasks, latest_image_only for others + simple_settings = ActSettings( + truncation=TruncationStrategySettings(strategy="simple") + ) + agent.act("Task requiring full visual history", settings=simple_settings) +``` + +**Available strategies:** +- `"simple"` (default): Conservative strategy preserving all images +- `"latest_image_only"` (experimental): Aggressive strategy keeping only latest image + +**Note:** Advanced configuration options (token limits, thresholds) use default values and cannot be customized through settings. If you need custom thresholds, you can still set the truncation strategy factory directly on the agent during initialization. + +## Cost Implications + +**Anthropic Claude API Pricing (approximate):** +- Input tokens: ~$3 per million tokens +- Output tokens: ~$15 per million tokens + +**Cost comparison for a 20-step task:** + +SimpleTruncationStrategy: +``` +20 screenshots × 2,000 tokens = 40,000 tokens +Additional context = 10,000 tokens +Total = 50,000 tokens +Cost = $0.15 per execution +``` + +LatestImageOnlyTruncationStrategy: +``` +1 screenshot × 2,000 tokens = 2,000 tokens +Additional context = 10,000 tokens +Total = 12,000 tokens +Cost = $0.036 per execution +Savings = $0.114 per execution (76% reduction) +``` + +For applications running thousands of executions, these savings are substantial. + +## Troubleshooting + +### Token limit errors + +**Symptom:** `MaxTokensExceededError` or API context length errors + +**Solutions:** +1. Lower `input_token_truncation_threshold` for earlier truncation +2. Reduce `max_input_tokens` for more aggressive truncation +3. Try `LatestImageOnlyTruncationStrategy` for dramatic reduction + +### Agent loses context during long conversations + +**Symptom:** Agent forgets earlier actions or repeats steps + +**Solutions:** +1. Increase `input_token_truncation_threshold` to preserve more history +2. Include critical information in system prompts, not just conversation +3. Break long tasks into smaller subtasks + +### Performance degrades with LatestImageOnlyTruncationStrategy + +**Symptom:** Agent makes mistakes or fails tasks + +**Solutions:** +1. Revert to `SimpleTruncationStrategy` for those tasks +2. Restructure tasks to require less historical visual context +3. Add explicit text descriptions of previous screens in tool results + +## Best Practices + +1. **Start with the default:** Use `SimpleTruncationStrategy` for new projects +2. **Monitor costs:** If token usage is high, test `LatestImageOnlyTruncationStrategy` +3. **Test thoroughly:** Verify agent success rates with the experimental strategy +4. **Use selectively:** Apply `LatestImageOnlyTruncationStrategy` only where appropriate + +## See Also + +- [Using Models](using-models.md) - Learn about AI models and their capabilities +- [Caching](caching.md) - Speed up repeated tasks with action caching +- [Observability](observability.md) - Monitor and debug agent behavior diff --git a/src/askui/__init__.py b/src/askui/__init__.py index 4470405f..d2963685 100644 --- a/src/askui/__init__.py +++ b/src/askui/__init__.py @@ -37,7 +37,11 @@ ToolUseBlockParam, UrlImageSourceParam, ) -from .models.shared.settings import ActSettings, MessageSettings +from .models.shared.settings import ( + ActSettings, + MessageSettings, + TruncationStrategySettings, +) from .models.shared.tools import Tool from .models.types.response_schemas import ResponseSchema, ResponseSchemaBase from .retry import ConfigurableRetry, Retry @@ -100,6 +104,7 @@ "Tool", "ToolResultBlockParam", "ToolUseBlockParam", + "TruncationStrategySettings", "UrlImageSourceParam", "VisionAgent", ] diff --git a/src/askui/models/shared/agent.py b/src/askui/models/shared/agent.py index fdde6c32..011893c4 100644 --- a/src/askui/models/shared/agent.py +++ b/src/askui/models/shared/agent.py @@ -14,6 +14,7 @@ from askui.models.shared.settings import ActSettings from askui.models.shared.tools import ToolCollection from askui.models.shared.truncation_strategies import ( + LatestImageOnlyTruncationStrategyFactory, SimpleTruncationStrategyFactory, TruncationStrategy, TruncationStrategyFactory, @@ -140,13 +141,20 @@ def act( ) -> None: _settings = settings or ActSettings() _tool_collection = tools or ToolCollection() - truncation_strategy = ( - self._truncation_strategy_factory.create_truncation_strategy( - tools=_tool_collection.to_params(), - system=_settings.messages.system or None, - messages=messages, - model=model, - ) + + # Create truncation strategy factory based on settings + truncation_strategy_factory: TruncationStrategyFactory + if _settings.truncation.strategy == "latest_image_only": + truncation_strategy_factory = LatestImageOnlyTruncationStrategyFactory() + else: + # Use default factory from initialization if "simple" or use default + truncation_strategy_factory = self._truncation_strategy_factory + + truncation_strategy = truncation_strategy_factory.create_truncation_strategy( + tools=_tool_collection.to_params(), + system=_settings.messages.system or None, + messages=messages, + model=model, ) self._step( model=model, diff --git a/src/askui/models/shared/settings.py b/src/askui/models/shared/settings.py index 547d97b6..fe937167 100644 --- a/src/askui/models/shared/settings.py +++ b/src/askui/models/shared/settings.py @@ -11,6 +11,7 @@ COMPUTER_USE_20250124_BETA_FLAG = "computer-use-2025-01-24" COMPUTER_USE_20251124_BETA_FLAG = "computer-use-2025-11-24" +TRUNCATION_STRATEGY = Literal["simple", "latest_image_only"] CACHING_STRATEGY = Literal["read", "write", "both", "no"] @@ -25,10 +26,31 @@ class MessageSettings(BaseModel): temperature: float | Omit = Field(default=omit, ge=0.0, le=1.0) +class TruncationStrategySettings(BaseModel): + """Settings for conversation truncation strategy. + + Controls how conversation history is managed to stay within token limits. + + Attributes: + strategy: The truncation strategy to use: + - "simple" (default): Conservative strategy that preserves all images + until limits are approached. Provides maximum stability. + - "latest_image_only" (experimental): Aggressive strategy that keeps only + the most recent screenshot, dramatically reducing token usage by up + to 90%. May affect model performance when historical visual context + is needed. + """ + + strategy: TRUNCATION_STRATEGY = "simple" + + class ActSettings(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) messages: MessageSettings = Field(default_factory=MessageSettings) + truncation: TruncationStrategySettings = Field( + default_factory=TruncationStrategySettings + ) class CachedExecutionToolSettings(BaseModel): diff --git a/src/askui/models/shared/truncation_strategies.py b/src/askui/models/shared/truncation_strategies.py index 935ebcec..1fbe05da 100644 --- a/src/askui/models/shared/truncation_strategies.py +++ b/src/askui/models/shared/truncation_strategies.py @@ -1,3 +1,4 @@ +import logging from dataclasses import dataclass from typing import Annotated @@ -7,11 +8,15 @@ from askui.models.shared.agent_message_param import ( CacheControlEphemeralParam, + ContentBlockParam, + ImageBlockParam, MessageParam, TextBlockParam, ) from askui.models.shared.token_counter import SimpleTokenCounter, TokenCounter +logger = logging.getLogger(__name__) + # needs to be below limits imposed by endpoint MAX_INPUT_TOKENS = 100_000 @@ -323,6 +328,126 @@ def _cluster_into_tool_calling_loops(self) -> list[list[MessageContainer]]: return loops +class LatestImageOnlyTruncationStrategy(SimpleTruncationStrategy): + """Truncation strategy that keeps only the latest image to save tokens. + + Extends SimpleTruncationStrategy by keeping only the most recent image + in the conversation history. All older images are replaced with text + placeholders to significantly reduce token consumption. + + This strategy maintains the same truncation logic as SimpleTruncationStrategy + but adds automatic removal of old images before returning messages. + + WARNING: This is an experimental feature. Keeping only the latest image may + affect model performance in scenarios where historical visual context is important. + + Args: + Same as SimpleTruncationStrategy. + """ + + def __init__( + self, + tools: list[BetaToolUnionParam] | None, + system: str | list[BetaTextBlockParam] | None, + messages: list[MessageParam], + model: str, + max_input_tokens: int = MAX_INPUT_TOKENS, + input_token_truncation_threshold: Annotated[ + float, Field(gt=0.0, lt=1.0) + ] = 0.75, + max_messages: int = MAX_MESSAGES, + message_truncation_threshold: Annotated[float, Field(gt=0.0, lt=1.0)] = 0.75, + token_counter: TokenCounter | None = None, + ) -> None: + super().__init__( + tools=tools, + system=system, + messages=messages, + model=model, + max_input_tokens=max_input_tokens, + input_token_truncation_threshold=input_token_truncation_threshold, + max_messages=max_messages, + message_truncation_threshold=message_truncation_threshold, + token_counter=token_counter, + ) + # Log warning on first use + logger.warning( + "Using experimental LatestImageOnlyTruncationStrategy. " + "This strategy removes old images from conversation history to save " + "tokens. Only the most recent image is kept. This may affect model " + "performance in scenarios requiring historical visual context." + ) + + @property + @override + def messages(self) -> list[MessageParam]: + self._move_cache_control_to_last_non_tool_result_user_message() + self._remove_old_images_from_messages() + return self._messages + + def _remove_old_images_from_messages(self) -> None: + """Remove all images except those in the last message containing images.""" + # Find the index of the last message that contains images + last_image_message_index = -1 + for i in reversed(range(len(self._messages))): + if self._message_has_image(self._messages[i]): + last_image_message_index = i + break + + # If no images found, nothing to do + if last_image_message_index == -1: + return + + # Remove images from all messages before the last one with images + for i in range(last_image_message_index): + self._remove_images_from_message(self._messages[i]) + + def _message_has_image(self, message: MessageParam) -> bool: + """Check if a message contains any image blocks.""" + if not isinstance(message.content, list): + return False + + for block in message.content: + if block.type == "image": + return True + # Check inside tool_result blocks + if block.type == "tool_result" and isinstance(block.content, list): + for inner_block in block.content: + if inner_block.type == "image": + return True + return False + + def _remove_images_from_message(self, message: MessageParam) -> None: + """Replace all image blocks in a message with text placeholders.""" + if not isinstance(message.content, list): + return + + new_content: list[ContentBlockParam] = [] + for block in message.content: + if block.type == "image": + # Replace image with text placeholder + new_content.append( + TextBlockParam(text="[Image removed to save tokens]") + ) + elif block.type == "tool_result": + # Handle images inside tool_result blocks + if isinstance(block.content, list): + new_tool_result_content: list[TextBlockParam | ImageBlockParam] = [] + for inner_block in block.content: + if inner_block.type == "image": + new_tool_result_content.append( + TextBlockParam(text="[Image removed to save tokens]") + ) + else: + new_tool_result_content.append(inner_block) + block.content = new_tool_result_content + new_content.append(block) + else: + new_content.append(block) + + message.content = new_content + + class TruncationStrategyFactory: def create_truncation_strategy( self, @@ -374,3 +499,60 @@ def create_truncation_strategy( message_truncation_threshold=self._message_truncation_threshold, token_counter=self._token_counter, ) + + +class LatestImageOnlyTruncationStrategyFactory(TruncationStrategyFactory): + """Factory for creating LatestImageOnlyTruncationStrategy instances. + + Creates truncation strategies that keep only the latest image from conversation + history to save tokens. + + WARNING: This is an experimental feature. A warning will be displayed when + the strategy is first created. + + Args: + max_input_tokens (int, optional): Maximum input tokens allowed. + Defaults to 100,000. + input_token_truncation_threshold (float, optional): Fraction of max tokens + to truncate at. Defaults to 0.75. + max_messages (int, optional): Maximum messages allowed. Defaults to 100,000. + message_truncation_threshold (float, optional): Fraction of max messages + to truncate at. Defaults to 0.75. + token_counter (TokenCounter | None, optional): Token counter instance. + Defaults to SimpleTokenCounter. + """ + + def __init__( + self, + max_input_tokens: int = MAX_INPUT_TOKENS, + input_token_truncation_threshold: Annotated[ + float, Field(gt=0.0, lt=1.0) + ] = 0.75, + max_messages: int = MAX_MESSAGES, + message_truncation_threshold: Annotated[float, Field(gt=0.0, lt=1.0)] = 0.75, + token_counter: TokenCounter | None = None, + ) -> None: + self._max_input_tokens = max_input_tokens + self._input_token_truncation_threshold = input_token_truncation_threshold + self._max_messages = max_messages + self._message_truncation_threshold = message_truncation_threshold + self._token_counter = token_counter or SimpleTokenCounter() + + def create_truncation_strategy( + self, + tools: list[BetaToolUnionParam] | None, + system: str | list[BetaTextBlockParam] | None, + messages: list[MessageParam], + model: str, + ) -> TruncationStrategy: + return LatestImageOnlyTruncationStrategy( + tools=tools, + system=system, + messages=messages, + model=model, + max_input_tokens=self._max_input_tokens, + input_token_truncation_threshold=self._input_token_truncation_threshold, + max_messages=self._max_messages, + message_truncation_threshold=self._message_truncation_threshold, + token_counter=self._token_counter, + ) diff --git a/tests/unit/models/shared/__init__.py b/tests/unit/models/shared/__init__.py new file mode 100644 index 00000000..5d21e4b1 --- /dev/null +++ b/tests/unit/models/shared/__init__.py @@ -0,0 +1 @@ +# Tests for shared models diff --git a/tests/unit/models/shared/test_agent_truncation_integration.py b/tests/unit/models/shared/test_agent_truncation_integration.py new file mode 100644 index 00000000..95c6ec48 --- /dev/null +++ b/tests/unit/models/shared/test_agent_truncation_integration.py @@ -0,0 +1,262 @@ +"""Integration tests for Agent.act() with different truncation strategies.""" + +import logging +from unittest.mock import Mock + +import pytest +from typing_extensions import Literal + +from askui.models.shared.agent import Agent +from askui.models.shared.agent_message_param import ( + Base64ImageSourceParam, + ImageBlockParam, + MessageParam, +) +from askui.models.shared.messages_api import MessagesApi +from askui.models.shared.settings import ActSettings, TruncationStrategySettings + + +def _create_text_message(role: Literal["user", "assistant"], text: str) -> MessageParam: + """Helper to create a simple text message.""" + return MessageParam(role=role, content=text) + + +def _create_image_message( + role: Literal["user", "assistant"], image_data: str = "image_data" +) -> MessageParam: + """Helper to create a message with an image.""" + return MessageParam( + role=role, + content=[ + ImageBlockParam( + type="image", + source=Base64ImageSourceParam( + type="base64", media_type="image/png", data=image_data + ), + ) + ], + ) + + +class TestAgentTruncationStrategyIntegration: + """Integration tests for Agent with truncation strategies.""" + + def _create_mock_messages_api(self) -> MessagesApi: + """Create a mock MessagesApi that returns a simple response.""" + mock_api = Mock(spec=MessagesApi) + # Mock the create_message to return a simple assistant response + mock_api.create_message.return_value = MessageParam( + role="assistant", + content="Response", + stop_reason="end_turn", + ) + return mock_api + + def test_agent_uses_default_simple_truncation_strategy( + self, caplog: pytest.LogCaptureFixture + ) -> None: + """Test that Agent uses SimpleTruncationStrategy by default.""" + mock_api = self._create_mock_messages_api() + agent = Agent(messages_api=mock_api) + + messages = [_create_text_message("user", "Hello")] + + with caplog.at_level(logging.WARNING): + agent.act( + messages=messages, + model="claude-3-5-sonnet-20241022", + ) + + # Verify no warning about experimental strategy + assert not any( + "experimental LatestImageOnlyTruncationStrategy" in record.message + for record in caplog.records + ) + + # Verify API was called + mock_api.create_message.assert_called_once() # type: ignore[attr-defined] + + def test_agent_uses_simple_strategy_when_explicitly_set( + self, caplog: pytest.LogCaptureFixture + ) -> None: + """Test that Agent uses SimpleTruncationStrategy when explicitly set.""" + mock_api = self._create_mock_messages_api() + agent = Agent(messages_api=mock_api) + + messages = [_create_text_message("user", "Hello")] + settings = ActSettings(truncation=TruncationStrategySettings(strategy="simple")) + + with caplog.at_level(logging.WARNING): + agent.act( + messages=messages, + model="claude-3-5-sonnet-20241022", + settings=settings, + ) + + # Verify no warning about experimental strategy + assert not any( + "experimental LatestImageOnlyTruncationStrategy" in record.message + for record in caplog.records + ) + + # Verify API was called + mock_api.create_message.assert_called_once() # type: ignore[attr-defined] + + def test_agent_uses_latest_image_only_strategy_when_configured( + self, caplog: pytest.LogCaptureFixture + ) -> None: + """Test that Agent uses LatestImageOnlyTruncationStrategy when configured.""" + mock_api = self._create_mock_messages_api() + agent = Agent(messages_api=mock_api) + + messages = [_create_text_message("user", "Hello")] + settings = ActSettings( + truncation=TruncationStrategySettings(strategy="latest_image_only") + ) + + with caplog.at_level(logging.WARNING): + agent.act( + messages=messages, + model="claude-3-5-sonnet-20241022", + settings=settings, + ) + + # Verify warning about experimental strategy was logged + assert any( + "experimental LatestImageOnlyTruncationStrategy" in record.message + for record in caplog.records + ) + + # Verify API was called + mock_api.create_message.assert_called_once() # type: ignore[attr-defined] + + def test_agent_latest_image_only_strategy_removes_old_images(self) -> None: + """Test that latest_image_only strategy actually removes old images.""" + mock_api = self._create_mock_messages_api() + agent = Agent(messages_api=mock_api) + + messages = [ + _create_image_message("user", "old_image"), + _create_text_message("assistant", "Response"), + _create_image_message("user", "new_image"), + ] + + settings = ActSettings( + truncation=TruncationStrategySettings(strategy="latest_image_only") + ) + + agent.act( + messages=messages, + model="claude-3-5-sonnet-20241022", + settings=settings, + ) + + # Get the messages that were sent to the API + call_args = mock_api.create_message.call_args # type: ignore[attr-defined] + sent_messages = call_args.kwargs["messages"] + + # First message should have image replaced + assert isinstance(sent_messages[0].content, list) + assert not any(block.type == "image" for block in sent_messages[0].content) + assert any( + block.type == "text" and block.text == "[Image removed to save tokens]" + for block in sent_messages[0].content + ) + + # Last message should keep the image + assert isinstance(sent_messages[2].content, list) + assert any(block.type == "image" for block in sent_messages[2].content) + + def test_agent_simple_strategy_preserves_all_images(self) -> None: + """Test that simple strategy preserves all images.""" + mock_api = self._create_mock_messages_api() + agent = Agent(messages_api=mock_api) + + messages = [ + _create_image_message("user", "image1"), + _create_text_message("assistant", "Response"), + _create_image_message("user", "image2"), + ] + + settings = ActSettings(truncation=TruncationStrategySettings(strategy="simple")) + + agent.act( + messages=messages, + model="claude-3-5-sonnet-20241022", + settings=settings, + ) + + # Get the messages that were sent to the API + call_args = mock_api.create_message.call_args # type: ignore[attr-defined] + sent_messages = call_args.kwargs["messages"] + + # Both messages should still have images + assert isinstance(sent_messages[0].content, list) + assert any(block.type == "image" for block in sent_messages[0].content) + + assert isinstance(sent_messages[2].content, list) + assert any(block.type == "image" for block in sent_messages[2].content) + + def test_agent_can_switch_strategies_between_act_calls( + self, caplog: pytest.LogCaptureFixture + ) -> None: + """Test that different strategies can be used for different act() calls.""" + mock_api = self._create_mock_messages_api() + agent = Agent(messages_api=mock_api) + + messages = [_create_text_message("user", "Hello")] + + # First call with simple strategy + settings_simple = ActSettings( + truncation=TruncationStrategySettings(strategy="simple") + ) + with caplog.at_level(logging.WARNING): + caplog.clear() + agent.act( + messages=messages.copy(), + model="claude-3-5-sonnet-20241022", + settings=settings_simple, + ) + + # No warning for simple strategy + assert not any( + "experimental LatestImageOnlyTruncationStrategy" in record.message + for record in caplog.records + ) + + # Second call with latest_image_only strategy + settings_latest = ActSettings( + truncation=TruncationStrategySettings(strategy="latest_image_only") + ) + with caplog.at_level(logging.WARNING): + caplog.clear() + agent.act( + messages=messages.copy(), + model="claude-3-5-sonnet-20241022", + settings=settings_latest, + ) + + # Warning for experimental strategy + assert any( + "experimental LatestImageOnlyTruncationStrategy" in record.message + for record in caplog.records + ) + + # Verify both calls were made + assert mock_api.create_message.call_count == 2 # type: ignore[attr-defined] + + def test_backwards_compatibility_with_no_settings(self) -> None: + """Test that Agent works without any settings (backwards compatibility).""" + mock_api = self._create_mock_messages_api() + agent = Agent(messages_api=mock_api) + + messages = [_create_text_message("user", "Hello")] + + # Should work without settings parameter + agent.act( + messages=messages, + model="claude-3-5-sonnet-20241022", + ) + + # Verify API was called + mock_api.create_message.assert_called_once() # type: ignore[attr-defined] diff --git a/tests/unit/models/shared/test_settings.py b/tests/unit/models/shared/test_settings.py new file mode 100644 index 00000000..f688fd7b --- /dev/null +++ b/tests/unit/models/shared/test_settings.py @@ -0,0 +1,75 @@ +import pytest +from pydantic import ValidationError + +from askui.models.shared.settings import ActSettings, TruncationStrategySettings + + +class TestTruncationStrategySettings: + """Tests for TruncationStrategySettings.""" + + def test_default_strategy_is_simple(self) -> None: + """Test that the default truncation strategy is 'simple'.""" + settings = TruncationStrategySettings() + assert settings.strategy == "simple" + + def test_can_set_simple_strategy(self) -> None: + """Test that 'simple' strategy can be explicitly set.""" + settings = TruncationStrategySettings(strategy="simple") + assert settings.strategy == "simple" + + def test_can_set_latest_image_only_strategy(self) -> None: + """Test that 'latest_image_only' strategy can be set.""" + settings = TruncationStrategySettings(strategy="latest_image_only") + assert settings.strategy == "latest_image_only" + + def test_rejects_invalid_strategy(self) -> None: + """Test that invalid strategy values are rejected.""" + with pytest.raises(ValidationError): + TruncationStrategySettings(strategy="invalid_strategy") # pyright: ignore[reportArgumentType] + + def test_serialization(self) -> None: + """Test that settings can be serialized.""" + settings = TruncationStrategySettings(strategy="latest_image_only") + serialized = settings.model_dump() + assert serialized == {"strategy": "latest_image_only"} + + def test_deserialization(self) -> None: + """Test that settings can be deserialized.""" + data = {"strategy": "simple"} + settings = TruncationStrategySettings(**data) + assert settings.strategy == "simple" + + +class TestActSettings: + """Tests for ActSettings integration with TruncationStrategySettings.""" + + def test_default_act_settings_has_default_truncation(self) -> None: + """Test that ActSettings has default truncation settings.""" + settings = ActSettings() + assert settings.truncation.strategy == "simple" + + def test_can_set_truncation_strategy_in_act_settings(self) -> None: + """Test that truncation strategy can be set in ActSettings.""" + settings = ActSettings( + truncation=TruncationStrategySettings(strategy="latest_image_only") + ) + assert settings.truncation.strategy == "latest_image_only" + + def test_act_settings_serialization_includes_truncation(self) -> None: + """Test that ActSettings serialization includes truncation settings.""" + settings = ActSettings( + truncation=TruncationStrategySettings(strategy="latest_image_only") + ) + serialized = settings.model_dump() + assert "truncation" in serialized + assert serialized["truncation"]["strategy"] == "latest_image_only" + + def test_act_settings_deserialization_with_truncation(self) -> None: + """Test that ActSettings can be deserialized with truncation settings.""" + data = { + "messages": {"max_tokens": 4096}, + "truncation": {"strategy": "latest_image_only"}, + } + settings = ActSettings(**data) + assert settings.truncation.strategy == "latest_image_only" + assert settings.messages.max_tokens == 4096 diff --git a/tests/unit/models/shared/test_truncation_strategies.py b/tests/unit/models/shared/test_truncation_strategies.py new file mode 100644 index 00000000..7a8c4e0f --- /dev/null +++ b/tests/unit/models/shared/test_truncation_strategies.py @@ -0,0 +1,463 @@ +import logging + +import pytest +from typing_extensions import Literal + +from askui.models.shared.agent_message_param import ( + Base64ImageSourceParam, + ImageBlockParam, + MessageParam, + TextBlockParam, + ToolResultBlockParam, +) +from askui.models.shared.truncation_strategies import ( + LatestImageOnlyTruncationStrategy, + LatestImageOnlyTruncationStrategyFactory, + SimpleTruncationStrategy, + SimpleTruncationStrategyFactory, +) + + +def _create_text_message(role: Literal["user", "assistant"], text: str) -> MessageParam: + """Helper to create a simple text message.""" + return MessageParam(role=role, content=text) + + +def _create_image_message( + role: Literal["user", "assistant"], image_data: str = "image_data" +) -> MessageParam: + """Helper to create a message with an image.""" + return MessageParam( + role=role, + content=[ + ImageBlockParam( + type="image", + source=Base64ImageSourceParam( + type="base64", media_type="image/png", data=image_data + ), + ) + ], + ) + + +def _create_mixed_message( + role: Literal["user", "assistant"], text: str, image_data: str = "image_data" +) -> MessageParam: + """Helper to create a message with both text and image.""" + return MessageParam( + role=role, + content=[ + TextBlockParam(type="text", text=text), + ImageBlockParam( + type="image", + source=Base64ImageSourceParam( + type="base64", media_type="image/png", data=image_data + ), + ), + ], + ) + + +def _create_tool_result_with_image(image_data: str = "image_data") -> MessageParam: + """Helper to create a message with a tool_result containing an image.""" + return MessageParam( + role="user", + content=[ + ToolResultBlockParam( + type="tool_result", + tool_use_id="tool_123", + content=[ + ImageBlockParam( + type="image", + source=Base64ImageSourceParam( + type="base64", media_type="image/png", data=image_data + ), + ) + ], + ) + ], + ) + + +def _has_image_in_message(message: MessageParam) -> bool: + """Helper to check if a message contains images.""" + if not isinstance(message.content, list): + return False + + for block in message.content: + if block.type == "image": + return True + if block.type == "tool_result" and isinstance(block.content, list): + for inner_block in block.content: + if inner_block.type == "image": + return True + return False + + +def _has_placeholder_in_message(message: MessageParam) -> bool: + """Helper to check if a message contains image removal placeholders.""" + if not isinstance(message.content, list): + return False + + for block in message.content: + if block.type == "text" and block.text == "[Image removed to save tokens]": + return True + if block.type == "tool_result" and isinstance(block.content, list): + for inner_block in block.content: + if ( + inner_block.type == "text" + and inner_block.text == "[Image removed to save tokens]" + ): + return True + return False + + +class TestLatestImageOnlyTruncationStrategy: + """Tests for LatestImageOnlyTruncationStrategy.""" + + def test_keeps_only_latest_image_in_conversation( + self, caplog: pytest.LogCaptureFixture + ) -> None: + """Test that only the latest image is kept and older ones are replaced.""" + messages = [ + _create_text_message("user", "First message"), + _create_image_message("user", "image1"), + _create_text_message("assistant", "Response 1"), + _create_image_message("user", "image2"), + _create_text_message("assistant", "Response 2"), + _create_image_message("user", "image3"), # This should be kept + ] + + with caplog.at_level(logging.WARNING): + strategy = LatestImageOnlyTruncationStrategy( + tools=None, + system=None, + messages=messages.copy(), + model="claude-3-5-sonnet-20241022", + ) + + # Check warning was logged + assert any( + "experimental LatestImageOnlyTruncationStrategy" in record.message + for record in caplog.records + ) + + result_messages = strategy.messages + + # First image should be replaced with placeholder + assert not _has_image_in_message(result_messages[1]) + assert _has_placeholder_in_message(result_messages[1]) + + # Second image should be replaced with placeholder + assert not _has_image_in_message(result_messages[3]) + assert _has_placeholder_in_message(result_messages[3]) + + # Third image (latest) should be kept + assert _has_image_in_message(result_messages[5]) + assert not _has_placeholder_in_message(result_messages[5]) + + def test_keeps_images_in_latest_message_with_multiple_images(self) -> None: + """Test that all images in the latest message with images are kept.""" + messages = [ + _create_image_message("user", "old_image"), + _create_text_message("assistant", "Response"), + MessageParam( + role="user", + content=[ + ImageBlockParam( + type="image", + source=Base64ImageSourceParam( + type="base64", media_type="image/png", data="new_image1" + ), + ), + ImageBlockParam( + type="image", + source=Base64ImageSourceParam( + type="base64", media_type="image/png", data="new_image2" + ), + ), + ], + ), + ] + + strategy = LatestImageOnlyTruncationStrategy( + tools=None, + system=None, + messages=messages.copy(), + model="claude-3-5-sonnet-20241022", + ) + + result_messages = strategy.messages + + # First message image should be replaced + assert _has_placeholder_in_message(result_messages[0]) + assert not _has_image_in_message(result_messages[0]) + + # Last message should keep both images + assert _has_image_in_message(result_messages[2]) + assert isinstance(result_messages[2].content, list) + image_count = sum( + 1 for block in result_messages[2].content if block.type == "image" + ) + assert image_count == 2 + + def test_handles_images_in_tool_result_blocks(self) -> None: + """Test that images inside tool_result blocks are handled correctly.""" + messages = [ + _create_tool_result_with_image("old_tool_image"), + _create_text_message("assistant", "Response"), + _create_tool_result_with_image("new_tool_image"), # This should be kept + ] + + strategy = LatestImageOnlyTruncationStrategy( + tools=None, + system=None, + messages=messages.copy(), + model="claude-3-5-sonnet-20241022", + ) + + result_messages = strategy.messages + + # First tool_result image should be replaced + assert _has_placeholder_in_message(result_messages[0]) + assert not _has_image_in_message(result_messages[0]) + + # Last tool_result image should be kept + assert _has_image_in_message(result_messages[2]) + assert not _has_placeholder_in_message(result_messages[2]) + + def test_handles_mixed_content_messages(self) -> None: + """Test messages with both text and images.""" + messages = [ + _create_mixed_message("user", "Look at this", "image1"), + _create_text_message("assistant", "I see it"), + _create_mixed_message("user", "Now look at this", "image2"), + ] + + strategy = LatestImageOnlyTruncationStrategy( + tools=None, + system=None, + messages=messages.copy(), + model="claude-3-5-sonnet-20241022", + ) + + result_messages = strategy.messages + + # First message: text should remain, image should be replaced + assert isinstance(result_messages[0].content, list) + text_blocks = [b for b in result_messages[0].content if b.type == "text"] + assert len(text_blocks) == 2 # Original text + placeholder + assert any(b.text == "Look at this" for b in text_blocks) + assert any(b.text == "[Image removed to save tokens]" for b in text_blocks) + assert not _has_image_in_message(result_messages[0]) + + # Last message: both text and image should be kept + assert isinstance(result_messages[2].content, list) + assert any( + b.type == "text" and b.text == "Now look at this" + for b in result_messages[2].content + ) + assert _has_image_in_message(result_messages[2]) + + def test_handles_conversation_with_no_images(self) -> None: + """Test that conversations with no images are not affected.""" + messages = [ + _create_text_message("user", "Hello"), + _create_text_message("assistant", "Hi there"), + _create_text_message("user", "How are you?"), + ] + + strategy = LatestImageOnlyTruncationStrategy( + tools=None, + system=None, + messages=messages.copy(), + model="claude-3-5-sonnet-20241022", + ) + + result_messages = strategy.messages + + # All messages should remain unchanged + assert len(result_messages) == 3 + assert not any(_has_image_in_message(msg) for msg in result_messages) + assert not any(_has_placeholder_in_message(msg) for msg in result_messages) + + def test_handles_single_image_message(self) -> None: + """Test conversation with only one image (should be kept).""" + messages = [ + _create_text_message("user", "Hello"), + _create_image_message("user", "only_image"), + _create_text_message("assistant", "I see the image"), + ] + + strategy = LatestImageOnlyTruncationStrategy( + tools=None, + system=None, + messages=messages.copy(), + model="claude-3-5-sonnet-20241022", + ) + + result_messages = strategy.messages + + # The single image should be kept + assert _has_image_in_message(result_messages[1]) + assert not _has_placeholder_in_message(result_messages[1]) + + def test_preserves_non_image_content_blocks(self) -> None: + """Test that non-image content blocks are preserved correctly.""" + messages = [ + MessageParam( + role="user", + content=[ + TextBlockParam(type="text", text="First text"), + ImageBlockParam( + type="image", + source=Base64ImageSourceParam( + type="base64", media_type="image/png", data="image1" + ), + ), + TextBlockParam(type="text", text="Second text"), + ], + ), + _create_image_message("user", "image2"), + ] + + strategy = LatestImageOnlyTruncationStrategy( + tools=None, + system=None, + messages=messages.copy(), + model="claude-3-5-sonnet-20241022", + ) + + result_messages = strategy.messages + + # First message: text blocks should be preserved, image replaced + assert isinstance(result_messages[0].content, list) + text_blocks = [b for b in result_messages[0].content if b.type == "text"] + assert len(text_blocks) == 3 # Two original texts + placeholder + assert any(b.text == "First text" for b in text_blocks) + assert any(b.text == "Second text" for b in text_blocks) + assert any(b.text == "[Image removed to save tokens]" for b in text_blocks) + + def test_inherits_simple_truncation_behavior(self) -> None: + """Test that the strategy still inherits SimpleTruncationStrategy behavior.""" + # Create a conversation that would trigger truncation + # Use small limits to trigger truncation + messages = [_create_text_message("user", "Message")] * 100 + + strategy = LatestImageOnlyTruncationStrategy( + tools=None, + system=None, + messages=messages.copy(), + model="claude-3-5-sonnet-20241022", + max_messages=50, + message_truncation_threshold=0.5, + ) + + # Strategy should be created successfully and inherit truncation logic + assert isinstance(strategy, SimpleTruncationStrategy) + assert isinstance(strategy, LatestImageOnlyTruncationStrategy) + + +class TestLatestImageOnlyTruncationStrategyFactory: + """Tests for LatestImageOnlyTruncationStrategyFactory.""" + + def test_creates_latest_image_only_strategy(self) -> None: + """Test that factory creates LatestImageOnlyTruncationStrategy instance.""" + factory = LatestImageOnlyTruncationStrategyFactory() + messages = [_create_text_message("user", "Hello")] + + strategy = factory.create_truncation_strategy( + tools=None, + system=None, + messages=messages, + model="claude-3-5-sonnet-20241022", + ) + + assert isinstance(strategy, LatestImageOnlyTruncationStrategy) + assert isinstance(strategy, SimpleTruncationStrategy) + + def test_factory_can_be_instantiated_with_custom_parameters(self) -> None: + """Test that factory accepts custom parameters.""" + custom_max_tokens = 50_000 + custom_threshold = 0.6 + + factory = LatestImageOnlyTruncationStrategyFactory( + max_input_tokens=custom_max_tokens, + input_token_truncation_threshold=custom_threshold, + ) + + messages = [_create_text_message("user", "Hello")] + + # Verify the strategy can be created with custom parameters + strategy = factory.create_truncation_strategy( + tools=None, + system=None, + messages=messages, + model="claude-3-5-sonnet-20241022", + ) + + # Verify it's the correct type + assert isinstance(strategy, LatestImageOnlyTruncationStrategy) + + def test_factory_creates_functional_strategy(self) -> None: + """Test that factory creates a working strategy instance.""" + factory = LatestImageOnlyTruncationStrategyFactory() + messages = [ + _create_image_message("user", "image1"), + _create_image_message("user", "image2"), + ] + + strategy = factory.create_truncation_strategy( + tools=None, + system=None, + messages=messages.copy(), + model="claude-3-5-sonnet-20241022", + ) + + # Test that the strategy actually performs image removal + result_messages = strategy.messages + assert not _has_image_in_message(result_messages[0]) + assert _has_image_in_message(result_messages[1]) + + +class TestSimpleTruncationStrategyFactory: + """Tests for SimpleTruncationStrategyFactory to ensure backwards compatibility.""" + + def test_creates_simple_strategy(self) -> None: + """Test that factory creates SimpleTruncationStrategy instance.""" + factory = SimpleTruncationStrategyFactory() + messages = [_create_text_message("user", "Hello")] + + strategy = factory.create_truncation_strategy( + tools=None, + system=None, + messages=messages, + model="claude-3-5-sonnet-20241022", + ) + + assert isinstance(strategy, SimpleTruncationStrategy) + assert not isinstance(strategy, LatestImageOnlyTruncationStrategy) + + def test_simple_strategy_preserves_all_images(self) -> None: + """Test that SimpleTruncationStrategy does NOT remove images.""" + factory = SimpleTruncationStrategyFactory() + messages = [ + _create_image_message("user", "image1"), + _create_text_message("assistant", "Response"), + _create_image_message("user", "image2"), + ] + + strategy = factory.create_truncation_strategy( + tools=None, + system=None, + messages=messages.copy(), + model="claude-3-5-sonnet-20241022", + ) + + result_messages = strategy.messages + + # All images should be preserved + assert _has_image_in_message(result_messages[0]) + assert _has_image_in_message(result_messages[2]) + # No placeholders should be present + assert not any(_has_placeholder_in_message(msg) for msg in result_messages)