diff --git a/.gitignore b/.gitignore index 246d7df9..600c9ba0 100644 --- a/.gitignore +++ b/.gitignore @@ -166,4 +166,6 @@ reports/ /chat /askui_chat.db .cache/ +.askui_cache/ +playground/* diff --git a/docs/caching.md b/docs/caching.md index d4da680c..df7fae7b 100644 --- a/docs/caching.md +++ b/docs/caching.md @@ -1,6 +1,6 @@ # Caching (Experimental) -**CAUTION: The Caching feature is still in alpha state and subject to change! Use it at your own risk. In case you run into issues, you can disable caching by removing the caching_settings parameter or by explicitly setting the caching_strategy to `no`.** +**CAUTION: The Caching feature is still in alpha state and subject to change! Use it at your own risk. In case you run into issues, you can disable caching by removing the caching_settings parameter or by explicitly setting the strategy to `None`.** The caching mechanism allows you to record and replay agent action sequences (trajectories) for faster and more robust test execution. This feature is particularly useful for regression testing, where you want to replay known-good interaction sequences to verify that your application still behaves correctly. @@ -8,48 +8,93 @@ The caching mechanism allows you to record and replay agent action sequences (tr The caching system works by recording all tool use actions (mouse movements, clicks, typing, etc.) performed by the agent during an `act()` execution. These recorded sequences can then be replayed in subsequent executions, allowing the agent to skip the decision-making process and execute the actions directly. +**New in v0.1:** The caching system now includes advanced features like parameter support for dynamic values, smart handling of non-cacheable tools that require agent intervention, comprehensive message history tracking, and automatic failure detection with recovery capabilities. + +**New in v0.2:** Visual validation using perceptual hashing ensures cached trajectories execute only when the UI state matches expectations. The settings structure has been refactored for better clarity, separating writing settings from execution settings. + ## Caching Strategies +**Updated in v0.2:** Strategy names have been renamed for clarity. + The caching mechanism supports four strategies, configured via the `caching_settings` parameter in the `act()` method: -- **`"no"`** (default): No caching is used. The agent executes normally without recording or replaying actions. -- **`"write"`**: Records all agent actions to a cache file for future replay. -- **`"read"`**: Provides tools to the agent to list and execute previously cached trajectories. -- **`"both"`**: Combines read and write modes - the agent can use existing cached trajectories and will also record new ones. +- **`None`** (default): No caching is used. The agent executes normally without recording or replaying actions. +- **`"record"`**: Records all agent actions to a cache file for future replay. +- **`"execute"`**: Provides tools to the agent to list and execute previously cached trajectories. +- **`"both"`**: Combines execute and record modes - the agent can use existing cached trajectories and will also record new ones. ## Configuration +**Updated in v0.2:** Caching settings have been refactored for better clarity, separating writing-related settings from execution-related settings. 
+ Caching is configured using the `CachingSettings` class: ```python -from askui.models.shared.settings import CachingSettings, CachedExecutionToolSettings +from askui.models.shared.settings import ( + CachingSettings, + CacheWritingSettings, + CacheExecutionSettings, +) caching_settings = CachingSettings( - strategy="write", # One of: "read", "write", "both", "no" + strategy="both", # One of: "execute", "record", "both", or None cache_dir=".cache", # Directory to store cache files - filename="my_test.json", # Filename for the cache file (optional for write mode) - execute_cached_trajectory_tool_settings=CachedExecutionToolSettings( - delay_time_between_action=0.5 # Delay in seconds between each cached action - ) + writing_settings=CacheWritingSettings( + filename="my_test.json", # Cache file name + parameter_identification_strategy="llm", # Auto-detect dynamic values + visual_verification_method="phash", # Visual validation method + visual_validation_region_size=100, # Size of validation region (pixels) + visual_validation_threshold=10, # Hamming distance threshold + ), + execution_settings=CacheExecutionSettings( + delay_time_between_action=0.5 # Delay in seconds between each action + ), ) ``` ### Parameters -- **`strategy`**: The caching strategy to use (`"read"`, `"write"`, `"both"`, or `"no"`). +- **`strategy`**: The caching strategy to use (`"execute"`, `"record"`, `"both"`, or `None`). **Updated in v0.2:** Renamed from "read"/"write"/"no" to "execute"/"record"/None for clarity. - **`cache_dir`**: Directory where cache files are stored. Defaults to `".cache"`. -- **`filename`**: Name of the cache file to write to or read from. If not specified in write mode, a timestamped filename will be generated automatically (format: `cached_trajectory_YYYYMMDDHHMMSSffffff.json`). -- **`execute_cached_trajectory_tool_settings`**: Configuration for the trajectory execution tool (optional). See [Execution Settings](#execution-settings) below. +- **`writing_settings`**: **New in v0.2!** Configuration for cache recording. See [Writing Settings](#writing-settings) below. Can be `None` if only executing caches. +- **`execution_settings`**: **New in v0.2!** Configuration for cache execution. See [Execution Settings](#execution-settings) below. Can be `None` if only recording caches. + +### Writing Settings + +**New in v0.2!** The `CacheWritingSettings` class configures how cache files are recorded: + +```python +from askui.models.shared.settings import CacheWritingSettings + +writing_settings = CacheWritingSettings( + filename="my_test.json", # Name of cache file to create + parameter_identification_strategy="llm", # "llm" or "preset" + visual_verification_method="phash", # "phash", "ahash", or "none" + visual_validation_region_size=100, # Size of region to validate (pixels) + visual_validation_threshold=10, # Hamming distance threshold (0-64) +) +``` + +#### Parameters + +- **`filename`**: Name of the cache file to write. Defaults to `""` (auto-generates timestamped filename: `cached_trajectory_YYYYMMDDHHMMSSffffff.json`). +- **`parameter_identification_strategy`**: When `"llm"` (default), uses AI to automatically identify and parameterize dynamic values like dates, usernames, and IDs. When `"preset"`, only manually specified parameters using `{{...}}` syntax are detected. See [Automatic Parameter Identification](#automatic-parameter-identification). 
+- **`visual_verification_method`**: **New in v0.2!** Visual validation method to use: + - `"phash"` (default): Perceptual hash using DCT - robust to minor changes like compression and lighting + - `"ahash"`: Average hash - simpler and faster, less robust to transformations + - `"none"`: Disable visual validation +- **`visual_validation_region_size`**: **New in v0.2!** Size of the square region (in pixels) to extract around interaction coordinates for visual validation. Defaults to `100` (100×100 pixel region). +- **`visual_validation_threshold`**: **New in v0.2!** Maximum Hamming distance (0-64) between stored and current visual hashes to consider a match. Lower values require closer matches. Defaults to `10`. ### Execution Settings -The `CachedExecutionToolSettings` class allows you to configure how cached trajectories are executed: +The `CacheExecutionSettings` class configures how cached trajectories are executed: ```python -from askui.models.shared.settings import CachedExecutionToolSettings +from askui.models.shared.settings import CacheExecutionSettings -execution_settings = CachedExecutionToolSettings( - delay_time_between_action=0.5 # Delay in seconds between each action (default: 0.5) +execution_settings = CacheExecutionSettings( + delay_time_between_action=0.5 # Delay in seconds between each action ) ``` @@ -63,28 +108,35 @@ You can adjust this value based on your application's responsiveness: ## Usage Examples -### Writing a Cache (Recording) +### Recording a Cache Record agent actions to a cache file for later replay: ```python from askui import VisionAgent -from askui.models.shared.settings import CachingSettings +from askui.models.shared.settings import CachingSettings, CacheWritingSettings with VisionAgent() as agent: agent.act( goal="Fill out the login form with username 'admin' and password 'secret123'", caching_settings=CachingSettings( - strategy="write", + strategy="record", cache_dir=".cache", - filename="login_test.json" + writing_settings=CacheWritingSettings( + filename="login_test.json", + visual_verification_method="phash", # Enable visual validation + ), ) ) ``` -After execution, a cache file will be created at `.cache/login_test.json` containing all the tool use actions performed by the agent. +After execution, a cache file will be created at `.cache/login_test.json` containing: +- All tool use actions performed by the agent +- Metadata about the execution +- **New in v0.2:** Visual validation hashes for click and type actions +- Automatically detected cache parameters (if any) -### Reading from Cache (Replaying) +### Executing from Cache Provide the agent with access to previously recorded trajectories: @@ -96,22 +148,103 @@ with VisionAgent() as agent: agent.act( goal="Fill out the login form", caching_settings=CachingSettings( - strategy="read", + strategy="execute", + cache_dir=".cache" + ) + ) +``` + +When using `strategy="execute"`, the agent receives two tools: + +1. **`RetrieveCachedTestExecutions`**: Lists all available cache files in the cache directory +2. **`ExecuteCachedTrajectory`**: Executes a cached trajectory. Can start from the beginning (default) or continue from a specific step index using the optional `start_from_step_index` parameter (useful after handling non-cacheable steps) + +The agent will automatically check if a relevant cached trajectory exists and use it if appropriate. During execution, the agent can see all screenshots and results in the message history. 
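+
+For illustration, a cached-trajectory call issued by the agent might look like the following tool use block (the field names match the documented tool parameters, but the exact tool-call schema is internal and may differ):
+
+```json
+{
+  "type": "tool_use",
+  "name": "ExecuteCachedTrajectory",
+  "input": {
+    "trajectory_file": ".cache/login_test.json",
+    "start_from_step_index": 0,
+    "parameter_values": {}
+  }
+}
+```
+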
After executing a cached trajectory, the agent will verify the results and make corrections if needed. + +### Using Cache Parameters for Dynamic Values + +**New in v0.1:** Trajectories can contain cache_parameters for dynamic values that change between executions: + +```python +from askui import VisionAgent +from askui.models.shared.settings import CachingSettings + +# When recording, use dynamic values as normal +# The system automatically detects patterns like dates and user-specific data +with VisionAgent() as agent: + agent.act( + goal="Create a new task for today with the title 'Review PR'", + caching_settings=CachingSettings( + strategy="record", + cache_dir=".cache", + writing_settings=CacheWritingSettings( + filename="create_task.json" + ) + ) + ) + +# Later, when executing, the agent can provide parameter values +# If the cache file contains {{current_date}} or {{task_title}}, provide them: +with VisionAgent() as agent: + agent.act( + goal="Create a task using the cached flow", + caching_settings=CachingSettings( + strategy="execute", + cache_dir=".cache" + ) + ) + # The agent will automatically detect required cache_parameters and can provide them + # via the parameter_values parameter when calling ExecuteCachedTrajectory +``` + +Cache Parameters use the syntax `{{variable_name}}` and are automatically detected during cache file creation. When executing a trajectory with cache_parameters, the agent must provide values for all required cache_parameters. + +### Handling Non-Cacheable Steps + +**New in v0.1:** Some tools cannot be cached and require the agent to execute them live. Examples include debugging tools, contextual decisions, or tools that depend on runtime state. + +```python +from askui import VisionAgent +from askui.models.shared.settings import CachingSettings + +with VisionAgent() as agent: + agent.act( + goal="Debug the login form by checking element states", + caching_settings=CachingSettings( + strategy="execute", cache_dir=".cache" ) ) + # If the cached trajectory contains non-cacheable steps: + # 1. Execution pauses when reaching the non-cacheable step + # 2. Agent receives NEEDS_AGENT status with current step index + # 3. Agent executes the non-cacheable step manually + # 4. Agent uses ExecuteCachedTrajectory with start_from_step_index to resume ``` -When using `strategy="read"`, the agent receives two additional tools: +Tools can be marked as non-cacheable by setting `is_cacheable=False` in their definition. When trajectory execution reaches a non-cacheable tool, it pauses and returns control to the agent for manual execution. -1. **`retrieve_available_trajectories_tool`**: Lists all available cache files in the cache directory -2. **`execute_cached_executions_tool`**: Executes a specific cached trajectory +### Continuing from a Specific Step -The agent will automatically check if a relevant cached trajectory exists and use it if appropriate. After executing a cached trajectory, the agent will verify the results and make corrections if needed. 
+**New in v0.1:** After handling a non-cacheable step or recovering from a failure, the agent can continue execution from a specific step index using the `start_from_step_index` parameter: + +```python +# The agent uses ExecuteCachedTrajectory with start_from_step_index like this: +result = execute_cached_trajectory_tool( + trajectory_file=".cache/my_test.json", + start_from_step_index=5, # Continue from step 5 + parameter_values={"date": "2025-12-11"} # Provide any required cache_parameters +) +``` + +This is particularly useful for: +- Resuming after manual execution of non-cacheable steps +- Recovering from partial failures +- Skipping steps that are no longer needed ### Referencing Cache Files in Goal Prompts -When using `strategy="read"` or `strategy="both"`, you need to inform the agent about which cache files are available and when to use them. This is done by including cache file information directly in your goal prompt. +When using `strategy="execute"` or `strategy="both"`, you need to inform the agent about which cache files are available and when to use them. This is done by including cache file information directly in your goal prompt. #### Explicit Cache File References @@ -124,11 +257,11 @@ from askui.models.shared.settings import CachingSettings with VisionAgent() as agent: agent.act( goal="""Open the website in Google Chrome. - - If the cache file "open_website_in_chrome.json" is available, please use it + + If the cache file "open_website_in_chrome.json" is available, please use it for this execution. It will open a new window in Chrome and navigate to the website.""", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir=".cache" ) ) @@ -147,11 +280,11 @@ test_id = "TEST_001" with VisionAgent() as agent: agent.act( goal=f"""Execute test {test_id} according to the test definition. - - Check if a cache file named "{test_id}.json" exists. If it does, use it to + + Check if a cache file named "{test_id}.json" exists. If it does, use it to replay the test actions, then verify the results.""", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir="test_cache" ) ) @@ -168,12 +301,12 @@ from askui.models.shared.settings import CachingSettings with VisionAgent() as agent: agent.act( goal="""Fill out the user registration form. - - Look for cache files that match the pattern "user_registration_*.json". - Choose the most recent one if multiple are available, as it likely contains + + Look for cache files that match the pattern "user_registration_*.json". + Choose the most recent one if multiple are available, as it likely contains the most up-to-date interaction sequence.""", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir=".cache" ) ) @@ -190,14 +323,14 @@ from askui.models.shared.settings import CachingSettings with VisionAgent() as agent: agent.act( goal="""Complete the full checkout process: - + 1. If "login.json" exists, use it to log in 2. If "add_to_cart.json" exists, use it to add items to cart 3. 
If "checkout.json" exists, use it to complete the checkout - + After each cached execution, verify the step completed successfully before proceeding.""", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir=".cache" ) ) @@ -215,15 +348,15 @@ You can customize the delay between cached actions to match your application's r ```python from askui import VisionAgent -from askui.models.shared.settings import CachingSettings, CachedExecutionToolSettings +from askui.models.shared.settings import CachingSettings, CacheExecutionSettings with VisionAgent() as agent: agent.act( goal="Fill out the login form", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir=".cache", - execute_cached_trajectory_tool_settings=CachedExecutionToolSettings( + execution_settings=CacheExecutionSettings( delay_time_between_action=1.0 # Wait 1 second between each action ) ) @@ -261,108 +394,1018 @@ In this mode: ## Cache File Format -Cache files are JSON files containing an array of tool use blocks. Each block represents a single tool invocation with the following structure: +**New in v0.1:** Cache files now use an enhanced format with metadata tracking, parameter support, and execution history. + +**New in v0.2:** Cache files include visual validation metadata and enhanced trajectory steps with visual hashes. + +### v0.2 Format (Current) + +Cache files are JSON objects with the following structure: ```json -[ +{ + "metadata": { + "version": "0.2", + "created_at": "2025-12-30T10:30:00Z", + "goal": "Greet user {{user_name}} and log them in", + "last_executed_at": "2025-12-30T15:45:00Z", + "token_usage": { + "input_tokens": 1250, + "output_tokens": 380, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0 + }, + "execution_attempts": 3, + "failures": [ + { + "timestamp": "2025-12-30T14:20:00Z", + "step_index": 5, + "error_message": "Visual validation failed: UI region changed", + "failure_count_at_step": 1 + } + ], + "is_valid": true, + "invalidation_reason": null, + "visual_verification_method": "phash", + "visual_validation_region_size": 100, + "visual_validation_threshold": 10 + }, + "trajectory": [ { - "type": "tool_use", - "id": "toolu_01AbCdEfGhIjKlMnOpQrStUv", - "name": "computer", - "input": { - "action": "mouse_move", - "coordinate": [150, 200] - } + "type": "tool_use", + "id": "toolu_01AbCdEfGhIjKlMnOpQrStUv", + "name": "computer", + "input": { + "action": "left_click", + "coordinate": [450, 320] + }, + "visual_representation": "80c0e3f3e3e7e381c7c78f1f3f3f7f7e" }, { - "type": "tool_use", - "id": "toolu_02AbCdEfGhIjKlMnOpQrStUv", - "name": "computer", - "input": { - "action": "left_click" - } + "type": "tool_use", + "id": "toolu_02XyZaBcDeFgHiJkLmNoPqRs", + "name": "computer", + "input": { + "action": "type", + "text": "Hello {{user_name}}!" + }, + "visual_representation": "91d1f4e4d4c6c282c6c79e2e4e4e6e6d" }, { - "type": "tool_use", - "id": "toolu_03AbCdEfGhIjKlMnOpQrStUv", - "name": "computer", - "input": { - "action": "type", - "text": "admin" - } + "type": "tool_use", + "id": "toolu_03StUvWxYzAbCdEfGhIjKlMn", + "name": "print_debug_info", + "input": {}, + "visual_representation": null } -] + ], + "cache_parameters": { + "user_name": "Name of the user to greet" + } +} ``` -Note: Screenshot actions are excluded from cached trajectories as they don't modify the UI state. +**Note:** In the example above, `print_debug_info` is marked as non-cacheable (`is_cacheable=False`), so its `input` field is blank (`{}`). 
This saves space and privacy since non-cacheable tools aren't executed from cache anyway. + +#### Metadata Fields + +- **`version`**: Cache file format version (currently "0.2") +- **`created_at`**: ISO 8601 timestamp when the cache was created +- **`goal`**: The original goal/instruction given to the agent when recording this trajectory. Cache Parameters are applied to the goal text just like in the trajectory, making it easy to understand what the cache was designed to accomplish. +- **`last_executed_at`**: ISO 8601 timestamp of the last execution (null if never executed) +- **`token_usage`**: **New in v0.1!** Token usage statistics from the recording execution +- **`execution_attempts`**: Number of times this trajectory has been executed +- **`failures`**: List of failures encountered during execution (see [Failure Tracking](#failure-tracking)) +- **`is_valid`**: Boolean indicating if the cache is still considered valid +- **`invalidation_reason`**: Optional string explaining why the cache was invalidated +- **`visual_verification_method`**: **New in v0.2!** Visual validation method used when recording (`"phash"`, `"ahash"`, or `null`) +- **`visual_validation_region_size`**: **New in v0.2!** Size of the validation region in pixels (e.g., `100` for 100×100 pixels) +- **`visual_validation_threshold`**: **New in v0.2!** Hamming distance threshold for visual validation (0-64) + +#### Cache Parameters + +The `cache_parameters` object maps parameter names to their descriptions. Cache Parameters in the trajectory use the syntax `{{parameter_name}}` and must be substituted with actual values during execution. + +#### Failure Tracking + +Each failure record contains: +- **`timestamp`**: When the failure occurred +- **`step_index`**: Which step failed (0-indexed) +- **`error_message`**: The error that occurred +- **`failure_count_at_step`**: How many times this specific step has failed + +This information helps with cache invalidation decisions and debugging. ## How It Works -### Write Mode +### Internal Architecture -In write mode, the `CacheWriter` class: +The caching system consists of several key components: + +- **`CacheWriter`**: Handles recording trajectories in record mode +- **`CacheExecutionManager`**: Manages cache execution state, flow control, and metadata updates during trajectory replay +- **`TrajectoryExecutor`**: Executes individual steps from cached trajectories +- **Agent**: Orchestrates the conversation flow and delegates cache execution to `CacheExecutionManager` + +When executing a cached trajectory, the `Agent` class delegates all cache-related logic to `CacheExecutionManager`, which handles: +- State management (execution mode, verification pending, etc.) +- Execution flow control (success, failure, needs agent, completed) +- Message history building and injection +- Metadata updates (execution attempts, failures, invalidation) + +This separation of concerns keeps the Agent focused on conversation orchestration while CacheExecutionManager handles all caching complexity. + +### Record Mode + +In record mode, the `CacheWriter` class: 1. Intercepts all assistant messages via a callback function 2. Extracts tool use blocks from the messages -3. Stores them in memory during execution -4. Writes them to a JSON file when the agent finishes (on `stop_reason="end_turn"`) -5. Automatically skips writing if a cached execution was used (to avoid recording replays) +3. 
**Enhances with visual validation** (New in v0.2): + - For click and type actions, captures screenshot before execution + - Extracts region around interaction coordinate + - Computes perceptual hash using selected method (pHash/aHash) + - Attaches hash and validation settings to tool block +4. Stores enhanced tool blocks in memory during execution +5. When agent finishes (on `stop_reason="end_turn"`): + - **Automatically identifies cache_parameters** using AI (if `parameter_identification_strategy=llm`) + - Analyzes trajectory to find dynamic values (dates, usernames, IDs, etc.) + - Generates descriptive parameter definitions + - Replaces identified values with `{{parameter_name}}` syntax in trajectory + - Applies same replacements to the goal text + - **Blanks non-cacheable tool inputs** by setting `input: {}` for tools with `is_cacheable=False` (saves space and privacy) + - **Writes to JSON file** with: + - v0.2 metadata (version, timestamps, goal, token usage, visual validation settings) + - Trajectory of tool use blocks (with cache_parameters, visual hashes, and blanked inputs) + - Parameter definitions with descriptions +6. Automatically skips writing if a cached execution was used (to avoid recording replays) + +### Execute Mode + +In execute mode: + +1. Two caching tools are added to the agent's toolbox: + - `RetrieveCachedTestExecutions`: Lists available trajectories + - `ExecuteCachedTrajectory`: Executes from the beginning or continues from a specific step using `start_from_step_index` +2. A special system prompt (`CACHE_USE_PROMPT`) instructs the agent on: + - How to use trajectories + - Parameter handling + - Non-cacheable step management + - Failure recovery strategies +3. The agent can list available cache files and choose appropriate ones +4. During execution via `TrajectoryExecutor`: + - **Visual validation** (New in v0.2): Before each validated step, captures current UI and compares hash to stored hash + - Each step is executed sequentially with configurable delays + - All tools in the trajectory are executed, including screenshots and retrieval tools + - Non-cacheable tools trigger a pause with `NEEDS_AGENT` status + - Cache Parameters are validated and substituted before execution + - Message history is built with assistant (tool use) and user (tool result) messages + - Agent sees all screenshots and results in the message history +5. Execution can pause for agent intervention: + - When visual validation fails (New in v0.2) + - When reaching non-cacheable tools + - When errors occur (with failure details) +6. Agent can resume execution: + - Using `ExecuteCachedTrajectory` with `start_from_step_index` from the pause point + - Skipping failed or irrelevant steps +7. Results are verified by the agent, with corrections made as needed + +### Message History + +**New in v0.1:** During cached trajectory execution, a complete message history is built and returned to the agent. 
This includes: + +- **Assistant messages**: Containing `ToolUseBlockParam` for each action +- **User messages**: Containing `ToolResultBlockParam` with: + - Text results from tool execution + - Screenshots (when available) + - Error messages (on failure) + +This visibility allows the agent to: +- See the current UI state via screenshots +- Understand what actions were taken +- Detect when execution has diverged from expectations +- Make informed decisions about corrections or retries + +### Non-Cacheable Tools + +Tools can be marked as non-cacheable by setting `is_cacheable=False` in their definition: + +```python +from askui.models.shared.tools import Tool + +class DebugPrintTool(Tool): + name = "print_debug" + description = "Print debug information about current state" + is_cacheable = False # This tool requires agent context + + def __call__(self, message: str) -> str: + # Tool implementation... + pass +``` + +During trajectory execution, when a non-cacheable tool is encountered: + +1. `TrajectoryExecutor` pauses execution +2. Returns `ExecutionResult` with status `NEEDS_AGENT` +3. Includes current step index and message history +4. Agent receives control to execute the step manually +5. Agent uses `ExecuteCachedTrajectory` with `start_from_step_index` to resume from next step + +This mechanism is essential for tools that: +- Require runtime context (debugging, inspection) +- Make decisions based on current state +- Have side effects that shouldn't be blindly replayed +- Depend on external systems that may have changed + +## Failure Handling + +**New in v0.1:** Enhanced failure handling provides the agent with detailed information about what went wrong and where. + +### When Execution Fails + +If a step fails during trajectory execution: + +1. Execution stops at the failed step +2. `ExecutionResult` includes: + - Status: `FAILED` + - `step_index`: Which step failed + - `error_message`: The specific error + - `message_history`: All actions and results up to the failure +3. Failure is recorded in cache metadata for tracking +4. Agent receives the failure information and can decide: + - **Retry**: Execute remaining steps manually + - **Resume**: Fix the issue and use `ExecuteCachedTrajectory` with `start_from_step_index` from next step + - **Abort**: Report that cache needs re-recording + +### Failure Tracking + +Cache metadata tracks all failures: +```json +"failures": [ + { + "timestamp": "2025-12-11T14:20:00Z", + "step_index": 5, + "error_message": "Element not found: login button", + "failure_count_at_step": 2 + } +] +``` + +This information enables: +- Smart cache invalidation (too many failures → invalid cache) +- Debugging (which steps are problematic) +- Metrics (cache reliability over time) +- Auto-recovery strategies (skip commonly failing steps) + +### Agent Recovery Options + +The agent has several recovery strategies: + +1. **Manual Execution**: Execute remaining steps without cache +2. **Partial Resume**: Fix the issue (e.g., wait for element) then continue from next step +3. **Skip and Continue**: Skip the failed step and continue from a later step +4. 
**Report Invalid**: Mark the cache as outdated and request re-recording + +Example agent decision flow: +``` +Trajectory fails at step 5: "Element not found: submit button" +↓ +Agent takes screenshot to assess current state +↓ +Agent sees submit button is present but has different text +↓ +Agent clicks the button manually +↓ +Agent calls ExecuteCachedTrajectory(start_from_step_index=6) +↓ +Execution continues successfully +``` + +## Cache Parameters + +**New in v0.1:** Cache Parameters enable dynamic value substitution in cached trajectories. + +### Parameter Syntax + +Cache Parameters use double curly braces: `{{parameter_name}}` + +Valid parameter names: +- Must start with a letter or underscore +- Can contain letters, numbers, and underscores +- Examples: `{{date}}`, `{{user_name}}`, `{{order_id_123}}` + +### Automatic Cache Parameter Identification + +**New in v0.1!** The caching system uses AI to automatically identify and parameterize dynamic values when recording trajectories. + +#### How It Works + +When `parameter_identification_strategy=llm` (the default), the system: + +1. **Records the trajectory** as normal during agent execution +2. **Analyzes the trajectory** using an LLM to identify dynamic values such as: + - Dates and timestamps (e.g., "2025-12-11", "10:30 AM") + - Usernames, emails, names (e.g., "john.doe", "test@example.com") + - Session IDs, tokens, UUIDs, API keys + - Dynamic text referencing current state or time + - File paths with user-specific or time-specific components + - Temporary or generated identifiers +3. **Generates parameter definitions** with descriptive names and documentation: + ```json + { + "name": "current_date", + "value": "2025-12-11", + "description": "Current date in YYYY-MM-DD format" + } + ``` +4. **Replaces values with cache_parameters** in both the trajectory AND the goal: + - Original: `"text": "Login as john.doe"` + - Result: `"text": "Login as {{username}}"` +5. **Saves the templated trajectory** to the cache file + +#### Benefits + +✅ **No manual work** - Automatically identifies dynamic values +✅ **Smart detection** - LLM understands semantic meaning (dates vs coordinates) +✅ **Descriptive** - Generates helpful descriptions for each parameter +✅ **Applies to goal** - Goal text also gets parameter replacement + +#### What Gets Detected + +The AI identifies values that are likely to change between executions: + +**Will be detected as cache_parameters:** +- Dates: "2025-12-11", "Dec 11, 2025", "12/11/2025" +- Times: "10:30 AM", "14:45:00", "2025-12-11T10:30:00Z" +- Usernames: "john.doe", "admin_user", "test_account" +- Emails: "user@example.com", "test@domain.org" +- IDs: "uuid-1234-5678", "session_abc123", "order_9876" +- Names: "John Smith", "Jane Doe" +- Dynamic text: "Today is 2025-12-11", "Logged in as john.doe" + +**Will NOT be detected as cache_parameters:** +- UI coordinates: `{"x": 100, "y": 200}` +- Fixed button labels: "Submit", "Cancel", "OK" +- Configuration values: `{"timeout": 30, "retries": 3}` +- Generic actions: "click", "type", "scroll" +- Boolean values: `true`, `false` + +#### Disabling Auto-Identification + +If you prefer manual parameter control: + +```python +caching_settings = CachingSettings( + strategy="record", + writing_settings=CacheWritingSettings( + parameter_identification_strategy="preset" # Only detect {{...}} syntax + ) +) +``` + +With `parameter_identification_strategy="preset"`, only manually specified cache_parameters using the `{{...}}` syntax will be detected. 
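+
+To make the placeholder mechanics concrete, here is a minimal sketch of how `{{...}}` detection and substitution can work (illustrative only; the library's internal implementation may differ):
+
+```python
+import re
+
+# Parameter names must start with a letter or underscore and may contain
+# letters, numbers, and underscores, per the syntax rules above.
+PARAM_PATTERN = re.compile(r"\{\{([A-Za-z_][A-Za-z0-9_]*)\}\}")
+
+def find_cache_parameters(text: str) -> set[str]:
+    """Collect all {{parameter_name}} placeholders in a trajectory value."""
+    return set(PARAM_PATTERN.findall(text))
+
+def substitute(text: str, parameter_values: dict[str, str]) -> str:
+    """Replace each placeholder with its value, failing on missing parameters."""
+    missing = find_cache_parameters(text) - parameter_values.keys()
+    if missing:
+        raise ValueError(f"Missing required cache_parameters: {sorted(missing)}")
+    return PARAM_PATTERN.sub(lambda m: parameter_values[m.group(1)], text)
+
+# substitute("Login as {{username}}", {"username": "john.doe"})
+# -> "Login as john.doe"
+```
+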
+ +#### Logging + +To see what cache_parameters are being identified, enable INFO-level logging: + +```python +import logging +logging.basicConfig(level=logging.INFO) +``` + +You'll see output like: +``` +INFO: Using LLM to identify cache_parameters in trajectory +INFO: Identified 3 cache_parameters in trajectory +DEBUG: - current_date: 2025-12-11 (Current date in YYYY-MM-DD format) +DEBUG: - username: john.doe (Username for login) +DEBUG: - session_id: abc123 (Session identifier) +INFO: Replaced 3 parameter values in trajectory +INFO: Applied parameter replacement to goal: Login as john.doe -> Login as {{username}} +``` + +### Manual Cache Parameters + +You can also manually create cache_parameters when recording by using the syntax in your goal description. The system will preserve `{{...}}` patterns in tool inputs. + +### Providing Parameter Values + +When executing a trajectory with cache_parameters, the agent must provide values: + +```python +# Via ExecuteCachedTrajectory +result = execute_cached_trajectory_tool( + trajectory_file=".cache/my_test.json", + parameter_values={ + "current_date": "2025-12-11", + "user_email": "test@example.com" + } +) + +# Via ExecuteCachedTrajectory with start_from_step_index +result = execute_cached_trajectory_tool( + trajectory_file=".cache/my_test.json", + start_from_step_index=3, # Continue from step 3 + parameter_values={ + "current_date": "2025-12-11", + "user_email": "test@example.com" + } +) +``` + +### Parameter Validation + +Before execution, the system validates that: +- All required cache_parameters have values provided +- No required cache_parameters are missing -### Read Mode +If validation fails, execution is aborted with a clear error message listing missing cache_parameters. -In read mode: +### Use Cases -1. Two caching tools are added to the agent's toolbox -2. A special system prompt (`CACHE_USE_PROMPT`) is appended to instruct the agent on how to use trajectories -3. The agent can call `retrieve_available_trajectories_tool` to see available cache files -4. The agent can call `execute_cached_executions_tool` with a trajectory file path to replay it -5. During replay, each tool use block is executed sequentially with a configurable delay between actions (default: 0.5 seconds) -6. Screenshot and trajectory retrieval tools are skipped during replay -7. The agent is instructed to verify results after replay and make corrections if needed +Cache Parameters are particularly useful for: +- **Date-dependent workflows**: Testing with current/future dates +- **User-specific actions**: Different users, emails, names +- **Order/transaction IDs**: Testing with different identifiers +- **Environment-specific values**: API endpoints, credentials +- **Parameterized testing**: Running same flow with different data -The delay between actions can be customized using `CachedExecutionToolSettings` to accommodate different application response times. +Example: +```json +{ + "name": "computer", + "input": { + "action": "type", + "text": "Schedule meeting for {{meeting_date}} with {{attendee_email}}" + } +} +``` + +## Visual Validation + +**New in v0.2!** Visual validation ensures cached trajectories execute only when the UI state matches the recorded state, preventing actions from being executed on incorrect UI elements. + +### How It Works + +During cache recording (record mode), the system: +1. **Captures screenshots** before each interaction (clicks, typing, key presses) +2. 
**Extracts a region** (e.g., 100×100 pixels) around the interaction coordinate +3. **Computes a perceptual hash** of that region using the selected method +4. **Stores the hash** in the trajectory step along with validation settings + +During cache execution (execute mode), the system: +1. **Captures the current UI state** before each step +2. **Extracts the same region** around the interaction coordinate +3. **Computes the hash** of the current region +4. **Compares hashes** using Hamming distance +5. **Validates the match** against the threshold +6. **Executes the step** only if validation passes, otherwise returns control to the agent + +### Visual Validation Methods + +#### pHash (Perceptual Hash) + +Default method using Discrete Cosine Transform (DCT): + +```python +writing_settings=CacheWritingSettings( + visual_verification_method="phash", # Default + visual_validation_region_size=100, + visual_validation_threshold=10, +) +``` + +**Characteristics:** +- ✅ Robust to minor changes (compression, scaling, lighting adjustments) +- ✅ Sensitive to structural changes (moved buttons, different layouts) +- ✅ Best for most use cases +- ⚠️ Slightly slower than aHash + +**When to use:** +- Production environments where UI may have subtle variations +- Cross-platform testing (different rendering engines) +- Long-lived caches that may encounter minor UI updates + +#### aHash (Average Hash) + +Simpler method using mean pixel values: + +```python +writing_settings=CacheWritingSettings( + visual_verification_method="ahash", + visual_validation_region_size=100, + visual_validation_threshold=10, +) +``` + +**Characteristics:** +- ✅ Fast computation +- ✅ Simple and predictable +- ⚠️ Less robust to transformations +- ⚠️ More sensitive to color/brightness changes + +**When to use:** +- Development/testing environments with controlled conditions +- Performance-critical scenarios +- UI that rarely changes + +#### Disabled + +Disable visual validation entirely: + +```python +writing_settings=CacheWritingSettings( + visual_verification_method="none", +) +``` + +**When to use:** +- UI that never changes +- Testing the caching system itself +- Debugging trajectory execution + +### Configuration Options + +#### Region Size + +The `visual_validation_region_size` parameter controls the size of the square region extracted around each interaction coordinate: + +```python +writing_settings=CacheWritingSettings( + visual_validation_region_size=50, # 50×50 pixel region (smaller, faster) + # visual_validation_region_size=100, # 100×100 pixel region (default, balanced) + # visual_validation_region_size=200, # 200×200 pixel region (larger, more context) +) +``` + +**Smaller regions (50-75 pixels):** +- ✅ Faster processing +- ✅ More focused validation (just the element) +- ⚠️ May miss context changes -## Limitations +**Larger regions (150-200 pixels):** +- ✅ Captures more UI context +- ✅ Detects broader layout changes +- ⚠️ Slower processing +- ⚠️ More sensitive to unrelated UI changes -- **UI State Sensitivity**: Cached trajectories assume the UI is in the same state as when they were recorded. If the UI has changed, the replay may fail or produce incorrect results. -- **No on_message Callback**: When using `strategy="write"` or `strategy="both"`, you cannot provide a custom `on_message` callback, as the caching system uses this callback to record actions. 
+**Default (100 pixels):** +- Balanced between speed and context +- Suitable for most use cases + +#### Validation Threshold + +The `visual_validation_threshold` parameter controls how similar the UI must be (Hamming distance, 0-64): + +```python +writing_settings=CacheWritingSettings( + visual_validation_threshold=5, # Strict: requires very close match + # visual_validation_threshold=10, # Default: balanced + # visual_validation_threshold=20, # Lenient: allows more variation +) +``` + +**Lower thresholds (0-5):** +- Very strict matching +- Fails on minor UI changes +- Best for pixel-perfect UIs + +**Medium thresholds (8-15):** +- Balanced sensitivity +- Tolerates minor variations +- **Default: 10** + +**Higher thresholds (20-30):** +- Lenient matching +- May allow too much variation +- Risk of false positives + +### Validated Actions + +Visual validation is applied to actions that interact with specific UI coordinates: + +**Validated automatically:** +- `left_click` +- `right_click` +- `double_click` +- `middle_click` +- `type` (validates input field location) +- `key` (validates focus location) + +**NOT validated:** +- `mouse_move` (movement doesn't require validation) +- `screenshot` (no UI interaction) +- Non-computer tools +- Tools marked as `is_cacheable=False` + +### Handling Validation Failures + +When visual validation fails during cache execution: + +1. **Execution stops** at the failed step +2. **Agent receives notification** with details: + - Which step failed + - The validation error message + - Current message history and screenshots +3. **Agent can decide**: + - Take a screenshot to assess current UI state + - Execute the step manually if safe + - Skip the step and continue + - Invalidate the cache and request re-recording + +Example agent recovery flow: +``` +Step 5 validation fails: "Visual validation failed: UI region changed (distance: 15 > threshold: 10)" +↓ +Agent takes screenshot to see current state +↓ +Agent sees button is present but slightly moved +↓ +Agent clicks button manually at new location +↓ +Agent continues execution from step 6 +``` + +### Best Practices + +1. **Choose the right method:** + - Use `phash` (default) for most cases + - Use `ahash` only for controlled environments + - Never use `none` in production + +2. **Tune the threshold:** + - Start with default (10) + - Increase if getting too many false failures + - Decrease if allowing incorrect executions + +3. **Adjust region size:** + - Use default (100) initially + - Increase for complex layouts + - Decrease for simple, isolated elements + +4. **Monitor validation logs:** + - Enable INFO logging to see validation results + - Track failure patterns + - Adjust settings based on failure analysis + +5. **Re-record when needed:** + - After significant UI changes + - When validation consistently fails + - After threshold/region adjustments + +### Logging + +Enable INFO-level logging to see visual validation activity: + +```python +import logging +logging.basicConfig(level=logging.INFO) +``` + +During **recording**, you'll see: +``` +INFO: ✓ Visual validation added to computer action=left_click at coordinate (450, 320) (hash=80c0e3f3e3e7e381...) +INFO: ✓ Visual validation added to computer action=type at coordinate (450, 380) (hash=91d1f4e4d4c6c282...) +``` + +During **execution**, validation happens silently on success. 
On **failure**, you'll see: +``` +WARNING: Visual validation failed at step 5: Visual validation failed: UI region changed significantly (Hamming distance: 15 > threshold: 10) +WARNING: Handing execution back to agent. +``` + +## Limitations and Considerations + +### Current Limitations + +- **UI State Sensitivity**: Cached trajectories assume the UI is in the same state as when they were recorded. If the UI has changed significantly, replay may fail. +- **No on_message Callback**: When using `strategy="record"` or `strategy="both"`, you cannot provide a custom `on_message` callback, as the caching system uses this callback to record actions. - **Verification Required**: After executing a cached trajectory, the agent should verify that the results are correct, as UI changes may cause partial failures. +### Best Practices + +1. **Always Verify Results**: After cached execution, verify the outcome matches expectations +2. **Handle Failures Gracefully**: Provide clear recovery paths when trajectories fail +3. **Use Cache Parameters Wisely**: Identify dynamic values that should be parameterized +4. **Mark Non-Cacheable Tools**: Properly mark tools that require agent intervention +5. **Monitor Cache Validity**: Track execution attempts and failures to identify stale caches +6. **Test Cache Replay**: Periodically test that cached trajectories still work +7. **Version Your Caches**: Use descriptive filenames or directories for different app versions +8. **Adjust Delays**: Tune `delay_time_between_action` based on your app's responsiveness + +### When to Re-Record + +Consider re-recording a cached trajectory when: +- UI layout or element positions have changed significantly +- Workflow steps have been added, removed, or reordered +- Failures occur consistently at the same steps +- Execution takes significantly longer than expected +- The cache has been marked invalid due to failure patterns + +## Migration from v0.1 to v0.2 + +v0.2 introduces visual validation and refactored settings structure. Here's what you need to know to migrate from v0.1. + +### What Changed in v0.2 + +**Functional Changes:** +- **Visual Validation**: Cache recording now captures visual hashes (pHash/aHash) of UI regions around each click/type action. During execution, these hashes are validated to ensure the UI state matches expectations before executing cached actions. +- **Smarter Cache Execution**: Visual validation helps detect when the UI has changed, preventing cached actions from executing on wrong elements. 
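+
+For intuition, the new validation step boils down to a perceptual-hash comparison over a small region around each interaction coordinate. A minimal sketch, assuming Pillow and the `imagehash` package (the actual implementation may use different helpers):
+
+```python
+import imagehash
+from PIL import Image
+
+def region_matches(
+    screenshot: Image.Image,
+    coordinate: tuple[int, int],
+    stored_hash: str,
+    region_size: int = 100,
+    threshold: int = 10,
+) -> bool:
+    """Compare the UI region around `coordinate` with the hash stored at record time."""
+    x, y = coordinate
+    half = region_size // 2
+    region = screenshot.crop((x - half, y - half, x + half, y + half))
+    # Hamming distance between 64-bit hashes; lower means more similar
+    distance = imagehash.phash(region) - imagehash.hex_to_hash(stored_hash)
+    return distance <= threshold
+```
+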
+ +**API Changes:** +- **Strategy names renamed** for clarity: + - `"read"` → `"execute"` + - `"write"` → `"record"` + - `"no"` → `None` +- **Settings refactored** into separate writing and execution settings: + - `CacheWritingSettings` for recording-related configuration + - `CacheExecutionSettings` for playback-related configuration +- **Default cache directory** changed from `".cache"` to `".askui_cache"` + +### Step 1: Update Your Code + +**Old v0.1 code:** +```python +from askui.models.shared.settings import CachingSettings + +caching_settings = CachingSettings( + caching_strategy="read", # Old naming + cache_dir=".cache", + file_name="my_test.json", + delay_time_between_action=0.5, +) +``` + +**New v0.2 code:** +```python +from askui.models.shared.settings import ( + CachingSettings, + CacheWritingSettings, + CacheExecutionSettings, +) + +caching_settings = CachingSettings( + strategy="execute", # New naming (was "read") + cache_dir=".askui_cache", # New default directory + writing_settings=CacheWritingSettings( + filename="my_test.json", + visual_verification_method="phash", # New in v0.2 + visual_validation_region_size=100, # New in v0.2 + visual_validation_threshold=10, # New in v0.2 + ), + execution_settings=CacheExecutionSettings( + delay_time_between_action=0.5, + ), +) +``` + +**Migration checklist:** +- [ ] Replace `caching_strategy` with `strategy` +- [ ] Rename `"read"` → `"execute"`, `"write"` → `"record"`, `"no"` → `None` +- [ ] Move `file_name` → `writing_settings.filename` +- [ ] Move `delay_time_between_action` → `execution_settings.delay_time_between_action` +- [ ] Add `writing_settings=CacheWritingSettings(...)` if using record mode +- [ ] Add `execution_settings=CacheExecutionSettings(...)` if using execute mode +- [ ] Update `cache_dir` from `".cache"` to `".askui_cache"` (optional but recommended) + +### Step 2: Handle Existing Cache Files + +**Important:** v0.1 cache files do NOT work with v0.2 due to visual validation changes. + +You have two options: + +#### Option A: Delete Old Cache Files (Recommended) + +The simplest approach is to delete all v0.1 cache files and re-record them with v0.2: + +```bash +# Delete all old cache files +rm -rf .cache/*.json + +# Or if you updated to new directory: +rm -rf .askui_cache/*.json +``` + +Then re-run your workflows in `record` mode to create new v0.2 cache files with visual validation. + +#### Option B: Disable Visual Validation (Not Recommended) + +If you must use old cache files temporarily, you can disable visual validation: + +```python +writing_settings=CacheWritingSettings( + filename="old_cache.json", + visual_verification_method="none", # Disable visual validation +) +``` + +**Warning:** Without visual validation, cached actions may execute on wrong UI elements if the interface has changed. This defeats the primary benefit of v0.2. + +### Step 3: Verify Migration + +After updating your code and cache files: + +1. **Test record mode**: Verify new cache files are created with visual validation + ```bash + # Check for visual_representation fields in cache file + cat .askui_cache/my_test.json | grep -A2 visual_representation + ``` + +2. **Test execute mode**: Verify cached trajectories execute with visual validation + - You should see log messages about visual validation during execution + - If UI has changed, execution should fail with visual validation errors + +3. 
**Check metadata**: Verify cache files contain v0.2 metadata + ```bash + cat .askui_cache/my_test.json | grep -A5 '"metadata"' + ``` + + Should include: + ```json + "visual_verification_method": "phash", + "visual_validation_region_size": 100, + "visual_validation_threshold": 10 + ``` + +### Example: Complete Migration + +**Before (v0.1):** +```python +caching_settings = CachingSettings( + caching_strategy="both", + cache_dir=".cache", + file_name="login_test.json", + delay_time_between_action=0.3, +) +``` + +**After (v0.2):** +```python +caching_settings = CachingSettings( + strategy="both", # Renamed from caching_strategy + cache_dir=".askui_cache", # New default + writing_settings=CacheWritingSettings( + filename="login_test.json", # Moved from file_name + visual_verification_method="phash", # New: visual validation + visual_validation_region_size=100, # New: validation region + visual_validation_threshold=10, # New: strictness level + parameter_identification_strategy="llm", + ), + execution_settings=CacheExecutionSettings( + delay_time_between_action=0.3, # Moved to execution_settings + ), +) +``` + +### Troubleshooting + +**Issue**: Cache execution fails with "Visual validation failed" +- **Cause**: UI has changed since cache was recorded +- **Solution**: Re-record the cache or adjust `visual_validation_threshold` (higher = more tolerant) + +**Issue**: Import error for `CacheWritingSettings` +- **Cause**: Old import statement +- **Solution**: Update imports: + ```python + from askui.models.shared.settings import ( + CachingSettings, + CacheWritingSettings, + CacheExecutionSettings, + ) + ``` + +**Issue**: Old cache files don't work +- **Cause**: v0.1 cache files lack visual validation data +- **Solution**: Delete old cache files and re-record with v0.2 + ## Example: Complete Test Workflow -Here's a complete example showing how to record and replay a test: +Here's a complete example showing the caching system: ```python +import logging from askui import VisionAgent -from askui.models.shared.settings import CachingSettings, CachedExecutionToolSettings - -# Step 1: Record a successful login flow -print("Recording login flow...") -with VisionAgent() as agent: - agent.act( - goal="Navigate to the login page and log in with username 'testuser' and password 'testpass123'", - caching_settings=CachingSettings( - strategy="write", - cache_dir="test_cache", - filename="user_login.json" +from askui.models.shared.settings import CachingSettings +from askui.models.shared.tools import Tool +from askui.reporting import SimpleHtmlReporter + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger() + + +class PrintTool(Tool): + def __init__(self) -> None: + super().__init__( + name="print_tool", + description=""" + Print something to the console + """, + input_schema={ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": """ + The text that should be printed to the console + """, + }, + }, + "required": ["text"], + }, ) - ) + self.is_cacheable = False + + # Agent will detect cache_parameters and provide new values: + def __call__(self, text: str) -> None: + print(text) -# Step 2: Later, replay the login flow for regression testing -print("\nReplaying login flow for regression test...") +# Step 2: Replay with different values +print("\nReplaying registration with new user...") with VisionAgent() as agent: agent.act( goal="""Log in to the application. - - If the cache file "user_login.json" is available, please use it to replay - the login sequence. 
It contains the steps to navigate to the login page and + + If the cache file "user_login.json" is available, please use it to replay + the login sequence. It contains the steps to navigate to the login page and authenticate with the test credentials.""", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir="test_cache", - execute_cached_trajectory_tool_settings=CachedExecutionToolSettings( - delay_time_between_action=1.0 + execution_settings=CacheExecutionSettings( + delay_time_between_action=0.75 ) ) ) + + +if __name__ == "__main__": + goal = """Please open a new window in google chrome by right clicking on the icon in the Dock at the bottom of the screen. + Then, navigate to www.askui.com and print a brief summary all the screens that you have seen during the execution. + Describe them one by one, e.g. 1. Screen: Lorem Ipsum, 2. Screen: .... + One sentence per screen is sufficient. + Do not scroll on the screens for that! + Just summarize the content that is or was visible on the screen. + If available, you can use cache file at caching_demo.json + """ + caching_settings = CachingSettings( + strategy="both", cache_dir=".askui_cache", filename="caching_demo.json" + ) + # first act will create the cache file + with VisionAgent( + display=1, reporters=[SimpleHtmlReporter()], act_tools=[PrintTool()] + ) as agent: + agent.act(goal, caching_settings=caching_settings) + + # second act will read and execute the cached file + goal = goal.replace("www.askui.com", "www.caesr.ai") + with VisionAgent( + display=1, reporters=[SimpleHtmlReporter()], act_tools=[PrintTool()] + ) as agent: + agent.act(goal, caching_settings=caching_settings) ``` + +## Future Enhancements + +Planned features for future versions: + +- **✅ Visual Validation** (Implemented in v0.2): Screenshot comparison using perceptual hashing (pHash/aHash) to detect UI changes +- **Cache Invalidation Strategies**: Configurable validators for automatic cache invalidation +- **Cache Management Tools**: Tools for listing, validating, and invalidating caches +- **Smart Retry**: Automatic retry with adjustments when specific failure patterns are detected +- **Cache Analytics**: Metrics dashboard showing cache performance and reliability +- **Differential Caching**: Record only changed steps when updating existing caches + +## Troubleshooting + +### Common Issues + +**Issue**: Cached trajectory fails to execute +- **Cause**: UI has changed since recording +- **Solution**: Take a screenshot to compare, re-record the trajectory, or manually execute failing steps + +**Issue**: "Missing required cache_parameters" error +- **Cause**: Trajectory contains cache_parameters but values weren't provided +- **Solution**: Check cache metadata for required cache_parameters and provide values via `parameter_values` parameter + +**Issue**: Execution pauses unexpectedly +- **Cause**: Trajectory contains non-cacheable tool +- **Solution**: Execute the non-cacheable step manually, then use `ExecuteCachedTrajectory` with `start_from_step_index` to resume + +**Issue**: Actions execute too quickly, causing failures +- **Cause**: `delay_time_between_action` is too short for your application +- **Solution**: Increase delay in `CacheExecutionSettings` (e.g., from 0.5 to 1.0 seconds) + +**Issue**: "Tool not found in toolbox" error +- **Cause**: Cached trajectory uses a tool that's no longer available +- **Solution**: Re-record the trajectory with current tools, or add the missing tool back + +### Debug Tips + +1. 
**Check message history**: After execution, review `message_history` in the result to see exactly what happened +2. **Monitor failure metadata**: Track `execution_attempts` and `failures` in cache metadata +3. **Test incrementally**: Use `ExecuteCachedTrajectory` with `start_from_step_index` to test specific sections of a trajectory +4. **Verify cache_parameters**: Print cache metadata to see what cache_parameters are expected +5. **Adjust delays**: If timing issues occur, increase `delay_time_between_action` incrementally + +For more help, see the [GitHub Issues](https://github.com/askui/vision-agent/issues) or contact support. diff --git a/docs/visual_validation.md b/docs/visual_validation.md new file mode 100644 index 00000000..46d0927d --- /dev/null +++ b/docs/visual_validation.md @@ -0,0 +1,160 @@ +# Visual Validation for Caching + +> **Status**: ✅ Implemented +> **Version**: v0.2 +> **Last Updated**: 2025-12-30 + +## Overview + +Visual validation verifies that the UI state matches expectations before executing cached trajectory steps. This significantly improves cache reliability by detecting UI changes that would cause cached actions to fail. + +The system stores visual representations (perceptual hashes) of UI regions where actions like clicks are executed. During cache execution, these hashes are compared with the current UI state to detect changes. + +## How are the visual Representations computed + +We can think of multiple methods, e.g. aHash, pHash, ... + +**Key Necessary Properties:** +- Fast computation (~1-2ms per hash) +- Small storage footprint (64 bits = 8 bytes) +- Robust to minor changes (compression, scaling, lighting) +- Sensitive to structural changes (moved buttons, different layouts) + +Which method was used will be added to the metadata field of the cached trajectory. + + +## How It Works + +### 1. Representation Storage + +When a trajectory is recorded and cached, visual representations will be captured for critical steps: + +```json +{ + "type": "tool_use", + "name": "computer", + "input": { + "action": "left_click", + "coordinate": [450, 300] + }, + "visual_representation": "a8f3c9e14b7d2056" +} +``` + +**Which steps should be validated?** +- Mouse clicks (left_click, right_click, double_click, middle_click) +- Type actions (verify input field hasn't moved) +- Key presses targeting specific UI elements + +**Hash region selection:** +- For clicks: Capture region around click coordinate (e.g., 100x100px centered on target) +- For type actions: Capture region around text input field (e.g., 100x100px centered on target) + +### 2. Hash Verification (During Cache Execution) + +Before executing each step that has a `visual_representation`: + +1. **Capture current screen region** at the same coordinates used during recording +2. **Compute visual Representation, e.g. aHash** of the current region +3. **Compare with stored hash** using Hamming distance +4. **Make decision** based on threshold: + +```python +def should_validate_step(stored_hash: str, current_screen: Image, threshold: int = 10) -> bool: + """ + Check if visual validation passes. 
+ +## How It Works + +### 1. Representation Storage + +When a trajectory is recorded and cached, visual representations are captured for critical steps: + +```json +{ + "type": "tool_use", + "name": "computer", + "input": { + "action": "left_click", + "coordinate": [450, 300] + }, + "visual_representation": "a8f3c9e14b7d2056" +} +``` + +**Which steps should be validated?** +- Mouse clicks (left_click, right_click, double_click, middle_click) +- Type actions (verify input field hasn't moved) +- Key presses targeting specific UI elements + +**Hash region selection:** +- For clicks: Capture region around click coordinate (e.g., 100x100px centered on target) +- For type actions: Capture region around text input field (e.g., 100x100px centered on target) + +### 2. Hash Verification (During Cache Execution) + +Before executing each step that has a `visual_representation`: + +1. **Capture current screen region** at the same coordinates used during recording +2. **Compute the visual representation** (e.g., aHash) of the current region +3. **Compare with stored hash** using Hamming distance +4. **Make decision** based on threshold: + +```python +def visual_validation_passes(stored_hash: str, current_screen: Image, threshold: int = 10) -> bool: + """ + Check if visual validation passes. + + Args: + stored_hash: The aHash stored in the cache (hex string) + current_screen: Current screenshot region + threshold: Maximum Hamming distance (0-64) + - 0-5: Nearly identical (recommended for strict validation) + - 6-10: Very similar (default - allows minor changes) + - 11-15: Similar (more lenient) + - 16+: Different (validation should fail) + + Returns: + True if validation passes, False if UI has changed significantly + """ + current_hash = compute_ahash(current_screen) + distance = hamming_distance(int(stored_hash, 16), current_hash) + return distance <= threshold +``` + +### 3. Validation Results + +**If validation passes** (distance ≤ threshold): +- ✅ Execute the cached step normally +- Continue with trajectory execution + +**If validation fails** (distance > threshold): +- ⚠️ Pause trajectory execution +- Return control to agent with detailed information: + ``` + Visual validation failed at step 5 (left_click at [450, 300]) because the UI region has changed significantly compared to when it was recorded. + Please inspect the current UI state and perform the necessary step. + ``` + +## Configuration + +Visual validation is configured via `CacheWritingSettings` (part of the caching settings): + +```python +# In settings +class CacheWritingSettings: + visual_verification_method: CACHING_VISUAL_VERIFICATION_METHOD = "phash" # or "ahash", "none" + visual_validation_region_size: int = 100 # size of the validated region (pixels) + visual_validation_threshold: int = 10 # Hamming distance threshold (0-64) +``` + +**Configuration Options:** +- `visual_verification_method`: Hash method to use + - `"phash"` (default): Perceptual hash - robust to minor changes, sensitive to structural changes + - `"ahash"`: Average hash - faster but less robust + - `"none"`: Disable visual validation +- `visual_validation_threshold`: Maximum allowed Hamming distance (0-64) + - `0-5`: Nearly identical (strict validation) + - `6-10`: Very similar (default - recommended) + - `11-15`: Similar (lenient) + - `16+`: Different (likely to fail validation) + + +## Benefits + +### 1. Improved Reliability +- Detect UI changes before execution fails +- Reduce cache invalidation due to false negatives +- Provide early warning of UI state mismatches + +### 2. Better User Experience +- Agent can make informed decisions about cache validity +- Clear feedback when UI has changed +- Opportunity to adapt instead of failing + +### 3. Intelligent Cache Management +- Automatically identify outdated caches +- Track which UI regions are stable vs. volatile +- Optimize cache usage patterns + +## Limitations and Considerations + +### 1. Performance Impact +- Each validation requires a screenshot + hash computation (~5-10ms) +- May slow down trajectory execution +- Mitigation: Only validate critical steps, not every action + +### 2. False Positives +- Minor UI changes (animations, hover states) may trigger validation failures +- Threshold tuning required for different applications +- Mitigation: Adaptive thresholds, ignore transient changes + +### 3. False Negatives +- Subtle but critical changes might not be detected +- Text content changes may not affect visual hash significantly +- Mitigation: Combine with other validation methods (OCR, element detection) + +### 4.
Storage Overhead +- Each validated step adds 8 bytes (visual_hash) + 1 byte (flag) +- A 100-step trajectory adds ~900 bytes +- Mitigation: Acceptable overhead for improved reliability diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 00000000..39e27570 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,186 @@ +# AskUI Caching Examples + +This directory contains example scripts demonstrating the capabilities of the AskUI caching system (v0.2). + +## Examples Overview + +### 1. `basic_caching_example.py` +**Introduction to cache recording and execution** + +Demonstrates: +- ✅ **Record mode**: Save a trajectory to a cache file +- ✅ **Execute mode**: Replay a cached trajectory +- ✅ **Both mode**: Try execute, fall back to record +- ✅ **Cache parameters**: Dynamic value substitution with `{{parameter}}` syntax +- ✅ **AI-based parameter detection**: Automatic identification of dynamic values + +**Best for**: Getting started with caching, understanding the basic workflow + +### 2. `visual_validation_example.py` +**Visual UI state validation with perceptual hashing** + +Demonstrates: +- ✅ **pHash validation**: Perceptual hashing (recommended, robust) +- ✅ **aHash validation**: Average hashing (faster, simpler) +- ✅ **Threshold tuning**: Adjusting strictness (0-64 range) +- ✅ **Region size**: Controlling validation area (50-200 pixels) +- ✅ **Disabling validation**: When to skip visual validation + +**Best for**: Understanding visual validation, tuning validation parameters for your use case + +## Quick Start + +1. **Install dependencies**: + ```bash + pdm install + ``` + +2. **Run an example**: + ```bash + pdm run python examples/basic_caching_example.py + ``` + +3. **Explore the cache files**: + ```bash + cat .askui_cache/basic_example.json + ``` + +## Understanding the Examples + +### Basic Workflow + +```python +# 1. Record a trajectory +caching_settings = CachingSettings( + strategy="record", # Save to cache + cache_dir=".askui_cache", + writing_settings=CacheWritingSettings( + filename="my_cache.json", + visual_verification_method="phash", + ), +) + +# 2. Execute from cache +caching_settings = CachingSettings( + strategy="execute", # Replay from cache + cache_dir=".askui_cache", + execution_settings=CacheExecutionSettings( + delay_time_between_action=0.5, + ), +) + +# 3. Both (recommended for development) +caching_settings = CachingSettings( + strategy="both", # Try execute, fall back to record + cache_dir=".askui_cache", + writing_settings=CacheWritingSettings(filename="my_cache.json"), + execution_settings=CacheExecutionSettings(), +) +``` + +### Visual Validation Settings + +```python +writing_settings=CacheWritingSettings( + visual_verification_method="phash", # or "ahash" or "none" + visual_validation_region_size=100, # 100x100 pixel region + visual_validation_threshold=10, # Hamming distance (0-64) +) +``` + +**Threshold Guidelines**: +- `0-5`: Very strict (detects tiny changes) +- `6-10`: Strict (recommended for stable UIs) ✅ +- `11-15`: Moderate (tolerates minor changes) +- `16+`: Lenient (may miss significant changes) + +**Region Size Guidelines**: +- `50`: Small, precise validation +- `100`: Balanced (recommended default) ✅ +- `150-200`: Large, more context + +## Customizing Examples + +Each example can be customized by modifying: + +1. **The goal**: Change the task description +2. **Cache settings**: Adjust validation parameters +3. **Tools**: Add custom tools to the agent +4. 
**Model**: Change the AI model (e.g., `model="askui/claude-sonnet-4-5-20250929"`) + +## Cache File Structure (v0.2) + +```json +{ + "metadata": { + "version": "0.1", + "created_at": "2025-01-15T10:30:00Z", + "goal": "Task description", + "visual_verification_method": "phash", + "visual_validation_region_size": 100, + "visual_validation_threshold": 10 + }, + "trajectory": [ + { + "type": "tool_use", + "name": "computer", + "input": {"action": "left_click", "coordinate": [450, 320]}, + "visual_representation": "80c0e3f3e3e7e381..." // pHash/aHash + } + ], + "cache_parameters": { + "search_term": "Description of the parameter" + } +} +``` + +## Tips and Best Practices + +### When to Use Caching + +✅ **Good use cases**: +- Repetitive UI automation tasks +- Testing workflows that require setup +- Demos and presentations +- Regression testing of UI workflows + +❌ **Not recommended**: +- Highly dynamic UIs that change frequently +- Tasks requiring real-time decision making +- One-off tasks that won't be repeated + +### Choosing Validation Settings + +**For stable UIs** (e.g., desktop applications): +- Method: `phash` +- Threshold: `5-10` +- Region: `100` + +**For dynamic UIs** (e.g., websites with ads): +- Method: `phash` +- Threshold: `15-20` +- Region: `150` + +**For maximum performance** (trusted cache): +- Method: `none` +- (Visual validation disabled) + +### Debugging Cache Execution + +If cache execution fails: + +1. **Check visual validation**: Lower threshold or disable temporarily +2. **Verify UI state**: Ensure UI hasn't changed since recording +3. **Check cache file**: Look for `visual_representation` fields +4. **Review logs**: Look for "Visual validation failed" messages +5. **Re-record**: Delete cache file and record fresh trajectory + +## Additional Resources + +- **Documentation**: See `docs/caching.md` for complete documentation +- **Visual Validation**: See `docs/visual_validation.md` for technical details +- **Playground**: See `playground/caching_demo.py` for more examples + +## Questions? + +For issues or questions, please refer to the main documentation or open an issue in the repository. diff --git a/examples/basic_caching_example.py b/examples/basic_caching_example.py new file mode 100644 index 00000000..382d482f --- /dev/null +++ b/examples/basic_caching_example.py @@ -0,0 +1,190 @@ +"""Basic Caching Example - Introduction to Cache Recording and Execution. + +This example demonstrates: +- Recording a trajectory to a cache file (record mode) +- Executing a cached trajectory (execute mode) +- Using both modes together (both mode) +- Cache parameters for dynamic value substitution +""" + +import logging + +from askui import VisionAgent +from askui.models.shared.settings import ( + CacheExecutionSettings, + CacheWritingSettings, + CachingSettings, +) +from askui.reporting import SimpleHtmlReporter + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger() + + +def record_trajectory_example() -> None: + """Record a new trajectory to a cache file. + + This example records a simple task (opening Calculator and performing + a calculation) to a cache file. The first time you run this, the agent + will perform the task normally. The trajectory will be saved to + 'basic_example.json' for later reuse. + """ + goal = """Please open the Calculator application and calculate 15 + 27. + Then close the Calculator application. 
+ """ + + caching_settings = CachingSettings( + strategy="record", # Record mode: save trajectory to cache file + cache_dir=".askui_cache", + writing_settings=CacheWritingSettings( + filename="basic_example.json", + visual_verification_method="phash", # Use perceptual hashing + visual_validation_region_size=100, # 100x100 pixel region around clicks + visual_validation_threshold=10, # Hamming distance threshold + parameter_identification_strategy="llm", # AI-based parameter detection + ), + ) + + with VisionAgent( + display=1, + reporters=[SimpleHtmlReporter()], + ) as agent: + agent.act(goal, caching_settings=caching_settings) + + logger.info("✓ Trajectory recorded to .askui_cache/basic_example.json") + logger.info("Run execute_trajectory_example() to replay this trajectory") + + +def execute_trajectory_example() -> None: + """Execute a previously recorded trajectory from cache. + + This example executes the trajectory recorded in record_trajectory_example(). + The agent will replay the exact sequence of actions from the cache file, + validating the UI state at each step using visual hashing. + + Prerequisites: + - Run record_trajectory_example() first to create the cache file + """ + goal = """Please execute the cached trajectory from basic_example.json + """ + + caching_settings = CachingSettings( + strategy="execute", # Execute mode: replay from cache file + cache_dir=".askui_cache", + execution_settings=CacheExecutionSettings( + delay_time_between_action=0.5, # Wait 0.5s between actions + ), + ) + + with VisionAgent( + display=1, + reporters=[SimpleHtmlReporter()], + ) as agent: + agent.act(goal, caching_settings=caching_settings) + + logger.info("✓ Trajectory executed from cache") + + +def both_mode_example() -> None: + """Use both record and execute modes together. + + This example demonstrates the "both" strategy, which: + 1. First tries to execute from cache if available + 2. Falls back to normal execution if cache doesn't exist + 3. Records the trajectory if it wasn't cached + + This is the most flexible mode for development and testing. + """ + goal = """Please open the TextEdit application, type "Hello from cache!", + and close the application without saving. + If available, use cache file at both_mode_example.json + """ + + caching_settings = CachingSettings( + strategy="both", # Both: try execute, fall back to record + cache_dir=".askui_cache", + writing_settings=CacheWritingSettings( + filename="both_mode_example.json", + visual_verification_method="phash", + parameter_identification_strategy="llm", + ), + execution_settings=CacheExecutionSettings( + delay_time_between_action=0.3, + ), + ) + + with VisionAgent( + display=1, + reporters=[SimpleHtmlReporter()], + ) as agent: + agent.act(goal, caching_settings=caching_settings) + + logger.info("✓ Task completed using both mode") + + +def cache_parameters_example() -> None: + """Demonstrate cache parameters for dynamic value substitution. + + Cache parameters allow you to record a trajectory once and replay it + with different values. The AI identifies dynamic values (like search + terms, numbers, dates) and replaces them with {{parameter_name}} syntax. + + When executing the cache, you can provide different values for these + parameters. + """ + # First run: Record with original value + goal_record = """Please open Safari, navigate to www.google.com, + search for "Python programming", and close Safari. 
+ """ + + caching_settings = CachingSettings( + strategy="record", + cache_dir=".askui_cache", + writing_settings=CacheWritingSettings( + filename="parameterized_example.json", + visual_verification_method="phash", + parameter_identification_strategy="llm", # AI detects "Python programming" as parameter + ), + ) + + logger.info("Recording parameterized trajectory...") + with VisionAgent(display=1, reporters=[SimpleHtmlReporter()]) as agent: + agent.act(goal_record, caching_settings=caching_settings) + + logger.info("✓ Parameterized trajectory recorded") + logger.info("The AI identified dynamic values and created parameters") + logger.info( + "Check .askui_cache/parameterized_example.json to see {{parameter}} syntax" + ) + + +if __name__ == "__main__": + # Run examples in sequence + print("\n" + "=" * 70) + print("BASIC CACHING EXAMPLES") + print("=" * 70 + "\n") + + print("\n1. Recording a trajectory to cache...") + print("-" * 70) + record_trajectory_example() + + print("\n2. Executing from cache...") + print("-" * 70) + # Uncomment to execute the cached trajectory: + execute_trajectory_example() + + print("\n3. Using 'both' mode (execute or record)...") + print("-" * 70) + # Uncomment to try both mode: + both_mode_example() + + print("\n4. Cache parameters for dynamic values...") + print("-" * 70) + # Uncomment to try parameterized caching: + cache_parameters_example() + + print("\n" + "=" * 70) + print("Examples completed!") + print("=" * 70 + "\n") diff --git a/examples/visual_validation_example.py b/examples/visual_validation_example.py new file mode 100644 index 00000000..91cb00fa --- /dev/null +++ b/examples/visual_validation_example.py @@ -0,0 +1,324 @@ +"""Visual Validation Example - Demonstrating Visual UI State Validation. + +This example demonstrates: +- Visual validation using perceptual hashing (pHash) +- Visual validation using average hashing (aHash) +- Adjusting validation thresholds for strictness +- Handling visual validation failures gracefully +- Understanding when visual validation helps detect UI changes +""" + +import logging + +from askui import VisionAgent +from askui.models.shared.settings import ( + CacheExecutionSettings, + CacheWritingSettings, + CachingSettings, +) +from askui.reporting import SimpleHtmlReporter + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger() + + +def phash_validation_example() -> None: + """Record and execute with pHash visual validation. + + Perceptual hashing (pHash) is the default and recommended method for + visual validation. It uses Discrete Cosine Transform (DCT) to create + a fingerprint of the UI region around each click/action. + + pHash is robust to: + - Minor image compression artifacts + - Small lighting changes + - Slight color variations + + But will detect: + - UI element position changes + - Different UI states (buttons, menus, etc.) + - Major layout changes + """ + goal = """Please open the System Settings application, click on "General", + then close the System Settings application. 
+ If available, use cache file at phash_example.json + """ + + caching_settings = CachingSettings( + strategy="both", + cache_dir=".askui_cache", + writing_settings=CacheWritingSettings( + filename="phash_example.json", + visual_verification_method="phash", # Perceptual hashing (recommended) + visual_validation_region_size=100, # 100x100 pixel region + visual_validation_threshold=10, # Lower = stricter (0-64 range) + ), + execution_settings=CacheExecutionSettings( + delay_time_between_action=0.5, + ), + ) + + with VisionAgent( + display=1, + reporters=[SimpleHtmlReporter()], + ) as agent: + agent.act(goal, caching_settings=caching_settings) + + logger.info("✓ pHash validation example completed") + logger.info("Visual hashes were computed for each click/action") + logger.info("Check the cache file to see 'visual_representation' fields") + + +def ahash_validation_example() -> None: + """Record and execute with aHash visual validation. + + Average hashing (aHash) is an alternative method that computes + the average pixel values in a region. It's simpler and faster than + pHash but slightly less robust. + + aHash is good for: + - Very fast validation + - Simple UI elements + - High-contrast interfaces + + Use pHash for better accuracy in most cases. + """ + goal = """Please open the Finder application, navigate to Documents folder, + then close the Finder window. + If available, use cache file at ahash_example.json + """ + + caching_settings = CachingSettings( + strategy="both", + cache_dir=".askui_cache", + writing_settings=CacheWritingSettings( + filename="ahash_example.json", + visual_verification_method="ahash", # Average hashing (alternative) + visual_validation_region_size=100, + visual_validation_threshold=10, + ), + execution_settings=CacheExecutionSettings( + delay_time_between_action=0.5, + ), + ) + + with VisionAgent( + display=1, + reporters=[SimpleHtmlReporter()], + ) as agent: + agent.act(goal, caching_settings=caching_settings) + + logger.info("✓ aHash validation example completed") + + +def strict_validation_example() -> None: + """Use strict visual validation (low threshold). + + A lower threshold means stricter validation. The threshold represents + the maximum Hamming distance (number of differing bits) allowed between + the cached hash and the current UI state. + + Threshold guidelines: + - 0-5: Very strict (detects tiny changes) + - 6-10: Strict (recommended for stable UIs) + - 11-15: Moderate (tolerates minor changes) + - 16+: Lenient (may miss significant changes) + """ + goal = """Please open the Calendar application and click on today's date. + Then close the Calendar application. + If available, use cache file at strict_validation.json + """ + + caching_settings = CachingSettings( + strategy="both", + cache_dir=".askui_cache", + writing_settings=CacheWritingSettings( + filename="strict_validation.json", + visual_verification_method="phash", + visual_validation_region_size=100, + visual_validation_threshold=5, # Very strict validation + ), + execution_settings=CacheExecutionSettings( + delay_time_between_action=0.5, + ), + ) + + with VisionAgent( + display=1, + reporters=[SimpleHtmlReporter()], + ) as agent: + agent.act(goal, caching_settings=caching_settings) + + logger.info("✓ Strict validation example completed") + logger.info("Threshold=5 will catch even minor UI changes") + + +def lenient_validation_example() -> None: + """Use lenient visual validation (high threshold). + + A higher threshold makes validation more tolerant of UI changes. 
+ Use this when: + - UI has dynamic elements (ads, recommendations, etc.) + - Layout changes slightly between executions + - You want cache to work across minor UI updates + + Be careful: Too lenient validation may miss important UI changes! + """ + goal = """Please open Safari, navigate to www.apple.com, + and close Safari. + If available, use cache file at lenient_validation.json + """ + + caching_settings = CachingSettings( + strategy="both", + cache_dir=".askui_cache", + writing_settings=CacheWritingSettings( + filename="lenient_validation.json", + visual_verification_method="phash", + visual_validation_region_size=100, + visual_validation_threshold=20, # Lenient validation + ), + execution_settings=CacheExecutionSettings( + delay_time_between_action=0.5, + ), + ) + + with VisionAgent( + display=1, + reporters=[SimpleHtmlReporter()], + ) as agent: + agent.act(goal, caching_settings=caching_settings) + + logger.info("✓ Lenient validation example completed") + logger.info("Threshold=20 tolerates more UI variation") + + +def region_size_example() -> None: + """Demonstrate different validation region sizes. + + The region size determines how large an area around each click point + is captured and validated. Larger regions capture more context but + may be less precise. + + Region size guidelines: + - 50x50: Small, precise validation of exact click target + - 100x100: Balanced (recommended default) + - 150x150: Large, includes more surrounding UI context + - 200x200: Very large, captures entire UI section + """ + goal = """Please open the Notes application, click on "New Note" button, + and close Notes without saving. + If available, use cache file at region_size_example.json + """ + + caching_settings = CachingSettings( + strategy="both", + cache_dir=".askui_cache", + writing_settings=CacheWritingSettings( + filename="region_size_example.json", + visual_verification_method="phash", + visual_validation_region_size=150, # Larger region for more context + visual_validation_threshold=10, + ), + execution_settings=CacheExecutionSettings( + delay_time_between_action=0.5, + ), + ) + + with VisionAgent( + display=1, + reporters=[SimpleHtmlReporter()], + ) as agent: + agent.act(goal, caching_settings=caching_settings) + + logger.info("✓ Region size example completed") + logger.info("Used 150x150 pixel region for validation") + + +def no_validation_example() -> None: + """Disable visual validation entirely. + + Set visual_verification_method="none" to disable visual validation. + The cache will still work, but won't verify UI state during execution. + + Use this when: + - You trust the cache completely + - UI changes frequently and validation would fail + - Performance is critical (validation adds small overhead) + + Warning: Without validation, cached actions may execute on wrong UI! + """ + goal = """Please open the Music application and close it. 
+ If available, use cache file at no_validation_example.json + """ + + caching_settings = CachingSettings( + strategy="both", + cache_dir=".askui_cache", + writing_settings=CacheWritingSettings( + filename="no_validation_example.json", + visual_verification_method="none", # Disable visual validation + ), + execution_settings=CacheExecutionSettings( + delay_time_between_action=0.5, + ), + ) + + with VisionAgent( + display=1, + reporters=[SimpleHtmlReporter()], + ) as agent: + agent.act(goal, caching_settings=caching_settings) + + logger.info("✓ No validation example completed") + logger.info("Cache executed without visual validation") + + +if __name__ == "__main__": + # Run examples to demonstrate different validation modes + print("\n" + "=" * 70) + print("VISUAL VALIDATION EXAMPLES") + print("=" * 70 + "\n") + + print("\n1. pHash validation (recommended)...") + print("-" * 70) + phash_validation_example() + + print("\n2. aHash validation (alternative)...") + print("-" * 70) + ahash_validation_example() + + print("\n3. Strict validation (threshold=5)...") + print("-" * 70) + strict_validation_example() + + print("\n4. Lenient validation (threshold=20)...") + print("-" * 70) + lenient_validation_example() + + print("\n5. Large region size (150x150)...") + print("-" * 70) + region_size_example() + + print("\n6. No validation (disabled)...") + print("-" * 70) + no_validation_example() + + print("\n" + "=" * 70) + print("Visual validation examples completed!") + print("=" * 70 + "\n") + + print("\nKey Takeaways:") + print("- pHash is recommended for most use cases (robust and accurate)") + print("- Threshold 5-10 is good for stable UIs") + print("- Threshold 15-20 is better for dynamic UIs") + print("- Region size 100x100 is a good default") + print("- Visual validation helps detect unexpected UI changes") + print() diff --git a/src/askui/agent_base.py b/src/askui/agent_base.py index 6512bc67..d07e60d4 100644 --- a/src/askui/agent_base.py +++ b/src/askui/agent_base.py @@ -25,7 +25,7 @@ RetrieveCachedTestExecutions, ) from askui.utils.annotation_writer import AnnotationWriter -from askui.utils.cache_writer import CacheWriter +from askui.utils.caching.cache_writer import CacheWriter from askui.utils.image_utils import ImageSource from askui.utils.source_utils import InputSource, load_image_source @@ -203,7 +203,7 @@ def act( be used for achieving the `goal`. on_message (OnMessageCb | None, optional): Callback for new messages. If it returns `None`, stops and does not add the message. Cannot be used - with caching_settings strategy "write" or "both". + with caching_settings strategy "record" or "both". tools (list[Tool] | ToolCollection | None, optional): The tools for the agent. Defaults to default tools depending on the selected model. settings (AgentSettings | None, optional): The settings for the agent.
@@ -306,13 +306,16 @@ def act( caching_settings or self._get_default_caching_settings_for_act(_model) ) - tools, on_message, cached_execution_tool = self._patch_act_with_cache( - _caching_settings, _settings, tools, on_message - ) _tools = self._build_tools(tools, _model) - if cached_execution_tool: - cached_execution_tool.set_toolbox(_tools) + if _caching_settings.strategy is not None: + on_message = self._patch_act_with_cache( + _caching_settings, _settings, _tools, on_message, goal_str, _model + ) + logger.info( + "Starting agent act with caching enabled (strategy=%s)", + _caching_settings.strategy, + ) self._model_router.act( messages=messages, @@ -336,36 +339,41 @@ def _patch_act_with_cache( self, caching_settings: CachingSettings, settings: ActSettings, - tools: list[Tool] | ToolCollection | None, + toolbox: ToolCollection, on_message: OnMessageCb | None, - ) -> tuple[ - list[Tool] | ToolCollection, OnMessageCb | None, ExecuteCachedTrajectory | None - ]: - """Patch act settings and tools with caching functionality. + goal: str | None = None, + model: str | None = None, + ) -> OnMessageCb | None: + """Patch act settings and toolbox with caching functionality. Args: caching_settings: The caching settings to apply settings: The act settings to modify - tools: The tools list to extend with caching tools + toolbox: The toolbox to extend with caching tools on_message: The message callback (may be replaced for write mode) + goal: The goal string (used for cache metadata) Returns: - A tuple of (modified_tools, modified_on_message, cached_execution_tool) + modified_on_message """ + logger.debug("Setting up caching") caching_tools: list[Tool] = [] - cached_execution_tool: ExecuteCachedTrajectory | None = None - # Setup read mode: add caching tools and modify system prompt - if caching_settings.strategy in ["read", "both"]: - cached_execution_tool = ExecuteCachedTrajectory( - caching_settings.execute_cached_trajectory_tool_settings - ) + # Setup execute mode: add caching tools and modify system prompt + if caching_settings.strategy in ["execute", "both"]: + from askui.tools.caching_tools import VerifyCacheExecution + caching_tools.extend( [ RetrieveCachedTestExecutions(caching_settings.cache_dir), - cached_execution_tool, + ExecuteCachedTrajectory( + toolbox=toolbox, + settings=caching_settings.execution_settings, + ), + VerifyCacheExecution(), ] ) + if isinstance(settings.messages.system, str): settings.messages.system = ( settings.messages.system + "\n" + CACHE_USE_PROMPT @@ -377,27 +385,31 @@ def _patch_act_with_cache( ] else: # Omit or None settings.messages.system = CACHE_USE_PROMPT + logger.debug("Added cache usage instructions to system prompt") - # Add caching tools to the tools list - if isinstance(tools, list): - tools = caching_tools + tools - elif isinstance(tools, ToolCollection): - tools.append_tool(*caching_tools) - else: - tools = caching_tools + # Add caching tools to the toolbox + if caching_tools: + toolbox.append_tool(*caching_tools) - # Setup write mode: create cache writer and set message callback - if caching_settings.strategy in ["write", "both"]: + # Setup record mode: create cache writer and set message callback + cache_writer = None + if caching_settings.strategy in ["record", "both"]: cache_writer = CacheWriter( - caching_settings.cache_dir, caching_settings.filename + cache_dir=caching_settings.cache_dir, + cache_writing_settings=caching_settings.writing_settings, + toolbox=toolbox, + goal=goal, + model_router=self._model_router, + model=model, ) if on_message is 
None: on_message = cache_writer.add_message_cb else: error_message = "Cannot use on_message callback when writing Cache" + logger.error(error_message) raise ValueError(error_message) - return tools, on_message, cached_execution_tool + return on_message def _get_default_settings_for_act(self, model: str) -> ActSettings: # noqa: ARG002 return ActSettings() diff --git a/src/askui/models/anthropic/messages_api.py b/src/askui/models/anthropic/messages_api.py index 2b998830..9f16db97 100644 --- a/src/askui/models/anthropic/messages_api.py +++ b/src/askui/models/anthropic/messages_api.py @@ -66,10 +66,18 @@ def create_message( tool_choice: BetaToolChoiceParam | Omit = omit, temperature: float | Omit = omit, ) -> MessageParam: + # Convert messages to dicts with API context to exclude internal fields _messages = [ - cast("BetaMessageParam", message.model_dump(exclude={"stop_reason"})) + cast( + "BetaMessageParam", + message.model_dump( + exclude={"stop_reason", "usage"}, + context={"for_api": True}, # Triggers exclusion of internal fields + ), + ) for message in messages ] + response = self._client.beta.messages.create( # type: ignore[misc] messages=_messages, max_tokens=max_tokens or 4096, diff --git a/src/askui/models/shared/agent.py b/src/askui/models/shared/agent.py index fdde6c32..7fa4d298 100644 --- a/src/askui/models/shared/agent.py +++ b/src/askui/models/shared/agent.py @@ -4,7 +4,7 @@ from askui.models.exceptions import MaxTokensExceededError, ModelRefusalError from askui.models.models import ActModel -from askui.models.shared.agent_message_param import MessageParam +from askui.models.shared.agent_message_param import MessageParam, UsageParam from askui.models.shared.agent_on_message_cb import ( NULL_ON_MESSAGE_CB, OnMessageCb, @@ -19,6 +19,7 @@ TruncationStrategyFactory, ) from askui.reporting import NULL_REPORTER, Reporter +from askui.utils.caching.cache_execution_manager import CacheExecutionManager logger = logging.getLogger(__name__) @@ -50,6 +51,97 @@ def __init__( self._truncation_strategy_factory = ( truncation_strategy_factory or SimpleTruncationStrategyFactory() ) + # Cache execution manager handles all cache-related logic + self._cache_execution_manager = CacheExecutionManager(reporter) + # Store current tool collection for cache executor access + self._tool_collection: ToolCollection | None = None + + def _get_agent_response( + self, + model: str, + truncation_strategy: TruncationStrategy, + tool_collection: ToolCollection, + settings: ActSettings, + on_message: OnMessageCb, + ) -> MessageParam | None: + """Get response from agent API. 
+ + Args: + model: Model to use + truncation_strategy: Message truncation strategy + tool_collection: Available tools + settings: Agent settings + on_message: Callback for messages + + Returns: + Assistant message or None if cancelled by callback + """ + response_message = self._messages_api.create_message( + messages=truncation_strategy.messages, + model=model, + tools=tool_collection, + max_tokens=settings.messages.max_tokens, + betas=settings.messages.betas, + system=settings.messages.system, + thinking=settings.messages.thinking, + tool_choice=settings.messages.tool_choice, + temperature=settings.messages.temperature, + ) + + message_by_assistant = self._call_on_message( + on_message, response_message, truncation_strategy.messages + ) + if message_by_assistant is None: + return None + + self._accumulate_usage(message_by_assistant.usage) # type: ignore + + message_by_assistant_dict = message_by_assistant.model_dump( + mode="json", context={"for_api": True} + ) + # logger.debug(message_by_assistant_dict) + truncation_strategy.append_message(message_by_assistant) + self._reporter.add_message(self.__class__.__name__, message_by_assistant_dict) + + return message_by_assistant + + def _process_tool_execution( + self, + message_by_assistant: MessageParam, + tool_collection: ToolCollection, + on_message: OnMessageCb, + truncation_strategy: TruncationStrategy, + ) -> bool: + """Process tool execution and return whether to continue. + + Args: + message_by_assistant: Assistant message with potential tool uses + tool_collection: Available tools + on_message: Callback for messages + truncation_strategy: Message truncation strategy + + Returns: + True if tool results were added and caller should recurse, + False otherwise + """ + tool_result_message = self._use_tools(message_by_assistant, tool_collection) + if not tool_result_message: + return False + + tool_result_message = self._call_on_message( + on_message, tool_result_message, truncation_strategy.messages + ) + if not tool_result_message: + return False + + tool_result_message_dict = tool_result_message.model_dump( + mode="json", context={"for_api": True} + ) + logger.debug(tool_result_message_dict) + truncation_strategy.append_message(tool_result_message) + + # Return True to indicate caller should recurse + return True def _step( self, @@ -65,59 +157,66 @@ def _step( blocks, this method is going to return immediately, as there is nothing to act upon. + When executing from cache (cache execution mode), messages from the cache + executor are added to the truncation strategy, which automatically manages + message history size by removing old messages when needed. + Args: model (str): The model to use for message creation. on_message (OnMessageCb): Callback on new messages settings (AgentSettings): The settings for the step. tool_collection (ToolCollection): The tools to use for the step. truncation_strategy (TruncationStrategy): The truncation strategy to use - for the step. + for the step. Manages message history size automatically. 
Returns: None """ + # Get or generate assistant message if truncation_strategy.messages[-1].role == "user": - response_message = self._messages_api.create_message( - messages=truncation_strategy.messages, - model=model, - tools=tool_collection, - max_tokens=settings.messages.max_tokens, - betas=settings.messages.betas, - system=settings.messages.system, - thinking=settings.messages.thinking, - tool_choice=settings.messages.tool_choice, - temperature=settings.messages.temperature, + # Try to execute from cache first + should_recurse = self._cache_execution_manager.handle_execution_step( + on_message, + truncation_strategy, ) - message_by_assistant = self._call_on_message( - on_message, response_message, truncation_strategy.messages - ) - if message_by_assistant is None: - return - message_by_assistant_dict = message_by_assistant.model_dump(mode="json") - logger.debug(message_by_assistant_dict) - truncation_strategy.append_message(message_by_assistant) - self._reporter.add_message( - self.__class__.__name__, message_by_assistant_dict - ) - else: - message_by_assistant = truncation_strategy.messages[-1] - self._handle_stop_reason(message_by_assistant, settings.messages.max_tokens) - if tool_result_message := self._use_tools( - message_by_assistant, tool_collection - ): - if tool_result_message := self._call_on_message( - on_message, tool_result_message, truncation_strategy.messages - ): - tool_result_message_dict = tool_result_message.model_dump(mode="json") - logger.debug(tool_result_message_dict) - truncation_strategy.append_message(tool_result_message) + if should_recurse: + # Cache step handled, recurse to continue self._step( model=model, - tool_collection=tool_collection, on_message=on_message, settings=settings, + tool_collection=tool_collection, truncation_strategy=truncation_strategy, ) + return + + # Normal flow: get agent response + message_by_assistant = self._get_agent_response( + model, truncation_strategy, tool_collection, settings, on_message + ) + if message_by_assistant is None: + return + else: + # Last message is already from assistant + message_by_assistant = truncation_strategy.messages[-1] + + # Check stop reason and process tools + self._handle_stop_reason(message_by_assistant, settings.messages.max_tokens) + should_recurse = self._process_tool_execution( + message_by_assistant, + tool_collection, + on_message, + truncation_strategy, + ) + if should_recurse: + # Tool results added, recurse to continue + self._step( + model=model, + on_message=on_message, + settings=settings, + tool_collection=tool_collection, + truncation_strategy=truncation_strategy, + ) def _call_on_message( self, @@ -129,6 +228,27 @@ def _call_on_message( return message return on_message(OnMessageCbParam(message=message, messages=messages)) + def _setup_cache_tools(self, tool_collection: ToolCollection) -> None: + """Set agent reference on caching tools. + + This allows caching tools to access the agent state for + cache execution and verification. 
+ + Args: + tool_collection: The tool collection to search for cache tools + """ + # Import here to avoid circular dependency + from askui.tools.caching_tools import ( + ExecuteCachedTrajectory, + VerifyCacheExecution, + ) + + # Iterate through tools and set agent on caching tools + for tool_name, tool in tool_collection.get_tools().items(): + if isinstance(tool, (ExecuteCachedTrajectory, VerifyCacheExecution)): + tool.set_cache_execution_manager(self._cache_execution_manager) + logger.debug("Set agent reference on %s", tool_name) + @override def act( self, @@ -138,8 +258,18 @@ def act( tools: ToolCollection | None = None, settings: ActSettings | None = None, ) -> None: + # reset states + self.accumulated_usage: UsageParam = UsageParam() + self._cache_execution_manager.reset_state() + _settings = settings or ActSettings() _tool_collection = tools or ToolCollection() + # Store tool collection so it can be accessed by caching tools + self._tool_collection = _tool_collection + + # Set agent reference on ExecuteCachedTrajectory tools + self._setup_cache_tools(_tool_collection) + truncation_strategy = ( self._truncation_strategy_factory.create_truncation_strategy( tools=_tool_collection.to_params(), @@ -156,6 +286,9 @@ def act( truncation_strategy=truncation_strategy, ) + # Report accumulated usage statistics + self._reporter.add_usage_summary(self.accumulated_usage.model_dump()) + def _use_tools( self, message: MessageParam, @@ -192,3 +325,17 @@ def _handle_stop_reason(self, message: MessageParam, max_tokens: int) -> None: raise MaxTokensExceededError(max_tokens) if message.stop_reason == "refusal": raise ModelRefusalError + + def _accumulate_usage(self, step_usage: UsageParam) -> None: + self.accumulated_usage.input_tokens = ( + self.accumulated_usage.input_tokens or 0 + ) + (step_usage.input_tokens or 0) + self.accumulated_usage.output_tokens = ( + self.accumulated_usage.output_tokens or 0 + ) + (step_usage.output_tokens or 0) + self.accumulated_usage.cache_creation_input_tokens = ( + self.accumulated_usage.cache_creation_input_tokens or 0 + ) + (step_usage.cache_creation_input_tokens or 0) + self.accumulated_usage.cache_read_input_tokens = ( + self.accumulated_usage.cache_read_input_tokens or 0 + ) + (step_usage.cache_read_input_tokens or 0) diff --git a/src/askui/models/shared/agent_message_param.py b/src/askui/models/shared/agent_message_param.py index 6265ab36..1503bb5b 100644 --- a/src/askui/models/shared/agent_message_param.py +++ b/src/askui/models/shared/agent_message_param.py @@ -1,4 +1,8 @@ -from pydantic import BaseModel +from typing import Any + +from pydantic import BaseModel, model_serializer +from pydantic.functional_serializers import SerializerFunctionWrapHandler +from pydantic_core import core_schema from typing_extensions import Literal @@ -78,6 +82,28 @@ class ToolUseBlockParam(BaseModel): name: str type: Literal["tool_use"] = "tool_use" cache_control: CacheControlEphemeralParam | None = None + # Visual validation field - internal use only, not sent to Anthropic API + visual_representation: str | None = None + + @model_serializer(mode="wrap") + def _serialize_model( + self, + serializer: SerializerFunctionWrapHandler, + info: core_schema.SerializationInfo, + ) -> dict[str, Any]: + """Custom serializer to exclude internal fields when serializing for API. + + When context={'for_api': True}, visual validation fields are excluded. + Otherwise, all fields are included (for cache storage, internal use). 
+ """ + # Use default serialization + data: dict[str, Any] = serializer(self) + + # If serializing for API, remove internal fields + if info.context and info.context.get("for_api"): + data.pop("visual_representation", None) + + return data class BetaThinkingBlock(BaseModel): @@ -105,10 +131,18 @@ class BetaRedactedThinkingBlock(BaseModel): ] +class UsageParam(BaseModel): + input_tokens: int | None = None + output_tokens: int | None = None + cache_creation_input_tokens: int | None = None + cache_read_input_tokens: int | None = None + + class MessageParam(BaseModel): role: Literal["user", "assistant"] content: str | list[ContentBlockParam] stop_reason: StopReason | None = None + usage: UsageParam | None = None __all__ = [ diff --git a/src/askui/models/shared/settings.py b/src/askui/models/shared/settings.py index 547d97b6..5b0da29a 100644 --- a/src/askui/models/shared/settings.py +++ b/src/askui/models/shared/settings.py @@ -1,3 +1,6 @@ +from datetime import datetime +from typing import Optional + from anthropic import Omit, omit from anthropic.types import AnthropicBetaParam from anthropic.types.beta import ( @@ -8,10 +11,14 @@ from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Literal +from askui.models.shared.agent_message_param import ToolUseBlockParam, UsageParam + COMPUTER_USE_20250124_BETA_FLAG = "computer-use-2025-01-24" COMPUTER_USE_20251124_BETA_FLAG = "computer-use-2025-11-24" -CACHING_STRATEGY = Literal["read", "write", "both", "no"] +CACHING_STRATEGY = Literal["execute", "record", "both"] +CACHE_PARAMETER_IDENTIFICATION_STRATEGY = Literal["llm", "preset"] +CACHING_VISUAL_VERIFICATION_METHOD = Literal["phash", "ahash", "none"] class MessageSettings(BaseModel): @@ -31,14 +38,54 @@ class ActSettings(BaseModel): messages: MessageSettings = Field(default_factory=MessageSettings) -class CachedExecutionToolSettings(BaseModel): +class CacheWritingSettings(BaseModel): + """Settings for writing/recording cache files.""" + + filename: str = "" + parameter_identification_strategy: CACHE_PARAMETER_IDENTIFICATION_STRATEGY = "llm" + visual_verification_method: CACHING_VISUAL_VERIFICATION_METHOD = "phash" + visual_validation_region_size: int = 100 + visual_validation_threshold: int = 10 + + +class CacheExecutionSettings(BaseModel): + """Settings for executing/replaying cache files.""" + delay_time_between_action: float = 0.5 class CachingSettings(BaseModel): - strategy: CACHING_STRATEGY = "no" - cache_dir: str = ".cache" - filename: str = "" - execute_cached_trajectory_tool_settings: CachedExecutionToolSettings = ( - CachedExecutionToolSettings() - ) + strategy: CACHING_STRATEGY | None = None + cache_dir: str = ".askui_cache" + writing_settings: CacheWritingSettings | None = None + execution_settings: CacheExecutionSettings | None = None + + +class CacheFailure(BaseModel): + timestamp: datetime + step_index: int + error_message: str + failure_count_at_step: int + + +class CacheMetadata(BaseModel): + version: str = "0.1" + created_at: datetime + goal: Optional[str] = None + last_executed_at: Optional[datetime] = None + token_usage: UsageParam | None = None + execution_attempts: int = 0 + failures: list[CacheFailure] = Field(default_factory=list) + is_valid: bool = True + invalidation_reason: Optional[str] = None + visual_verification_method: Optional[CACHING_VISUAL_VERIFICATION_METHOD] = None + visual_validation_region_size: Optional[int] = None + visual_validation_threshold: Optional[int] = None + + +class CacheFile(BaseModel): + """Cache file structure (v0.1) 
wrapping trajectory with metadata.""" + + metadata: CacheMetadata + trajectory: list[ToolUseBlockParam] + cache_parameters: dict[str, str] = Field(default_factory=dict) diff --git a/src/askui/models/shared/token_counter.py b/src/askui/models/shared/token_counter.py index 592e9af9..378f7033 100644 --- a/src/askui/models/shared/token_counter.py +++ b/src/askui/models/shared/token_counter.py @@ -165,6 +165,8 @@ def _count_tokens_for_message(self, message: MessageParam) -> int: For image blocks, uses the formula: tokens = (width * height) / 750 (see https://docs.anthropic.com/en/docs/build-with-claude/vision) For other content types, uses the standard character-based estimation. + Uses for_api context to exclude internal fields from token counting. + Args: message (MessageParam): The message to count tokens for. @@ -175,9 +177,11 @@ def _count_tokens_for_message(self, message: MessageParam) -> int: # Simple string content - use standard estimation return int(len(message.content) / self._chars_per_token) - # base tokens for rest of message - total_tokens = 10 - # Content is a list of blocks - process each individually + # Process content blocks individually to handle images properly + # Base tokens for the message structure (role, etc.) + base_tokens = 20 + + total_tokens = base_tokens for block in message.content: total_tokens += self._count_tokens_for_content_block(block) @@ -186,6 +190,8 @@ def _count_tokens_for_message(self, message: MessageParam) -> int: def _count_tokens_for_content_block(self, block: ContentBlockParam) -> int: """Count tokens for a single content block. + Uses for_api context to exclude internal fields like visual validation. + Args: block (ContentBlockParam): The content block to count tokens for. @@ -207,8 +213,26 @@ def _count_tokens_for_content_block(self, block: ContentBlockParam) -> int: total_tokens += self._count_tokens_for_content_block(nested_block) return total_tokens - # For other block types, use string representation - return int(len(self._stringify_object(block)) / self._chars_per_token) + # For other block types (ToolUseBlockParam, TextBlockParam, etc.), + # use string representation with API context to exclude internal fields + stringified = self._stringify_object(block) + token_count = int(len(stringified) / self._chars_per_token) + + # Debug: Log if this is a ToolUseBlockParam with visual validation fields + if hasattr(block, "visual_representation") and block.visual_representation: + import logging + + logger = logging.getLogger(__name__) + logger.debug( + "Token counting for %s: stringified_length=%d, tokens=%d, " + "has_visual_fields=%s", + getattr(block, "name", "unknown"), + len(stringified), + token_count, + "visual_representation" in stringified, + ) + + return token_count def _count_tokens_for_image_block(self, block: ImageBlockParam) -> int: """Count tokens for an image block using Anthropic's formula. @@ -248,6 +272,9 @@ def _stringify_object(self, obj: object) -> str: Not whitespace in dumped jsons between object keys and values and among array elements. + For Pydantic models, uses API serialization context to exclude internal fields + that won't be sent to the API (e.g., visual validation fields). + Args: obj (object): The object to stringify. 
@@ -256,6 +283,16 @@ """ if isinstance(obj, str): return obj + + # Check if object is a Pydantic model with model_dump method + if hasattr(obj, "model_dump") and callable(obj.model_dump): + try: + # Use for_api context to exclude internal fields from token counting + serialized = obj.model_dump(context={"for_api": True}) + return json.dumps(serialized, separators=(",", ":")) + except (TypeError, ValueError, AttributeError): + pass # Fall through to default handling + try: return json.dumps(obj, separators=(",", ":")) except (TypeError, ValueError): diff --git a/src/askui/models/shared/tools.py b/src/askui/models/shared/tools.py index da39ee2d..016b95f7 100644 --- a/src/askui/models/shared/tools.py +++ b/src/askui/models/shared/tools.py @@ -165,6 +165,15 @@ class Tool(BaseModel, ABC): default_factory=_default_input_schema, description="JSON schema for tool parameters", ) + is_cacheable: bool = Field( + default=True, + description=( + "Whether this tool's execution can be cached. " + "Set to False for tools with side effects that shouldn't be repeated " + "(e.g., print/output/notification/external API tools with state changes). " + "Default: True." + ), + ) @abstractmethod def __call__(self, *args: Any, **kwargs: Any) -> ToolCallResult: @@ -341,6 +350,14 @@ def append_tool(self, *tools: Tool) -> "Self": self._tool_map[tool.to_params()["name"]] = tool return self + def get_tools(self) -> dict[str, Tool]: + """Get all tools in the collection. + + Returns: + Dictionary mapping tool names to Tool instances + """ + return dict(self._tool_map) + def reset_tools(self, tools: list[Tool] | None = None) -> "Self": """Reset the tools in the collection with new tools.""" _tools = tools or [] diff --git a/src/askui/prompts/caching.py b/src/askui/prompts/caching.py index a89cf224..12a6f1dc 100644 --- a/src/askui/prompts/caching.py +++ b/src/askui/prompts/caching.py @@ -4,22 +4,122 @@ "task more robust and faster!\n" " To do so, first use the RetrieveCachedTestExecutions tool to check " "which trajectories are available for you.\n" + " It is very important that you use the RetrieveCachedTestExecutions tool and not " + "another tool for finding precomputed trajectories.\n" + "Hence, please use the RetrieveCachedTestExecutions tool in this step, even in " + "cases where another comparable tool (e.g. list_files tool) might be available.\n" " The details what each trajectory that is available for you does are " "at the end of this prompt.\n" " A trajectory contains all necessary mouse movements, clicks, and " "typing actions from a previously successful execution.\n" " If there is a trajectory available for a step you need to take, " "always use it!\n" - " You can execute a trajectory with the ExecuteCachedExecution tool.\n" - " After a trajectory was executed, make sure to verify the results! " "While it works most of the time, occasionally, the execution can be " "(partly) incorrect.
So make sure to verify if everything is filled out " - "as expected, and make corrections where necessary!\n" + "\n" + " EXECUTING TRAJECTORIES:\n" + " - Use ExecuteCachedTrajectory to execute a cached trajectory\n" + " - You will see all screenshots and results from the execution in " + "the message history\n" + " - After execution completes, verify the results are correct\n" + " - If execution fails partway, you'll see exactly where it failed " + "and can decide how to proceed\n" + "\n" + " CACHING_PARAMETERS:\n" + " - Trajectories may contain dynamic parameters like " + "{{current_date}} or {{user_name}}\n" + " - When executing a trajectory, check if it requires " + "parameter values\n" + " - Provide parameter values using the parameter_values " + "parameter as a dictionary\n" + " - Example: ExecuteCachedTrajectory(trajectory_file='test.json', " + "parameter_values={'current_date': '2025-12-11'})\n" + " - If required parameters are missing, execution will fail with " + "a clear error message\n" + "\n" + " NON-CACHEABLE STEPS:\n" + " - Some tools cannot be cached and require your direct execution " + "(e.g., print_debug, contextual decisions)\n" + " - When trajectory execution reaches a non-cacheable step, it will " + "pause and return control to you\n" + " - You'll receive a NEEDS_AGENT status with the current " + "step index\n" + " - Execute the non-cacheable step manually using your " + "regular tools\n" + " - After completing the non-cacheable step, continue the trajectory " + "using ExecuteCachedTrajectory with start_from_step_index\n" + "\n" + " CONTINUING TRAJECTORIES:\n" + " - Use ExecuteCachedTrajectory with start_from_step_index to resume " + "execution after handling a non-cacheable step\n" + " - Provide the same trajectory file and the step index where " + "execution should continue\n" + " - Example: ExecuteCachedTrajectory(trajectory_file='test.json', " + "start_from_step_index=5, parameter_values={...})\n" + " - The tool will execute remaining steps from that index onwards\n" + "\n" + " FAILURE HANDLING:\n" + " - If a trajectory fails during execution, you'll see the error " + "message and the step where it failed\n" + " - Analyze the failure: Was it due to UI changes, timing issues, " + "or incorrect state?\n" + " - Options for handling failures:\n" + " 1. Execute the remaining steps manually\n" + " 2. Fix the issue and retry from a specific step using " + "ExecuteCachedTrajectory with start_from_step_index\n" + " 3. 
Report that the cached trajectory is outdated and needs " "re-recording\n" + "\n" + " BEST PRACTICES:\n" + " - Always verify results after trajectory execution completes\n" + " - While trajectories work most of the time, occasionally " "execution can be partly incorrect\n" + " - Make corrections where necessary after cached execution\n" + " - If you need to make any corrections after a trajectory " "execution, please mark the cached execution as failed\n" + " - If a trajectory consistently fails, it may be invalid and " "should be re-recorded\n" " \n" " \n" " There are several trajectories available to you.\n" " Their filename is a unique testID.\n" - " If executed using the ExecuteCachedExecution tool, a trajectory will " - "automatically execute all necessary steps for the test with that id.\n" + " If executed using the ExecuteCachedTrajectory tool, a trajectory " + "will automatically execute all necessary steps for the test with " + "that id.\n" " \n" ) + +CACHING_PARAMETER_IDENTIFIER_SYSTEM_PROMPT = """You are analyzing UI automation \ trajectories to identify values that should be parameterized. + +Identify values that are likely to change between executions, such as: +- Dates and timestamps (e.g., "2025-12-11", "10:30 AM", "2025-12-11T14:30:00Z") +- Usernames, emails, names (e.g., "john.doe", "test@example.com", "John Smith") +- Session IDs, tokens, UUIDs, API keys +- Dynamic text that references current state or time-sensitive information +- File paths with user-specific or time-specific components +- Temporary or generated identifiers + +DO NOT mark as parameters: +- UI element coordinates (x, y positions) +- Fixed button labels or static UI text +- Configuration values that don't change (e.g., timeouts, retry counts) +- Generic action names like "click", "type", "scroll" +- Tool names +- Boolean values or common constants + +For each parameter, provide: +1. A descriptive name in snake_case (e.g., "current_date", "user_email") +2. The actual value found in the trajectory +3. A brief description of what it represents + +Return your analysis as a JSON object with this structure: +{ + "parameters": [ + { + "name": "current_date", + "value": "2025-12-11", + "description": "Current date in YYYY-MM-DD format" + } + ] +} + +If no parameters are found, return an empty parameters array.""" diff --git a/src/askui/reporting.py b/src/askui/reporting.py index 1116f009..96408061 100644 --- a/src/askui/reporting.py +++ b/src/askui/reporting.py @@ -56,6 +56,21 @@ def add_message( """ raise NotImplementedError + @abstractmethod + def add_usage_summary(self, usage: dict[str, int | None]) -> None: + """Add usage statistics summary to the report. + + Called at the end of an act() execution with accumulated token usage. + + Args: + usage (dict[str, int | None]): Accumulated usage statistics containing: + - input_tokens: Total input tokens sent to API + - output_tokens: Total output tokens generated + - cache_creation_input_tokens: Tokens written to prompt cache + - cache_read_input_tokens: Tokens read from prompt cache + """ + raise NotImplementedError + @abstractmethod def generate(self) -> None: """Generates the final report.
@@ -81,6 +96,10 @@ def add_message( ) -> None: pass + @override + def add_usage_summary(self, usage: dict[str, int | None]) -> None: + pass + @override def generate(self) -> None: pass @@ -120,6 +139,12 @@ def generate(self) -> None: for report in self._reporters: report.generate() + @override + def add_usage_summary(self, usage: dict[str, int | None]) -> None: + """Add usage summary to all reporters.""" + for reporter in self._reporters: + reporter.add_usage_summary(usage) + class SystemInfo(TypedDict): platform: str @@ -139,6 +164,7 @@ def __init__(self, report_dir: str = "reports") -> None: self.report_dir = Path(report_dir) self.messages: list[dict[str, Any]] = [] self.system_info = self._collect_system_info() + self.usage_summary: dict[str, int | None] | None = None def _collect_system_info(self) -> SystemInfo: """Collect system and Python information""" @@ -179,6 +205,11 @@ def add_message( } self.messages.append(message) + @override + def add_usage_summary(self, usage: dict[str, int | None]) -> None: + """Store usage summary for inclusion in the report.""" + self.usage_summary = usage + @override def generate(self) -> None: """Generate an HTML report file. @@ -684,42 +715,53 @@ def generate(self) -> None: -
-        <h2>Conversation Log</h2>
-        <table>
-            <tr>
-                <th>Time</th>
-                <th>Role</th>
-                <th>Content</th>
-            </tr>
-            {% for msg in messages %}
-            <tr>
-                <td>{{ msg.timestamp.strftime('%H:%M:%S.%f')[:-3] }} UTC</td>
-                <td>{{ msg.role }}</td>
-                <td>
-                    {% if msg.is_json %}
-                    <pre>{{ msg.content }}</pre>
-                    {% else %}
-                    {{ msg.content }}
-                    {% endif %}
-                    {% for image in msg.images %}
-                    <img src="data:image/png;base64,{{ image }}" alt="Message image">
-                    {% endfor %}
-                </td>
-            </tr>
-            {% endfor %}
-        </table>
+        {% if usage_summary %}
+        <h2>Token Usage</h2>
+        <table>
+            {% if usage_summary.get('input_tokens') is not none %}
+            <tr>
+                <td>Input Tokens</td>
+                <td>{{ "{:,}".format(usage_summary.get('input_tokens')) }}</td>
+            </tr>
+            {% endif %}
+            {% if usage_summary.get('output_tokens') is not none %}
+            <tr>
+                <td>Output Tokens</td>
+                <td>{{ "{:,}".format(usage_summary.get('output_tokens')) }}</td>
+            </tr>
+            {% endif %}
+        </table>
+        {% endif %}
+
+        <h2>Conversation Log</h2>
+        <table>
+            <tr>
+                <th>Time</th>
+                <th>Role</th>
+                <th>Content</th>
+            </tr>
+            {% for msg in messages %}
+            <tr>
+                <td>{{ msg.timestamp.strftime('%H:%M:%S') }}</td>
+                <td>{{ msg.role }}</td>
+                <td>
+                    {% if msg.is_json %}
+                    <pre>{{ msg.content }}</pre>
+                    {% else %}
+                    {{ msg.content }}
+                    {% endif %}
+                    {% for image in msg.images %}
+                    <img src="data:image/png;base64,{{ image }}" alt="Message image">
+                    {% endfor %}
+                </td>
+            </tr>
+            {% endfor %}
+        </table>
""" @@ -729,6 +771,7 @@ def generate(self) -> None: timestamp=datetime.now(tz=timezone.utc), messages=self.messages, system_info=self.system_info, + usage_summary=self.usage_summary, ) report_path = ( @@ -811,6 +854,10 @@ def add_message( attachment_type=self.allure.attachment_type.PNG, ) + @override + def add_usage_summary(self, usage: dict[str, int | None]) -> None: + """No-op for AllureReporter as usage is not part of Allure reports.""" + @override def generate(self) -> None: """No-op for AllureReporter as reports are generated in real-time.""" diff --git a/src/askui/tools/caching_tools.py b/src/askui/tools/caching_tools.py index 4abb3a55..bcba4ef6 100644 --- a/src/askui/tools/caching_tools.py +++ b/src/askui/tools/caching_tools.py @@ -1,13 +1,22 @@ +import json import logging -import time from pathlib import Path +from typing import TYPE_CHECKING from pydantic import validate_call from typing_extensions import override -from ..models.shared.settings import CachedExecutionToolSettings +from ..models.shared.settings import CacheExecutionSettings from ..models.shared.tools import Tool, ToolCollection -from ..utils.cache_writer import CacheWriter +from ..utils.cache_parameter_handler import CacheParameterHandler +from ..utils.caching.cache_execution_manager import CacheExecutionManager +from ..utils.caching.cache_manager import CacheManager +from ..utils.caching.cache_writer import CacheWriter + +if TYPE_CHECKING: + from ..models.shared.agent_message_param import ToolUseBlockParam + from ..models.shared.settings import CacheFile + from ..utils.trajectory_executor import TrajectoryExecutor logger = logging.getLogger() @@ -27,28 +36,93 @@ def __init__(self, cache_dir: str, trajectories_format: str = ".json") -> None: "replayed using the execute_trajectory_tool. Call this tool " "first to see which trajectories are available before " "executing one. The tool returns a list of file paths to " - "available trajectory files." + "available trajectory files.\n\n" + "By default, only valid (non-invalidated) caches are returned. " + "Set include_invalid=True to see all caches including those " + "marked as invalid due to repeated failures." ), + input_schema={ + "type": "object", + "properties": { + "include_invalid": { + "type": "boolean", + "description": ( + "Whether to include invalid/invalidated caches in " + "the results. Default is False (only show valid " + "caches)." 
+ ), + "default": False, + }, + }, + "required": [], + }, ) self._cache_dir = Path(cache_dir) self._trajectories_format = trajectories_format @override @validate_call - def __call__(self) -> list[str]: # type: ignore + def __call__(self, include_invalid: bool = False) -> list[str]: # type: ignore + logger.info( + "Retrieving cached trajectories from %s (include_invalid=%s)", + self._cache_dir, + include_invalid, + ) + if not Path.is_dir(self._cache_dir): error_msg = f"Trajectories directory not found: {self._cache_dir}" logger.error(error_msg) raise FileNotFoundError(error_msg) - available = [ - str(f) + all_files = [ + f for f in self._cache_dir.iterdir() if str(f).endswith(self._trajectories_format) ] + logger.debug("Found %d total cache files", len(all_files)) + + if not include_invalid: + # Filter out invalid caches + valid_files = [] + invalid_count = 0 + unreadable_count = 0 + for f in all_files: + try: + cache_file = CacheWriter.read_cache_file(f) + if cache_file.metadata.is_valid: + valid_files.append(str(f)) + else: + invalid_count += 1 + logger.debug( + "Excluding invalid cache: %s (reason: %s)", + f.name, + cache_file.metadata.invalidation_reason, + ) + except Exception: # noqa: PERF203 + unreadable_count += 1 + logger.exception("Failed to read cache file %s", f.name) + # If we can't read it, exclude it + continue + available = valid_files + logger.info( + "Found %d valid cache(s), excluded %d invalid, %d unreadable", + len(valid_files), + invalid_count, + unreadable_count, + ) + else: + available = [str(f) for f in all_files] + logger.info("Retrieved %d cache file(s) (all included)", len(available)) if not available: - warning_msg = f"Warning: No trajectory files found in {self._cache_dir}" + if include_invalid: + warning_msg = f"Warning: No trajectory files found in {self._cache_dir}" + else: + warning_msg = ( + f"Warning: No valid trajectory files found in " + f"{self._cache_dir}. " + "Try include_invalid=True to see all caches." + ) logger.warning(warning_msg) return available @@ -56,25 +130,41 @@ def __call__(self) -> list[str]: # type: ignore class ExecuteCachedTrajectory(Tool): """ - Execute a predefined trajectory to fast-forward through UI interactions + Execute or continue a predefined trajectory to fast-forward through UI interactions """ - def __init__(self, settings: CachedExecutionToolSettings | None = None) -> None: + def __init__( + self, + toolbox: ToolCollection, + settings: CacheExecutionSettings | None = None, + ) -> None: super().__init__( name="execute_cached_executions_tool", description=( - "Execute a pre-recorded trajectory to automatically perform a " - "sequence of UI interactions. This tool replays mouse movements, " - "clicks, and typing actions from a previously successful execution.\n\n" + "Activate cache execution mode to replay a pre-recorded " + "trajectory. This tool sets up the agent to execute cached UI " + "interactions step-by-step.\n\n" "Before using this tool:\n" "1. Use retrieve_available_trajectories_tool to see which " "trajectory files are available\n" "2. Select the appropriate trajectory file path from the " "returned list\n" - "3. Pass the full file path to this tool\n\n" - "The trajectory will be executed step-by-step, and you should " - "verify the results afterward. Note: Trajectories may fail if " - "the UI state has changed since they were recorded." + "3. If the trajectory contains parameters (e.g., " + "{{current_date}}), provide values for them in the " + "parameter_values parameter\n" + "4. 
Pass the full file path to this tool\n\n"
+                "Cache parameters allow dynamic values to be injected during "
+                "execution. For example, if a trajectory types "
+                "'{{current_date}}', you must provide "
+                "parameter_values={'current_date': '2025-12-11'}.\n\n"
+                "To continue from a specific step (e.g., after manually "
+                "handling a non-cacheable step), use the start_from_step_index "
+                "parameter. By default, execution starts from the beginning "
+                "(step 0).\n\n"
+                "Once activated, the agent will execute cached steps "
+                "automatically. If a non-cacheable step is encountered, the "
+                "agent will be asked to handle it manually before resuming "
+                "cache execution."
             ),
             input_schema={
                 "type": "object",
@@ -87,58 +177,771 @@ def __init__(self, settings: CachedExecutionToolSettings | None = None) -> None:
                 "properties": {
                     "trajectory_file": {
                         "type": "string",
                         "description": (
                             "Full path to the trajectory file to execute (use "
                             "retrieve_available_trajectories_tool to see "
                             "available files)"
                         ),
                     },
+                    "start_from_step_index": {
+                        "type": "integer",
+                        "description": (
+                            "Optional: The step index to start or resume "
+                            "execution from (0-based). Use 0 (default) to start "
+                            "from the beginning. Use a higher index to continue "
+                            "from a specific step, e.g., after manually handling "
+                            "a non-cacheable step."
+                        ),
+                        "default": 0,
+                    },
+                    "parameter_values": {
+                        "type": "object",
+                        "description": (
+                            "Optional dictionary mapping parameter names to "
+                            "their values. Required if the trajectory contains "
+                            "parameters like {{variable}}. Example: "
+                            "{'current_date': '2025-12-11', 'user_name': 'Alice'}"
+                        ),
+                        "additionalProperties": {"type": "string"},
+                        "default": {},
+                    },
                 },
                 "required": ["trajectory_file"],
             },
         )
         if not settings:
-            settings = CachedExecutionToolSettings()
+            settings = CacheExecutionSettings()
         self._settings = settings
-
-    def set_toolbox(self, toolbox: ToolCollection) -> None:
-        """Set the AgentOS/AskUiControllerClient reference for executing actions."""
+        self._cache_execution_manager: CacheExecutionManager | None = None
         self._toolbox = toolbox
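A usage sketch may clarify how the new parameters combine. This is a hedged, direct-call example (the agent normally invokes the tool through its JSON schema); `toolbox` and `manager` are placeholders for objects that exist in your setup, and the `.cache/test.json` path is illustrative:

```python
from askui.models.shared.settings import CacheExecutionSettings
from askui.tools.caching_tools import ExecuteCachedTrajectory

toolbox = ...  # your existing ToolCollection (assumed to be constructed elsewhere)
manager = ...  # a CacheExecutionManager wired to the agent (see below in this diff)

tool = ExecuteCachedTrajectory(
    toolbox=toolbox,
    settings=CacheExecutionSettings(delay_time_between_action=0.5),
)
tool.set_cache_execution_manager(manager)  # required before calling the tool

# Start from the beginning, supplying values for the trajectory's parameters:
result = tool(
    trajectory_file=".cache/test.json",
    parameter_values={"current_date": "2025-12-11"},
)

# Resume at step 5 after manually handling a non-cacheable step:
result = tool(
    trajectory_file=".cache/test.json",
    start_from_step_index=5,
    parameter_values={"current_date": "2025-12-11"},
)
```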
+    def set_cache_execution_manager(
+        self, cache_execution_manager: CacheExecutionManager
+    ) -> None:
+        """Set the cache execution manager used to activate cache execution mode.
+
+        Args:
+            cache_execution_manager: The CacheExecutionManager that will drive
+                execution of the cached trajectory
+        """
+        self._cache_execution_manager = cache_execution_manager
+
+    def _validate_trajectory_file(self, trajectory_file: str) -> str | None:
+        """Validate that trajectory file exists.
+
+        Args:
+            trajectory_file: Path to the trajectory file
+
+        Returns:
+            Error message if validation fails, None otherwise
+        """
+        if not Path(trajectory_file).is_file():
+            error_msg = (
+                f"Trajectory file not found: {trajectory_file}\n"
+                "Use retrieve_available_trajectories_tool to see "
+                "available files."
+            )
+            logger.error(error_msg)
+            return error_msg
+        return None
+
+    def _validate_step_index(
+        self, start_from_step_index: int, trajectory_length: int
+    ) -> str | None:
+        """Validate step index is within bounds.
+
+        Args:
+            start_from_step_index: Index to start from
+            trajectory_length: Total number of steps in trajectory
+
+        Returns:
+            Error message if validation fails, None otherwise
+        """
+        logger.debug(
+            "Validating start_from_step_index=%d (trajectory has %d steps)",
+            start_from_step_index,
+            trajectory_length,
+        )
+        if start_from_step_index < 0 or start_from_step_index >= trajectory_length:
+            error_msg = (
+                f"Invalid start_from_step_index: {start_from_step_index}. "
+                f"Trajectory has {trajectory_length} steps "
+                f"(valid indices: 0-{trajectory_length - 1})."
+            )
+            logger.error(error_msg)
+            return error_msg
+        return None
+
+    def _validate_parameters(
+        self,
+        trajectory: list["ToolUseBlockParam"],
+        parameter_values: dict[str, str],
+        cache_parameters: dict[str, str],
+    ) -> str | None:
+        """Validate parameter values.
+
+        Args:
+            trajectory: The cached trajectory
+            parameter_values: User-provided parameter values
+            cache_parameters: Parameters defined in cache file
+
+        Returns:
+            Error message if validation fails, None otherwise
+        """
+        logger.debug("Validating parameter values")
+        is_valid, missing = CacheParameterHandler.validate_parameters(
+            trajectory, parameter_values
+        )
+        if not is_valid:
+            error_msg = (
+                f"Missing required parameter values: {', '.join(missing)}\n"
+                f"The trajectory contains the following parameters: "
+                f"{', '.join(cache_parameters.keys())}\n"
+                f"Please provide values for all parameters in the "
+                f"parameter_values parameter."
+            )
+            logger.error(error_msg)
+            return error_msg
+        return None
+
+    def _create_executor(
+        self,
+        cache_file: "CacheFile",
+        parameter_values: dict[str, str],
+        start_from_step_index: int,
+    ) -> "TrajectoryExecutor":
+        """Create and configure trajectory executor.
+
+        Args:
+            cache_file: The cache file to execute
+            parameter_values: Parameter values to use
+            start_from_step_index: Index to start execution from
+
+        Returns:
+            Configured TrajectoryExecutor instance
+        """
+        # Read visual validation settings ONLY from cache metadata
+        # Visual validation is only enabled if the cache was recorded with it
+        visual_validation_enabled = False
+        visual_hash_method = "phash"  # Default (unused if validation disabled)
+        visual_validation_threshold = 10  # Default (unused if validation disabled)
+        visual_validation_region_size = 100  # Default (unused if validation disabled)
+
+        if cache_file.metadata.visual_verification_method:
+            # Cache has visual validation metadata - use those exact settings
+            visual_validation_enabled = (
+                cache_file.metadata.visual_verification_method != "none"
+            )
+            visual_hash_method = cache_file.metadata.visual_verification_method
+
+            if cache_file.metadata.visual_validation_threshold is not None:
+                visual_validation_threshold = (
+                    cache_file.metadata.visual_validation_threshold
+                )
+
+            if cache_file.metadata.visual_validation_region_size is not None:
+                visual_validation_region_size = (
+                    cache_file.metadata.visual_validation_region_size
+                )
+
+            logger.debug(
+                "Visual validation enabled from cache metadata: "
+                "method=%s, threshold=%d, region_size=%d",
+                visual_hash_method,
+                visual_validation_threshold,
+                visual_validation_region_size,
+            )
+        else:
+            # Cache doesn't have visual validation metadata - don't validate
+            logger.debug(
+                "Visual validation disabled: cache file has no visual "
+                "validation metadata"
+            )
+
+        logger.debug(
+            "Creating TrajectoryExecutor with delay=%ss, visual_validation=%s",
+            self._settings.delay_time_between_action,
+            visual_validation_enabled,
+        )
+
+        # Import here to avoid circular dependency
+        from askui.utils.trajectory_executor import TrajectoryExecutor
+
+        executor = TrajectoryExecutor(
+            trajectory=cache_file.trajectory,
+            toolbox=self._toolbox,
+            parameter_values=parameter_values,
+            delay_time=self._settings.delay_time_between_action,
+            visual_validation_enabled=visual_validation_enabled,
+            visual_validation_threshold=visual_validation_threshold,
+            visual_hash_method=visual_hash_method,
+            visual_validation_region_size=visual_validation_region_size,
+        )
+
+        # Set the starting position if continuing
+        if start_from_step_index
> 0: + executor.current_step_index = start_from_step_index + logger.debug( + "Set executor start position to step %d", start_from_step_index + ) + + return executor + + def _format_success_message( + self, + trajectory_file: str, + trajectory_length: int, + start_from_step_index: int, + parameter_count: int, + ) -> str: + """Format success message. + + Args: + trajectory_file: Path to trajectory file + trajectory_length: Total steps in trajectory + start_from_step_index: Starting step index + parameter_count: Number of parameters used + + Returns: + Formatted success message + """ + if start_from_step_index == 0: + success_msg = ( + f"✓ Cache execution mode activated for " + f"{Path(trajectory_file).name}. " + f"Will execute {trajectory_length} cached steps." + ) + else: + remaining_steps = trajectory_length - start_from_step_index + success_msg = ( + f"✓ Cache execution mode activated, resuming from step " + f"{start_from_step_index}. " + f"Will execute {remaining_steps} remaining cached steps." + ) + + if parameter_count > 0: + success_msg += f" Using {parameter_count} parameter value(s)." + + return success_msg + @override @validate_call - def __call__(self, trajectory_file: str) -> str: - if not hasattr(self, "_toolbox"): - error_msg = "Toolbox not set. Call set_toolbox() first." + def __call__( + self, + trajectory_file: str, + start_from_step_index: int = 0, + parameter_values: dict[str, str] | None = None, + ) -> str: + """Activate cache execution mode for the agent. + + This method validates the cache file and sets up the agent to execute + cached steps. The actual execution happens in the agent's step loop. + + Returns: + Success message indicating cache mode has been activated + """ + if parameter_values is None: + parameter_values = {} + + logger.info( + "Activating cache execution mode: %s (start_from_step=%d)", + Path(trajectory_file).name, + start_from_step_index, + ) + + # Validate agent is set + if not self._cache_execution_manager: + error_msg = ( + "Cache Execution Manager not set. Call " + "set_cache_execution_manager() first." + ) logger.error(error_msg) raise RuntimeError(error_msg) + # Validate trajectory file exists + if error := self._validate_trajectory_file(trajectory_file): + return error + + # Load cache file + logger.debug("Loading cache file: %s", trajectory_file) + cache_file = CacheWriter.read_cache_file(Path(trajectory_file)) + + logger.debug( + "Cache loaded: %d steps, %d parameters, valid=%s", + len(cache_file.trajectory), + len(cache_file.cache_parameters), + cache_file.metadata.is_valid, + ) + + # Warn if cache is invalid + if not cache_file.metadata.is_valid: + warning_msg = ( + f"WARNING: Using invalid cache from " + f"{Path(trajectory_file).name}. " + f"Reason: {cache_file.metadata.invalidation_reason}. " + "This cache may not work correctly." 
+ ) + logger.warning(warning_msg) + + # Validate step index + if error := self._validate_step_index( + start_from_step_index, len(cache_file.trajectory) + ): + return error + + # Validate parameters + if error := self._validate_parameters( + cache_file.trajectory, parameter_values, cache_file.cache_parameters + ): + return error + + # Create and configure executor + executor = self._create_executor( + cache_file, parameter_values, start_from_step_index + ) + + # Store executor and cache info in agent state + self._cache_execution_manager.activate_execution( + executor=executor, + cache_file=cache_file, + cache_file_path=trajectory_file, + ) + + # Format and return success message + success_msg = self._format_success_message( + trajectory_file, + len(cache_file.trajectory), + start_from_step_index, + len(parameter_values), + ) + logger.info(success_msg) + return success_msg + + +class InspectCacheMetadata(Tool): + """ + Inspect detailed metadata for a cached trajectory file + """ + + def __init__(self) -> None: + super().__init__( + name="inspect_cache_metadata_tool", + description=( + "Inspect and display detailed metadata for a cached trajectory file. " + "This tool shows information about:\n" + "- Cache version and creation timestamp\n" + "- Execution statistics (attempts, last execution time)\n" + "- Validity status and invalidation reason (if invalid)\n" + "- Failure history with timestamps and error messages\n" + "- Parameters and trajectory step count\n\n" + "Use this tool to debug cache issues or understand why a cache " + "might be failing or invalidated." + ), + input_schema={ + "type": "object", + "properties": { + "trajectory_file": { + "type": "string", + "description": ( + "Full path to the trajectory file to inspect. " + "Use retrieve_available_trajectories_tool to " + "find available files." + ), + }, + }, + "required": ["trajectory_file"], + }, + ) + + @override + @validate_call + def __call__(self, trajectory_file: str) -> str: + logger.info("Inspecting cache metadata: %s", Path(trajectory_file).name) + if not Path(trajectory_file).is_file(): error_msg = ( f"Trajectory file not found: {trajectory_file}\n" "Use retrieve_available_trajectories_tool to see available files." ) logger.error(error_msg) - raise FileNotFoundError(error_msg) + return error_msg + + try: + cache_file = CacheWriter.read_cache_file(Path(trajectory_file)) + except Exception: + error_msg = f"Failed to read cache file {Path(trajectory_file).name}" + logger.exception(error_msg) + return error_msg + + metadata = cache_file.metadata + logger.debug( + "Metadata loaded: version=%s, valid=%s, attempts=%d, failures=%d", + metadata.version, + metadata.is_valid, + metadata.execution_attempts, + len(metadata.failures), + ) - # Load and execute trajectory - trajectory = CacheWriter.read_cache_file(Path(trajectory_file)) - info_msg = f"Executing cached trajectory from {trajectory_file}" - logger.info(info_msg) - for step in trajectory: - if ( - "screenshot" in step.name - or step.name == "retrieve_available_trajectories_tool" - ): - continue - try: - self._toolbox.run([step]) - except Exception as e: - error_msg = f"An error occured during the cached execution: {e}" - logger.exception(error_msg) - return ( - f"An error occured while executing the trajectory from " - f"{trajectory_file}. Please verify the UI state and " - "continue without cache." 
+ # Format the metadata into a readable string + lines = [ + "=== Cache Metadata ===", + f"File: {trajectory_file}", + "", + "--- Basic Info ---", + f"Version: {metadata.version}", + f"Created: {metadata.created_at}", + f"Last Executed: {metadata.last_executed_at or 'Never'}", + "", + "--- Execution Statistics ---", + f"Total Execution Attempts: {metadata.execution_attempts}", + f"Total Failures: {len(metadata.failures)}", + "", + "--- Validity Status ---", + f"Is Valid: {metadata.is_valid}", + ] + + if not metadata.is_valid: + lines.append(f"Invalidation Reason: {metadata.invalidation_reason}") + + lines.append("") + lines.append("--- Trajectory Info ---") + lines.append(f"Total Steps: {len(cache_file.trajectory)}") + lines.append(f"Parameters: {len(cache_file.cache_parameters)}") + if cache_file.cache_parameters: + lines.append( + f"Parameter Names: {', '.join(cache_file.cache_parameters.keys())}" + ) + + if metadata.failures: + lines.append("") + lines.append("--- Failure History ---") + for i, failure in enumerate(metadata.failures, 1): + lines.append(f"Failure {i}:") + lines.append(f" Timestamp: {failure.timestamp}") + lines.append(f" Step Index: {failure.step_index}") + lines.append( + f" Failure Count at Step: {failure.failure_count_at_step}" + ) + lines.append(f" Error: {failure.error_message}") + + return "\n".join(lines) + + +class RevalidateCache(Tool): + """ + Manually mark a cache as valid (reset invalidation) + """ + + def __init__(self) -> None: + super().__init__( + name="revalidate_cache_tool", + description=( + "Manually mark a cache as valid, resetting any previous invalidation. " + "Use this tool when:\n" + "- A cache was invalidated but the underlying issue has been fixed\n" + "- You want to give a previously failing cache another chance\n" + "- You've manually verified the cache should work now\n\n" + "This will:\n" + "- Set is_valid=True\n" + "- Clear the invalidation_reason\n" + "- Keep existing failure history (for debugging)\n" + "- Keep execution attempt counters\n\n" + "Note: The cache can still be auto-invalidated again if it " + "continues to fail." + ), + input_schema={ + "type": "object", + "properties": { + "trajectory_file": { + "type": "string", + "description": ( + "Full path to the trajectory file to revalidate. " + "Use retrieve_available_trajectories_tool with " + "include_invalid=True to find invalidated caches." + ), + }, + }, + "required": ["trajectory_file"], + }, + ) + + @override + @validate_call + def __call__(self, trajectory_file: str) -> str: + if not Path(trajectory_file).is_file(): + error_msg = ( + f"Trajectory file not found: {trajectory_file}\n" + "Use retrieve_available_trajectories_tool to see available files." 
+            )
+            logger.error(error_msg)
+            return error_msg
+
+        try:
+            cache_file = CacheWriter.read_cache_file(Path(trajectory_file))
+        except Exception:
+            error_msg = f"Failed to read cache file {trajectory_file}"
+            logger.exception(error_msg)
+            return error_msg
+
+        # Mark cache as valid
+        cache_manager = CacheManager()
+        was_invalid = not cache_file.metadata.is_valid
+        previous_reason = cache_file.metadata.invalidation_reason
+
+        cache_manager.mark_cache_valid(cache_file)
+
+        # Write back to disk
+        try:
+            cache_path = Path(trajectory_file)
+            with cache_path.open("w") as f:
+                json.dump(
+                    cache_file.model_dump(mode="json"),
+                    f,
+                    indent=2,
+                    default=str,
+                )
+        except Exception:
+            error_msg = f"Failed to write cache file {trajectory_file}"
+            logger.exception(error_msg)
+            return error_msg
+
+        if was_invalid:
+            logger.info("Cache revalidated: %s", trajectory_file)
+            return (
+                f"Successfully revalidated cache: {trajectory_file}\n"
+                f"Previous invalidation reason was: {previous_reason}\n"
+                "The cache is now marked as valid and can be used again."
+            )
+        logger.info("Cache was already valid: %s", trajectory_file)
+        return f"Cache {trajectory_file} was already valid. No changes made."
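The revalidate/invalidate pair simply toggles metadata in the cache file on disk. A hedged sketch of that file's shape follows; the field names come from the `CacheFile` and `CacheFailure` models used in this changeset, while the exact serialized layout, the id, and the `computer` tool name are assumptions:

```python
# Hedged sketch of the on-disk cache file these tools read and rewrite.
example_cache_file = {
    "metadata": {
        "version": "1.0",  # assumption
        "created_at": "2025-12-11T10:00:00+00:00",
        "last_executed_at": None,
        "execution_attempts": 0,
        "failures": [],  # CacheFailure entries: timestamp, step_index,
                         # error_message, failure_count_at_step
        "is_valid": True,
        "invalidation_reason": None,
        "visual_verification_method": "phash",
        "visual_validation_threshold": 10,
        "visual_validation_region_size": 100,
    },
    "cache_parameters": {
        "current_date": "Current date in YYYY-MM-DD format",
    },
    "trajectory": [
        {
            "type": "tool_use",
            "id": "toolu_example",  # illustrative id
            "name": "computer",     # assumed tool name
            "input": {"action": "type", "text": "{{current_date}}"},
        },
    ],
}
```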
+
+
+class InvalidateCache(Tool):
+    """
+    Manually mark a cache as invalid
+    """
+
+    def __init__(self) -> None:
+        super().__init__(
+            name="invalidate_cache_tool",
+            description=(
+                "Manually mark a cache as invalid with a custom reason. "
+                "Use this tool when:\n"
+                "- You've determined a cache is no longer reliable\n"
+                "- The UI has changed and the cached actions won't work\n"
+                "- You want to prevent automatic execution of a problematic cache\n\n"
+                "This will:\n"
+                "- Set is_valid=False\n"
+                "- Record your custom invalidation reason\n"
+                "- Keep all existing metadata (failures, execution attempts)\n"
+                "- Hide the cache from default trajectory listings\n\n"
+                "The cache can later be revalidated using revalidate_cache_tool "
+                "if the issue is resolved."
+            ),
+            input_schema={
+                "type": "object",
+                "properties": {
+                    "trajectory_file": {
+                        "type": "string",
+                        "description": (
+                            "Full path to the trajectory file to "
+                            "invalidate. "
+                            "Use retrieve_available_trajectories_tool to "
+                            "find available files."
+                        ),
+                    },
+                    "reason": {
+                        "type": "string",
+                        "description": (
+                            "Reason for invalidating this cache. "
+                            "Be specific about why "
+                            "this cache should not be used "
+                            "(e.g., 'UI changed - button moved', "
+                            "'Workflow outdated', 'Replaced by new cache')."
+                        ),
+                    },
+                },
+                "required": ["trajectory_file", "reason"],
+            },
+        )
+
+    @override
+    @validate_call
+    def __call__(self, trajectory_file: str, reason: str) -> str:
+        if not Path(trajectory_file).is_file():
+            error_msg = (
+                f"Trajectory file not found: {trajectory_file}\n"
+                "Use retrieve_available_trajectories_tool to see available files."
+            )
+            logger.error(error_msg)
+            return error_msg
+
+        try:
+            cache_file = CacheWriter.read_cache_file(Path(trajectory_file))
+        except Exception:
+            error_msg = f"Failed to read cache file {trajectory_file}"
+            logger.exception(error_msg)
+            return error_msg
+
+        # Mark cache as invalid
+        cache_manager = CacheManager()
+        was_valid = cache_file.metadata.is_valid
+
+        cache_manager.invalidate_cache(cache_file, reason=reason)
+
+        # Write back to disk
+        try:
+            cache_path = Path(trajectory_file)
+            with cache_path.open("w") as f:
+                json.dump(
+                    cache_file.model_dump(mode="json"),
+                    f,
+                    indent=2,
+                    default=str,
                 )
-            time.sleep(self._settings.delay_time_between_action)
+        except Exception:
+            error_msg = f"Failed to write cache file {trajectory_file}"
+            logger.exception(error_msg)
+            return error_msg
-        logger.info("Finished executing cached trajectory")
+        logger.info("Cache manually invalidated: %s", trajectory_file)
+
+        if was_valid:
+            return (
+                f"Successfully invalidated cache: {trajectory_file}\n"
+                f"Reason: {reason}\n"
+                "The cache will not appear in default trajectory listings. "
+                "Use revalidate_cache_tool to restore it if needed."
+            )
         return (
-            f"Successfully executed trajectory from {trajectory_file}. "
-            "Please verify the UI state."
+            f"Cache {trajectory_file} was already invalid.\n"
+            f"Updated invalidation reason to: {reason}"
+        )
+
+
+class VerifyCacheExecution(Tool):
+    """Tool for agent to explicitly report cache execution verification results."""
+
+    def __init__(self) -> None:
+        super().__init__(
+            name="verify_cache_execution",
+            description=(
+                "IMPORTANT: Call this tool immediately after reviewing a "
+                "cached trajectory execution.\n\n"
+                "Report whether the cached execution successfully achieved "
+                "the target system state. You MUST call this tool to complete "
+                "the cache verification process.\n\n"
+                "Set success=True if:\n"
+                "- The cached execution correctly achieved the intended goal\n"
+                "- The final state matches what was expected\n"
+                "- No corrections or additional actions were needed\n\n"
+                "Set success=False if:\n"
+                "- The execution did not achieve the target state\n"
+                "- You had to make corrections or perform additional actions\n"
+                "- The final state is incorrect or incomplete"
+            ),
+            input_schema={
+                "type": "object",
+                "properties": {
+                    "success": {
+                        "type": "boolean",
+                        "description": (
+                            "True if cached execution correctly "
+                            "achieved target state, "
+                            "False if execution was incorrect or "
+                            "corrections were needed"
+                        ),
+                    },
+                    "verification_notes": {
+                        "type": "string",
+                        "description": (
+                            "Brief explanation of what you verified. "
+                            "If success=False, describe what was "
+                            "wrong and what corrections you made."
+                        ),
+                    },
+                },
+                "required": ["success", "verification_notes"],
+            },
+        )
+        self.is_cacheable = False  # Verification is not cacheable
+        self._cache_execution_manager: CacheExecutionManager | None = None
+
+    def set_cache_execution_manager(
+        self, cache_execution_manager: CacheExecutionManager
+    ) -> None:
+        """Set the cache execution manager used to record verification results.
+
+        Args:
+            cache_execution_manager: The CacheExecutionManager tracking the
+                current cache execution
+        """
+        self._cache_execution_manager = cache_execution_manager
+
+    @override
+    @validate_call
+    def __call__(self, success: bool, verification_notes: str) -> str:
+        """Record cache verification result.
+ + Args: + success: Whether cache execution achieved target state + verification_notes: Explanation of verification result + + Returns: + Confirmation message + """ + logger.info( + "Cache verification reported: success=%s, notes=%s", + success, + verification_notes, ) + if not self._cache_execution_manager: + error_msg = ( + "Cache Execution Manager not set. Cannot record verification result." + ) + logger.error(error_msg) + return error_msg + + # Check if there's a cache file to update (more reliable than checking flag) + cache_file, cache_file_path = self._cache_execution_manager.get_cache_info() + if not (cache_file and cache_file_path): + warning_msg = ( + "No cache file to update. " + "Verification tool called without recent cache execution." + ) + logger.warning(warning_msg) + return warning_msg + + # Debug log if verification flag wasn't explicitly set + # (This can happen if verification is called directly without the flag, + # but we still proceed since we have the cache file) + if not self._cache_execution_manager.is_cache_verification_pending(): + logger.debug( + "Verification flag not set, but cache file exists. " + "This is normal for direct verification calls." + ) + + # Update cache metadata based on verification result + if success: + self._cache_execution_manager.update_metadata_on_completion(success=True) + result_msg = f"✓ Cache verification successful: {verification_notes}" + logger.info(result_msg) + else: + error_msg = ( + f"Cache execution did not lead to target system state: " + f"{verification_notes}" + ) + self._cache_execution_manager.update_metadata_on_failure( + step_index=-1, # -1 indicates verification failure + error_message=error_msg, + ) + result_msg = ( + f"✗ Cache verification failed: {verification_notes}\n\n" + "The cached trajectory did not achieve the target system " + "state correctly. You should now continue to complete the " + "task manually from the current state. Use your tools to " + "finish achieving the goal, taking into account what the " + "cache attempted to do and what corrections are needed." + ) + logger.warning(result_msg) + + # Clear verification flag and cache references after verification + self._cache_execution_manager.clear_cache_state() + + return result_msg diff --git a/src/askui/utils/cache_parameter_handler.py b/src/askui/utils/cache_parameter_handler.py new file mode 100644 index 00000000..b0eaaf27 --- /dev/null +++ b/src/askui/utils/cache_parameter_handler.py @@ -0,0 +1,493 @@ +"""Cache parameter handling for trajectory recording and execution. + +This module provides utilities for: +- Identifying dynamic values that should become parameters (recording phase) +- Validating and substituting parameter values (execution phase) + +Cache parameters use the {{parameter_name}} syntax and allow dynamic values +to be injected during cache execution. 
+""" + +import json +import logging +import re +from typing import Any + +from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam +from askui.models.shared.messages_api import MessagesApi +from askui.prompts.caching import CACHING_PARAMETER_IDENTIFIER_SYSTEM_PROMPT + +logger = logging.getLogger(__name__) + +# Regex pattern for matching parameters: {{parameter_name}} +# Allows alphanumeric characters and underscores, must start with letter/underscore +CACHE_PARAMETER_PATTERN = r"\{\{([a-zA-Z_][a-zA-Z0-9_]*)\}\}" + + +class CacheParameterDefinition: + """Represents a cache parameter identified in a trajectory.""" + + def __init__(self, name: str, value: Any, description: str) -> None: + self.name = name + self.value = value + self.description = description + + def __repr__(self) -> str: + return f"CacheParameterDefinition(name={self.name}, value={self.value})" + + +class CacheParameterHandler: + """Handles all cache parameter operations for trajectory recording and execution.""" + + # ======================================================================== + # RECORDING PHASE: Parameter identification and templatization + # ======================================================================== + + @staticmethod + def identify_and_parameterize( + trajectory: list[ToolUseBlockParam], + goal: str | None, + identification_strategy: str, + messages_api: MessagesApi | None = None, + model: str | None = None, + ) -> tuple[str | None, list[ToolUseBlockParam], dict[str, str]]: + """Identify parameters and return parameterized trajectory + goal. + + This is the main entry point for the recording phase. It orchestrates + parameter identification and templatization of both trajectory and goal. + + Args: + trajectory: The trajectory to analyze and parameterize + goal: The goal text to parameterize (optional) + identification_strategy: "llm" for AI-based or "preset" for manual + api_provider: API provider for LLM calls (only used for "llm" strategy) + model: Model to use for LLM-based identification + + Returns: + Tuple of: + - Parameterized goal text (or None if no goal) + - Parameterized trajectory (with {{param}} syntax) + - Dict mapping parameter names to descriptions + """ + if identification_strategy == "llm" and trajectory and messages_api and model: + # Use LLM to identify parameters + try: + logger.info("Trying to extract parameters using the strategy 'llm'") + parameters_dict, parameter_definitions = ( + CacheParameterHandler._identify_parameters_with_llm( + trajectory, messages_api, model + ) + ) + + if parameter_definitions: + # Replace values with {{parameter}} syntax in trajectory + parameterized_trajectory = ( + CacheParameterHandler._replace_values_with_parameters( + trajectory, parameter_definitions + ) + ) + + # Apply same replacement to goal text + parameterized_goal = goal + if goal: + parameterized_goal = ( + CacheParameterHandler._apply_parameters_to_text( + goal, parameter_definitions + ) + ) + + n_parameters = len(parameter_definitions) + logger.info( + "Replaced %s parameter values in trajectory", n_parameters + ) + return parameterized_goal, parameterized_trajectory, parameters_dict + + else: # noqa: RET505 + # No parameters identified + logger.info("No parameters identified in trajectory") + return goal, trajectory, {} + + except Exception: + logger.exception( + "An error occurred while extracting parameters using the strategy" + "'llm'. 
Will use 'preset' strategy instead" + ) + + # Manual extraction (preset strategy) + logger.info("Extracting parameters using the strategy 'preset'") + parameter_names = CacheParameterHandler.extract_parameters(trajectory) + parameters_dict = { + name: f"Parameter for {name}" + for name in parameter_names # Generic desc + } + n_parameters = len(parameter_names) + logger.info("Extracted %s manual parameters from trajectory", n_parameters) + return goal, trajectory, parameters_dict + + @staticmethod + def _identify_parameters_with_llm( + trajectory: list[ToolUseBlockParam], + messages_api: MessagesApi, + model: str = "claude-sonnet-4-5-20250929", + ) -> tuple[dict[str, str], list[CacheParameterDefinition]]: + """Identify parameters in a trajectory using LLM analysis. + + Args: + trajectory: The trajectory to analyze (list of tool use blocks) + messages_api: Messages API instance for LLM calls + model: Model to use for analysis + + Returns: + Tuple of: + - Dict mapping parameter names to descriptions + - List of CacheParameterDefinition objects with name, value, and description + """ + if not trajectory: + logger.debug("Empty trajectory provided, skipping parameter identification") + return {}, [] + + logger.info( + "Starting parameter identification for trajectory with %s steps", + len(trajectory), + ) + + # Convert trajectory to serializable format for analysis + trajectory_data = [tool.model_dump(mode="json") for tool in trajectory] + logger.debug("Converted %s tool blocks to JSON format", len(trajectory_data)) + + user_message = ( + "Analyze this UI automation trajectory and identify all values that " + "should be parameters:\n\n" + f"```json\n{json.dumps(trajectory_data, indent=2)}\n```\n\n" + "Return only the JSON object with identified parameters. " + "Be thorough but conservative - only mark values that are clearly " + "dynamic or time-sensitive." 
+ ) + + response_text = "" # Initialize for error logging + try: + # Make single API call + logger.debug("Calling LLM (%s) to analyze trajectory for parameters", model) + response = messages_api.create_message( + messages=[MessageParam(role="user", content=user_message)], + model=model, + system=CACHING_PARAMETER_IDENTIFIER_SYSTEM_PROMPT, + max_tokens=4096, + temperature=0.0, # Deterministic for analysis + ) + logger.debug("Received response from LLM") + + # Extract text from response + if isinstance(response.content, list): + response_text = next( + ( + block.text + for block in response.content + if hasattr(block, "text") + ), + "", + ) + else: + response_text = str(response.content) + + # Parse the JSON response + logger.debug("Parsing LLM response to extract parameter definitions") + # Handle markdown code blocks if present + if "```json" in response_text: + logger.debug("Removing JSON markdown code block wrapper from response") + response_text = ( + response_text.split("```json")[1].split("```")[0].strip() + ) + elif "```" in response_text: + logger.debug("Removing code block wrapper from response") + response_text = response_text.split("```")[1].split("```")[0].strip() + + parameter_data = json.loads(response_text) + logger.debug( + "Successfully parsed JSON response with %s parameters", + len(parameter_data.get("parameters", [])), + ) + + # Convert to our data structures + parameter_definitions = [ + CacheParameterDefinition( + name=p["name"], value=p["value"], description=p["description"] + ) + for p in parameter_data.get("parameters", []) + ] + + parameters_dict = {p.name: p.description for p in parameter_definitions} + + if parameter_definitions: + logger.info( + "Successfully identified %s parameters in trajectory", + len(parameter_definitions), + ) + for p in parameter_definitions: + logger.debug(" - %s: %s (%s)", p.name, p.value, p.description) + else: + logger.info( + "No parameters identified in trajectory " + "(this is normal for trajectories with only static values)" + ) + + except json.JSONDecodeError as e: + logger.warning( + "Failed to parse LLM response as JSON: %s. " + "Falling back to empty parameter list.", + e, + extra={"response_text": response_text[:500]}, # Log first 500 chars + ) + return {}, [] + except Exception as e: # noqa: BLE001 + logger.warning( + "Failed to identify parameters with LLM: %s. " + "Falling back to empty parameter list.", + e, + exc_info=True, + ) + return {}, [] + else: + return parameters_dict, parameter_definitions + + @staticmethod + def _replace_values_with_parameters( + trajectory: list[ToolUseBlockParam], + parameter_definitions: list[CacheParameterDefinition], + ) -> list[ToolUseBlockParam]: + """Replace actual values in trajectory with {{parameter_name}} syntax. + + This is the reverse of substitute_parameters - it takes identified values + and replaces them with parameter syntax for saving to cache. 
+ + Args: + trajectory: The trajectory to templatize + parameter_definitions: List of CacheParameterDefinition objects with + name and value attributes + + Returns: + New trajectory with values replaced by parameters + """ + # Build replacement map: value -> parameter name + replacements = { + str(p.value): f"{{{{{p.name}}}}}" for p in parameter_definitions + } + + # Apply replacements to each tool block + parameterized_trajectory = [] + for tool_block in trajectory: + parameterized_input = CacheParameterHandler._replace_values_in_value( + tool_block.input, replacements + ) + + parameterized_trajectory.append( + ToolUseBlockParam( + id=tool_block.id, + name=tool_block.name, + input=parameterized_input, + type=tool_block.type, + cache_control=tool_block.cache_control, + visual_representation=tool_block.visual_representation, + ) + ) + + return parameterized_trajectory + + @staticmethod + def _apply_parameters_to_text( + text: str, parameter_definitions: list[CacheParameterDefinition] + ) -> str: + """Apply parameter replacement to a text string (e.g., goal). + + Args: + text: The text to parameterize + parameter_definitions: List of parameter definitions + + Returns: + Text with values replaced by {{parameter}} syntax + """ + # Build replacement map: value -> parameter syntax + replacements = { + str(p.value): f"{{{{{p.name}}}}}" for p in parameter_definitions + } + # Sort by length descending to replace longer matches first + result = text + for actual_value in sorted(replacements.keys(), key=len, reverse=True): + if actual_value in result: + result = result.replace(actual_value, replacements[actual_value]) + return result + + @staticmethod + def _replace_values_in_value(value: Any, replacements: dict[str, str]) -> Any: + """Recursively replace actual values with parameter syntax. + + Args: + value: Any value (str, dict, list, etc.) to process + replacements: Dict mapping actual values to parameter syntax + + Returns: + New value with replacements applied + """ + if isinstance(value, str): + # Replace exact matches and substring matches + result = value + # Sort by length descending to replace longer matches first + # This prevents partial replacements + for actual_value in sorted(replacements.keys(), key=len, reverse=True): + if actual_value in result: + result = result.replace(actual_value, replacements[actual_value]) + return result + if isinstance(value, dict): + # Recursively replace in dict values + return { + k: CacheParameterHandler._replace_values_in_value(v, replacements) + for k, v in value.items() + } + if isinstance(value, list): + # Recursively replace in list items + return [ + CacheParameterHandler._replace_values_in_value(item, replacements) + for item in value + ] + # For non-string types, check if the value matches exactly + str_value = str(value) + if str_value in replacements: + # Return the parameter as a string + return replacements[str_value] + return value + + # ======================================================================== + # EXECUTION PHASE: Parameter extraction, validation, and substitution + # ======================================================================== + + @staticmethod + def extract_parameters(trajectory: list[ToolUseBlockParam]) -> set[str]: + """Extract all parameter names from a trajectory. + + Scans all tool inputs for {{parameter_name}} patterns and returns + a set of unique parameter names. 
+ + Args: + trajectory: List of tool use blocks to scan + + Returns: + Set of unique parameter names found in the trajectory + """ + parameters: set[str] = set() + + for step in trajectory: + # Recursively find parameters in the input object + parameters.update(CacheParameterHandler._extract_from_value(step.input)) + + return parameters + + @staticmethod + def _extract_from_value(value: Any) -> set[str]: + """Recursively extract parameters from a value. + + Args: + value: Any value (str, dict, list, etc.) to search for parameters + + Returns: + Set of parameter names found + """ + parameters: set[str] = set() + + if isinstance(value, str): + # Find all matches in the string + matches = re.finditer(CACHE_PARAMETER_PATTERN, value) + parameters.update(match.group(1) for match in matches) + elif isinstance(value, dict): + # Recursively search dict values + for v in value.values(): + parameters.update(CacheParameterHandler._extract_from_value(v)) + elif isinstance(value, list): + # Recursively search list items + for item in value: + parameters.update(CacheParameterHandler._extract_from_value(item)) + + return parameters + + @staticmethod + def validate_parameters( + trajectory: list[ToolUseBlockParam], provided_values: dict[str, str] + ) -> tuple[bool, list[str]]: + """Validate that all required parameters have values. + + Args: + trajectory: List of tool use blocks containing parameters + provided_values: Dict of parameter names to their values + + Returns: + Tuple of (is_valid, missing_parameters) + - is_valid: True if all parameters have values, False otherwise + - missing_parameters: List of parameter names that are missing values + """ + required_parameters = CacheParameterHandler.extract_parameters(trajectory) + missing = [name for name in required_parameters if name not in provided_values] + + return len(missing) == 0, missing + + @staticmethod + def substitute_parameters( + tool_block: ToolUseBlockParam, parameter_values: dict[str, str] + ) -> ToolUseBlockParam: + """Replace parameters in a tool block with actual values. + + Creates a new ToolUseBlockParam with all {{parameter}} occurrences + replaced with their corresponding values from parameter_values. + + Args: + tool_block: The tool use block containing parameters + parameter_values: Dict mapping parameter names to replacement values + + Returns: + New ToolUseBlockParam with parameters substituted + """ + # Deep copy the input and substitute parameters + substituted_input = CacheParameterHandler._substitute_in_value( + tool_block.input, parameter_values + ) + + # Create new ToolUseBlockParam with substituted values + return ToolUseBlockParam( + id=tool_block.id, + name=tool_block.name, + input=substituted_input, + type=tool_block.type, + cache_control=tool_block.cache_control, + ) + + @staticmethod + def _substitute_in_value(value: Any, parameter_values: dict[str, str]) -> Any: + """Recursively substitute parameters in a value. + + Args: + value: Any value (str, dict, list, etc.) 
containing parameters + parameter_values: Dict of parameter names to replacement values + + Returns: + New value with parameters substituted + """ + if isinstance(value, str): + # Replace all parameters in the string using literal string replacement + # This avoids regex interpretation issues with backslashes in values + result = value + for name, replacement in parameter_values.items(): + placeholder = f"{{{{{name}}}}}" # Creates "{{parameter_name}}" + result = result.replace(placeholder, replacement) + return result + if isinstance(value, dict): + # Recursively substitute in dict values + return { + k: CacheParameterHandler._substitute_in_value(v, parameter_values) + for k, v in value.items() + } + if isinstance(value, list): + # Recursively substitute in list items + return [ + CacheParameterHandler._substitute_in_value(item, parameter_values) + for item in value + ] + # Return other types as-is + return value diff --git a/src/askui/utils/cache_writer.py b/src/askui/utils/cache_writer.py deleted file mode 100644 index 36508c73..00000000 --- a/src/askui/utils/cache_writer.py +++ /dev/null @@ -1,70 +0,0 @@ -import json -import logging -from datetime import datetime, timezone -from pathlib import Path - -from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam -from askui.models.shared.agent_on_message_cb import OnMessageCbParam - -logger = logging.getLogger(__name__) - - -class CacheWriter: - def __init__(self, cache_dir: str = ".cache", file_name: str = "") -> None: - self.cache_dir = Path(cache_dir) - self.cache_dir.mkdir(exist_ok=True) - self.messages: list[ToolUseBlockParam] = [] - if file_name and not file_name.endswith(".json"): - file_name += ".json" - self.file_name = file_name - self.was_cached_execution = False - - def add_message_cb(self, param: OnMessageCbParam) -> MessageParam: - """Add a message to cache.""" - if param.message.role == "assistant": - contents = param.message.content - if isinstance(contents, list): - for content in contents: - if isinstance(content, ToolUseBlockParam): - self.messages.append(content) - if content.name == "execute_cached_executions_tool": - self.was_cached_execution = True - if param.message.stop_reason == "end_turn": - self.generate() - - return param.message - - def set_file_name(self, file_name: str) -> None: - if not file_name.endswith(".json"): - file_name += ".json" - self.file_name = file_name - - def reset(self, file_name: str = "") -> None: - self.messages = [] - if file_name and not file_name.endswith(".json"): - file_name += ".json" - self.file_name = file_name - self.was_cached_execution = False - - def generate(self) -> None: - if self.was_cached_execution: - logger.info("Will not write cache file as this was a cached execution") - return - if not self.file_name: - self.file_name = ( - f"cached_trajectory_{datetime.now(tz=timezone.utc):%Y%m%d%H%M%S%f}.json" - ) - cache_file_path = self.cache_dir / self.file_name - - messages_json = [m.model_dump() for m in self.messages] - with cache_file_path.open("w", encoding="utf-8") as f: - json.dump(messages_json, f, indent=4) - info_msg = f"Cache File written at {str(cache_file_path)}" - logger.info(info_msg) - self.reset() - - @staticmethod - def read_cache_file(cache_file_path: Path) -> list[ToolUseBlockParam]: - with cache_file_path.open("r", encoding="utf-8") as f: - raw_trajectory = json.load(f) - return [ToolUseBlockParam(**step) for step in raw_trajectory] diff --git a/src/askui/utils/caching/__init__.py b/src/askui/utils/caching/__init__.py new file mode 
100644 index 00000000..be6e4b17 --- /dev/null +++ b/src/askui/utils/caching/__init__.py @@ -0,0 +1 @@ +"""Caching utilities for trajectory execution.""" diff --git a/src/askui/utils/caching/cache_execution_manager.py b/src/askui/utils/caching/cache_execution_manager.py new file mode 100644 index 00000000..765bd2a2 --- /dev/null +++ b/src/askui/utils/caching/cache_execution_manager.py @@ -0,0 +1,357 @@ +"""Manager for cache execution flow and state.""" + +import json +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +from askui.models.shared.agent_message_param import MessageParam, TextBlockParam +from askui.models.shared.agent_on_message_cb import OnMessageCb, OnMessageCbParam +from askui.models.shared.truncation_strategies import TruncationStrategy +from askui.reporting import Reporter +from askui.utils.caching.cache_manager import CacheManager +from askui.utils.trajectory_executor import ExecutionResult + +if TYPE_CHECKING: + from askui.models.shared.settings import CacheFile + from askui.utils.trajectory_executor import TrajectoryExecutor + +logger = logging.getLogger(__name__) + + +class CacheExecutionManager: + """Manages cache execution flow, state, and metadata updates. + + This class encapsulates all cache-related logic, keeping the Agent class + focused on conversation orchestration. + """ + + def __init__(self, reporter: Reporter) -> None: + """Initialize cache execution manager. + + Args: + reporter: Reporter for logging messages and actions + """ + self._reporter = reporter + # Cache execution state + self._executing_from_cache: bool = False + self._cache_executor: "TrajectoryExecutor | None" = None + self._cache_file_path: str | None = None + self._cache_file: "CacheFile | None" = None + # Track cache verification after execution completes + self._cache_verification_pending: bool = False + + def reset_state(self) -> None: + """Reset cache execution state.""" + self._executing_from_cache = False + self._cache_executor = None + self._cache_file_path = None + self._cache_file = None + self._cache_verification_pending = False + logger.debug("Reset cache execution state") + + def activate_execution( + self, + executor: "TrajectoryExecutor", + cache_file: "CacheFile", + cache_file_path: str, + ) -> None: + """Activate cache execution mode. + + Args: + executor: The trajectory executor to use + cache_file: The cache file being executed + cache_file_path: Path to the cache file + """ + self._cache_executor = executor + self._cache_file = cache_file + self._cache_file_path = cache_file_path + self._executing_from_cache = True + + def get_cache_info(self) -> tuple["CacheFile | None", str | None]: + """Get current cache file and path. + + Returns: + Tuple of (cache_file, cache_file_path) + """ + return (self._cache_file, self._cache_file_path) + + def is_cache_verification_pending(self) -> bool: + """Check if cache verification is pending. + + Returns: + True if verification is pending + """ + return self._cache_verification_pending + + def clear_cache_state(self) -> None: + """Clear cache execution state after verification.""" + self._cache_verification_pending = False + self._cache_file = None + self._cache_file_path = None + + def handle_execution_step( + self, + on_message: OnMessageCb, + truncation_strategy: TruncationStrategy, + ) -> bool: + """Handle cache execution step. 
+
+        Args:
+            on_message: Callback for messages
+            truncation_strategy: Message truncation strategy
+
+        Returns:
+            True if cache step was handled and caller should recurse,
+            False if should continue with normal flow
+        """
+        if not (self._executing_from_cache and self._cache_executor):
+            return False
+
+        logger.debug("Executing next step from cache")
+        result: ExecutionResult = self._cache_executor.execute_next_step()
+
+        if result.status == "SUCCESS":
+            return self._handle_cache_success(
+                result,
+                on_message,
+                truncation_strategy,
+            )
+        if result.status == "NEEDS_AGENT":
+            return self._handle_cache_needs_agent(
+                result,
+                on_message,
+                truncation_strategy,
+            )
+        if result.status == "COMPLETED":
+            return self._handle_cache_completed(truncation_strategy)
+        # result.status == "FAILED"
+        return self._handle_cache_failed(result)
+
+    def _handle_cache_success(
+        self,
+        result: ExecutionResult,
+        on_message: OnMessageCb,
+        truncation_strategy: TruncationStrategy,
+    ) -> bool:
+        """Handle successful cache step execution.
+
+        Returns:
+            True if messages were added and caller should recurse,
+            False otherwise
+        """
+        if len(result.message_history) < 2:
+            return False
+
+        assistant_msg = result.message_history[-2]
+        user_msg = result.message_history[-1]
+
+        # Add assistant message (tool use)
+        message_by_assistant = self._call_on_message(
+            on_message, assistant_msg, truncation_strategy.messages
+        )
+        if message_by_assistant is None:
+            return True
+
+        truncation_strategy.append_message(message_by_assistant)
+        self._reporter.add_message(
+            self.__class__.__name__, message_by_assistant.model_dump(mode="json")
+        )
+
+        # Add user message (tool result)
+        user_msg_processed = self._call_on_message(
+            on_message, user_msg, truncation_strategy.messages
+        )
+        if user_msg_processed is None:
+            return True
+
+        truncation_strategy.append_message(user_msg_processed)
+
+        # Return True to indicate caller should recurse
+        return True
+
+    def _handle_cache_needs_agent(
+        self,
+        result: ExecutionResult,
+        on_message: OnMessageCb,
+        truncation_strategy: TruncationStrategy,
+    ) -> bool:
+        """Handle cache execution pausing for non-cacheable tool.
+
+        Injects a user message explaining that cache execution paused and
+        what the agent needs to execute next.
+
+        Returns:
+            False to indicate normal agent flow should continue
+        """
+        logger.info(
+            "Paused cache execution at step %d "
+            "(non-cacheable tool - agent will handle this step)",
+            result.step_index,
+        )
+        self._executing_from_cache = False
+
+        # Get the tool that needs to be executed
+        tool_to_execute = result.tool_result  # This is the ToolUseBlockParam
+
+        # Create a user message explaining what needs to be done
+        if tool_to_execute:
+            instruction_message = MessageParam(
+                role="user",
+                content=[
+                    TextBlockParam(
+                        type="text",
+                        text=(
+                            f"Cache execution paused at step "
+                            f"{result.step_index}. "
+                            f"The previous steps were executed successfully "
+                            f"from cache. "
+                            f"The next step requires the "
+                            f"'{tool_to_execute.name}' tool, "
+                            f"which cannot be executed from cache. "
+                            f"Please execute this tool with the necessary "
+                            f"parameters."
+ ), + ) + ], + ) + + # Add the instruction message to truncation strategy + instruction_msg = self._call_on_message( + on_message, instruction_message, truncation_strategy.messages + ) + if instruction_msg: + truncation_strategy.append_message(instruction_msg) + + return False # Fall through to normal agent API call + + def _handle_cache_completed(self, truncation_strategy: TruncationStrategy) -> bool: + """Handle cache execution completion.""" + logger.info( + "✓ Cache trajectory execution completed - requesting agent verification" + ) + self._executing_from_cache = False + self._cache_verification_pending = True + + # Inject verification request message + verification_request = MessageParam( + role="user", + content=[ + TextBlockParam( + type="text", + text=( + "[CACHE EXECUTION COMPLETED]\n\n" + "The CacheExecutor has automatically executed" + f" all steps from the cached trajectory" + f" '{self._cache_file_path}'. All previous tool calls in this" + f" conversation were replayed from cache, not performed by the" + f" agent.\n\n Please verify if the cached execution correctly" + " achieved the target system state using the" + " verify_cache_execution tool." + ), + ) + ], + ) + truncation_strategy.append_message(verification_request) + self._reporter.add_message( + self.__class__.__name__, verification_request.model_dump(mode="json") + ) + logger.debug("Injected cache verification request message") + return False # Fall through to let agent verify execution + + def _handle_cache_failed(self, result: ExecutionResult) -> bool: + """Handle cache execution failure.""" + logger.error( + "✗ Cache execution failed at step %d: %s", + result.step_index, + result.error_message, + ) + self._executing_from_cache = False + + # Update cache metadata + if self._cache_file and self._cache_file_path: + self.update_metadata_on_failure( + step_index=result.step_index, + error_message=result.error_message or "Unknown error", + ) + + return False # Fall through to let agent continue + + def _call_on_message( + self, + on_message: OnMessageCb | None, + message: MessageParam, + messages: list[MessageParam], + ) -> MessageParam | None: + """Call on_message callback if provided.""" + if on_message is None: + return message + + return on_message(OnMessageCbParam(message=message, messages=messages)) + + def update_metadata_on_completion(self, success: bool) -> None: + """Update cache metadata after execution completion. + + Args: + success: Whether the execution was successful + """ + if not self._cache_file or not self._cache_file_path: + return + + try: + cache_manager = CacheManager() + cache_manager.record_execution_attempt(self._cache_file, success=success) + + # Write updated metadata back to disk + cache_path = Path(self._cache_file_path) + with cache_path.open("w") as f: + json.dump( + self._cache_file.model_dump(mode="json"), + f, + indent=2, + default=str, + ) + logger.info("Updated cache metadata: %s", cache_path.name) + except Exception: + logger.exception("Failed to update cache metadata") + + def update_metadata_on_failure(self, step_index: int, error_message: str) -> None: + """Update cache metadata after execution failure. 
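+ + Records the attempt and the step failure, consults the configured + validators to decide whether the cache should now be invalidated, and + persists the updated metadata to disk.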
+ + Args: + step_index: The step index where failure occurred + error_message: The error message + """ + if not self._cache_file or not self._cache_file_path: + return + + try: + cache_manager = CacheManager() + cache_manager.record_execution_attempt(self._cache_file, success=False) + cache_manager.record_step_failure( + self._cache_file, + step_index=step_index, + error_message=error_message, + ) + + # Check if cache should be invalidated + should_inv, reason = cache_manager.should_invalidate( + self._cache_file, step_index=step_index + ) + if should_inv and reason: + logger.warning("Cache invalidated: %s", reason) + cache_manager.invalidate_cache(self._cache_file, reason=reason) + + # Write updated metadata back to disk + cache_path = Path(self._cache_file_path) + with cache_path.open("w") as f: + json.dump( + self._cache_file.model_dump(mode="json"), + f, + indent=2, + default=str, + ) + logger.debug("Updated cache metadata after failure: %s", cache_path.name) + except Exception: + logger.exception("Failed to update cache metadata") diff --git a/src/askui/utils/caching/cache_manager.py b/src/askui/utils/caching/cache_manager.py new file mode 100644 index 00000000..e874bde0 --- /dev/null +++ b/src/askui/utils/caching/cache_manager.py @@ -0,0 +1,73 @@ +"""Cache management for tracking execution and invalidation.""" + +from datetime import datetime, timezone +from typing import Optional + +from askui.models.shared.settings import CacheFailure, CacheFile +from askui.utils.caching.cache_validator import ( + CacheValidator, + CompositeCacheValidator, + StaleCacheValidator, + StepFailureCountValidator, + TotalFailureRateValidator, +) + + +class CacheManager: + """Manages cache metadata updates and validation using configurable validators.""" + + def __init__(self, validators: Optional[list[CacheValidator]] = None): + if validators is None: + self.validators = CompositeCacheValidator( + [ + StepFailureCountValidator(max_failures_per_step=3), + TotalFailureRateValidator(min_attempts=10, max_failure_rate=0.5), + StaleCacheValidator(max_age_days=30), + ] + ) + else: + self.validators = CompositeCacheValidator(validators) + + def record_execution_attempt( + self, + cache_file: CacheFile, + success: bool, + failure_info: Optional[CacheFailure] = None, + ) -> None: + cache_file.metadata.execution_attempts += 1 + if success: + cache_file.metadata.last_executed_at = datetime.now(tz=timezone.utc) + elif failure_info: + cache_file.metadata.failures.append(failure_info) + + def record_step_failure( + self, cache_file: CacheFile, step_index: int, error_message: str + ) -> None: + failure = CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=step_index, + error_message=error_message, + failure_count_at_step=self.get_failure_count_for_step( + cache_file, step_index + ) + + 1, + ) + cache_file.metadata.failures.append(failure) + + def should_invalidate( + self, cache_file: CacheFile, step_index: Optional[int] = None + ) -> tuple[bool, Optional[str]]: + return self.validators.should_invalidate(cache_file, step_index) + + def invalidate_cache(self, cache_file: CacheFile, reason: str) -> None: + cache_file.metadata.is_valid = False + cache_file.metadata.invalidation_reason = reason + + def mark_cache_valid(self, cache_file: CacheFile) -> None: + cache_file.metadata.is_valid = True + cache_file.metadata.invalidation_reason = None + + def get_failure_count_for_step(self, cache_file: CacheFile, step_index: int) -> int: + return sum( + 1 for f in cache_file.metadata.failures if f.step_index == 
step_index + ) diff --git a/src/askui/utils/caching/cache_validator.py b/src/askui/utils/caching/cache_validator.py new file mode 100644 index 00000000..5b113792 --- /dev/null +++ b/src/askui/utils/caching/cache_validator.py @@ -0,0 +1,242 @@ +"""Cache validation strategies for automatic cache invalidation. + +This module provides an extensible validator pattern that allows users to +define custom cache invalidation logic. The system includes built-in validators +for common scenarios like step failure counts, overall failure rates, and stale caches. +""" + +from abc import ABC, abstractmethod +from datetime import datetime, timedelta, timezone +from typing import Optional + +from askui.models.shared.settings import CacheFile + + +class CacheValidator(ABC): + """Abstract base class for cache validation strategies. + + Users can implement custom validators by subclassing this and implementing + the should_invalidate method. + """ + + @abstractmethod + def should_invalidate( + self, cache_file: CacheFile, step_index: Optional[int] = None + ) -> tuple[bool, Optional[str]]: + """Check if cache should be invalidated. + + Args: + cache_file: The cache file with metadata and trajectory + step_index: Optional step index where failure occurred + + Returns: + Tuple of (should_invalidate: bool, reason: Optional[str]) + """ + + @abstractmethod + def get_name(self) -> str: + """Return validator name for logging/debugging.""" + + +class CompositeCacheValidator(CacheValidator): + """Composite validator that combines multiple validation strategies. + + Invalidates cache if ANY of the validators returns True. + Users can add custom validators via add_validator(). + """ + + def __init__(self, validators: Optional[list[CacheValidator]] = None): + """Initialize composite validator. + + Args: + validators: Optional list of validators to include + """ + self.validators: list[CacheValidator] = validators or [] + + def add_validator(self, validator: CacheValidator) -> None: + """Add a validator to the composite. + + Args: + validator: The validator to add + """ + self.validators.append(validator) + + def should_invalidate( + self, cache_file: CacheFile, step_index: Optional[int] = None + ) -> tuple[bool, Optional[str]]: + """Check all validators, invalidate if any returns True. + + Args: + cache_file: The cache file with metadata and trajectory + step_index: Optional step index where failure occurred + + Returns: + Tuple of (should_invalidate: bool, reason: Optional[str]) + If multiple validators trigger, reasons are combined with "; " + """ + reasons = [] + for validator in self.validators: + should_inv, reason = validator.should_invalidate(cache_file, step_index) + if should_inv and reason: + reasons.append(f"{validator.get_name()}: {reason}") + + if reasons: + return True, "; ".join(reasons) + return False, None + + def get_name(self) -> str: + """Return validator name.""" + return "CompositeValidator" + + +# Built-in validators + + +class StepFailureCountValidator(CacheValidator): + """Invalidate if same step fails too many times. + + This validator counts how many times a specific step has failed + and invalidates the cache if it exceeds the threshold. + """ + + def __init__(self, max_failures_per_step: int = 3): + """Initialize validator. 
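+ + Example (illustrative): a stricter policy that invalidates after a single + failure at any step: + CacheManager(validators=[StepFailureCountValidator(max_failures_per_step=1)])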
+ + Args: + max_failures_per_step: Maximum number of failures allowed per step + """ + self.max_failures_per_step = max_failures_per_step + + def should_invalidate( + self, cache_file: CacheFile, step_index: Optional[int] = None + ) -> tuple[bool, Optional[str]]: + """Check if step has failed too many times. + + Args: + cache_file: The cache file with metadata and trajectory + step_index: The step index to check (required for this validator) + + Returns: + Tuple of (should_invalidate: bool, reason: Optional[str]) + """ + if step_index is None: + return False, None + + # Count failures at this specific step + failures_at_step = sum( + 1 for f in cache_file.metadata.failures if f.step_index == step_index + ) + + if failures_at_step >= self.max_failures_per_step: + return ( + True, + f"Step {step_index} failed {failures_at_step} times " + f"(max: {self.max_failures_per_step})", + ) + return False, None + + def get_name(self) -> str: + """Return validator name.""" + return "StepFailureCount" + + +class TotalFailureRateValidator(CacheValidator): + """Invalidate if overall failure rate is too high. + + This validator calculates the ratio of failures to execution attempts + and invalidates if the rate exceeds the threshold after a minimum + number of attempts. + """ + + def __init__(self, min_attempts: int = 10, max_failure_rate: float = 0.5): + """Initialize validator. + + Args: + min_attempts: Minimum execution attempts before checking rate + max_failure_rate: Maximum acceptable failure rate (0.0 to 1.0) + """ + self.min_attempts = min_attempts + self.max_failure_rate = max_failure_rate + + def should_invalidate( + self, cache_file: CacheFile, _step_index: Optional[int] = None + ) -> tuple[bool, Optional[str]]: + """Check if overall failure rate is too high. + + Args: + cache_file: The cache file with metadata and trajectory + _step_index: Unused for this validator + + Returns: + Tuple of (should_invalidate: bool, reason: Optional[str]) + """ + attempts = cache_file.metadata.execution_attempts + if attempts < self.min_attempts: + return False, None + + failures = len(cache_file.metadata.failures) + failure_rate = failures / attempts if attempts > 0 else 0.0 + + if failure_rate > self.max_failure_rate: + return ( + True, + f"Failure rate {failure_rate:.1%} exceeds " + f"{self.max_failure_rate:.1%} after {attempts} attempts", + ) + return False, None + + def get_name(self) -> str: + """Return validator name.""" + return "TotalFailureRate" + + +class StaleCacheValidator(CacheValidator): + """Invalidate if cache is old and has failures. + + This validator checks if a cache hasn't been successfully executed + in a long time AND has failures. Caches without failures are not + considered stale regardless of age. + """ + + def __init__(self, max_age_days: int = 30): + """Initialize validator. + + Args: + max_age_days: Maximum age in days for cache with failures + """ + self.max_age_days = max_age_days + + def should_invalidate( + self, cache_file: CacheFile, _step_index: Optional[int] = None + ) -> tuple[bool, Optional[str]]: + """Check if cache is stale (old + has failures). 
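+ + Age is measured from the last successful execution (normalized to UTC); + caches that were never executed or have no recorded failures are never + considered stale.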
+ + Args: + cache_file: The cache file with metadata and trajectory + _step_index: Unused for this validator + + Returns: + Tuple of (should_invalidate: bool, reason: Optional[str]) + """ + if not cache_file.metadata.last_executed_at: + return False, None + + if not cache_file.metadata.failures: + return False, None # No failures, age doesn't matter + + # Ensure last_executed_at is timezone-aware + last_executed = cache_file.metadata.last_executed_at + if last_executed.tzinfo is None: + last_executed = last_executed.replace(tzinfo=timezone.utc) + + age = datetime.now(tz=timezone.utc) - last_executed + if age > timedelta(days=self.max_age_days): + return ( + True, + f"Cache not successfully executed in {age.days} days and has failures", + ) + return False, None + + def get_name(self) -> str: + """Return validator name.""" + return "StaleCache" diff --git a/src/askui/utils/caching/cache_writer.py b/src/askui/utils/caching/cache_writer.py new file mode 100644 index 00000000..40fa47ef --- /dev/null +++ b/src/askui/utils/caching/cache_writer.py @@ -0,0 +1,497 @@ +import json +import logging +from datetime import datetime, timezone +from pathlib import Path +from typing import TYPE_CHECKING + +from PIL import Image + +from askui.models.model_router import ModelRouter +from askui.models.shared.agent_message_param import ( + MessageParam, + ToolUseBlockParam, + UsageParam, +) +from askui.models.shared.agent_on_message_cb import OnMessageCbParam +from askui.models.shared.facade import ModelFacade +from askui.models.shared.settings import ( + CacheFile, + CacheMetadata, + CacheWritingSettings, +) +from askui.models.shared.tools import ToolCollection +from askui.utils.cache_parameter_handler import CacheParameterHandler +from askui.utils.visual_validation import ( + compute_ahash, + compute_phash, + extract_region, + get_validation_coordinate, + should_validate_step, +) + +if TYPE_CHECKING: + from askui.models.models import ActModel + +logger = logging.getLogger(__name__) + + +class CacheWriter: + def __init__( + self, + cache_dir: str = ".cache", + cache_writing_settings: CacheWritingSettings | None = None, + toolbox: ToolCollection | None = None, + goal: str | None = None, + model_router: ModelRouter | None = None, + model: str | None = None, + ) -> None: + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(exist_ok=True) + self.messages: list[ToolUseBlockParam] = [] + + # Use default settings if not provided + self._cache_writing_settings = cache_writing_settings or CacheWritingSettings() + + # Extract file_name from settings + file_name = self._cache_writing_settings.filename + if file_name and not file_name.endswith(".json"): + file_name += ".json" + self.file_name = file_name + + self.was_cached_execution = False + self._goal = goal + self._model_router = model_router + self._model = model + self._toolbox: ToolCollection | None = None + self._accumulated_usage = UsageParam() + + # Extract visual verification settings from cache_writing_settings + self._visual_verification_method = ( + self._cache_writing_settings.visual_verification_method + ) + self._visual_validation_region_size = ( + self._cache_writing_settings.visual_validation_region_size + ) + self._visual_validation_threshold = ( + self._cache_writing_settings.visual_validation_threshold + ) + + # Set toolbox for cache writer so it can check which tools are cacheable + self._toolbox = toolbox + + def add_message_cb(self, param: OnMessageCbParam) -> MessageParam: + """Add a message to cache and accumulate usage statistics.""" + if 
param.message.role == "assistant": + contents = param.message.content + if isinstance(contents, list): + for content in contents: + if isinstance(content, ToolUseBlockParam): + # Detect if we're starting a cached execution + if content.name == "execute_cached_executions_tool": + self.was_cached_execution = True + + # Enhance with visual validation if applicable + # (skip during cached execution) + enhanced_content = self._enhance_with_visual_validation(content) + self.messages.append(enhanced_content) + + # Accumulate usage from assistant messages + if param.message.usage: + self._accumulate_usage(param.message.usage) + + if param.message.stop_reason == "end_turn": + self.generate() + + return param.message + + def set_file_name(self, file_name: str) -> None: + if not file_name.endswith(".json"): + file_name += ".json" + self.file_name = file_name + + def reset(self, file_name: str = "") -> None: + self.messages = [] + if file_name and not file_name.endswith(".json"): + file_name += ".json" + self.file_name = file_name + self.was_cached_execution = False + self._accumulated_usage = UsageParam() + + def generate(self) -> None: + if self.was_cached_execution: + logger.info("Will not write cache file as this was a cached execution") + return + + if not self.file_name: + self.file_name = ( + f"cached_trajectory_{datetime.now(tz=timezone.utc):%Y%m%d%H%M%S%f}.json" + ) + + cache_file_path = self.cache_dir / self.file_name + + goal_to_save, trajectory_to_save, parameters_dict = ( + self._parameterize_trajectory() + ) + + if self._toolbox is not None: + trajectory_to_save = self._blank_non_cacheable_tool_inputs( + trajectory_to_save + ) + else: + logger.info("No toolbox set, skipping non-cacheable tool input blanking") + + self._generate_cache_file( + goal_to_save, trajectory_to_save, parameters_dict, cache_file_path + ) + self.reset() + + def _parameterize_trajectory( + self, + ) -> tuple[str | None, list[ToolUseBlockParam], dict[str, str]]: + """Identify parameters and return parameterized trajectory + goal.""" + identification_strategy = "preset" + messages_api = None + model = None + + if self._cache_writing_settings.parameter_identification_strategy == "llm": + if self._model_router and self._model: + try: + _get_model: tuple[ActModel, str] = self._model_router._get_model( # noqa: SLF001 + self._model, "act" + ) + if isinstance(_get_model[0], ModelFacade): + act_model: ActModel = _get_model[0]._act_model # noqa: SLF001 + else: + act_model = _get_model[0] + model_name: str = _get_model[1] + if hasattr(act_model, "_messages_api"): + messages_api = act_model._messages_api # noqa: SLF001 + identification_strategy = "llm" + model = model_name + except Exception: + logger.exception( + "Using 'llm' for parameter identification caused an exception. " + "Will use 'preset' strategy instead" + ) + + return CacheParameterHandler.identify_and_parameterize( + trajectory=self.messages, + goal=self._goal, + identification_strategy=identification_strategy, + messages_api=messages_api, + model=model, + ) + + def _blank_non_cacheable_tool_inputs( + self, trajectory: list[ToolUseBlockParam] + ) -> list[ToolUseBlockParam]: + """Blank out input fields for non-cacheable tools to save space. + + For tools marked as is_cacheable=False, we replace their input with an + empty dict since we won't be executing them from cache anyway.
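+ The blanked blocks keep their position in the trajectory so the executor + can still pause at them and hand control back to the agent.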
+ + Args: + trajectory: The trajectory to process + + Returns: + New trajectory with non-cacheable tool inputs blanked out + """ + if self._toolbox is None: + return trajectory + + blanked_count = 0 + result: list[ToolUseBlockParam] = [] + for tool_block in trajectory: + # Check if this tool is cacheable + tool = self._toolbox.get_tools().get(tool_block.name) + + # If tool is not cacheable, blank out its input + if tool is not None and not tool.is_cacheable: + logger.debug( + "Blanking input for non-cacheable tool: %s", tool_block.name + ) + blanked_count += 1 + result.append( + ToolUseBlockParam( + id=tool_block.id, + name=tool_block.name, + input={}, # Blank out the input + type=tool_block.type, + cache_control=tool_block.cache_control, + ) + ) + else: + # Keep the tool block as-is + result.append(tool_block) + + if blanked_count > 0: + logger.info( + "Blanked inputs for %s non-cacheable tool(s) to save space", + blanked_count, + ) + + return result + + def _generate_cache_file( + self, + goal_to_save: str | None, + trajectory_to_save: list[ToolUseBlockParam], + parameters_dict: dict[str, str], + cache_file_path: Path, + ) -> None: + cache_file = CacheFile( + metadata=CacheMetadata( + version="0.1", + created_at=datetime.now(tz=timezone.utc), + goal=goal_to_save, + token_usage=self._accumulated_usage, + visual_verification_method=self._visual_verification_method, + visual_validation_region_size=self._visual_validation_region_size, + visual_validation_threshold=self._visual_validation_threshold, + ), + trajectory=trajectory_to_save, + cache_parameters=parameters_dict, + ) + + with cache_file_path.open("w", encoding="utf-8") as f: + json.dump(cache_file.model_dump(mode="json"), f, indent=4) + logger.info("Cache file successfully written: %s ", cache_file_path) + + def _enhance_with_visual_validation( # noqa: C901 + self, tool_block: ToolUseBlockParam + ) -> ToolUseBlockParam: + """Enhance ToolUseBlockParam with visual validation data if applicable. 
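+ + The pipeline: capture a screenshot, extract a square region around the + action coordinate, hash it with the configured method, and store the hash + in the block's visual_representation field.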
+ + Args: + tool_block: The tool use block to potentially enhance + + Returns: + Enhanced ToolUseBlockParam with visual validation data, or original if N/A + """ + # Skip if we're in a cached execution (recording disabled during replay) + if self.was_cached_execution: + return tool_block + + # Skip if visual verification is disabled + if self._visual_verification_method == "none": + logger.debug( + "Visual validation skipped for %s: method='none'", tool_block.name + ) + return tool_block + + # Skip if no toolbox available + if self._toolbox is None: + logger.warning( + "Visual validation skipped for %s: no toolbox available", + tool_block.name, + ) + return tool_block + + # Check if this tool input should be validated + action = None + if isinstance(tool_block.input, dict): + action = tool_block.input.get("action") + + if not should_validate_step(tool_block.name, action): + logger.debug( + "Visual validation skipped for %s action=%s: not a validatable action", + tool_block.name, + action, + ) + return tool_block + + # Get validation coordinate + if not isinstance(tool_block.input, dict): + logger.debug("Visual validation skipped: input is not a dict") + return tool_block + + coordinate = get_validation_coordinate(tool_block.input) + if coordinate is None: + logger.debug( + "Visual validation skipped for %s action=%s: no coordinate found", + tool_block.name, + action, + ) + return tool_block + + # Capture current screenshot and compute hash + try: + screenshot = self._capture_screenshot() + if screenshot is None: + logger.warning( + "Visual validation skipped for %s action=%s: " + "screenshot capture failed", + tool_block.name, + action, + ) + return tool_block + + # Extract region around coordinate + region = extract_region( + screenshot, coordinate, size=self._visual_validation_region_size + ) + + # Compute hash based on method + if self._visual_verification_method == "phash": + visual_hash = compute_phash(region) + else: # ahash (none was already handled earlier) + visual_hash = compute_ahash(region) + + # Create enhanced ToolUseBlockParam with visual validation data + enhanced = ToolUseBlockParam( + id=tool_block.id, + name=tool_block.name, + input=tool_block.input, + type=tool_block.type, + cache_control=tool_block.cache_control, + visual_representation=visual_hash, + ) + + logger.info( + "✓ Visual validation added to %s action=%s at coordinate %s " + "(hash=%s...)", + tool_block.name, + action, + coordinate, + visual_hash[:16], + ) + return enhanced # noqa: TRY300 + + except Exception as e: # noqa: BLE001 + logger.warning( + "Visual validation skipped for %s action=%s: " + "error during enhancement: %s", + tool_block.name, + action, + str(e), + ) + # Fall through to return original tool_block + + return tool_block + + def _capture_screenshot(self) -> Image.Image | None: + """Capture current screenshot using the computer tool. + + Returns: + PIL Image or None if screenshot capture fails + """ + if self._toolbox is None: + logger.warning("Cannot capture screenshot: toolbox is None") + return None + + # Get the computer tool from the toolbox + tools = self._toolbox.get_tools() + computer_tool = tools.get("computer") + + if computer_tool is None: + logger.warning( + "Cannot capture screenshot: computer tool not found in toolbox. 
" + "Available tools: %s", + list(tools.keys()), + ) + return None + + # Call the screenshot action + try: + # Try to call _screenshot() method directly if available + if hasattr(computer_tool, "_screenshot"): + result = computer_tool._screenshot() # noqa: SLF001 + if isinstance(result, Image.Image): + logger.debug("Screenshot captured successfully via _screenshot()") + return result + + # Fallback to calling via __call__ with action parameter + result = computer_tool(action="screenshot") + if isinstance(result, Image.Image): + logger.debug("Screenshot captured successfully via __call__") + return result + + logger.warning( + "Screenshot action did not return an Image. Type: %s, Value: %s", + type(result).__name__, + str(result)[:100], + ) + except Exception as e: # noqa: BLE001 + logger.warning( + "Error capturing screenshot for visual validation: %s: %s", + type(e).__name__, + str(e), + ) + + return None + + def _accumulate_usage(self, step_usage: UsageParam) -> None: + """Accumulate usage statistics from a single API call. + + Args: + step_usage: Usage from a single message + """ + self._accumulated_usage.input_tokens = ( + self._accumulated_usage.input_tokens or 0 + ) + (step_usage.input_tokens or 0) + self._accumulated_usage.output_tokens = ( + self._accumulated_usage.output_tokens or 0 + ) + (step_usage.output_tokens or 0) + self._accumulated_usage.cache_creation_input_tokens = ( + self._accumulated_usage.cache_creation_input_tokens or 0 + ) + (step_usage.cache_creation_input_tokens or 0) + self._accumulated_usage.cache_read_input_tokens = ( + self._accumulated_usage.cache_read_input_tokens or 0 + ) + (step_usage.cache_read_input_tokens or 0) + + @staticmethod + def read_cache_file(cache_file_path: Path) -> CacheFile: + """Read cache file with backward compatibility for v0.0 format. + + Returns: + CacheFile object with metadata and trajectory + """ + logger.debug("Reading cache file: %s", cache_file_path) + with cache_file_path.open("r", encoding="utf-8") as f: + raw_data = json.load(f) + + # Detect format version + if isinstance(raw_data, list): + # v0.0 format: just a list of tool use blocks + logger.info( + "Detected v0.0 cache format in %s, migrating to v0.1", + cache_file_path.name, + ) + trajectory = [ToolUseBlockParam(**step) for step in raw_data] + # Create default metadata for v0.0 files (migrated to v0.1 format) + cache_file = CacheFile( + metadata=CacheMetadata( + version="0.1", # Migrated from v0.0 to v0.1 format + created_at=datetime.fromtimestamp( + cache_file_path.stat().st_ctime, tz=timezone.utc + ), + ), + trajectory=trajectory, + cache_parameters={}, + ) + logger.info( + "Successfully loaded and migrated v0.0 cache: %s steps, 0 parameters", + len(trajectory), + ) + return cache_file + if isinstance(raw_data, dict) and "metadata" in raw_data: + # v0.1 format: structured with metadata + cache_file = CacheFile(**raw_data) + logger.info( + "Successfully loaded v0.1 cache: %s steps, %s parameters", + len(cache_file.trajectory), + len(cache_file.cache_parameters), + ) + if cache_file.metadata.goal: + logger.debug("Cache goal: %s", cache_file.metadata.goal) + return cache_file + logger.error( + "Unknown cache file format in %s. " + "Expected either a list (v0.0) or dict with 'metadata' key (v0.1).", + cache_file_path.name, + ) + msg = ( + f"Unknown cache file format in {cache_file_path}. " + "Expected either a list (v0.0) or dict with 'metadata' key (v0.1)." 
+ ) + raise ValueError(msg) diff --git a/src/askui/utils/trajectory_executor.py b/src/askui/utils/trajectory_executor.py new file mode 100644 index 00000000..a0400b9c --- /dev/null +++ b/src/askui/utils/trajectory_executor.py @@ -0,0 +1,414 @@ +"""Trajectory executor for step-by-step cache execution. + +This module provides the TrajectoryExecutor class that enables controlled +execution of cached trajectories with support for pausing at non-cacheable +steps, error handling, and agent intervention. +""" + +import logging +import time +from typing import Any, Optional + +from PIL import Image +from pydantic import BaseModel, Field +from typing_extensions import Literal + +from askui.models.shared.agent_message_param import ( + MessageParam, + ToolUseBlockParam, +) +from askui.models.shared.tools import ToolCollection +from askui.utils.cache_parameter_handler import CacheParameterHandler +from askui.utils.visual_validation import ( + extract_region, + get_validation_coordinate, + validate_visual_hash, +) + +logger = logging.getLogger(__name__) + + +class ExecutionResult(BaseModel): + """Result of executing a single step in a trajectory. + + Attributes: + status: Execution status (SUCCESS, FAILED, NEEDS_AGENT, COMPLETED) + step_index: Index of the step that was executed + tool_result: The ToolResultBlockParam returned by the tool (if any), + preserving proper data types like ImageBlockParam for screenshots + error_message: Error message if execution failed + screenshots_taken: List of screenshots captured during this step + message_history: List of MessageParam representing the conversation history, + with proper content types (ImageBlockParam, TextBlockParam, etc.) + """ + + status: Literal["SUCCESS", "FAILED", "NEEDS_AGENT", "COMPLETED"] + step_index: int + tool_result: Optional[Any] = None + error_message: Optional[str] = None + screenshots_taken: list[Any] = Field(default_factory=list) + message_history: list[MessageParam] = Field(default_factory=list) + + +class TrajectoryExecutor: + """Executes cached trajectories step-by-step with control flow. + + Supports pausing at non-cacheable steps, cache_parameter substitution, + and collecting execution results for the agent to review. + """ + + def __init__( + self, + trajectory: list[ToolUseBlockParam], + toolbox: ToolCollection, + parameter_values: dict[str, str] | None = None, + delay_time: float = 0.5, + visual_validation_enabled: bool = False, + visual_validation_threshold: int = 10, + visual_hash_method: str = "phash", + visual_validation_region_size: int = 100, + ): + """Initialize the trajectory executor. 
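+ + Example (illustrative; assumes a loaded cache file and a toolbox): + executor = TrajectoryExecutor(cache_file.trajectory, toolbox) + results = executor.execute_all()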
+ + Args: + trajectory: List of tool use blocks to execute + toolbox: ToolCollection for executing tools + parameter_values: Dict of parameter names to values + delay_time: Seconds to wait between step executions + visual_validation_enabled: Enable visual validation + visual_validation_threshold: Hamming distance threshold (0-64) + visual_hash_method: Hash method to use ('phash' or 'ahash') + visual_validation_region_size: Size of square region to extract (in pixels) + """ + self.trajectory = trajectory + self.toolbox = toolbox + self.parameter_values = parameter_values or {} + self.delay_time = delay_time + self.visual_validation_enabled = visual_validation_enabled + self.visual_validation_threshold = visual_validation_threshold + self.visual_hash_method = visual_hash_method + self.visual_validation_region_size = visual_validation_region_size + self.current_step_index = 0 + self.message_history: list[MessageParam] = [] + + def execute_next_step(self) -> ExecutionResult: + """Execute the next step in the trajectory. + + Returns: + ExecutionResult with status and details of the execution + + The method will: + 1. Check if there are more steps to execute + 2. Check if the step should be skipped (screenshots, retrieval tools) + 3. Check if the step is non-cacheable (needs agent) + 4. Substitute parameters + 5. Execute the tool and build messages with proper data types + 6. Return result with updated message history + + Note: Tool results are preserved with their proper data types (e.g., + ImageBlockParam for screenshots) and added to message history. The + agent's truncation strategy will manage message history size. + """ + # Check if we've completed all steps + if self.current_step_index >= len(self.trajectory): + return ExecutionResult( + status="COMPLETED", + step_index=self.current_step_index - 1, + message_history=self.message_history, + ) + + step = self.trajectory[self.current_step_index] + step_index = self.current_step_index + + # Check if step should be skipped + if self._should_skip_step(step): + logger.debug("Skipping step %d: %s", step_index, step.name) + self.current_step_index += 1 + # Recursively execute next step + return self.execute_next_step() + + # Check if step needs agent intervention (non-cacheable) + if self.should_pause_for_agent(step): + logger.info( + "Pausing at step %d: %s (non-cacheable tool)", + step_index, + step.name, + ) + # Return result with current tool step info for the agent to handle + # Note: We don't add any messages here - the cache manager will + # inject a user message explaining what needs to be done + return ExecutionResult( + status="NEEDS_AGENT", + step_index=step_index, + message_history=self.message_history.copy(), + tool_result=step, # Pass the tool use block for reference + ) + + # Visual validation: verify UI state matches cached expectations + # Compares stored visual hash with current screen region + if self.visual_validation_enabled: + is_valid, error_msg = self.validate_step_visually(step) + if not is_valid: + logger.warning( + "Visual validation failed at step %d: %s. 
" + "Handing execution back to agent.", + step_index, + error_msg, + ) + return ExecutionResult( + status="FAILED", + step_index=step_index, + error_message=error_msg, + message_history=self.message_history.copy(), + ) + + # Substitute parameters + substituted_step = CacheParameterHandler.substitute_parameters( + step, self.parameter_values + ) + + # Execute the tool + try: + logger.debug("Executing step %d: %s", step_index, step.name) + + # Add assistant message (tool use) to history + assistant_message = MessageParam( + role="assistant", + content=[substituted_step], + ) + self.message_history.append(assistant_message) + + # Execute the tool + tool_results = self.toolbox.run([substituted_step]) + + # toolbox.run() returns a list of content blocks + # (ToolResultBlockParam, etc.) We use these directly without + # converting to strings - this preserves proper data types like + # ImageBlockParam + + # Add user message (tool result) to history + user_message = MessageParam( + role="user", + content=tool_results if tool_results else [], + ) + self.message_history.append(user_message) + + # Move to next step + self.current_step_index += 1 + + # Add delay between actions + if self.current_step_index < len(self.trajectory): + time.sleep(self.delay_time) + + return ExecutionResult( + status="SUCCESS", + step_index=step_index, + tool_result=tool_results[0] if tool_results else None, + message_history=self.message_history.copy(), + ) + + except Exception as e: + logger.exception("Error executing step %d: %s", step_index, step.name) + return ExecutionResult( + status="FAILED", + step_index=step_index, + error_message=str(e), + message_history=self.message_history.copy(), + ) + + def execute_all(self) -> list[ExecutionResult]: + """Execute all steps in the trajectory until completion or pause. + + Returns: + List of ExecutionResult for all executed steps + + Execution stops when: + - All steps are completed + - A step fails + - A non-cacheable step is encountered + """ + results: list[ExecutionResult] = [] + + while True: + result = self.execute_next_step() + results.append(result) + + # Stop if we've completed, failed, or need agent + if result.status in ["COMPLETED", "FAILED", "NEEDS_AGENT"]: + break + + return results + + def should_pause_for_agent(self, step: ToolUseBlockParam) -> bool: + """Check if execution should pause for agent intervention. + + Args: + step: The tool use block to check + + Returns: + True if agent should execute this step, False if it can be cached + + Currently checks if the tool is marked as non-cacheable. + """ + # Get the tool from toolbox + tool = self.toolbox._tool_map.get(step.name) # noqa: SLF001 + + if tool is None: + # Tool not found in regular tools, might be MCP tool + # For now, assume MCP tools are cacheable + return False + + # Check if tool is marked as non-cacheable + return not tool.is_cacheable + + def get_current_step_index(self) -> int: + """Get the index of the current step. + + Returns: + Current step index + """ + return self.current_step_index + + def get_remaining_trajectory(self) -> list[ToolUseBlockParam]: + """Get the remaining steps in the trajectory. + + Returns: + List of tool use blocks that haven't been executed yet + """ + return self.trajectory[self.current_step_index :] + + def skip_current_step(self) -> None: + """Skip the current step and move to the next one. + + Useful when the agent manually executes a non-cacheable step. 
+ """ + if self.current_step_index < len(self.trajectory): + self.current_step_index += 1 + + def _should_skip_step(self, _step: ToolUseBlockParam) -> bool: + """Check if a step should be skipped during execution. + + Args: + step: The tool use block to check + + Returns: + True if step should be skipped, False otherwise + + Note: As of v0.1, no steps are skipped. All tools in the trajectory + are executed, including screenshots and trajectory retrieval tools. + """ + return False + + def validate_step_visually( + self, step: ToolUseBlockParam, current_screenshot: Any = None + ) -> tuple[bool, str | None]: + """Validate cached steps using visual hash comparison. + + Compares the current UI state against the stored visual hash to detect + if the UI has changed significantly since the trajectory was recorded. + + Args: + step: The trajectory step to validate + current_screenshot: Optional current screen capture (will capture if None) + + Returns: + Tuple of (is_valid: bool, error_message: str | None) + - (True, None) if validation passes or is disabled + - (False, error_msg) if validation fails + """ + # Skip validation if disabled + if not self.visual_validation_enabled: + return True, None + + # Skip if no visual representation stored (implies no validation needed) + if step.visual_representation is None: + return True, None + + # Get coordinate for validation + if not isinstance(step.input, dict): + return True, None + + coordinate = get_validation_coordinate(step.input) + if coordinate is None: + logger.debug( + "Could not extract coordinate from step %d for visual validation", + self.current_step_index, + ) + return True, None + + # Capture current screenshot if not provided + if current_screenshot is None: + current_screenshot = self._capture_screenshot() + if current_screenshot is None: + logger.warning( + "Could not capture screenshot for visual validation at step %d", + self.current_step_index, + ) + # Unable to validate, but don't fail execution + return True, None + + # Extract region around coordinate + try: + region = extract_region( + current_screenshot, coordinate, size=self.visual_validation_region_size + ) + except Exception as e: # noqa: BLE001 + logger.warning( + "Error extracting region for visual validation at step %d: %s", + self.current_step_index, + e, + ) + return True, None + + # Validate hash + is_valid, error_msg, _distance = validate_visual_hash( + stored_hash=step.visual_representation, + current_image=region, + threshold=self.visual_validation_threshold, + hash_method=self.visual_hash_method, + ) + + # Only log if validation fails + if not is_valid: + logger.warning( + "Visual validation failed at step %d: %s", + self.current_step_index, + error_msg, + ) + + return is_valid, error_msg + + def _capture_screenshot(self) -> Image.Image | None: + """Capture current screenshot using the computer tool. 
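+ + Tries the tool's private _screenshot() method first and falls back to + invoking the tool with action="screenshot".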
+ + Returns: + PIL Image or None if screenshot capture fails + """ + # Get the computer tool from toolbox + tools = self.toolbox.get_tools() + computer_tool = tools.get("computer") + + if computer_tool is None: + logger.debug("Computer tool not found in toolbox") + return None + + # Call the screenshot action + try: + # Try to call _screenshot() method directly if available + if hasattr(computer_tool, "_screenshot"): + result = computer_tool._screenshot() # noqa: SLF001 + if isinstance(result, Image.Image): + return result + + # Fallback to calling via __call__ with action parameter + result = computer_tool(action="screenshot") + if isinstance(result, Image.Image): + return result + + logger.warning( + "Screenshot action did not return an Image: %s", type(result) + ) + return None # noqa: TRY300 + except Exception: + logger.exception("Error capturing screenshot") + return None diff --git a/src/askui/utils/visual_validation.py b/src/askui/utils/visual_validation.py new file mode 100644 index 00000000..041807d5 --- /dev/null +++ b/src/askui/utils/visual_validation.py @@ -0,0 +1,312 @@ +"""Visual validation utilities for cache execution. + +This module provides utilities for computing and comparing visual hashes +to validate UI state before executing cached trajectory steps. +""" + +import logging +from typing import Any + +import numpy as np +from numpy.typing import NDArray +from PIL import Image + +logger = logging.getLogger(__name__) + + +def compute_phash(image: Image.Image, hash_size: int = 8) -> str: + """Compute perceptual hash (pHash) of an image. + + pHash is robust to minor changes (compression, scaling, lighting) while + being sensitive to structural changes (moved buttons, different layouts). + + Args: + image: PIL Image to hash + hash_size: Size of the hash (default 8 = 64-bit hash) + + Returns: + Hexadecimal string representation of the hash + + The algorithm: + 1. Resize image to hash_size x hash_size + 2. Convert to grayscale + 3. Compute DCT (Discrete Cosine Transform) + 4. Extract top-left 8x8 DCT coefficients + 5. Compute median of coefficients + 6. Create binary hash: 1 if coeff > median, else 0 + """ + # Resize to hash_size x hash_size + resized = image.resize((hash_size, hash_size), Image.Resampling.LANCZOS) + + # Convert to grayscale + gray = resized.convert("L") + + # Convert to numpy array + pixels = np.array(gray, dtype=np.float32) + + # Compute DCT (using a simple 2D DCT approximation) + # For production, consider using scipy.fftpack.dct + dct = _dct_2d(pixels) + + # Extract the top-left block of coefficients (the DC component at + # [0,0] is included in the median, as in common pHash implementations) + dct_low = dct[:hash_size, :hash_size] + + # Compute median + median = np.median(dct_low) + + # Create binary hash + diff = dct_low > median + + # Convert to hexadecimal string + hash_bytes = _binary_array_to_bytes(diff.flatten()) + return hash_bytes.hex() + + +def compute_ahash(image: Image.Image, hash_size: int = 8) -> str: + """Compute average hash (aHash) of an image. + + aHash is a simpler and faster alternative to pHash. It's less robust + to transformations but still useful for basic visual validation. + + Args: + image: PIL Image to hash + hash_size: Size of the hash (default 8 = 64-bit hash) + + Returns: + Hexadecimal string representation of the hash + + The algorithm: + 1. Resize image to hash_size x hash_size + 2. Convert to grayscale + 3. Compute mean pixel value + 4. 
Create binary hash: 1 if pixel > mean, else 0 + """ + # Resize to hash_size x hash_size + resized = image.resize((hash_size, hash_size), Image.Resampling.LANCZOS) + + # Convert to grayscale + gray = resized.convert("L") + + # Convert to numpy array + pixels = np.array(gray, dtype=np.float32) + + # Compute mean + mean = pixels.mean() + + # Create binary hash + diff = pixels > mean + + # Convert to hexadecimal string + hash_bytes = _binary_array_to_bytes(diff.flatten()) + return hash_bytes.hex() + + +def hamming_distance(hash1: str, hash2: str) -> int: + """Compute Hamming distance between two hash strings. + + The Hamming distance is the number of bit positions where the two + hashes differ. A distance of 0 means identical hashes, while 64 + means completely different (for 64-bit hashes). + + Args: + hash1: First hash (hexadecimal string) + hash2: Second hash (hexadecimal string) + + Returns: + Number of differing bits (0-64 for 64-bit hashes) + + Raises: + ValueError: If hashes have different lengths + """ + if len(hash1) != len(hash2): + msg = f"Hash lengths differ: {len(hash1)} vs {len(hash2)}" + raise ValueError(msg) + + # Convert hex strings to integers and XOR them + # XOR will have 1s where bits differ + xor_result = int(hash1, 16) ^ int(hash2, 16) + + # Count number of 1s (differing bits) + return (xor_result).bit_count() + + +def extract_region( + image: Image.Image, center: tuple[int, int], size: int = 100 +) -> Image.Image: + """Extract a square region from an image centered at given coordinates. + + Args: + image: Source image + center: (x, y) coordinates of region center + size: Size of the square region in pixels + + Returns: + Cropped image region + + The region is clipped to image boundaries if necessary. + """ + x, y = center + half_size = size // 2 + + # Calculate bounds, clipping to image boundaries + left = max(0, x - half_size) + top = max(0, y - half_size) + right = min(image.width, x + half_size) + bottom = min(image.height, y + half_size) + + # Crop and return + return image.crop((left, top, right, bottom)) + + +def validate_visual_hash( + stored_hash: str, + current_image: Image.Image, + threshold: int = 10, + hash_method: str = "phash", +) -> tuple[bool, str | None, int]: + """Validate that current image matches stored visual hash. 
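+ + Example (illustrative; the region comes from extract_region): + ok, err, dist = validate_visual_hash(stored_hash, region, threshold=10)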
+ + Args: + stored_hash: The hash stored in the cache + current_image: Current screenshot region + threshold: Maximum Hamming distance to accept (0-64) + - 0-5: Nearly identical (strict validation) + - 6-10: Very similar (recommended default) + - 11-15: Similar (lenient) + - 16+: Different (validation should fail) + hash_method: Hash method to use ('phash' or 'ahash') + + Returns: + Tuple of (is_valid, error_message, distance) + - is_valid: True if validation passes + - error_message: None if valid, error description if invalid + - distance: Hamming distance between hashes + """ + # Compute current hash + if hash_method == "phash": + current_hash = compute_phash(current_image) + elif hash_method == "ahash": + current_hash = compute_ahash(current_image) + else: + return False, f"Unknown hash method: {hash_method}", -1 + + # Compare hashes + try: + distance = hamming_distance(stored_hash, current_hash) + except ValueError as e: + return False, f"Hash comparison failed: {e}", -1 + + # Validate against threshold + if distance <= threshold: + return True, None, distance + + error_msg = ( + f"Visual validation failed: UI region changed significantly " + f"(Hamming distance: {distance} > threshold: {threshold})" + ) + return False, error_msg, distance + + +def should_validate_step(tool_name: str, action: str | None = None) -> bool: + """Determine if a tool step should have visual validation. + + Args: + tool_name: Name of the tool + action: Action type (for computer tool) + + Returns: + True if step should be validated + + Steps that should be validated: + - Mouse clicks (left_click, right_click, double_click, middle_click) + - Type actions (verify input field hasn't moved) + - Key presses targeting specific UI elements + """ + # Computer tool with click or type actions + if tool_name == "computer": + if action in [ + "left_click", + "right_click", + "double_click", + "middle_click", + "type", + "key", + ]: + return True + + # Other tools that interact with specific UI regions + # Add more as needed + return False + + +def get_validation_coordinate(tool_input: dict[str, Any]) -> tuple[int, int] | None: + """Extract the coordinate for visual validation from tool input. + + Args: + tool_input: Tool input dictionary + + Returns: + (x, y) coordinate tuple or None if not applicable + + For click actions, returns the click coordinate. + For type actions, returns the coordinate of the text input field. + """ + # Computer tool coordinates + if "coordinate" in tool_input: + coord = tool_input["coordinate"] + if isinstance(coord, list) and len(coord) == 2: + return (int(coord[0]), int(coord[1])) + + return None + + +# Private helper functions + + +def _dct_2d(image_array: NDArray[np.float32]) -> NDArray[np.float64]: + """Compute 2D Discrete Cosine Transform. + + This is a simplified implementation. For production use, consider + using scipy.fftpack.dct for better performance and accuracy. + + Args: + image_array: 2D numpy array + + Returns: + Approximate 2D DCT coefficients (FFT magnitudes, a real-valued array) + """ + # Using a simple DCT approximation via FFT + # For production, use: from scipy.fftpack import dct + # return dct(dct(image_array.T, norm='ortho').T, norm='ortho') + + # Simplified approach: use numpy's FFT and take the magnitude + fft = np.fft.fft2(image_array) + # Take absolute value and use as approximation + # Note: This is not a true DCT, but works for hash purposes + return np.abs(fft) + + +def _binary_array_to_bytes(binary_array: NDArray[np.bool_]) -> bytes: + """Convert binary numpy array to bytes. 
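+ + Bits are packed most-significant-bit first, eight per byte, with zero + padding appended when the array length is not a multiple of eight.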
+ + Args: + binary_array: 1D array of boolean values + + Returns: + Bytes representation + """ + # Convert to string of 0s and 1s + binary_string = "".join("1" if b else "0" for b in binary_array) + + # Pad to multiple of 8 + padding = 8 - (len(binary_string) % 8) + if padding != 8: + binary_string += "0" * padding + + # Convert to bytes + byte_array = bytearray() + for i in range(0, len(binary_string), 8): + byte = binary_string[i : i + 8] + byte_array.append(int(byte, 2)) + + return bytes(byte_array) diff --git a/tests/e2e/agent/test_act_caching.py b/tests/e2e/agent/test_act_caching.py index 70652b43..1166826a 100644 --- a/tests/e2e/agent/test_act_caching.py +++ b/tests/e2e/agent/test_act_caching.py @@ -9,7 +9,11 @@ from askui.agent import VisionAgent from askui.models.shared.agent_message_param import MessageParam from askui.models.shared.agent_on_message_cb import OnMessageCbParam -from askui.models.shared.settings import CachedExecutionToolSettings, CachingSettings +from askui.models.shared.settings import ( + CacheExecutionSettings, + CacheWritingSettings, + CachingSettings, +) def test_act_with_caching_strategy_read(vision_agent: VisionAgent) -> None: @@ -24,7 +28,7 @@ def test_act_with_caching_strategy_read(vision_agent: VisionAgent) -> None: vision_agent.act( goal="Tell me a joke", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir=str(cache_dir), ), ) @@ -41,9 +45,11 @@ def test_act_with_caching_strategy_write(vision_agent: VisionAgent) -> None: vision_agent.act( goal="Tell me a joke", caching_settings=CachingSettings( - strategy="write", + strategy="record", cache_dir=str(cache_dir), - filename=cache_filename, + writing_settings=CacheWritingSettings( + filename=cache_filename, + ), ), ) @@ -68,7 +74,9 @@ def test_act_with_caching_strategy_both(vision_agent: VisionAgent) -> None: caching_settings=CachingSettings( strategy="both", cache_dir=str(cache_dir), - filename=cache_filename, + writing_settings=CacheWritingSettings( + filename=cache_filename, + ), ), ) @@ -86,8 +94,7 @@ def test_act_with_caching_strategy_no(vision_agent: VisionAgent) -> None: vision_agent.act( goal="Tell me a joke", caching_settings=CachingSettings( - strategy="no", - cache_dir=str(cache_dir), + strategy=None, ), ) @@ -106,9 +113,11 @@ def test_act_with_custom_cache_dir_and_filename(vision_agent: VisionAgent) -> No vision_agent.act( goal="Tell me a joke", caching_settings=CachingSettings( - strategy="write", + strategy="record", cache_dir=str(custom_cache_dir), - filename=custom_filename, + writing_settings=CacheWritingSettings( + filename=custom_filename, + ), ), ) @@ -132,7 +141,7 @@ def dummy_callback(param: OnMessageCbParam) -> MessageParam: vision_agent.act( goal="Tell me a joke", caching_settings=CachingSettings( - strategy="write", + strategy="record", cache_dir=str(temp_dir), ), on_message=dummy_callback, @@ -170,9 +179,11 @@ def test_cache_file_contains_tool_use_blocks(vision_agent: VisionAgent) -> None: vision_agent.act( goal="Tell me a joke", caching_settings=CachingSettings( - strategy="write", + strategy="record", cache_dir=str(cache_dir), - filename=cache_filename, + writing_settings=CacheWritingSettings( + filename=cache_filename, + ), ), ) @@ -205,13 +216,14 @@ def test_act_with_custom_cached_execution_tool_settings( cache_file.write_text("[]", encoding="utf-8") # Act with custom execution tool settings - custom_settings = CachedExecutionToolSettings(delay_time_between_action=2.0) vision_agent.act( goal="Tell me a joke", caching_settings=CachingSettings( - 
strategy="read", + strategy="record", cache_dir=str(cache_dir), - execute_cached_trajectory_tool_settings=custom_settings, + execution_settings=CacheExecutionSettings( + delay_time_between_action=2.0, + ), ), ) diff --git a/tests/unit/models/test_messages_api_field_stripping.py b/tests/unit/models/test_messages_api_field_stripping.py new file mode 100644 index 00000000..464305ab --- /dev/null +++ b/tests/unit/models/test_messages_api_field_stripping.py @@ -0,0 +1,101 @@ +"""Tests for Pydantic-based context-aware serialization of internal fields.""" + +from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam + + +def test_tool_use_block_includes_fields_by_default() -> None: + """Test that visual validation fields are included in normal serialization.""" + tool_block = ToolUseBlockParam( + id="test_id", + name="computer", + input={"action": "left_click", "coordinate": [100, 200]}, + type="tool_use", + visual_representation="abc123", + ) + + # Default serialization includes all fields + serialized = tool_block.model_dump() + + assert serialized["visual_representation"] == "abc123" + assert serialized["id"] == "test_id" + assert serialized["name"] == "computer" + + +def test_tool_use_block_excludes_fields_for_api() -> None: + """Test that visual validation fields are excluded when for_api context is set.""" + tool_block = ToolUseBlockParam( + id="test_id", + name="computer", + input={"action": "left_click", "coordinate": [100, 200]}, + type="tool_use", + visual_representation="abc123", + ) + + # Serialization with for_api context excludes internal fields + serialized = tool_block.model_dump(context={"for_api": True}) + + # Internal fields should be excluded + assert "visual_representation" not in serialized + + # Other fields should remain + assert serialized["id"] == "test_id" + assert serialized["name"] == "computer" + assert serialized["input"] == {"action": "left_click", "coordinate": [100, 200]} + + +def test_tool_use_block_without_visual_validation() -> None: + """Test serialization of tool block without visual validation fields.""" + tool_block = ToolUseBlockParam( + id="test_id", + name="computer", + input={"action": "screenshot"}, + type="tool_use", + ) + + # Both modes should work fine + normal = tool_block.model_dump() + for_api = tool_block.model_dump(context={"for_api": True}) + + # Should not have visual representation field in either case (or it should be None) + assert ( + "visual_representation" not in normal or normal["visual_representation"] is None + ) + assert "visual_representation" not in for_api + + +def test_message_with_tool_use_context_propagation() -> None: + """Test that context propagates through nested MessageParam serialization.""" + tool_block = ToolUseBlockParam( + id="test_id", + name="computer", + input={"action": "left_click", "coordinate": [100, 200]}, + type="tool_use", + visual_representation="abc123", + ) + + message = MessageParam(role="assistant", content=[tool_block]) + + # Normal dump includes fields + normal = message.model_dump() + assert normal["content"][0]["visual_representation"] == "abc123" + + # API dump excludes fields + for_api = message.model_dump(context={"for_api": True}) + assert "visual_representation" not in for_api["content"][0] + + +def test_cache_storage_includes_all_fields() -> None: + """Test that cache storage (mode='json') includes all fields.""" + tool_block = ToolUseBlockParam( + id="test_id", + name="computer", + input={"action": "left_click", "coordinate": [100, 200]}, + type="tool_use", + 
visual_representation="abc123", + ) + + # Cache storage uses mode='json' without for_api context + cache_dump = tool_block.model_dump(mode="json") + + # Should include all fields for cache storage + assert cache_dump["visual_representation"] == "abc123" diff --git a/tests/unit/models/test_token_counting_visual_validation.py b/tests/unit/models/test_token_counting_visual_validation.py new file mode 100644 index 00000000..f97de19d --- /dev/null +++ b/tests/unit/models/test_token_counting_visual_validation.py @@ -0,0 +1,96 @@ +"""Test that token counting excludes visual validation fields.""" + +from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam +from askui.models.shared.token_counter import SimpleTokenCounter + + +def test_token_counting_excludes_visual_validation_fields() -> None: + """Verify that visual validation fields don't inflate token counts.""" + # Create two identical tool blocks, one with and one without visual validation + tool_block_without = ToolUseBlockParam( + id="test_id", + name="computer", + input={"action": "left_click", "coordinate": [100, 200]}, + type="tool_use", + ) + + tool_block_with = ToolUseBlockParam( + id="test_id", + name="computer", + input={"action": "left_click", "coordinate": [100, 200]}, + type="tool_use", + visual_representation="a" * 1000, # 1000 character hash + ) + + # Create messages + msg_without = MessageParam(role="assistant", content=[tool_block_without]) + msg_with = MessageParam(role="assistant", content=[tool_block_with]) + + # Count tokens + counter = SimpleTokenCounter() + counts_without = counter.count_tokens(messages=[msg_without]) + counts_with = counter.count_tokens(messages=[msg_with]) + + # Token counts should be identical (visual fields excluded from counting) + assert counts_without.total == counts_with.total, ( + f"Token counts differ: {counts_without.total} vs {counts_with.total}. " + "Visual validation fields should be excluded from token counting." 
+ ) + + +def test_token_counter_uses_api_context() -> None: + """Verify that token counter uses for_api context when stringifying objects.""" + tool_block = ToolUseBlockParam( + id="test_id", + name="computer", + input={"action": "left_click", "coordinate": [100, 200]}, + type="tool_use", + visual_representation="hash123", + ) + + counter = SimpleTokenCounter() + + # Stringify the object (as token counter does internally) + stringified = counter._stringify_object(tool_block) + + # Should not contain visual validation fields + assert "visual_representation" not in stringified + assert "hash123" not in stringified + + # Should contain regular fields + assert "test_id" in stringified + assert "computer" in stringified + + +def test_token_counting_with_multiple_tool_blocks() -> None: + """Test token counting with multiple tool blocks in one message.""" + blocks = [ + ToolUseBlockParam( + id=f"id_{i}", + name="computer", + input={"action": "left_click", "coordinate": [i * 100, i * 100]}, + type="tool_use", + visual_representation="x" * 500, # Large hash + ) + for i in range(5) + ] + + blocks_without_validation = [ + ToolUseBlockParam( + id=f"id_{i}", + name="computer", + input={"action": "left_click", "coordinate": [i * 100, i * 100]}, + type="tool_use", + ) + for i in range(5) + ] + + msg_with = MessageParam(role="assistant", content=blocks) + msg_without = MessageParam(role="assistant", content=blocks_without_validation) + + counter = SimpleTokenCounter() + counts_with = counter.count_tokens(messages=[msg_with]) + counts_without = counter.count_tokens(messages=[msg_without]) + + # Should have same token count + assert counts_with.total == counts_without.total diff --git a/tests/unit/tools/test_caching_tools.py b/tests/unit/tools/test_caching_tools.py index a4404114..7c9e69eb 100644 --- a/tests/unit/tools/test_caching_tools.py +++ b/tests/unit/tools/test_caching_tools.py @@ -2,18 +2,24 @@ import json import tempfile +from datetime import datetime, timezone from pathlib import Path from typing import Any -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest -from askui.models.shared.settings import CachedExecutionToolSettings +from askui.models.shared.settings import CacheExecutionSettings from askui.models.shared.tools import ToolCollection from askui.tools.caching_tools import ( ExecuteCachedTrajectory, RetrieveCachedTestExecutions, ) +from askui.utils.caching.cache_execution_manager import CacheExecutionManager + +# ============================================================================ +# RetrieveCachedTestExecutions Tests (unchanged from before) +# ============================================================================ def test_retrieve_cached_test_executions_lists_json_files() -> None: @@ -21,9 +27,22 @@ def test_retrieve_cached_test_executions_lists_json_files() -> None: with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) - # Create some cache files - (cache_dir / "cache1.json").write_text("{}", encoding="utf-8") - (cache_dir / "cache2.json").write_text("{}", encoding="utf-8") + # Create valid cache files with v0.1 format + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [], + "cache_parameters": {}, + } + (cache_dir / "cache1.json").write_text(json.dumps(cache_data), encoding="utf-8") + (cache_dir / 
"cache2.json").write_text(json.dumps(cache_data), encoding="utf-8") (cache_dir / "not_cache.txt").write_text("text", encoding="utf-8") tool = RetrieveCachedTestExecutions(cache_dir=str(cache_dir)) @@ -59,9 +78,22 @@ def test_retrieve_cached_test_executions_respects_custom_format() -> None: with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) - # Create files with different extensions - (cache_dir / "cache1.json").write_text("{}", encoding="utf-8") - (cache_dir / "cache2.traj").write_text("{}", encoding="utf-8") + # Create valid cache files with different extensions + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [], + "cache_parameters": {}, + } + (cache_dir / "cache1.json").write_text(json.dumps(cache_data), encoding="utf-8") + (cache_dir / "cache2.traj").write_text(json.dumps(cache_data), encoding="utf-8") # Default format (.json) tool_json = RetrieveCachedTestExecutions( @@ -80,254 +112,946 @@ def test_retrieve_cached_test_executions_respects_custom_format() -> None: assert "cache2.traj" in result_traj[0] -def test_execute_cached_execution_initializes_without_toolbox() -> None: - """Test that ExecuteCachedExecution can be initialized without toolbox.""" - tool = ExecuteCachedTrajectory() +def test_retrieve_caches_filters_invalid_by_default(tmp_path: Path) -> None: + """Test that RetrieveCachedTestExecutions filters out invalid caches by default.""" + cache_dir = tmp_path / "cache" + cache_dir.mkdir() + + # Create a valid cache + valid_cache = cache_dir / "valid.json" + valid_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [], + "cache_parameters": {}, + } + with valid_cache.open("w") as f: + json.dump(valid_data, f) + + # Create an invalid cache + invalid_cache = cache_dir / "invalid.json" + invalid_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 10, + "failures": [], + "is_valid": False, + "invalidation_reason": "Too many failures", + }, + "trajectory": [], + "cache_parameters": {}, + } + with invalid_cache.open("w") as f: + json.dump(invalid_data, f) + + tool = RetrieveCachedTestExecutions(cache_dir=str(cache_dir)) + + # Should only return valid cache + results = tool() + assert len(results) == 1 + assert str(valid_cache) in results + assert str(invalid_cache) not in results + + +def test_retrieve_caches_includes_invalid_when_requested(tmp_path: Path) -> None: + """Test that RetrieveCachedTestExecutions includes invalid caches when requested.""" + cache_dir = tmp_path / "cache" + cache_dir.mkdir() + + # Create a valid cache + valid_cache = cache_dir / "valid.json" + valid_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [], + "cache_parameters": {}, + } + with valid_cache.open("w") as f: + json.dump(valid_data, f) + + # Create an invalid cache + invalid_cache = cache_dir / "invalid.json" + invalid_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 10, + "failures": 
[], + "is_valid": False, + "invalidation_reason": "Too many failures", + }, + "trajectory": [], + "cache_parameters": {}, + } + with invalid_cache.open("w") as f: + json.dump(invalid_data, f) + + tool = RetrieveCachedTestExecutions(cache_dir=str(cache_dir)) + + # Should return both caches when include_invalid=True + results = tool(include_invalid=True) + assert len(results) == 2 + + +# ============================================================================ +# ExecuteCachedTrajectory Tests (refactored for new behavior) +# ============================================================================ + + +def test_execute_cached_execution_initializes_with_toolbox() -> None: + """Test that ExecuteCachedTrajectory can be initialized with toolbox.""" + mock_toolbox = MagicMock(spec=ToolCollection) + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) assert tool.name == "execute_cached_executions_tool" + assert tool._toolbox is mock_toolbox # noqa: SLF001 + assert tool._cache_execution_manager is None # noqa: SLF001 + +def test_execute_cached_execution_raises_error_without_cache_manager() -> None: + """Test that ExecuteCachedTrajectory raises error when cache manager not set.""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_file = Path(temp_dir) / "test.json" + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [], + "cache_parameters": {}, + } + cache_file.write_text(json.dumps(cache_data), encoding="utf-8") -def test_execute_cached_execution_raises_error_without_toolbox() -> None: - """Test that ExecuteCachedExecution raises error when toolbox not set.""" - tool = ExecuteCachedTrajectory() + mock_toolbox = MagicMock(spec=ToolCollection) + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) - with pytest.raises(RuntimeError, match="Toolbox not set"): - tool(trajectory_file="some_file.json") + with pytest.raises(RuntimeError, match="Cache Execution Manager not set"): + tool(trajectory_file=str(cache_file)) -def test_execute_cached_execution_raises_error_when_file_not_found() -> None: - """Test that ExecuteCachedExecution raises error if trajectory file doesn't exist""" - tool = ExecuteCachedTrajectory() +def test_execute_cached_execution_returns_error_when_file_not_found() -> None: + """Test that ExecuteCachedTrajectory returns error message if file doesn't exist.""" mock_toolbox = MagicMock(spec=ToolCollection) - tool.set_toolbox(mock_toolbox) + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) + mock_cache_manager = MagicMock(spec=CacheExecutionManager) + tool.set_cache_execution_manager(mock_cache_manager) - with pytest.raises(FileNotFoundError, match="Trajectory file not found"): - tool(trajectory_file="/non/existent/file.json") + result = tool(trajectory_file="/non/existent/file.json") + # New behavior: returns error message string instead of raising + assert isinstance(result, str) + assert "Trajectory file not found" in result -def test_execute_cached_execution_executes_trajectory() -> None: - """Test that ExecuteCachedExecution executes tools from trajectory file.""" + +def test_execute_cached_execution_activates_cache_mode() -> None: + """Test that ExecuteCachedTrajectory activates cache mode in the agent.""" with tempfile.TemporaryDirectory() as temp_dir: cache_file = Path(temp_dir) / "test_trajectory.json" - # Create a trajectory file - trajectory: list[dict[str, 
Any]] = [ - { - "id": "tool1", - "name": "click_tool", - "input": {"x": 100, "y": 200}, - "type": "tool_use", + # Create a trajectory file with v0.1 format + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": True, + "invalidation_reason": None, }, - { - "id": "tool2", - "name": "type_tool", - "input": {"text": "hello"}, - "type": "tool_use", + "trajectory": [ + { + "id": "tool1", + "name": "click_tool", + "input": {"x": 100, "y": 200}, + "type": "tool_use", + }, + { + "id": "tool2", + "name": "type_tool", + "input": {"text": "hello"}, + "type": "tool_use", + }, + ], + "cache_parameters": {}, + } + + with cache_file.open("w", encoding="utf-8") as f: + json.dump(cache_data, f) + + # Create mock agent with toolbox + mock_cache_manager = MagicMock(spec=CacheExecutionManager) + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox._tool_map = {} + + # Create and configure tool + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) + tool.set_cache_execution_manager(mock_cache_manager) + + # Call the tool + result = tool(trajectory_file=str(cache_file)) + + # Verify return type is string + assert isinstance(result, str) + assert "✓ Cache execution mode activated" in result + assert "2 cached steps" in result + + # Verify activate_execution was called on the cache manager + mock_cache_manager.activate_execution.assert_called_once() + # Verify the cache file path was passed correctly + call_args = mock_cache_manager.activate_execution.call_args + assert call_args.kwargs["cache_file_path"] == str(cache_file) + + +def test_execute_cached_execution_works_with_toolbox() -> None: + """Test that ExecuteCachedTrajectory works with toolbox provided.""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_file = Path(temp_dir) / "test_trajectory.json" + + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": True, + "invalidation_reason": None, }, - ] + "trajectory": [ + { + "id": "tool1", + "name": "test_tool", + "input": {}, + "type": "tool_use", + } + ], + "cache_parameters": {}, + } with cache_file.open("w", encoding="utf-8") as f: - json.dump(trajectory, f) + json.dump(cache_data, f) - # Execute the trajectory - tool = ExecuteCachedTrajectory() + # Create mock agent + mock_cache_manager = MagicMock(spec=CacheExecutionManager) + + # Create tool with toolbox mock_toolbox = MagicMock(spec=ToolCollection) - tool.set_toolbox(mock_toolbox) + mock_toolbox._tool_map = {} + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) + tool.set_cache_execution_manager(mock_cache_manager) result = tool(trajectory_file=str(cache_file)) - # Verify success message - assert "Successfully executed trajectory" in result - # Verify toolbox.run was called for each tool (2 calls) - assert mock_toolbox.run.call_count == 2 + # Should succeed using the toolbox + assert isinstance(result, str) + assert "✓ Cache execution mode activated" in result + + +def test_execute_cached_execution_set_cache_manager() -> None: + """Test that set_cache_execution_manager properly sets reference.""" + mock_toolbox = MagicMock(spec=ToolCollection) + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) + mock_cache_manager = MagicMock(spec=CacheExecutionManager) + + tool.set_cache_execution_manager(mock_cache_manager) + + assert 
tool._cache_execution_manager == mock_cache_manager # noqa: SLF001 + assert tool._toolbox == mock_toolbox # noqa: SLF001 + + +def test_execute_cached_execution_initializes_with_default_settings() -> None: + """Test that ExecuteCachedTrajectory uses default settings when none provided.""" + mock_toolbox = MagicMock(spec=ToolCollection) + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) + + # Should have default settings initialized + assert hasattr(tool, "_settings") + assert tool._settings.delay_time_between_action == 0.5 # noqa: SLF001 + + +def test_execute_cached_execution_initializes_with_custom_settings() -> None: + """Test that ExecuteCachedTrajectory accepts custom settings.""" + mock_toolbox = MagicMock(spec=ToolCollection) + custom_settings = CacheExecutionSettings(delay_time_between_action=1.0) + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox, settings=custom_settings) + + # Should have custom settings initialized + assert hasattr(tool, "_settings") + assert tool._settings.delay_time_between_action == 1.0 # noqa: SLF001 -def test_execute_cached_execution_skips_screenshot_tools() -> None: - """Test that ExecuteCachedExecution skips screenshot-related tools.""" +def test_execute_cached_execution_with_parameters() -> None: + """Test that ExecuteCachedTrajectory validates parameters.""" with tempfile.TemporaryDirectory() as temp_dir: cache_file = Path(temp_dir) / "test_trajectory.json" - # Create a trajectory with screenshot tools - trajectory: list[dict[str, Any]] = [ - { - "id": "tool1", - "name": "screenshot", - "input": {}, - "type": "tool_use", + # Create a v0.1 cache file with parameters + cache_data = { + "metadata": { + "version": "0.1", + "created_at": "2025-12-11T10:00:00Z", + "last_executed_at": None, + "execution_attempts": 0, + "failures": [], + "is_valid": True, + "invalidation_reason": None, }, - { - "id": "tool2", - "name": "click_tool", - "input": {"x": 100, "y": 200}, - "type": "tool_use", + "trajectory": [ + { + "id": "tool1", + "name": "type_tool", + "input": {"text": "Today is {{current_date}}"}, + "type": "tool_use", + }, + ], + "cache_parameters": { + "current_date": "Current date", }, - { - "id": "tool3", - "name": "retrieve_available_trajectories_tool", - "input": {}, - "type": "tool_use", + } + + with cache_file.open("w", encoding="utf-8") as f: + json.dump(cache_data, f) + + # Create mock agent + mock_cache_manager = MagicMock(spec=CacheExecutionManager) + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox._tool_map = {} + + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) + tool.set_cache_execution_manager(mock_cache_manager) + + result = tool( + trajectory_file=str(cache_file), + parameter_values={"current_date": "2025-12-11"}, + ) + + # Verify success + assert isinstance(result, str) + assert "✓ Cache execution mode activated" in result + assert "1 parameter value" in result + + +def test_execute_cached_execution_missing_parameters() -> None: + """Test that ExecuteCachedTrajectory returns error for missing parameters.""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_file = Path(temp_dir) / "test_trajectory.json" + + # Create a v0.1 cache file with parameters + cache_data = { + "metadata": { + "version": "0.1", + "created_at": "2025-12-11T10:00:00Z", + "last_executed_at": None, + "execution_attempts": 0, + "failures": [], + "is_valid": True, + "invalidation_reason": None, }, - ] + "trajectory": [ + { + "id": "tool1", + "name": "type_tool", + "input": {"text": "Date: {{current_date}}, User: {{user_name}}"}, + "type": 
"tool_use", + } + ], + "cache_parameters": { + "current_date": "Current date", + "user_name": "User name", + }, + } with cache_file.open("w", encoding="utf-8") as f: - json.dump(trajectory, f) + json.dump(cache_data, f) - # Execute the trajectory - tool = ExecuteCachedTrajectory() + # Create mock agent + mock_cache_manager = MagicMock(spec=CacheExecutionManager) mock_toolbox = MagicMock(spec=ToolCollection) - tool.set_toolbox(mock_toolbox) + + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) + tool.set_cache_execution_manager(mock_cache_manager) result = tool(trajectory_file=str(cache_file)) - # Verify only click_tool was executed (screenshot and retrieve tools skipped) - assert mock_toolbox.run.call_count == 1 - assert "Successfully executed trajectory" in result + # Verify error message + assert isinstance(result, str) + assert "Missing required parameter values" in result + assert "current_date" in result + assert "user_name" in result -def test_execute_cached_execution_handles_errors_gracefully() -> None: - """Test that ExecuteCachedExecution handles errors during execution.""" +def test_execute_cached_execution_no_parameters_backward_compat() -> None: + """Test backward compatibility: trajectories without parameters work fine.""" with tempfile.TemporaryDirectory() as temp_dir: cache_file = Path(temp_dir) / "test_trajectory.json" - # Create a trajectory + # Create a v0.0 cache file (old format, no parameters) trajectory: list[dict[str, Any]] = [ { "id": "tool1", - "name": "failing_tool", - "input": {}, + "name": "click_tool", + "input": {"x": 100, "y": 200}, "type": "tool_use", - }, + } ] with cache_file.open("w", encoding="utf-8") as f: json.dump(trajectory, f) - # Execute the trajectory with a failing tool - tool = ExecuteCachedTrajectory() + # Create mock agent + mock_cache_manager = MagicMock(spec=CacheExecutionManager) mock_toolbox = MagicMock(spec=ToolCollection) - mock_toolbox.run.side_effect = Exception("Tool execution failed") - tool.set_toolbox(mock_toolbox) + mock_toolbox._tool_map = {} + + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) + tool.set_cache_execution_manager(mock_cache_manager) result = tool(trajectory_file=str(cache_file)) - # Verify error message - assert "error occured" in result.lower() - assert "verify the UI state" in result + # Verify success + assert isinstance(result, str) + assert "✓ Cache execution mode activated" in result -def test_execute_cached_execution_set_toolbox() -> None: - """Test that set_toolbox properly sets the toolbox reference.""" - tool = ExecuteCachedTrajectory() - mock_toolbox = MagicMock(spec=ToolCollection) +def test_continue_cached_trajectory_from_middle() -> None: + """Test continuing execution from middle of trajectory.""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_file = Path(temp_dir) / "test_trajectory.json" - tool.set_toolbox(mock_toolbox) + # Create a trajectory with 5 steps + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [ + {"id": "1", "name": "tool1", "input": {}, "type": "tool_use"}, + {"id": "2", "name": "tool2", "input": {}, "type": "tool_use"}, + {"id": "3", "name": "tool3", "input": {}, "type": "tool_use"}, + {"id": "4", "name": "tool4", "input": {}, "type": "tool_use"}, + {"id": "5", "name": "tool5", "input": {}, "type": "tool_use"}, + ], + "cache_parameters": {}, + } - # After 
setting toolbox, should be able to access it - assert hasattr(tool, "_toolbox") - assert tool._toolbox == mock_toolbox + with cache_file.open("w", encoding="utf-8") as f: + json.dump(cache_data, f) + # Create mock agent + mock_cache_manager = MagicMock(spec=CacheExecutionManager) + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox._tool_map = {} -def test_execute_cached_execution_initializes_with_default_settings() -> None: - """Test that ExecuteCachedTrajectory uses default settings when none provided.""" - tool = ExecuteCachedTrajectory() + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) + tool.set_cache_execution_manager(mock_cache_manager) - # Should have default settings initialized - assert hasattr(tool, "_settings") + result = tool(trajectory_file=str(cache_file), start_from_step_index=2) + # Verify success message + assert isinstance(result, str) + assert "✓ Cache execution mode activated" in result + assert "resuming from step 2" in result + assert "3 remaining cached steps" in result -def test_execute_cached_execution_initializes_with_custom_settings() -> None: - """Test that ExecuteCachedTrajectory accepts custom settings.""" - custom_settings = CachedExecutionToolSettings(delay_time_between_action=1.0) - tool = ExecuteCachedTrajectory(settings=custom_settings) - # Should have custom settings initialized - assert hasattr(tool, "_settings") +def test_continue_cached_trajectory_invalid_step_index_negative() -> None: + """Test that negative step index returns error.""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_file = Path(temp_dir) / "test_trajectory.json" + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [ + {"id": "1", "name": "tool1", "input": {}, "type": "tool_use"}, + ], + "cache_parameters": {}, + } + + with cache_file.open("w", encoding="utf-8") as f: + json.dump(cache_data, f) + + mock_cache_manager = MagicMock(spec=CacheExecutionManager) + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox._tool_map = {} -def test_execute_cached_execution_uses_delay_time_between_actions() -> None: - """Test that ExecuteCachedTrajectory uses the configured delay time.""" + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) + tool.set_cache_execution_manager(mock_cache_manager) + + result = tool(trajectory_file=str(cache_file), start_from_step_index=-1) + + # Verify error message + assert isinstance(result, str) + assert "Invalid start_from_step_index" in result + + +def test_continue_cached_trajectory_invalid_step_index_too_large() -> None: + """Test that step index beyond trajectory length returns error.""" with tempfile.TemporaryDirectory() as temp_dir: cache_file = Path(temp_dir) / "test_trajectory.json" - # Create a trajectory with 3 actions - trajectory: list[dict[str, Any]] = [ - { - "id": "tool1", - "name": "click_tool", - "input": {"x": 100, "y": 200}, - "type": "tool_use", - }, - { - "id": "tool2", - "name": "type_tool", - "input": {"text": "hello"}, - "type": "tool_use", + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": True, + "invalidation_reason": None, }, - { - "id": "tool3", - "name": "move_tool", - "input": {"x": 300, "y": 400}, - "type": "tool_use", - }, - ] + "trajectory": [ + 
{"id": "1", "name": "tool1", "input": {}, "type": "tool_use"}, + {"id": "2", "name": "tool2", "input": {}, "type": "tool_use"}, + ], + "cache_parameters": {}, + } with cache_file.open("w", encoding="utf-8") as f: - json.dump(trajectory, f) + json.dump(cache_data, f) - # Execute with custom delay time - custom_settings = CachedExecutionToolSettings(delay_time_between_action=0.1) - tool = ExecuteCachedTrajectory(settings=custom_settings) + mock_cache_manager = MagicMock(spec=CacheExecutionManager) mock_toolbox = MagicMock(spec=ToolCollection) - tool.set_toolbox(mock_toolbox) + mock_toolbox._tool_map = {} + + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) + tool.set_cache_execution_manager(mock_cache_manager) - # Mock time.sleep to verify it's called with correct delay - with patch("time.sleep") as mock_sleep: - result = tool(trajectory_file=str(cache_file)) + result = tool(trajectory_file=str(cache_file), start_from_step_index=5) - # Verify success - assert "Successfully executed trajectory" in result - # Verify sleep was called 3 times (once after each action) - assert mock_sleep.call_count == 3 - # Verify it was called with the configured delay time - for call in mock_sleep.call_args_list: - assert call[0][0] == 0.1 + # Verify error message + assert isinstance(result, str) + assert "Invalid start_from_step_index" in result + assert "valid indices: 0-1" in result -def test_execute_cached_execution_default_delay_time() -> None: - """Test that ExecuteCachedTrajectory uses default delay time of 0.5s.""" +def test_continue_cached_trajectory_with_parameters() -> None: + """Test continuing execution with parameter substitution.""" with tempfile.TemporaryDirectory() as temp_dir: cache_file = Path(temp_dir) / "test_trajectory.json" - # Create a trajectory with 2 actions - trajectory: list[dict[str, Any]] = [ - { - "id": "tool1", - "name": "click_tool", - "input": {"x": 100, "y": 200}, - "type": "tool_use", + # Create a v0.1 cache file with parameters + cache_data = { + "metadata": { + "version": "0.1", + "created_at": "2025-12-11T10:00:00Z", + "last_executed_at": None, + "execution_attempts": 0, + "failures": [], + "is_valid": True, + "invalidation_reason": None, }, - { - "id": "tool2", - "name": "type_tool", - "input": {"text": "hello"}, - "type": "tool_use", + "trajectory": [ + { + "id": "1", + "name": "tool1", + "input": {"text": "Step 1"}, + "type": "tool_use", + }, + { + "id": "2", + "name": "tool2", + "input": {"text": "Date: {{current_date}}"}, + "type": "tool_use", + }, + { + "id": "3", + "name": "tool3", + "input": {"text": "User: {{user_name}}"}, + "type": "tool_use", + }, + ], + "cache_parameters": { + "current_date": "Current date", + "user_name": "User name", }, - ] + } with cache_file.open("w", encoding="utf-8") as f: - json.dump(trajectory, f) + json.dump(cache_data, f) - # Execute with default settings - tool = ExecuteCachedTrajectory() + # Create mock agent + mock_cache_manager = MagicMock(spec=CacheExecutionManager) mock_toolbox = MagicMock(spec=ToolCollection) - tool.set_toolbox(mock_toolbox) - - # Mock time.sleep to verify default delay is used - with patch("time.sleep") as mock_sleep: - result = tool(trajectory_file=str(cache_file)) - - # Verify success - assert "Successfully executed trajectory" in result - # Verify sleep was called with default delay of 0.5s - assert mock_sleep.call_count == 2 - for call in mock_sleep.call_args_list: - assert call[0][0] == 0.5 + mock_toolbox._tool_map = {} + + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) + 
tool.set_cache_execution_manager(mock_cache_manager) + + result = tool( + trajectory_file=str(cache_file), + start_from_step_index=1, + parameter_values={"current_date": "2025-12-11", "user_name": "Alice"}, + ) + + # Verify success + assert isinstance(result, str) + assert "✓ Cache execution mode activated" in result + assert "resuming from step 1" in result + + +def test_execute_cached_trajectory_warns_if_invalid( + tmp_path: Path, caplog: pytest.LogCaptureFixture +) -> None: + """Test that ExecuteCachedTrajectory warns when activating with invalid cache.""" + import logging + + caplog.set_level(logging.WARNING) + + cache_file = tmp_path / "test.json" + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 10, + "last_executed_at": None, + "failures": [], + "is_valid": False, + "invalidation_reason": "Cache marked invalid for testing", + }, + "trajectory": [ + {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, + ], + "cache_parameters": {}, + } + with cache_file.open("w") as f: + json.dump(cache_data, f) + + # Create mock agent + mock_cache_manager = MagicMock(spec=CacheExecutionManager) + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox._tool_map = {} + + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) + tool.set_cache_execution_manager(mock_cache_manager) + + result = tool(trajectory_file=str(cache_file)) + + # Should still activate but log warning + assert isinstance(result, str) + assert "✓ Cache execution mode activated" in result + + # Verify warning was logged + assert any("WARNING" in record.levelname for record in caplog.records) + assert any("invalid cache" in record.message.lower() for record in caplog.records) + + +# ============================================================================ +# InspectCacheMetadata Tests (unchanged from before) +# ============================================================================ + + +def test_inspect_cache_metadata_shows_basic_info(tmp_path: Path) -> None: + """Test that InspectCacheMetadata displays basic cache information.""" + from askui.tools.caching_tools import InspectCacheMetadata + + cache_file = tmp_path / "test.json" + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 5, + "last_executed_at": datetime.now(tz=timezone.utc).isoformat(), + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [ + {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, + {"id": "2", "name": "type", "input": {"text": "test"}, "type": "tool_use"}, + ], + "cache_parameters": {"current_date": "{{current_date}}"}, + } + with cache_file.open("w") as f: + json.dump(cache_data, f) + + tool = InspectCacheMetadata() + result = tool(trajectory_file=str(cache_file)) + + # Verify output contains key information + assert "=== Cache Metadata ===" in result + assert "Version: 0.1" in result + assert "Total Execution Attempts: 5" in result + assert "Is Valid: True" in result + assert "Total Steps: 2" in result + assert "Parameters: 1" in result + assert "current_date" in result + + +def test_inspect_cache_metadata_shows_failures(tmp_path: Path) -> None: + """Test that InspectCacheMetadata displays failure history.""" + from askui.tools.caching_tools import InspectCacheMetadata + + cache_file = tmp_path / "test.json" + cache_data = { + "metadata": { + "version": "0.1", + "created_at": 
datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 3, + "last_executed_at": None, + "failures": [ + { + "timestamp": datetime.now(tz=timezone.utc).isoformat(), + "step_index": 1, + "error_message": "Click failed", + "failure_count_at_step": 1, + }, + { + "timestamp": datetime.now(tz=timezone.utc).isoformat(), + "step_index": 1, + "error_message": "Click failed again", + "failure_count_at_step": 2, + }, + ], + "is_valid": False, + "invalidation_reason": "Too many failures at step 1", + }, + "trajectory": [ + {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, + ], + "cache_parameters": {}, + } + with cache_file.open("w") as f: + json.dump(cache_data, f) + + tool = InspectCacheMetadata() + result = tool(trajectory_file=str(cache_file)) + + # Verify failure information + assert "--- Failure History ---" in result + assert "Failure 1:" in result + assert "Failure 2:" in result + assert "Step Index: 1" in result + assert "Click failed" in result + assert "Is Valid: False" in result + assert "Invalidation Reason: Too many failures at step 1" in result + + +def test_inspect_cache_metadata_file_not_found() -> None: + """Test that InspectCacheMetadata handles missing files.""" + from askui.tools.caching_tools import InspectCacheMetadata + + tool = InspectCacheMetadata() + result = tool(trajectory_file="/nonexistent/file.json") + + assert "Trajectory file not found" in result + + +# ============================================================================ +# RevalidateCache Tests (unchanged from before) +# ============================================================================ + + +def test_revalidate_cache_marks_invalid_as_valid(tmp_path: Path) -> None: + """Test that RevalidateCache marks invalid cache as valid.""" + from askui.tools.caching_tools import RevalidateCache + + cache_file = tmp_path / "test.json" + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 3, + "last_executed_at": None, + "failures": [ + { + "timestamp": datetime.now(tz=timezone.utc).isoformat(), + "step_index": 1, + "error_message": "Error", + "failure_count_at_step": 1, + } + ], + "is_valid": False, + "invalidation_reason": "Manual invalidation", + }, + "trajectory": [ + {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, + ], + "cache_parameters": {}, + } + with cache_file.open("w") as f: + json.dump(cache_data, f) + + tool = RevalidateCache() + result = tool(trajectory_file=str(cache_file)) + + # Verify success message + assert "Successfully revalidated" in result + assert "Manual invalidation" in result + + # Read updated cache file + with cache_file.open("r") as f: + updated_data = json.load(f) + + # Verify cache is now valid + assert updated_data["metadata"]["is_valid"] is True + assert updated_data["metadata"]["invalidation_reason"] is None + # Failure history should still be there + assert len(updated_data["metadata"]["failures"]) == 1 + + +def test_revalidate_cache_already_valid(tmp_path: Path) -> None: + """Test that RevalidateCache handles already valid cache.""" + from askui.tools.caching_tools import RevalidateCache + + cache_file = tmp_path / "test.json" + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [ + {"id": "1", "name": "click", "input": {"x": 100}, "type": 
"tool_use"}, + ], + "cache_parameters": {}, + } + with cache_file.open("w") as f: + json.dump(cache_data, f) + + tool = RevalidateCache() + result = tool(trajectory_file=str(cache_file)) + + # Verify message indicates already valid + assert "already valid" in result + assert "No changes made" in result + + +def test_revalidate_cache_file_not_found() -> None: + """Test that RevalidateCache handles missing files.""" + from askui.tools.caching_tools import RevalidateCache + + tool = RevalidateCache() + result = tool(trajectory_file="/nonexistent/file.json") + + assert "Trajectory file not found" in result + + +# ============================================================================ +# InvalidateCache Tests (unchanged from before) +# ============================================================================ + + +def test_invalidate_cache_marks_valid_as_invalid(tmp_path: Path) -> None: + """Test that InvalidateCache marks valid cache as invalid.""" + from askui.tools.caching_tools import InvalidateCache + + cache_file = tmp_path / "test.json" + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 2, + "last_executed_at": datetime.now(tz=timezone.utc).isoformat(), + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [ + {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, + ], + "cache_parameters": {}, + } + with cache_file.open("w") as f: + json.dump(cache_data, f) + + tool = InvalidateCache() + result = tool(trajectory_file=str(cache_file), reason="UI changed - button moved") + + # Verify success message + assert "Successfully invalidated" in result + assert "UI changed - button moved" in result + + # Read updated cache file + with cache_file.open("r") as f: + updated_data = json.load(f) + + # Verify cache is now invalid + assert updated_data["metadata"]["is_valid"] is False + assert ( + updated_data["metadata"]["invalidation_reason"] == "UI changed - button moved" + ) + # Other metadata should be preserved + assert updated_data["metadata"]["execution_attempts"] == 2 + + +def test_invalidate_cache_updates_reason_if_already_invalid(tmp_path: Path) -> None: + """Test that InvalidateCache updates reason if already invalid.""" + from askui.tools.caching_tools import InvalidateCache + + cache_file = tmp_path / "test.json" + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": False, + "invalidation_reason": "Old reason", + }, + "trajectory": [ + {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, + ], + "cache_parameters": {}, + } + with cache_file.open("w") as f: + json.dump(cache_data, f) + + tool = InvalidateCache() + result = tool(trajectory_file=str(cache_file), reason="New reason") + + # Verify message indicates update + assert "already invalid" in result + assert "Updated invalidation reason to: New reason" in result + + # Read updated cache file + with cache_file.open("r") as f: + updated_data = json.load(f) + + # Verify reason was updated + assert updated_data["metadata"]["invalidation_reason"] == "New reason" + + +def test_invalidate_cache_file_not_found() -> None: + """Test that InvalidateCache handles missing files.""" + from askui.tools.caching_tools import InvalidateCache + + tool = InvalidateCache() + result = tool(trajectory_file="/nonexistent/file.json", reason="Test") + + assert 
"Trajectory file not found" in result diff --git a/tests/unit/utils/test_cache_manager.py b/tests/unit/utils/test_cache_manager.py new file mode 100644 index 00000000..9570a04d --- /dev/null +++ b/tests/unit/utils/test_cache_manager.py @@ -0,0 +1,409 @@ +"""Tests for cache manager.""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import pytest + +from askui.models.shared.agent_message_param import ToolUseBlockParam +from askui.models.shared.settings import CacheFailure, CacheFile, CacheMetadata +from askui.utils.caching.cache_manager import CacheManager +from askui.utils.caching.cache_validator import ( + CacheValidator, + CompositeCacheValidator, + StepFailureCountValidator, +) + + +@pytest.fixture +def sample_cache_file() -> CacheFile: + """Create a sample cache file for testing.""" + return CacheFile( + metadata=CacheMetadata( + version="0.1", + created_at=datetime.now(tz=timezone.utc), + execution_attempts=0, + is_valid=True, + ), + trajectory=[ + ToolUseBlockParam(id="1", name="click", input={"x": 100}, type="tool_use"), + ToolUseBlockParam( + id="2", name="type", input={"text": "test"}, type="tool_use" + ), + ], + cache_parameters={}, + ) + + +# Initialization Tests + + +def test_cache_manager_default_initialization() -> None: + """Test cache manager initializes with default validator.""" + manager = CacheManager() + assert manager.validators is not None + assert isinstance(manager.validators, CompositeCacheValidator) + assert len(manager.validators.validators) == 3 # 3 built-in validators + + +def test_cache_manager_custom_validator() -> None: + """Test cache manager with custom validator.""" + custom_validator = StepFailureCountValidator(max_failures_per_step=5) + manager = CacheManager(validators=[custom_validator]) + assert manager.validators.validators[0] is custom_validator + + +# Record Execution Attempt Tests + + +def test_record_execution_attempt_success(sample_cache_file: CacheFile) -> None: + """Test recording successful execution attempt.""" + manager = CacheManager() + initial_attempts = sample_cache_file.metadata.execution_attempts + initial_last_executed = sample_cache_file.metadata.last_executed_at + + manager.record_execution_attempt(sample_cache_file, success=True) + + assert sample_cache_file.metadata.execution_attempts == initial_attempts + 1 + assert sample_cache_file.metadata.last_executed_at is not None + assert sample_cache_file.metadata.last_executed_at != initial_last_executed + + +def test_record_execution_attempt_failure_with_info( + sample_cache_file: CacheFile, +) -> None: + """Test recording failed execution attempt with failure info.""" + manager = CacheManager() + initial_attempts = sample_cache_file.metadata.execution_attempts + initial_failures = len(sample_cache_file.metadata.failures) + + failure_info = CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Test error", + failure_count_at_step=1, + ) + + manager.record_execution_attempt( + sample_cache_file, success=False, failure_info=failure_info + ) + + assert sample_cache_file.metadata.execution_attempts == initial_attempts + 1 + assert len(sample_cache_file.metadata.failures) == initial_failures + 1 + assert sample_cache_file.metadata.failures[-1] == failure_info + + +def test_record_execution_attempt_failure_without_info( + sample_cache_file: CacheFile, +) -> None: + """Test recording failed execution attempt without failure info.""" + manager = CacheManager() + initial_attempts = 
sample_cache_file.metadata.execution_attempts + initial_failures = len(sample_cache_file.metadata.failures) + + manager.record_execution_attempt( + sample_cache_file, success=False, failure_info=None + ) + + assert sample_cache_file.metadata.execution_attempts == initial_attempts + 1 + assert ( + len(sample_cache_file.metadata.failures) == initial_failures + ) # No new failure added + + +# Record Step Failure Tests + + +def test_record_step_failure_first_failure(sample_cache_file: CacheFile) -> None: + """Test recording the first failure at a step.""" + manager = CacheManager() + + manager.record_step_failure( + sample_cache_file, step_index=1, error_message="First error" + ) + + assert len(sample_cache_file.metadata.failures) == 1 + failure = sample_cache_file.metadata.failures[0] + assert failure.step_index == 1 + assert failure.error_message == "First error" + assert failure.failure_count_at_step == 1 + + +def test_record_step_failure_multiple_at_same_step( + sample_cache_file: CacheFile, +) -> None: + """Test recording multiple failures at the same step.""" + manager = CacheManager() + + manager.record_step_failure( + sample_cache_file, step_index=1, error_message="Error 1" + ) + manager.record_step_failure( + sample_cache_file, step_index=1, error_message="Error 2" + ) + manager.record_step_failure( + sample_cache_file, step_index=1, error_message="Error 3" + ) + + assert len(sample_cache_file.metadata.failures) == 3 + assert sample_cache_file.metadata.failures[0].failure_count_at_step == 1 + assert sample_cache_file.metadata.failures[1].failure_count_at_step == 2 + assert sample_cache_file.metadata.failures[2].failure_count_at_step == 3 + + +def test_record_step_failure_different_steps(sample_cache_file: CacheFile) -> None: + """Test recording failures at different steps.""" + manager = CacheManager() + + manager.record_step_failure( + sample_cache_file, step_index=1, error_message="Error at step 1" + ) + manager.record_step_failure( + sample_cache_file, step_index=2, error_message="Error at step 2" + ) + manager.record_step_failure( + sample_cache_file, step_index=1, error_message="Another at step 1" + ) + + assert len(sample_cache_file.metadata.failures) == 3 + + step_1_failures = [ + f for f in sample_cache_file.metadata.failures if f.step_index == 1 + ] + step_2_failures = [ + f for f in sample_cache_file.metadata.failures if f.step_index == 2 + ] + + assert len(step_1_failures) == 2 + assert len(step_2_failures) == 1 + assert step_1_failures[1].failure_count_at_step == 2 # Second failure at step 1 + + +# Should Invalidate Tests + + +def test_should_invalidate_delegates_to_validator(sample_cache_file: CacheFile) -> None: + """Test that should_invalidate delegates to the validator.""" + mock_validator = MagicMock(spec=CacheValidator) + mock_validator.get_name.return_value = "Mock Validator" + mock_validator.should_invalidate.return_value = (True, "Test reason") + + manager = CacheManager(validators=[mock_validator]) + should_inv, reason = manager.should_invalidate(sample_cache_file, step_index=1) + + assert should_inv is True + assert reason == "Mock Validator: Test reason" + mock_validator.should_invalidate.assert_called_once_with(sample_cache_file, 1) + + +def test_should_invalidate_with_default_validator(sample_cache_file: CacheFile) -> None: + """Test should_invalidate with default built-in validators.""" + manager = CacheManager() + + # Add failures that exceed default thresholds + sample_cache_file.metadata.execution_attempts = 10 + sample_cache_file.metadata.failures = [ + 
CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message=f"Error {i}", + failure_count_at_step=i + 1, + ) + for i in range(6) + ] # 6/10 = 60% failure rate (exceeds default 50%) + + should_inv, reason = manager.should_invalidate(sample_cache_file) + assert should_inv is True + assert "Failure rate" in reason # type: ignore[operator] + + +# Invalidate Cache Tests + + +def test_invalidate_cache(sample_cache_file: CacheFile) -> None: + """Test marking cache as invalid.""" + manager = CacheManager() + assert sample_cache_file.metadata.is_valid is True + assert sample_cache_file.metadata.invalidation_reason is None + + manager.invalidate_cache(sample_cache_file, reason="Test invalidation") + + assert sample_cache_file.metadata.is_valid is False + assert sample_cache_file.metadata.invalidation_reason == "Test invalidation" # type: ignore[unreachable] + + +def test_invalidate_cache_multiple_times(sample_cache_file: CacheFile) -> None: + """Test invalidating cache multiple times updates reason.""" + manager = CacheManager() + + manager.invalidate_cache(sample_cache_file, reason="First reason") + assert sample_cache_file.metadata.invalidation_reason == "First reason" + + manager.invalidate_cache(sample_cache_file, reason="Second reason") + assert sample_cache_file.metadata.invalidation_reason == "Second reason" + + +# Mark Cache Valid Tests + + +def test_mark_cache_valid(sample_cache_file: CacheFile) -> None: + """Test marking cache as valid.""" + manager = CacheManager() + + # First invalidate + sample_cache_file.metadata.is_valid = False + sample_cache_file.metadata.invalidation_reason = "Was invalid" + + # Then mark valid + manager.mark_cache_valid(sample_cache_file) + + assert sample_cache_file.metadata.is_valid is True + assert sample_cache_file.metadata.invalidation_reason is None + + +def test_mark_cache_valid_already_valid(sample_cache_file: CacheFile) -> None: + """Test marking already valid cache as valid.""" + manager = CacheManager() + assert sample_cache_file.metadata.is_valid is True + + manager.mark_cache_valid(sample_cache_file) + + assert sample_cache_file.metadata.is_valid is True + assert sample_cache_file.metadata.invalidation_reason is None + + +# Get Failure Count for Step Tests + + +def test_get_failure_count_for_step_no_failures(sample_cache_file: CacheFile) -> None: + """Test getting failure count when no failures exist.""" + manager = CacheManager() + + count = manager.get_failure_count_for_step(sample_cache_file, step_index=1) + assert count == 0 + + +def test_get_failure_count_for_step_with_failures(sample_cache_file: CacheFile) -> None: + """Test getting failure count for specific step.""" + manager = CacheManager() + + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error 1", + failure_count_at_step=1, + ), + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=2, + error_message="Error 2", + failure_count_at_step=1, + ), + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error 3", + failure_count_at_step=2, + ), + ] + + count_step_1 = manager.get_failure_count_for_step(sample_cache_file, step_index=1) + count_step_2 = manager.get_failure_count_for_step(sample_cache_file, step_index=2) + + assert count_step_1 == 2 + assert count_step_2 == 1 + + +def test_get_failure_count_for_step_nonexistent_step( + sample_cache_file: CacheFile, +) -> None: + """Test getting failure count for step that hasn't 
failed.""" + manager = CacheManager() + + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error", + failure_count_at_step=1, + ) + ] + + count = manager.get_failure_count_for_step(sample_cache_file, step_index=99) + assert count == 0 + + +# Integration Tests + + +def test_full_workflow_with_failure_detection(sample_cache_file: CacheFile) -> None: + """Test complete workflow: record failures, detect threshold, invalidate.""" + manager = CacheManager() + + # Record 3 failures at step 1 (default threshold is 3) + for i in range(3): + manager.record_step_failure( + sample_cache_file, step_index=1, error_message=f"Error {i + 1}" + ) + + # Check if should invalidate + should_inv, reason = manager.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is True + assert "Step 1 failed 3 times" in reason # type: ignore[operator] + + # Invalidate + manager.invalidate_cache(sample_cache_file, reason=reason) # type: ignore[arg-type] + assert sample_cache_file.metadata.is_valid is False + + +def test_full_workflow_below_threshold(sample_cache_file: CacheFile) -> None: + """Test workflow where failures don't reach threshold.""" + manager = CacheManager() + + # Record 2 failures at step 1 (below default threshold of 3) + for i in range(2): + manager.record_step_failure( + sample_cache_file, step_index=1, error_message=f"Error {i + 1}" + ) + + # Check if should invalidate + should_inv, reason = manager.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is False + + # Cache should still be valid + assert sample_cache_file.metadata.is_valid is True + + +def test_workflow_with_custom_validator(sample_cache_file: CacheFile) -> None: + """Test workflow with custom validator with lower threshold.""" + # Custom validator with lower threshold + custom_validator = [StepFailureCountValidator(max_failures_per_step=2)] + manager = CacheManager(validators=custom_validator) # type: ignore[arg-type] + + # Record 2 failures (enough to trigger custom validator) + for i in range(2): + manager.record_step_failure( + sample_cache_file, step_index=1, error_message=f"Error {i + 1}" + ) + + should_inv, reason = manager.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is True + assert "Step 1 failed 2 times" in reason # type: ignore[operator] + + +def test_workflow_successful_execution_updates_timestamp( + sample_cache_file: CacheFile, +) -> None: + """Test that successful execution updates last_executed_at.""" + manager = CacheManager() + + # Record some failures first + manager.record_step_failure(sample_cache_file, step_index=1, error_message="Error") + assert sample_cache_file.metadata.last_executed_at is None + + # Record successful execution + manager.record_execution_attempt(sample_cache_file, success=True) + + assert sample_cache_file.metadata.last_executed_at is not None + assert sample_cache_file.metadata.execution_attempts == 1 # type: ignore[unreachable] diff --git a/tests/unit/utils/test_cache_parameter_handler.py b/tests/unit/utils/test_cache_parameter_handler.py new file mode 100644 index 00000000..287409e7 --- /dev/null +++ b/tests/unit/utils/test_cache_parameter_handler.py @@ -0,0 +1,453 @@ +"""Unit tests for CacheParameterHandler.""" + +import pytest + +from askui.models.shared.agent_message_param import ToolUseBlockParam +from askui.utils.cache_parameter_handler import ( + CACHE_PARAMETER_PATTERN, + CacheParameterHandler, +) + + +def 
test_parameter_pattern_matches_valid_parameters() -> None: + """Test that the regex pattern matches valid parameter syntax.""" + import re + + valid_parameters = [ + "{{variable}}", + "{{current_date}}", + "{{user_name}}", + "{{_private}}", + "{{VAR123}}", + ] + + for parameter in valid_parameters: + match = re.search(CACHE_PARAMETER_PATTERN, parameter) + assert match is not None, f"Should match valid parameter: {parameter}" + + +def test_parameter_pattern_does_not_match_invalid() -> None: + """Test that the regex pattern rejects invalid parameter syntax.""" + import re + + invalid_parameters = [ + "{{123invalid}}", # Starts with number + "{{var-name}}", # Contains hyphen + "{{var name}}", # Contains space + "{single}", # Single braces + "{{}}", # Empty + ] + + for parameter in invalid_parameters: + match = re.search(CACHE_PARAMETER_PATTERN, parameter) + if match and match.group(0) == parameter: + pytest.fail(f"Should not match invalid parameter: {parameter}") + + +def test_extract_parameters_from_simple_string() -> None: + """Test extracting parameters from a simple string input.""" + trajectory = [ + ToolUseBlockParam( + id="1", + name="computer", + input={"action": "type", "text": "Today is {{current_date}}"}, + type="tool_use", + ) + ] + + parameters = CacheParameterHandler.extract_parameters(trajectory) + assert parameters == {"current_date"} + + +def test_extract_parameters_multiple_in_one_string() -> None: + """Test extracting multiple parameters from one string.""" + trajectory = [ + ToolUseBlockParam( + id="1", + name="computer", + input={ + "action": "type", + "text": "Hello {{user_name}}, today is {{current_date}}", + }, + type="tool_use", + ) + ] + + parameters = CacheParameterHandler.extract_parameters(trajectory) + assert parameters == {"user_name", "current_date"} + + +def test_extract_parameters_from_nested_dict() -> None: + """Test extracting parameters from nested dictionary structures.""" + trajectory = [ + ToolUseBlockParam( + id="1", + name="complex_tool", + input={ + "outer": {"inner": {"text": "Value is {{nested_var}}"}}, + "another": "{{another_var}}", + }, + type="tool_use", + ) + ] + + parameters = CacheParameterHandler.extract_parameters(trajectory) + assert parameters == {"nested_var", "another_var"} + + +def test_extract_parameters_from_list() -> None: + """Test extracting parameters from lists in input.""" + trajectory = [ + ToolUseBlockParam( + id="1", + name="tool", + input={ + "items": [ + "{{item1}}", + "{{item2}}", + {"nested": "{{item3}}"}, + ] + }, + type="tool_use", + ) + ] + + parameters = CacheParameterHandler.extract_parameters(trajectory) + assert parameters == {"item1", "item2", "item3"} + + +def test_extract_parameters_no_parameters() -> None: + """Test that extracting from trajectory without parameters returns empty set.""" + trajectory = [ + ToolUseBlockParam( + id="1", + name="computer", + input={"action": "click", "coordinate": [100, 200]}, + type="tool_use", + ) + ] + + parameters = CacheParameterHandler.extract_parameters(trajectory) + assert parameters == set() + + +def test_extract_parameters_from_multiple_steps() -> None: + """Test extracting parameters from multiple trajectory steps.""" + trajectory = [ + ToolUseBlockParam( + id="1", + name="tool1", + input={"text": "{{var1}}"}, + type="tool_use", + ), + ToolUseBlockParam( + id="2", + name="tool2", + input={"text": "{{var2}}"}, + type="tool_use", + ), + ToolUseBlockParam( + id="3", + name="tool3", + input={"text": "{{var1}}"}, # Duplicate + type="tool_use", + ), + ] + + parameters = 
CacheParameterHandler.extract_parameters(trajectory) + assert parameters == {"var1", "var2"} # No duplicates + + +def test_validate_parameters_all_provided() -> None: + """Test validation passes when all parameters have values.""" + trajectory = [ + ToolUseBlockParam( + id="1", + name="tool", + input={"text": "{{var1}} and {{var2}}"}, + type="tool_use", + ) + ] + + is_valid, missing = CacheParameterHandler.validate_parameters( + trajectory, {"var1": "value1", "var2": "value2"} + ) + + assert is_valid is True + assert missing == [] + + +def test_validate_parameters_missing_some() -> None: + """Test validation fails when some parameters are missing.""" + trajectory = [ + ToolUseBlockParam( + id="1", + name="tool", + input={"text": "{{var1}} and {{var2}} and {{var3}}"}, + type="tool_use", + ) + ] + + is_valid, missing = CacheParameterHandler.validate_parameters( + trajectory, {"var1": "value1"} + ) + + assert is_valid is False + assert set(missing) == {"var2", "var3"} + + +def test_validate_parameters_extra_values_ok() -> None: + """Test validation passes when extra values are provided (they're ignored).""" + trajectory = [ + ToolUseBlockParam( + id="1", + name="tool", + input={"text": "{{var1}}"}, + type="tool_use", + ) + ] + + is_valid, missing = CacheParameterHandler.validate_parameters( + trajectory, {"var1": "value1", "extra_var": "extra_value"} + ) + + assert is_valid is True + assert missing == [] + + +def test_validate_parameters_no_parameters() -> None: + """Test validation passes when trajectory has no parameters.""" + trajectory = [ + ToolUseBlockParam( + id="1", + name="tool", + input={"text": "No parameters here"}, + type="tool_use", + ) + ] + + is_valid, missing = CacheParameterHandler.validate_parameters(trajectory, {}) + + assert is_valid is True + assert missing == [] + + +def test_substitute_parameters_simple_string() -> None: + """Test substituting parameters in a simple string.""" + tool_block = ToolUseBlockParam( + id="1", + name="computer", + input={"action": "type", "text": "Today is {{current_date}}"}, + type="tool_use", + ) + + result = CacheParameterHandler.substitute_parameters( + tool_block, {"current_date": "2025-12-11"} + ) + + assert result.input["text"] == "Today is 2025-12-11" # type: ignore[index] + assert result.id == tool_block.id + assert result.name == tool_block.name + + +def test_substitute_parameters_multiple() -> None: + """Test substituting multiple parameters in one string.""" + tool_block = ToolUseBlockParam( + id="1", + name="computer", + input={ + "action": "type", + "text": "Hello {{user_name}}, date is {{current_date}}", + }, + type="tool_use", + ) + + result = CacheParameterHandler.substitute_parameters( + tool_block, {"user_name": "Alice", "current_date": "2025-12-11"} + ) + + assert result.input["text"] == "Hello Alice, date is 2025-12-11" # type: ignore[index] + + +def test_substitute_parameters_nested_dict() -> None: + """Test substituting parameters in nested dictionaries.""" + tool_block = ToolUseBlockParam( + id="1", + name="tool", + input={ + "outer": {"inner": {"text": "Value: {{var1}}"}}, + "another": "{{var2}}", + }, + type="tool_use", + ) + + result = CacheParameterHandler.substitute_parameters( + tool_block, {"var1": "value1", "var2": "value2"} + ) + + assert result.input["outer"]["inner"]["text"] == "Value: value1" # type: ignore[index] + assert result.input["another"] == "value2" # type: ignore[index] + + +def test_substitute_parameters_in_list() -> None: + """Test substituting parameters in lists.""" + tool_block = 
ToolUseBlockParam( + id="1", + name="tool", + input={"items": ["{{item1}}", "static", {"nested": "{{item2}}"}]}, + type="tool_use", + ) + + result = CacheParameterHandler.substitute_parameters( + tool_block, {"item1": "value1", "item2": "value2"} + ) + + assert result.input["items"][0] == "value1" # type: ignore[index] + assert result.input["items"][1] == "static" # type: ignore[index] + assert result.input["items"][2]["nested"] == "value2" # type: ignore[index] + + +def test_substitute_parameters_no_change_if_no_parameters() -> None: + """Test that substitution doesn't change input without parameters.""" + tool_block = ToolUseBlockParam( + id="1", + name="computer", + input={"action": "click", "coordinate": [100, 200]}, + type="tool_use", + ) + + result = CacheParameterHandler.substitute_parameters(tool_block, {}) + + assert result.input == tool_block.input + + +def test_substitute_parameters_partial_substitution() -> None: + """Test that only provided parameters are substituted.""" + tool_block = ToolUseBlockParam( + id="1", + name="tool", + input={"text": "{{var1}} and {{var2}}"}, + type="tool_use", + ) + + result = CacheParameterHandler.substitute_parameters(tool_block, {"var1": "value1"}) + + assert result.input["text"] == "value1 and {{var2}}" # type: ignore[index] + + +def test_substitute_parameters_preserves_original() -> None: + """Test that substitution creates a new object, doesn't modify original.""" + tool_block = ToolUseBlockParam( + id="1", + name="tool", + input={"text": "{{var1}}"}, + type="tool_use", + ) + + original_input = tool_block.input.copy() # type: ignore[attr-defined] + CacheParameterHandler.substitute_parameters(tool_block, {"var1": "value1"}) + + # Original should be unchanged + assert tool_block.input == original_input + + +def test_substitute_parameters_with_special_characters() -> None: + """Test substitution with values containing special regex characters.""" + tool_block = ToolUseBlockParam( + id="1", + name="tool", + input={"text": "Pattern: {{pattern}}"}, + type="tool_use", + ) + + # Value contains regex special characters + result = CacheParameterHandler.substitute_parameters( + tool_block, {"pattern": r".*[test]$"} + ) + + assert result.input["text"] == r"Pattern: .*[test]$" # type: ignore[index] + + +def test_substitute_parameters_with_backslashes() -> None: + """Test substitution with values containing backslashes (e.g., Windows paths). + + This test reveals the bug where backslashes in parameter values cause + re.PatternError because re.sub() treats the replacement string as a + regex replacement pattern, not a literal string. + """ + tool_block = ToolUseBlockParam( + id="1", + name="tool", + input={"text": "Open file at {{file_path}}"}, + type="tool_use", + ) + + # Value contains backslashes (like Windows paths) + # Previously raised: re.PatternError: bad escape \A at position 2 + result = CacheParameterHandler.substitute_parameters( + tool_block, {"file_path": "C:\\AskUI\\test"} + ) + + assert result.input["text"] == "Open file at C:\\AskUI\\test" # type: ignore[index] + + +def test_substitute_parameters_with_various_backslash_sequences() -> None: + """Test substitution with various backslash escape sequences. 
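+ + One common fix (an implementation detail assumed here, not asserted by these tests): + pass a callable as repl, e.g. re.sub(pattern, lambda _m: value, text), since the return + value of a callable repl is inserted literally and never escape-processed.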
+ + Tests multiple scenarios where backslashes could cause issues: + - Windows UNC paths + - Regex patterns as values + - Various escape sequences that could be misinterpreted + """ + # Test case 1: Windows path with \D + tool_block = ToolUseBlockParam( + id="1", + name="tool", + input={"text": "Save to {{path}}"}, + type="tool_use", + ) + + result = CacheParameterHandler.substitute_parameters( + tool_block, {"path": "D:\\Data\\file.txt"} + ) + assert result.input["text"] == "Save to D:\\Data\\file.txt" # type: ignore[index] + + # Test case 2: UNC path + tool_block2 = ToolUseBlockParam( + id="2", + name="tool", + input={"text": "Network path: {{unc_path}}"}, + type="tool_use", + ) + + result2 = CacheParameterHandler.substitute_parameters( + tool_block2, {"unc_path": "\\\\server\\share\\folder"} + ) + assert result2.input["text"] == "Network path: \\\\server\\share\\folder" # type: ignore[index] + + # Test case 3: Regex pattern as value (with backslashes) + tool_block3 = ToolUseBlockParam( + id="3", + name="tool", + input={"text": "Match pattern {{regex}}"}, + type="tool_use", + ) + + result3 = CacheParameterHandler.substitute_parameters( + tool_block3, {"regex": "\\d+\\s+\\w+"} + ) + assert result3.input["text"] == "Match pattern \\d+\\s+\\w+" # type: ignore[index] + + +def test_substitute_parameters_same_parameter_multiple_times() -> None: + """Test substituting the same parameter appearing multiple times.""" + tool_block = ToolUseBlockParam( + id="1", + name="tool", + input={"text": "{{var}} is {{var}} is {{var}}"}, + type="tool_use", + ) + + result = CacheParameterHandler.substitute_parameters(tool_block, {"var": "value"}) + + assert result.input["text"] == "value is value is value" # type: ignore[index] diff --git a/tests/unit/utils/test_cache_validator.py b/tests/unit/utils/test_cache_validator.py new file mode 100644 index 00000000..926ec251 --- /dev/null +++ b/tests/unit/utils/test_cache_validator.py @@ -0,0 +1,520 @@ +"""Tests for cache validation strategies.""" + +from datetime import datetime, timedelta, timezone + +import pytest + +from askui.models.shared.agent_message_param import ToolUseBlockParam +from askui.models.shared.settings import CacheFailure, CacheFile, CacheMetadata +from askui.utils.caching.cache_validator import ( + CacheValidator, + CompositeCacheValidator, + StaleCacheValidator, + StepFailureCountValidator, + TotalFailureRateValidator, +) + + +@pytest.fixture +def sample_cache_file() -> CacheFile: + """Create a sample cache file for testing.""" + return CacheFile( + metadata=CacheMetadata( + version="0.1", + created_at=datetime.now(tz=timezone.utc), + execution_attempts=0, + is_valid=True, + ), + trajectory=[ + ToolUseBlockParam(id="1", name="click", input={"x": 100}, type="tool_use"), + ToolUseBlockParam( + id="2", name="type", input={"text": "test"}, type="tool_use" + ), + ], + cache_parameters={}, + ) + + +# StepFailureCountValidator Tests + + +def test_step_failure_count_validator_below_threshold( + sample_cache_file: CacheFile, +) -> None: + """Test validator does not invalidate when failures are below threshold.""" + validator = StepFailureCountValidator(max_failures_per_step=3) + + # Add 2 failures at step 1 + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error 1", + failure_count_at_step=1, + ), + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error 2", + failure_count_at_step=2, + ), + ] + + should_inv, reason = 
validator.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is False + assert reason is None + + +def test_step_failure_count_validator_at_threshold( + sample_cache_file: CacheFile, +) -> None: + """Test validator invalidates when failures reach threshold.""" + validator = StepFailureCountValidator(max_failures_per_step=3) + + # Add 3 failures at step 1 + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message=f"Error {i}", + failure_count_at_step=i + 1, + ) + for i in range(3) + ] + + should_inv, reason = validator.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is True + assert "Step 1 failed 3 times" in reason # type: ignore[operator] + + +def test_step_failure_count_validator_different_steps( + sample_cache_file: CacheFile, +) -> None: + """Test validator only counts failures for specific step.""" + validator = StepFailureCountValidator(max_failures_per_step=3) + + # Add failures at different steps + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error at step 1", + failure_count_at_step=1, + ), + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=2, + error_message="Error at step 2", + failure_count_at_step=1, + ), + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error at step 1 again", + failure_count_at_step=2, + ), + ] + + # Step 1 has 2 failures (below threshold) + should_inv, reason = validator.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is False + + # Step 2 has 1 failure (below threshold) + should_inv, reason = validator.should_invalidate(sample_cache_file, step_index=2) + assert should_inv is False + + +def test_step_failure_count_validator_no_step_index( + sample_cache_file: CacheFile, +) -> None: + """Test validator returns False when no step_index provided.""" + validator = StepFailureCountValidator(max_failures_per_step=3) + + should_inv, reason = validator.should_invalidate(sample_cache_file, step_index=None) + assert should_inv is False + assert reason is None + + +def test_step_failure_count_validator_name() -> None: + """Test validator returns correct name.""" + validator = StepFailureCountValidator() + assert validator.get_name() == "StepFailureCount" + + +# TotalFailureRateValidator Tests + + +def test_total_failure_rate_validator_below_min_attempts( + sample_cache_file: CacheFile, +) -> None: + """Test validator does not check rate below minimum attempts.""" + validator = TotalFailureRateValidator(min_attempts=10, max_failure_rate=0.5) + + sample_cache_file.metadata.execution_attempts = 5 + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error", + failure_count_at_step=1, + ) + for _ in range(4) + ] # 4/5 = 80% failure rate + + should_inv, reason = validator.should_invalidate(sample_cache_file) + assert should_inv is False # Too few attempts to judge + + +def test_total_failure_rate_validator_above_threshold( + sample_cache_file: CacheFile, +) -> None: + """Test validator invalidates when failure rate exceeds threshold.""" + validator = TotalFailureRateValidator(min_attempts=10, max_failure_rate=0.5) + + sample_cache_file.metadata.execution_attempts = 10 + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=i % 2, + error_message=f"Error 
{i}", + failure_count_at_step=1, + ) + for i in range(6) + ] # 6/10 = 60% failure rate + + should_inv, reason = validator.should_invalidate(sample_cache_file) + assert should_inv is True + assert "60.0%" in reason # type: ignore[operator] + assert "50.0%" in reason # type: ignore[operator] + + +def test_total_failure_rate_validator_below_threshold( + sample_cache_file: CacheFile, +) -> None: + """Test validator does not invalidate when rate is acceptable.""" + validator = TotalFailureRateValidator(min_attempts=10, max_failure_rate=0.5) + + sample_cache_file.metadata.execution_attempts = 10 + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message=f"Error {i}", + failure_count_at_step=1, + ) + for i in range(4) + ] # 4/10 = 40% failure rate + + should_inv, reason = validator.should_invalidate(sample_cache_file) + assert should_inv is False + + +def test_total_failure_rate_validator_zero_attempts( + sample_cache_file: CacheFile, +) -> None: + """Test validator handles zero attempts correctly.""" + validator = TotalFailureRateValidator(min_attempts=10, max_failure_rate=0.5) + + sample_cache_file.metadata.execution_attempts = 0 + sample_cache_file.metadata.failures = [] + + should_inv, reason = validator.should_invalidate(sample_cache_file) + assert should_inv is False + + +def test_total_failure_rate_validator_name() -> None: + """Test validator returns correct name.""" + validator = TotalFailureRateValidator() + assert validator.get_name() == "TotalFailureRate" + + +# StaleCacheValidator Tests + + +def test_stale_cache_validator_not_stale(sample_cache_file: CacheFile) -> None: + """Test validator does not invalidate recent cache.""" + validator = StaleCacheValidator(max_age_days=30) + + sample_cache_file.metadata.last_executed_at = datetime.now( + tz=timezone.utc + ) - timedelta(days=10) + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error", + failure_count_at_step=1, + ) + ] + + should_inv, reason = validator.should_invalidate(sample_cache_file) + assert should_inv is False + + +def test_stale_cache_validator_is_stale(sample_cache_file: CacheFile) -> None: + """Test validator invalidates old cache with failures.""" + validator = StaleCacheValidator(max_age_days=30) + + sample_cache_file.metadata.last_executed_at = datetime.now( + tz=timezone.utc + ) - timedelta(days=35) + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error", + failure_count_at_step=1, + ) + ] + + should_inv, reason = validator.should_invalidate(sample_cache_file) + assert should_inv is True + assert "35 days" in reason # type: ignore[operator] + + +def test_stale_cache_validator_old_but_no_failures( + sample_cache_file: CacheFile, +) -> None: + """Test validator does not invalidate old cache without failures.""" + validator = StaleCacheValidator(max_age_days=30) + + sample_cache_file.metadata.last_executed_at = datetime.now( + tz=timezone.utc + ) - timedelta(days=100) + sample_cache_file.metadata.failures = [] + + should_inv, reason = validator.should_invalidate(sample_cache_file) + assert should_inv is False # Old but no failures = still valid + + +def test_stale_cache_validator_never_executed(sample_cache_file: CacheFile) -> None: + """Test validator handles cache that was never executed.""" + validator = StaleCacheValidator(max_age_days=30) + + 
sample_cache_file.metadata.last_executed_at = None + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error", + failure_count_at_step=1, + ) + ] + + should_inv, reason = validator.should_invalidate(sample_cache_file) + assert should_inv is False # Never executed = can't be stale + + +def test_stale_cache_validator_name() -> None: + """Test validator returns correct name.""" + validator = StaleCacheValidator() + assert validator.get_name() == "StaleCache" + + +# CompositeCacheValidator Tests + + +def test_composite_validator_empty() -> None: + """Test composite validator with no validators.""" + validator = CompositeCacheValidator([]) + cache_file = CacheFile( + metadata=CacheMetadata( + version="0.1", + created_at=datetime.now(tz=timezone.utc), + execution_attempts=0, + is_valid=True, + ), + trajectory=[], + cache_parameters={}, + ) + + should_inv, reason = validator.should_invalidate(cache_file) + assert should_inv is False + assert reason is None + + +def test_composite_validator_single_validator_triggers( + sample_cache_file: CacheFile, +) -> None: + """Test composite validator with one validator that triggers.""" + step_validator = StepFailureCountValidator(max_failures_per_step=2) + composite = CompositeCacheValidator([step_validator]) + + # Add 2 failures + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message=f"Error {i}", + failure_count_at_step=i + 1, + ) + for i in range(2) + ] + + should_inv, reason = composite.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is True + assert "StepFailureCount" in reason # type: ignore[operator] + + +def test_composite_validator_multiple_validators_all_pass( + sample_cache_file: CacheFile, +) -> None: + """Test composite validator when all validators pass.""" + composite = CompositeCacheValidator( + [ + StepFailureCountValidator(max_failures_per_step=3), + TotalFailureRateValidator(min_attempts=10, max_failure_rate=0.5), + ] + ) + + sample_cache_file.metadata.execution_attempts = 10 + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error", + failure_count_at_step=1, + ) + ] # 1/10 = 10% rate, and only 1 failure at step 1 + + should_inv, reason = composite.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is False + + +def test_composite_validator_multiple_validators_one_triggers( + sample_cache_file: CacheFile, +) -> None: + """Test composite validator when one validator triggers.""" + composite = CompositeCacheValidator( + [ + StepFailureCountValidator(max_failures_per_step=2), + TotalFailureRateValidator(min_attempts=100, max_failure_rate=0.5), + ] + ) + + sample_cache_file.metadata.execution_attempts = 10 # Below min_attempts + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message=f"Error {i}", + failure_count_at_step=i + 1, + ) + for i in range(3) + ] # 3 failures at step 1 (exceeds threshold of 2) + + should_inv, reason = composite.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is True + assert "StepFailureCount" in reason # type: ignore[operator] + assert "Step 1 failed 3 times" in reason # type: ignore[operator] + + +def test_composite_validator_multiple_validators_multiple_trigger( + sample_cache_file: CacheFile, +) -> None: + """Test composite validator 
when multiple validators trigger.""" + composite = CompositeCacheValidator( + [ + StepFailureCountValidator(max_failures_per_step=2), + TotalFailureRateValidator(min_attempts=5, max_failure_rate=0.5), + ] + ) + + sample_cache_file.metadata.execution_attempts = 5 + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message=f"Error {i}", + failure_count_at_step=i + 1, + ) + for i in range(4) + ] # 4/5 = 80% rate (exceeds 50%), and 4 failures at step 1 (exceeds 2) + + should_inv, reason = composite.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is True + assert "StepFailureCount" in reason # type: ignore[operator] + assert "TotalFailureRate" in reason # type: ignore[operator] + assert ";" in reason # type: ignore[operator] # Multiple reasons combined + + +def test_composite_validator_add_validator(sample_cache_file: CacheFile) -> None: + """Test adding validator to composite after initialization.""" + composite = CompositeCacheValidator([]) + assert len(composite.validators) == 0 + + composite.add_validator(StepFailureCountValidator(max_failures_per_step=1)) + assert len(composite.validators) == 1 + + # Add failure + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error", + failure_count_at_step=1, + ) + ] + + should_inv, reason = composite.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is True + + +def test_composite_validator_name() -> None: + """Test composite validator returns correct name.""" + composite = CompositeCacheValidator([]) + assert composite.get_name() == "CompositeValidator" + + +# Custom Validator Tests + + +class CustomTestValidator(CacheValidator): + """Custom test validator for testing extensibility.""" + + def __init__(self, should_trigger: bool = False): + self.should_trigger = should_trigger + + def should_invalidate( + self, _cache_file: CacheFile, _step_index: int | None = None + ) -> tuple[bool, str | None]: + if self.should_trigger: + return True, "Custom validation failed" + return False, None + + def get_name(self) -> str: + return "CustomTest" + + +def test_custom_validator_integration(sample_cache_file: CacheFile) -> None: + """Test that custom validators work with composite.""" + custom = CustomTestValidator(should_trigger=True) + composite = CompositeCacheValidator([custom]) + + should_inv, reason = composite.should_invalidate(sample_cache_file) + assert should_inv is True + assert "CustomTest" in reason # type: ignore[operator] + assert "Custom validation failed" in reason # type: ignore[operator] + + +def test_custom_validator_with_built_in(sample_cache_file: CacheFile) -> None: + """Test custom validator alongside built-in validators.""" + custom = CustomTestValidator(should_trigger=False) + step_validator = StepFailureCountValidator(max_failures_per_step=1) + + composite = CompositeCacheValidator([custom, step_validator]) + + # Add failure to trigger step validator + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error", + failure_count_at_step=1, + ) + ] + + should_inv, reason = composite.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is True + assert "StepFailureCount" in reason # type: ignore[operator] + assert "CustomTest" not in reason # type: ignore[operator] # Custom didn't trigger diff --git a/tests/unit/utils/test_cache_writer.py 
b/tests/unit/utils/test_cache_writer.py index 2c875ae4..6f6e84d7 100644 --- a/tests/unit/utils/test_cache_writer.py +++ b/tests/unit/utils/test_cache_writer.py @@ -7,13 +7,19 @@ from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam from askui.models.shared.agent_on_message_cb import OnMessageCbParam -from askui.utils.cache_writer import CacheWriter +from askui.models.shared.settings import CacheFile, CacheWritingSettings + +# Note: CacheWriter now takes CacheWritingSettings (introduced in v0.2) instead of a bare file_name +from askui.utils.caching.cache_writer import CacheWriter def test_cache_writer_initialization() -> None: """Test CacheWriter initialization.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="test.json") + cache_writer = CacheWriter( + cache_dir=temp_dir, + cache_writing_settings=CacheWritingSettings(filename="test.json"), + ) assert cache_writer.cache_dir == Path(temp_dir) assert cache_writer.file_name == "test.json" assert cache_writer.messages == [] @@ -34,17 +40,26 @@ def test_cache_writer_creates_cache_directory() -> None: def test_cache_writer_adds_json_extension() -> None: """Test that CacheWriter adds .json extension if not present.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="test") + cache_writer = CacheWriter( + cache_dir=temp_dir, + cache_writing_settings=CacheWritingSettings(filename="test"), + ) assert cache_writer.file_name == "test.json" - cache_writer2 = CacheWriter(cache_dir=temp_dir, file_name="test.json") + cache_writer2 = CacheWriter( + cache_dir=temp_dir, + cache_writing_settings=CacheWritingSettings(filename="test.json"), + ) assert cache_writer2.file_name == "test.json" def test_cache_writer_add_message_cb_stores_tool_use_blocks() -> None: """Test that add_message_cb stores ToolUseBlockParam from assistant messages.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="test.json") + cache_writer = CacheWriter( + cache_dir=temp_dir, + cache_writing_settings=CacheWritingSettings(filename="test.json"), + ) tool_use_block = ToolUseBlockParam( id="test_id", @@ -73,7 +88,10 @@ def test_cache_writer_add_message_cb_stores_tool_use_blocks() -> None: def test_cache_writer_add_message_cb_ignores_non_tool_use_content() -> None: """Test that add_message_cb ignores non-ToolUseBlockParam content.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="test.json") + cache_writer = CacheWriter( + cache_dir=temp_dir, + cache_writing_settings=CacheWritingSettings(filename="test.json"), + ) message = MessageParam( role="assistant", @@ -93,7 +111,10 @@ def test_cache_writer_add_message_cb_ignores_non_tool_use_content() -> None: def test_cache_writer_add_message_cb_ignores_user_messages() -> None: """Test that add_message_cb ignores user messages.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="test.json") + cache_writer = CacheWriter( + cache_dir=temp_dir, + cache_writing_settings=CacheWritingSettings(filename="test.json"), + ) message = MessageParam( role="user", @@ -113,7 +134,10 @@ def test_cache_writer_add_message_cb_ignores_user_messages() -> None: def test_cache_writer_detects_cached_execution() -> None: """Test that CacheWriter detects when execute_cached_executions_tool is used.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = 
CacheWriter(cache_dir=temp_dir, file_name="test.json") + cache_writer = CacheWriter( + cache_dir=temp_dir, + cache_writing_settings=CacheWritingSettings(filename="test.json"), + ) tool_use_block = ToolUseBlockParam( id="cached_exec_id", @@ -138,10 +162,15 @@ def test_cache_writer_detects_cached_execution() -> None: def test_cache_writer_generate_writes_file() -> None: - """Test that generate() writes messages to a JSON file.""" + """Test that generate() writes messages to a JSON file in v0.1 format.""" with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) - cache_writer = CacheWriter(cache_dir=str(cache_dir), file_name="output.json") + cache_writer = CacheWriter( + cache_dir=str(cache_dir), + cache_writing_settings=CacheWritingSettings( + filename="output.json", parameter_identification_strategy="preset" + ), + ) # Add some tool use blocks tool_use1 = ToolUseBlockParam( @@ -164,22 +193,39 @@ def test_cache_writer_generate_writes_file() -> None: cache_file = cache_dir / "output.json" assert cache_file.exists() - # Verify file content + # Verify file content (v0.1 format) with cache_file.open("r", encoding="utf-8") as f: data = json.load(f) - assert len(data) == 2 - assert data[0]["id"] == "id1" - assert data[0]["name"] == "tool1" - assert data[1]["id"] == "id2" - assert data[1]["name"] == "tool2" + # Check v0.1 structure + assert "metadata" in data + assert "trajectory" in data + assert "cache_parameters" in data + + # Check metadata + assert data["metadata"]["version"] == "0.1" + assert "created_at" in data["metadata"] + assert data["metadata"]["execution_attempts"] == 0 + assert data["metadata"]["is_valid"] is True + + # Check trajectory + assert len(data["trajectory"]) == 2 + assert data["trajectory"][0]["id"] == "id1" + assert data["trajectory"][0]["name"] == "tool1" + assert data["trajectory"][1]["id"] == "id2" + assert data["trajectory"][1]["name"] == "tool2" def test_cache_writer_generate_auto_names_file() -> None: """Test that generate() auto-generates filename if not provided.""" with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) - cache_writer = CacheWriter(cache_dir=str(cache_dir), file_name="") + cache_writer = CacheWriter( + cache_dir=str(cache_dir), + cache_writing_settings=CacheWritingSettings( + filename="", parameter_identification_strategy="preset" + ), + ) tool_use = ToolUseBlockParam( id="id1", @@ -200,7 +246,10 @@ def test_cache_writer_generate_skips_cached_execution() -> None: """Test that generate() doesn't write file for cached executions.""" with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) - cache_writer = CacheWriter(cache_dir=str(cache_dir), file_name="test.json") + cache_writer = CacheWriter( + cache_dir=str(cache_dir), + cache_writing_settings=CacheWritingSettings(filename="test.json"), + ) cache_writer.was_cached_execution = True cache_writer.messages = [ @@ -222,7 +271,10 @@ def test_cache_writer_generate_skips_cached_execution() -> None: def test_cache_writer_reset() -> None: """Test that reset() clears messages and filename.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="original.json") + cache_writer = CacheWriter( + cache_dir=temp_dir, + cache_writing_settings=CacheWritingSettings(filename="original.json"), + ) # Add some data cache_writer.messages = [ @@ -243,12 +295,12 @@ def test_cache_writer_reset() -> None: assert cache_writer.was_cached_execution is False -def test_cache_writer_read_cache_file() -> None: - """Test 
that read_cache_file() loads ToolUseBlockParam from JSON.""" +def test_cache_writer_read_cache_file_v0_0() -> None: + """Test backward compatibility: read_cache_file() loads v0.0 format.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_file = Path(temp_dir) / "test_cache.json" + cache_file_path = Path(temp_dir) / "test_cache.json" - # Create a cache file + # Create a v0.0 cache file (just a list) trajectory: list[dict[str, Any]] = [ { "id": "id1", @@ -264,25 +316,80 @@ def test_cache_writer_read_cache_file() -> None: }, ] - with cache_file.open("w", encoding="utf-8") as f: + with cache_file_path.open("w", encoding="utf-8") as f: json.dump(trajectory, f) # Read cache file - result = CacheWriter.read_cache_file(cache_file) + result = CacheWriter.read_cache_file(cache_file_path) + + # Should return CacheFile with migrated v0.0 data (now v0.1) + assert isinstance(result, CacheFile) + assert result.metadata.version == "0.1" # Migrated from v0.0 to v0.1 + assert len(result.trajectory) == 2 + assert isinstance(result.trajectory[0], ToolUseBlockParam) + assert result.trajectory[0].id == "id1" + assert result.trajectory[0].name == "tool1" + assert isinstance(result.trajectory[1], ToolUseBlockParam) + assert result.trajectory[1].id == "id2" + assert result.trajectory[1].name == "tool2" + + +def test_cache_writer_read_cache_file_v0_1() -> None: + """Test that read_cache_file() loads v0.1 format correctly.""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_file_path = Path(temp_dir) / "test_cache_v0_1.json" + + # Create a v0.1 cache file + cache_data = { + "metadata": { + "version": "0.1", + "created_at": "2025-12-11T10:00:00Z", + "last_executed_at": None, + "execution_attempts": 0, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [ + { + "id": "id1", + "name": "tool1", + "input": {"param": "value1"}, + "type": "tool_use", + }, + { + "id": "id2", + "name": "tool2", + "input": {"param": "value2"}, + "type": "tool_use", + }, + ], + "cache_parameters": {"current_date": "Current date in YYYY-MM-DD format"}, + } + + with cache_file_path.open("w", encoding="utf-8") as f: + json.dump(cache_data, f) + + # Read cache file + result = CacheWriter.read_cache_file(cache_file_path) - assert len(result) == 2 - assert isinstance(result[0], ToolUseBlockParam) - assert result[0].id == "id1" - assert result[0].name == "tool1" - assert isinstance(result[1], ToolUseBlockParam) - assert result[1].id == "id2" - assert result[1].name == "tool2" + # Should return CacheFile + assert isinstance(result, CacheFile) + assert result.metadata.version == "0.1" + assert result.metadata.is_valid is True + assert len(result.trajectory) == 2 + assert result.trajectory[0].id == "id1" + assert result.trajectory[1].id == "id2" + assert "current_date" in result.cache_parameters def test_cache_writer_set_file_name() -> None: """Test that set_file_name() updates the filename.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="original.json") + cache_writer = CacheWriter( + cache_dir=temp_dir, + cache_writing_settings=CacheWritingSettings(filename="original.json"), + ) cache_writer.set_file_name("new_name") assert cache_writer.file_name == "new_name.json" @@ -295,7 +402,12 @@ def test_cache_writer_generate_resets_after_writing() -> None: """Test that generate() calls reset() after writing the file.""" with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) - cache_writer = CacheWriter(cache_dir=str(cache_dir), 
file_name="test.json") + cache_writer = CacheWriter( + cache_dir=str(cache_dir), + cache_writing_settings=CacheWritingSettings( + filename="test.json", parameter_identification_strategy="preset" + ), + ) cache_writer.messages = [ ToolUseBlockParam( @@ -310,3 +422,77 @@ def test_cache_writer_generate_resets_after_writing() -> None: # After generate, messages should be empty assert cache_writer.messages == [] + + +def test_cache_writer_detects_and_stores_parameters() -> None: + """Test that CacheWriter detects parameters and stores them in metadata.""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_dir = Path(temp_dir) + cache_writer = CacheWriter( + cache_dir=str(cache_dir), + cache_writing_settings=CacheWritingSettings( + filename="test.json", parameter_identification_strategy="preset" + ), + ) + + # Add tool use blocks with parameters + cache_writer.messages = [ + ToolUseBlockParam( + id="id1", + name="computer", + input={"action": "type", "text": "Today is {{current_date}}"}, + type="tool_use", + ), + ToolUseBlockParam( + id="id2", + name="computer", + input={"action": "type", "text": "User: {{user_name}}"}, + type="tool_use", + ), + ] + + cache_writer.generate() + + # Read back the cache file + cache_file = cache_dir / "test.json" + with cache_file.open("r", encoding="utf-8") as f: + data = json.load(f) + + # Verify parameters were detected and stored + assert "cache_parameters" in data + assert "current_date" in data["cache_parameters"] + assert "user_name" in data["cache_parameters"] + assert len(data["cache_parameters"]) == 2 + + +def test_cache_writer_empty_parameters_when_none_found() -> None: + """Test that parameters dict is empty when no parameters exist.""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_dir = Path(temp_dir) + cache_writer = CacheWriter( + cache_dir=str(cache_dir), + cache_writing_settings=CacheWritingSettings( + filename="test.json", parameter_identification_strategy="preset" + ), + ) + + # Add tool use blocks without parameters + cache_writer.messages = [ + ToolUseBlockParam( + id="id1", + name="computer", + input={"action": "click", "coordinate": [100, 200]}, + type="tool_use", + ) + ] + + cache_writer.generate() + + # Read back the cache file + cache_file = cache_dir / "test.json" + with cache_file.open("r", encoding="utf-8") as f: + data = json.load(f) + + # Verify parameters dict is empty + assert "cache_parameters" in data + assert data["cache_parameters"] == {} diff --git a/tests/unit/utils/test_trajectory_executor.py b/tests/unit/utils/test_trajectory_executor.py new file mode 100644 index 00000000..76e0f010 --- /dev/null +++ b/tests/unit/utils/test_trajectory_executor.py @@ -0,0 +1,845 @@ +"""Unit tests for TrajectoryExecutor.""" + +from unittest.mock import MagicMock + +from askui.models.shared.agent_message_param import ( + MessageParam, + TextBlockParam, + ToolResultBlockParam, + ToolUseBlockParam, +) +from askui.models.shared.tools import ToolCollection +from askui.utils.trajectory_executor import TrajectoryExecutor + + +def test_trajectory_executor_initialization() -> None: + """Test TrajectoryExecutor initialization.""" + trajectory = [ToolUseBlockParam(id="1", name="tool1", input={}, type="tool_use")] + toolbox = ToolCollection() + + executor = TrajectoryExecutor( + trajectory=trajectory, + toolbox=toolbox, + parameter_values={"var": "value"}, + delay_time=0.1, + ) + + assert executor.trajectory == trajectory + assert executor.toolbox == toolbox + assert executor.parameter_values == {"var": "value"} + assert 
executor.delay_time == 0.1 + assert executor.current_step_index == 0 + + +def test_trajectory_executor_execute_simple_step() -> None: + """Test executing a simple step.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Tool result")], + ) + ] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam( + id="1", name="test_tool", input={"param": "value"}, type="tool_use" + ) + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + result = executor.execute_next_step() + + assert result.status == "SUCCESS" + assert result.step_index == 0 + assert result.error_message is None + assert executor.current_step_index == 1 + assert mock_toolbox.run.call_count == 1 + + +def test_trajectory_executor_execute_all_steps() -> None: + """Test executing all steps in a trajectory.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam(id="1", name="tool1", input={}, type="tool_use"), + ToolUseBlockParam(id="2", name="tool2", input={}, type="tool_use"), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + results = executor.execute_all() + + # Should have 3 results: 2 successful steps + 1 completed + assert len(results) == 3 + assert results[0].status == "SUCCESS" + assert results[0].step_index == 0 + assert results[1].status == "SUCCESS" + assert results[1].step_index == 1 + assert results[2].status == "COMPLETED" + assert executor.current_step_index == 2 + assert mock_toolbox.run.call_count == 2 + + +def test_trajectory_executor_executes_screenshot_tools() -> None: + """Test that screenshot tools are executed (not skipped).""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Screenshot result")], + ) + ] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam(id="1", name="screenshot", input={}, type="tool_use"), + ToolUseBlockParam(id="2", name="tool1", input={}, type="tool_use"), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + result = executor.execute_next_step() + + # Should execute screenshot tool + assert result.status == "SUCCESS" + assert result.step_index == 0 # First step executed + assert mock_toolbox.run.call_count == 1 + # Verify screenshot tool was called + assert mock_toolbox.run.call_args[0][0][0].name == "screenshot" + + +def test_trajectory_executor_executes_retrieve_trajectories_tool() -> None: + """Test that retrieve_available_trajectories_tool is executed (not skipped).""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Trajectory list")], + ) + ] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam( + id="1", + name="retrieve_available_trajectories_tool", + input={}, + type="tool_use", + ), + ToolUseBlockParam(id="2", name="tool1", input={}, type="tool_use"), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + result = executor.execute_next_step() + + # Should 
execute retrieve tool + assert result.status == "SUCCESS" + assert result.step_index == 0 # First step executed + assert mock_toolbox.run.call_count == 1 + # Verify retrieve tool was called + assert ( + mock_toolbox.run.call_args[0][0][0].name + == "retrieve_available_trajectories_tool" + ) + + +def test_trajectory_executor_pauses_at_non_cacheable_tool() -> None: + """Test that execution pauses at non-cacheable tools.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] + + # Create mock tools in the toolbox + cacheable_tool = MagicMock() + cacheable_tool.is_cacheable = True + non_cacheable_tool = MagicMock() + non_cacheable_tool.is_cacheable = False + + mock_toolbox._tool_map = { + "cacheable": cacheable_tool, + "non_cacheable": non_cacheable_tool, + } + + trajectory = [ + ToolUseBlockParam(id="1", name="cacheable", input={}, type="tool_use"), + ToolUseBlockParam(id="2", name="non_cacheable", input={}, type="tool_use"), + ToolUseBlockParam(id="3", name="cacheable", input={}, type="tool_use"), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + results = executor.execute_all() + + # Should execute first step, then pause at non-cacheable + assert len(results) == 2 + assert results[0].status == "SUCCESS" + assert results[0].step_index == 0 + assert results[1].status == "NEEDS_AGENT" + assert results[1].step_index == 1 + assert mock_toolbox.run.call_count == 1 # Only first step executed + assert executor.current_step_index == 1 # Paused at step 1 + + +def test_trajectory_executor_handles_tool_error() -> None: + """Test that executor handles tool execution errors gracefully.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox._tool_map = {} + mock_toolbox.run.side_effect = Exception("Tool execution failed") + + trajectory = [ + ToolUseBlockParam(id="1", name="failing_tool", input={}, type="tool_use") + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + result = executor.execute_next_step() + + assert result.status == "FAILED" + assert result.step_index == 0 + assert "Tool execution failed" in (result.error_message or "") + + +def test_trajectory_executor_substitutes_cache_parameters() -> None: + """Test that executor substitutes cache_parameters before execution.""" + captured_steps = [] + + def capture_run(steps): # type: ignore + captured_steps.extend(steps) + return [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] + + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run = capture_run + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam( + id="1", + name="test_tool", + input={"text": "Hello {{name}}"}, + type="tool_use", + ) + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, + toolbox=mock_toolbox, + parameter_values={"name": "Alice"}, + delay_time=0, + ) + + result = executor.execute_next_step() + + assert result.status == "SUCCESS" + assert len(captured_steps) == 1 + assert captured_steps[0].input["text"] == "Hello Alice" + + +def test_trajectory_executor_get_current_step_index() -> None: + """Test getting current step index.""" + toolbox = ToolCollection() + trajectory = [ToolUseBlockParam(id="1", name="tool1", input={}, type="tool_use")] + + executor = TrajectoryExecutor(trajectory=trajectory, toolbox=toolbox, delay_time=0) + + 
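# get_current_step_index() mirrors the mutable current_step_index attribute. + 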
assert executor.get_current_step_index() == 0 + + executor.current_step_index = 5 + assert executor.get_current_step_index() == 5 + + +def test_trajectory_executor_get_remaining_trajectory() -> None: + """Test getting remaining trajectory steps.""" + toolbox = ToolCollection() + trajectory = [ + ToolUseBlockParam(id="1", name="tool1", input={}, type="tool_use"), + ToolUseBlockParam(id="2", name="tool2", input={}, type="tool_use"), + ToolUseBlockParam(id="3", name="tool3", input={}, type="tool_use"), + ] + + executor = TrajectoryExecutor(trajectory=trajectory, toolbox=toolbox, delay_time=0) + + # Initially, all steps remain + remaining = executor.get_remaining_trajectory() + assert len(remaining) == 3 + + # After advancing to step 1 + executor.current_step_index = 1 + remaining = executor.get_remaining_trajectory() + assert len(remaining) == 2 + assert remaining[0].id == "2" + assert remaining[1].id == "3" + + # At the end + executor.current_step_index = 3 + remaining = executor.get_remaining_trajectory() + assert len(remaining) == 0 + + +def test_trajectory_executor_skip_current_step() -> None: + """Test skipping the current step.""" + toolbox = ToolCollection() + trajectory = [ + ToolUseBlockParam(id="1", name="tool1", input={}, type="tool_use"), + ToolUseBlockParam(id="2", name="tool2", input={}, type="tool_use"), + ] + + executor = TrajectoryExecutor(trajectory=trajectory, toolbox=toolbox, delay_time=0) + + assert executor.current_step_index == 0 + executor.skip_current_step() + assert executor.current_step_index == 1 + executor.skip_current_step() + assert executor.current_step_index == 2 + + +def test_trajectory_executor_skip_at_end_does_nothing() -> None: + """Test that skipping at the end doesn't cause errors.""" + toolbox = ToolCollection() + trajectory = [ToolUseBlockParam(id="1", name="tool1", input={}, type="tool_use")] + + executor = TrajectoryExecutor(trajectory=trajectory, toolbox=toolbox, delay_time=0) + + executor.current_step_index = 1 # Already at end + executor.skip_current_step() + assert executor.current_step_index == 1 # Stays at end + + +def test_trajectory_executor_completed_status_when_done() -> None: + """Test that executor returns COMPLETED when all steps are done.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] + mock_toolbox._tool_map = {} + + trajectory = [ToolUseBlockParam(id="1", name="tool1", input={}, type="tool_use")] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + # Execute the step + result1 = executor.execute_next_step() + assert result1.status == "SUCCESS" + + # Try to execute again (no more steps) + result2 = executor.execute_next_step() + assert result2.status == "COMPLETED" + + +def test_trajectory_executor_execute_all_stops_on_failure() -> None: + """Test that execute_all stops when a step fails.""" + # Mock to fail on second call + call_count = [0] + + def mock_run(_steps): # type: ignore + call_count[0] += 1 + if call_count[0] == 2: + msg = "Second call fails" + raise Exception(msg) # noqa: TRY002 + return [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] + + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run = mock_run + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam(id="1", name="tool1", input={}, type="tool_use"), + ToolUseBlockParam(id="2", 
name="tool1", input={}, type="tool_use"), + ToolUseBlockParam(id="3", name="tool1", input={}, type="tool_use"), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + results = executor.execute_all() + + # Should have 2 results: 1 success + 1 failure (stopped) + assert len(results) == 2 + assert results[0].status == "SUCCESS" + assert results[1].status == "FAILED" + assert executor.current_step_index == 1 # Stopped at failed step + + +def test_trajectory_executor_builds_message_history() -> None: + """Test that executor builds message history during execution.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result1")], + ) + ] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam( + id="tool1", name="test_tool", input={"x": 100}, type="tool_use" + ) + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + result = executor.execute_next_step() + + # Verify message history is built + assert result.status == "SUCCESS" + assert len(result.message_history) == 2 # Assistant + User message + + # Verify assistant message (tool use) + assert isinstance(result.message_history[0], MessageParam) + assert result.message_history[0].role == "assistant" + assert isinstance(result.message_history[0].content, list) + assert len(result.message_history[0].content) == 1 + assert isinstance(result.message_history[0].content[0], ToolUseBlockParam) + + # Verify user message (tool result) + assert isinstance(result.message_history[1], MessageParam) + assert result.message_history[1].role == "user" + assert isinstance(result.message_history[1].content, list) + assert len(result.message_history[1].content) == 1 + assert isinstance(result.message_history[1].content[0], ToolResultBlockParam) + + +def test_trajectory_executor_message_history_accumulates() -> None: + """Test that message history accumulates across multiple steps.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam(id="1", name="tool1", input={}, type="tool_use"), + ToolUseBlockParam(id="2", name="tool2", input={}, type="tool_use"), + ToolUseBlockParam(id="3", name="tool3", input={}, type="tool_use"), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + results = executor.execute_all() + + # Should have 4 results: 3 successful steps + 1 completed + assert len(results) == 4 + + # Check message history grows + assert len(results[0].message_history) == 2 # Step 1: assistant + user + assert len(results[1].message_history) == 4 # Step 2: + assistant + user + assert len(results[2].message_history) == 6 # Step 3: + assistant + user + assert len(results[3].message_history) == 6 # Completed: same as last step + + +def test_trajectory_executor_message_history_contains_tool_use_id() -> None: + """Test that tool result has correct tool_use_id matching tool use.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Success")], + ) + ] + mock_toolbox._tool_map = {} + + trajectory = [ToolUseBlockParam(id="1", name="tool", input={}, type="tool_use")] + 
+ executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + result = executor.execute_next_step() + + # Get tool use and tool result + tool_use = result.message_history[0].content[0] + tool_result = result.message_history[1].content[0] + + # Verify tool_use_id matches + assert isinstance(tool_use, ToolUseBlockParam) + assert isinstance(tool_result, ToolResultBlockParam) + assert tool_result.tool_use_id == tool_use.id + assert tool_result.tool_use_id == "1" + + +def test_trajectory_executor_message_history_includes_text_result() -> None: + """Test that tool results include text content.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Tool executed successfully")], + ) + ] + mock_toolbox._tool_map = {} + + trajectory = [ToolUseBlockParam(id="1", name="tool", input={}, type="tool_use")] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + result = executor.execute_next_step() + + # Get tool result + tool_result_block = result.message_history[1].content[0] + assert isinstance(tool_result_block, ToolResultBlockParam) + + # Verify text content is included + assert isinstance(tool_result_block.content, list) + assert len(tool_result_block.content) == 1 + assert isinstance(tool_result_block.content[0], TextBlockParam) + assert tool_result_block.content[0].text == "Tool executed successfully" + + +def test_trajectory_executor_message_history_on_failure() -> None: + """Test that message history is included even when execution fails.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox._tool_map = {} + mock_toolbox.run.side_effect = Exception("Tool failed") + + trajectory = [ + ToolUseBlockParam(id="1", name="failing_tool", input={}, type="tool_use") + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + result = executor.execute_next_step() + + assert result.status == "FAILED" + # Message history should include the assistant message (tool use) + # but not the user message (since execution failed) + assert len(result.message_history) == 1 + assert result.message_history[0].role == "assistant" + + +def test_trajectory_executor_message_history_on_pause() -> None: + """Test that message history is included when execution pauses.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] + + cacheable_tool = MagicMock() + cacheable_tool.is_cacheable = True + non_cacheable_tool = MagicMock() + non_cacheable_tool.is_cacheable = False + + mock_toolbox._tool_map = { + "cacheable": cacheable_tool, + "non_cacheable": non_cacheable_tool, + } + + trajectory = [ + ToolUseBlockParam(id="1", name="cacheable", input={}, type="tool_use"), + ToolUseBlockParam(id="2", name="non_cacheable", input={}, type="tool_use"), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + results = executor.execute_all() + + # Should pause at non-cacheable step + assert len(results) == 2 + assert results[0].status == "SUCCESS" + assert results[1].status == "NEEDS_AGENT" + + # Message history should include only the successfully executed cacheable step: + # 1. First step: assistant message (tool use) + # 2. 
First step: user message (tool result) + # The non-cacheable tool is NOT in message history - instead it's in tool_result + assert len(results[1].message_history) == 2 + assert results[1].message_history[0].role == "assistant" # First cacheable tool use + assert results[1].message_history[1].role == "user" # First tool result + + # The non-cacheable tool should be in tool_result for reference + assert results[1].tool_result is not None + assert results[1].tool_result.name == "non_cacheable" + + +def test_trajectory_executor_message_history_order() -> None: + """Test that message history maintains correct order.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam(id="1", name="tool1", input={"step": 1}, type="tool_use"), + ToolUseBlockParam(id="2", name="tool2", input={"step": 2}, type="tool_use"), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + results = executor.execute_all() + + # Get final message history + final_history = results[-1].message_history + + # Should have 4 messages: assistant, user, assistant, user + assert len(final_history) == 4 + assert final_history[0].role == "assistant" # Step 1 tool use + assert final_history[1].role == "user" # Step 1 result + assert final_history[2].role == "assistant" # Step 2 tool use + assert final_history[3].role == "user" # Step 2 result + + # Verify step order in tool use + tool_use_1 = final_history[0].content[0] + tool_use_2 = final_history[2].content[0] + assert isinstance(tool_use_1, ToolUseBlockParam) + assert isinstance(tool_use_2, ToolUseBlockParam) + assert tool_use_1.input == {"step": 1} + assert tool_use_2.input == {"step": 2} + + +# Visual Validation Extension Point Tests + + +def test_visual_validation_disabled_by_default() -> None: + """Test that visual validation is disabled by default.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam( + id="1", + name="click", + input={"x": 100}, + type="tool_use", + visual_representation="abc123", + ), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + # visual_validation_enabled should be False by default + assert executor.visual_validation_enabled is False + + # Should execute successfully without validation + results = executor.execute_all() + assert results[0].status == "SUCCESS" + + +def test_visual_validation_enabled_flag() -> None: + """Test that visual validation flag can be enabled.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam(id="1", name="click", input={"x": 100}, type="tool_use"), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, + toolbox=mock_toolbox, + delay_time=0, + visual_validation_enabled=True, + ) + + # visual_validation_enabled should be True + assert executor.visual_validation_enabled is True + + +def test_validate_step_visually_hook_exists() -> None: + """Test that validate_step_visually hook exists and returns 
+
+
+# Visual Validation Extension Point Tests
+
+
+def test_visual_validation_disabled_by_default() -> None:
+    """Test that visual validation is disabled by default."""
+    mock_toolbox = MagicMock(spec=ToolCollection)
+    mock_toolbox.run.return_value = [
+        ToolResultBlockParam(
+            tool_use_id="1",
+            content=[TextBlockParam(type="text", text="Result")],
+        )
+    ]
+    mock_toolbox._tool_map = {}
+
+    trajectory = [
+        ToolUseBlockParam(
+            id="1",
+            name="click",
+            input={"x": 100},
+            type="tool_use",
+            visual_representation="abc123",
+        ),
+    ]
+
+    executor = TrajectoryExecutor(
+        trajectory=trajectory, toolbox=mock_toolbox, delay_time=0
+    )
+
+    # visual_validation_enabled should be False by default
+    assert executor.visual_validation_enabled is False
+
+    # Should execute successfully without validation
+    results = executor.execute_all()
+    assert results[0].status == "SUCCESS"
+
+
+def test_visual_validation_enabled_flag() -> None:
+    """Test that the visual validation flag can be enabled."""
+    mock_toolbox = MagicMock(spec=ToolCollection)
+    mock_toolbox.run.return_value = [
+        ToolResultBlockParam(
+            tool_use_id="1",
+            content=[TextBlockParam(type="text", text="Result")],
+        )
+    ]
+    mock_toolbox._tool_map = {}
+
+    trajectory = [
+        ToolUseBlockParam(id="1", name="click", input={"x": 100}, type="tool_use"),
+    ]
+
+    executor = TrajectoryExecutor(
+        trajectory=trajectory,
+        toolbox=mock_toolbox,
+        delay_time=0,
+        visual_validation_enabled=True,
+    )
+
+    # visual_validation_enabled should be True
+    assert executor.visual_validation_enabled is True
+
+
+def test_validate_step_visually_hook_exists() -> None:
+    """Test that the validate_step_visually hook exists and has the
+    expected signature."""
+    mock_toolbox = MagicMock(spec=ToolCollection)
+    mock_toolbox._tool_map = {}
+
+    trajectory = [
+        ToolUseBlockParam(id="1", name="click", input={"x": 100}, type="tool_use"),
+    ]
+
+    executor = TrajectoryExecutor(
+        trajectory=trajectory, toolbox=mock_toolbox, delay_time=0
+    )
+
+    # Hook should exist
+    assert hasattr(executor, "validate_step_visually")
+    assert callable(executor.validate_step_visually)
+
+    # Hook should return an (is_valid, error_msg) pair
+    step = trajectory[0]
+    is_valid, error_msg = executor.validate_step_visually(step)
+
+    assert isinstance(is_valid, bool)
+    assert is_valid is True  # Currently always returns True
+    assert error_msg is None
+
+
+def test_validate_step_visually_always_passes_when_disabled() -> None:
+    """Test that validation always passes when disabled (default behavior)."""
+    mock_toolbox = MagicMock(spec=ToolCollection)
+    mock_toolbox.run.return_value = [
+        ToolResultBlockParam(
+            tool_use_id="1",
+            content=[TextBlockParam(type="text", text="Result")],
+        )
+    ]
+    mock_toolbox._tool_map = {}
+
+    # Create a step with visual validation fields
+    trajectory = [
+        ToolUseBlockParam(
+            id="1",
+            name="click",
+            input={"x": 100},
+            type="tool_use",
+            visual_representation="abc123",
+        ),
+    ]
+
+    # Validation disabled (default)
+    executor = TrajectoryExecutor(
+        trajectory=trajectory,
+        toolbox=mock_toolbox,
+        delay_time=0,
+        visual_validation_enabled=False,
+    )
+
+    # Should execute without calling validate_step_visually
+    results = executor.execute_all()
+    assert results[0].status == "SUCCESS"
+
+
+def test_validate_step_visually_hook_called_when_enabled() -> None:
+    """Test that validate_step_visually is called when enabled."""
+    mock_toolbox = MagicMock(spec=ToolCollection)
+    mock_toolbox.run.return_value = [
+        ToolResultBlockParam(
+            tool_use_id="1",
+            content=[TextBlockParam(type="text", text="Result")],
+        )
+    ]
+    mock_toolbox._tool_map = {}
+
+    trajectory = [
+        ToolUseBlockParam(
+            id="1",
+            name="click",
+            input={"x": 100},
+            type="tool_use",
+            visual_representation="abc123",
+        ),
+    ]
+
+    executor = TrajectoryExecutor(
+        trajectory=trajectory,
+        toolbox=mock_toolbox,
+        delay_time=0,
+        visual_validation_enabled=True,
+    )
+
+    # Wrap the validation hook to track calls
+    original_validate = executor.validate_step_visually
+    validation_called = []
+
+    def mock_validate(step, screenshot=None) -> tuple[bool, str | None]:  # type: ignore[no-untyped-def]
+        validation_called.append(step)
+        return original_validate(step, screenshot)
+
+    executor.validate_step_visually = mock_validate  # type: ignore[assignment]
+
+    # Execute the trajectory
+    results = executor.execute_all()
+
+    # Validation should have been called exactly once
+    assert len(validation_called) == 1
+    assert results[0].status == "SUCCESS"
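+
+
+# Illustrative sketch of the hook contract exercised above (a hypothetical
+# stand-in, not the production implementation): the hook takes a trajectory
+# step plus an optional screenshot and returns an (is_valid, error_msg)
+# pair. The always-pass variant below mirrors the default behaviour the
+# tests assert when validation is disabled.
+def _sketch_validate_step_visually(
+    step: ToolUseBlockParam, screenshot: object = None
+) -> tuple[bool, str | None]:
+    # Without visual validation, every step is considered valid.
+    return True, None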
+ """ + # Create step with visual representation field + step = ToolUseBlockParam( + id="1", + name="click", + input={"x": 100, "y": 200}, + type="tool_use", + visual_representation="a8f3c9e14b7d2056", + ) + + # Field should be accessible + assert step.visual_representation == "a8f3c9e14b7d2056" + + # Default value should be None + step_default = ToolUseBlockParam( + id="2", name="type", input={"text": "hello"}, type="tool_use" + ) + + assert step_default.visual_representation is None diff --git a/tests/unit/utils/test_visual_validation.py b/tests/unit/utils/test_visual_validation.py new file mode 100644 index 00000000..acc50a6c --- /dev/null +++ b/tests/unit/utils/test_visual_validation.py @@ -0,0 +1,247 @@ +"""Tests for visual validation utilities.""" + +import pytest +from PIL import Image, ImageDraw + +from askui.utils.visual_validation import ( + compute_ahash, + compute_phash, + extract_region, + get_validation_coordinate, + hamming_distance, + should_validate_step, + validate_visual_hash, +) + + +class TestHashComputation: + """Test hash computation functions.""" + + def test_compute_phash_returns_hex_string(self) -> None: + """Test that compute_phash returns a hexadecimal string.""" + # Create a simple test image + img = Image.new("RGB", (100, 100), color="red") + hash_result = compute_phash(img) + + # Should be a hex string + assert isinstance(hash_result, str) + assert len(hash_result) > 0 + # Should be valid hex + int(hash_result, 16) # Will raise if not valid hex + + def test_compute_ahash_returns_hex_string(self) -> None: + """Test that compute_ahash returns a hexadecimal string.""" + # Create a simple test image + img = Image.new("RGB", (100, 100), color="blue") + hash_result = compute_ahash(img) + + # Should be a hex string + assert isinstance(hash_result, str) + assert len(hash_result) > 0 + # Should be valid hex + int(hash_result, 16) # Will raise if not valid hex + + def test_identical_images_produce_same_phash(self) -> None: + """Test that identical images produce identical phashes.""" + img1 = Image.new("RGB", (100, 100), color="green") + img2 = Image.new("RGB", (100, 100), color="green") + + hash1 = compute_phash(img1) + hash2 = compute_phash(img2) + + assert hash1 == hash2 + + def test_different_images_produce_different_phash(self) -> None: + """Test that different images produce different phashes.""" + # Create images with patterns, not solid colors + img1 = Image.new("RGB", (100, 100), color="white") + draw1 = ImageDraw.Draw(img1) + draw1.rectangle((10, 10, 50, 50), fill="red") + + img2 = Image.new("RGB", (100, 100), color="white") + draw2 = ImageDraw.Draw(img2) + draw2.rectangle((60, 60, 90, 90), fill="blue") + + hash1 = compute_phash(img1) + hash2 = compute_phash(img2) + + assert hash1 != hash2 + + +class TestHammingDistance: + """Test Hamming distance calculation.""" + + def test_identical_hashes_have_zero_distance(self) -> None: + """Test that identical hashes have Hamming distance of 0.""" + hash1 = "a1b2c3d4" + hash2 = "a1b2c3d4" + + distance = hamming_distance(hash1, hash2) + assert distance == 0 + + def test_different_hashes_have_nonzero_distance(self) -> None: + """Test that different hashes have non-zero Hamming distance.""" + hash1 = "ffffffff" # All 1s in binary + hash2 = "00000000" # All 0s in binary + + distance = hamming_distance(hash1, hash2) + assert distance > 0 + + def test_hamming_distance_raises_on_different_lengths(self) -> None: + """Test that hamming_distance raises ValueError for different lengths.""" + hash1 = "a1b2" + hash2 = "a1b2c3" + + 
+
+
+class TestHammingDistance:
+    """Test Hamming distance calculation."""
+
+    def test_identical_hashes_have_zero_distance(self) -> None:
+        """Test that identical hashes have a Hamming distance of 0."""
+        hash1 = "a1b2c3d4"
+        hash2 = "a1b2c3d4"
+
+        distance = hamming_distance(hash1, hash2)
+        assert distance == 0
+
+    def test_different_hashes_have_nonzero_distance(self) -> None:
+        """Test that different hashes have a non-zero Hamming distance."""
+        hash1 = "ffffffff"  # All 1s in binary
+        hash2 = "00000000"  # All 0s in binary
+
+        distance = hamming_distance(hash1, hash2)
+        assert distance > 0
+
+    def test_hamming_distance_raises_on_different_lengths(self) -> None:
+        """Test that hamming_distance raises ValueError for different lengths."""
+        hash1 = "a1b2"
+        hash2 = "a1b2c3"
+
+        with pytest.raises(ValueError, match="Hash lengths differ"):
+            hamming_distance(hash1, hash2)
+
+
+class TestExtractRegion:
+    """Test region extraction from images."""
+
+    def test_extract_region_returns_image(self) -> None:
+        """Test that extract_region returns a PIL Image."""
+        img = Image.new("RGB", (200, 200), color="red")
+        center = (100, 100)
+
+        region = extract_region(img, center, size=50)
+
+        assert isinstance(region, Image.Image)
+
+    def test_extract_region_has_correct_size(self) -> None:
+        """Test that the extracted region has the correct size."""
+        img = Image.new("RGB", (200, 200), color="red")
+        center = (100, 100)
+        size = 50
+
+        region = extract_region(img, center, size=size)
+
+        # Region should be at most size x size (it may be clipped at edges)
+        assert region.width <= size
+        assert region.height <= size
+
+    def test_extract_region_at_edge(self) -> None:
+        """Test that extract_region handles centers near the image edge."""
+        img = Image.new("RGB", (100, 100), color="red")
+        center = (10, 10)  # Near edge
+
+        # Should not raise an error
+        region = extract_region(img, center, size=50)
+        assert isinstance(region, Image.Image)
+
+
+class TestValidateVisualHash:
+    """Test visual hash validation."""
+
+    def test_validate_visual_hash_passes_for_identical_images(self) -> None:
+        """Test validation passes for identical images."""
+        img = Image.new("RGB", (100, 100), color="red")
+        stored_hash = compute_phash(img)
+
+        is_valid, error_msg, distance = validate_visual_hash(
+            stored_hash, img, threshold=10, hash_method="phash"
+        )
+
+        assert is_valid is True
+        assert error_msg is None
+        assert distance == 0
+
+    def test_validate_visual_hash_fails_for_different_images(self) -> None:
+        """Test validation fails for very different images."""
+        # Create images with different patterns
+        img1 = Image.new("RGB", (100, 100), color="white")
+        draw1 = ImageDraw.Draw(img1)
+        draw1.rectangle((10, 10, 50, 50), fill="red")
+
+        img2 = Image.new("RGB", (100, 100), color="white")
+        draw2 = ImageDraw.Draw(img2)
+        draw2.rectangle((60, 60, 90, 90), fill="blue")
+
+        stored_hash = compute_phash(img1)
+
+        is_valid, error_msg, distance = validate_visual_hash(
+            stored_hash, img2, threshold=5, hash_method="phash"
+        )
+
+        # Should fail because the Hamming distance exceeds the threshold
+        assert is_valid is False
+        assert error_msg is not None
+        assert "Visual validation failed" in error_msg
+
+    def test_validate_visual_hash_with_ahash_method(self) -> None:
+        """Test validation works with the ahash method."""
+        img = Image.new("RGB", (100, 100), color="green")
+        stored_hash = compute_ahash(img)
+
+        is_valid, error_msg, distance = validate_visual_hash(
+            stored_hash, img, threshold=10, hash_method="ahash"
+        )
+
+        assert is_valid is True
+        assert error_msg is None
+        assert distance == 0
+
+    def test_validate_visual_hash_unknown_method(self) -> None:
+        """Test validation fails gracefully with an unknown hash method."""
+        img = Image.new("RGB", (100, 100), color="red")
+        stored_hash = "abcdef"
+
+        is_valid, error_msg, distance = validate_visual_hash(
+            stored_hash, img, threshold=10, hash_method="unknown_method"
+        )
+
+        assert is_valid is False
+        assert error_msg is not None
+        assert "Unknown hash method" in error_msg
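+
+
+# Illustrative sketch of how the pieces above are assumed to compose during
+# cache execution (the helper names are from this module; the surrounding
+# executor code is hypothetical). The size and threshold values echo the
+# documented defaults (region size 100 px, Hamming threshold 10):
+#
+#     coordinate = get_validation_coordinate(tool_input)
+#     if coordinate is not None:
+#         region = extract_region(screenshot, coordinate, size=100)
+#         is_valid, error_msg, distance = validate_visual_hash(
+#             stored_hash, region, threshold=10, hash_method="phash"
+#         )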
+
+
+class TestShouldValidateStep:
+    """Test step validation logic."""
+
+    def test_should_validate_left_click(self) -> None:
+        """Test that left_click actions should be validated."""
+        assert should_validate_step("computer", "left_click") is True
+
+    def test_should_validate_right_click(self) -> None:
+        """Test that right_click actions should be validated."""
+        assert should_validate_step("computer", "right_click") is True
+
+    def test_should_validate_type_action(self) -> None:
+        """Test that type actions should be validated."""
+        assert should_validate_step("computer", "type") is True
+
+    def test_should_not_validate_screenshot(self) -> None:
+        """Test that screenshot actions should not be validated."""
+        assert should_validate_step("computer", "screenshot") is False
+
+    def test_should_not_validate_unknown_tool(self) -> None:
+        """Test that unknown tools should not be validated."""
+        assert should_validate_step("unknown_tool", None) is False
+
+
+class TestGetValidationCoordinate:
+    """Test coordinate extraction for validation."""
+
+    def test_get_validation_coordinate_from_computer_tool(self) -> None:
+        """Test extracting the coordinate from computer tool input."""
+        tool_input = {"action": "left_click", "coordinate": [450, 300]}
+
+        coord = get_validation_coordinate(tool_input)
+
+        assert coord == (450, 300)
+
+    def test_get_validation_coordinate_returns_none_without_coordinate(self) -> None:
+        """Test that None is returned when the input has no coordinate."""
+        tool_input = {"action": "screenshot"}
+
+        coord = get_validation_coordinate(tool_input)
+
+        assert coord is None
+
+    def test_get_validation_coordinate_handles_invalid_format(self) -> None:
+        """Test that an invalid coordinate format is handled gracefully."""
+        tool_input = {"coordinate": "invalid"}
+
+        coord = get_validation_coordinate(tool_input)
+
+        assert coord is None
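+
+
+# Illustrative sketch matching the behaviour asserted above (hypothetical,
+# not necessarily the real implementation):
+def _sketch_get_validation_coordinate(
+    tool_input: dict,
+) -> tuple[int, int] | None:
+    coordinate = tool_input.get("coordinate")
+    # Only a two-element list/tuple counts as a usable coordinate; anything
+    # else (missing key, wrong type) yields None.
+    if isinstance(coordinate, (list, tuple)) and len(coordinate) == 2:
+        return int(coordinate[0]), int(coordinate[1])
+    return None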