From 0165cee228207d7e210609fc7cce2176681949fa Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Fri, 12 Dec 2025 14:55:48 +0100 Subject: [PATCH 01/30] feat(caching): major caching refactoring that adds caching v0.1 functionality --- docs/caching.md | 781 +++++++++++-- src/askui/agent_base.py | 61 +- src/askui/models/shared/agent.py | 253 ++++- src/askui/models/shared/settings.py | 32 + src/askui/models/shared/tools.py | 17 + src/askui/prompts/caching.py | 85 +- src/askui/tools/caching_tools.py | 831 +++++++++++++- src/askui/utils/cache_execution_manager.py | 339 ++++++ src/askui/utils/cache_manager.py | 145 +++ src/askui/utils/cache_migration.py | 303 +++++ src/askui/utils/cache_validator.py | 242 ++++ src/askui/utils/cache_writer.py | 219 +++- src/askui/utils/placeholder_handler.py | 298 +++++ src/askui/utils/placeholder_identifier.py | 134 +++ src/askui/utils/trajectory_executor.py | 333 ++++++ tests/unit/tools/test_caching_tools.py | 1061 +++++++++++++++--- tests/unit/utils/test_cache_manager.py | 378 +++++++ tests/unit/utils/test_cache_migration.py | 360 ++++++ tests/unit/utils/test_cache_validator.py | 486 ++++++++ tests/unit/utils/test_cache_writer.py | 169 ++- tests/unit/utils/test_placeholder_handler.py | 378 +++++++ tests/unit/utils/test_trajectory_executor.py | 754 +++++++++++++ 22 files changed, 7294 insertions(+), 365 deletions(-) create mode 100644 src/askui/utils/cache_execution_manager.py create mode 100644 src/askui/utils/cache_manager.py create mode 100644 src/askui/utils/cache_migration.py create mode 100644 src/askui/utils/cache_validator.py create mode 100644 src/askui/utils/placeholder_handler.py create mode 100644 src/askui/utils/placeholder_identifier.py create mode 100644 src/askui/utils/trajectory_executor.py create mode 100644 tests/unit/utils/test_cache_manager.py create mode 100644 tests/unit/utils/test_cache_migration.py create mode 100644 tests/unit/utils/test_cache_validator.py create mode 100644 tests/unit/utils/test_placeholder_handler.py create mode 100644 tests/unit/utils/test_trajectory_executor.py diff --git a/docs/caching.md b/docs/caching.md index d4da680c..d5bafe86 100644 --- a/docs/caching.md +++ b/docs/caching.md @@ -8,6 +8,8 @@ The caching mechanism allows you to record and replay agent action sequences (tr The caching system works by recording all tool use actions (mouse movements, clicks, typing, etc.) performed by the agent during an `act()` execution. These recorded sequences can then be replayed in subsequent executions, allowing the agent to skip the decision-making process and execute the actions directly. +**New in v0.1:** The caching system now includes advanced features like placeholder support for dynamic values, smart handling of non-cacheable tools that require agent intervention, comprehensive message history tracking, and automatic failure detection with recovery capabilities. 
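+
+A minimal record-and-replay sketch (each setting used here is explained in the sections below):
+
+```python
+from askui import VisionAgent
+from askui.models.shared.settings import CachingSettings
+
+# First run: the agent reasons through every step and its tool calls are recorded.
+with VisionAgent() as agent:
+    agent.act(
+        goal="Open the settings dialog",
+        caching_settings=CachingSettings(strategy="write", filename="open_settings.json"),
+    )
+
+# Later runs: the agent can replay the recorded trajectory instead of re-deciding each step.
+with VisionAgent() as agent:
+    agent.act(
+        goal="Open the settings dialog",
+        caching_settings=CachingSettings(strategy="read"),
+    )
+```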
+ ## Caching Strategies The caching mechanism supports four strategies, configured via the `caching_settings` parameter in the `act()` method: @@ -28,6 +30,7 @@ caching_settings = CachingSettings( strategy="write", # One of: "read", "write", "both", "no" cache_dir=".cache", # Directory to store cache files filename="my_test.json", # Filename for the cache file (optional for write mode) + auto_identify_placeholders=True, # Auto-detect dynamic values (default: True) execute_cached_trajectory_tool_settings=CachedExecutionToolSettings( delay_time_between_action=0.5 # Delay in seconds between each cached action ) @@ -39,6 +42,7 @@ caching_settings = CachingSettings( - **`strategy`**: The caching strategy to use (`"read"`, `"write"`, `"both"`, or `"no"`). - **`cache_dir`**: Directory where cache files are stored. Defaults to `".cache"`. - **`filename`**: Name of the cache file to write to or read from. If not specified in write mode, a timestamped filename will be generated automatically (format: `cached_trajectory_YYYYMMDDHHMMSSffffff.json`). +- **`auto_identify_placeholders`**: **New in v0.1!** When `True` (default), uses AI to automatically identify and parameterize dynamic values like dates, usernames, and IDs during cache recording. When `False`, only manually specified placeholders (using `{{...}}` syntax) are detected. See [Automatic Placeholder Identification](#automatic-placeholder-identification). - **`execute_cached_trajectory_tool_settings`**: Configuration for the trajectory execution tool (optional). See [Execution Settings](#execution-settings) below. ### Execution Settings @@ -82,7 +86,7 @@ with VisionAgent() as agent: ) ``` -After execution, a cache file will be created at `.cache/login_test.json` containing all the tool use actions performed by the agent. +After execution, a cache file will be created at `.cache/login_test.json` containing all the tool use actions performed by the agent, along with metadata about the execution. ### Reading from Cache (Replaying) @@ -102,12 +106,91 @@ with VisionAgent() as agent: ) ``` -When using `strategy="read"`, the agent receives two additional tools: +When using `strategy="read"`, the agent receives two tools: + +1. **`RetrieveCachedTestExecutions`**: Lists all available cache files in the cache directory +2. **`ExecuteCachedTrajectory`**: Executes a cached trajectory. Can start from the beginning (default) or continue from a specific step index using the optional `start_from_step_index` parameter (useful after handling non-cacheable steps) + +The agent will automatically check if a relevant cached trajectory exists and use it if appropriate. During execution, the agent can see all screenshots and results in the message history. After executing a cached trajectory, the agent will verify the results and make corrections if needed. 
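+
+Illustratively, a typical read-mode sequence of tool calls issued by the agent looks like this (pseudo-calls in the style used throughout this document, not a Python API you invoke yourself):
+
+```python
+# 1. Discover available cache files (only valid caches are listed by default)
+available = retrieve_available_trajectories_tool()  # e.g. ['.cache/login_test.json']
+
+# 2. Replay a trajectory from the beginning
+execute_cached_trajectory_tool(trajectory_file=".cache/login_test.json")
+
+# 3. Resume from a later step, e.g. after handling a non-cacheable step manually
+execute_cached_trajectory_tool(
+    trajectory_file=".cache/login_test.json",
+    start_from_step_index=3,
+)
+```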
+ +### Using Placeholders for Dynamic Values + +**New in v0.1:** Trajectories can contain placeholders for dynamic values that change between executions: + +```python +from askui import VisionAgent +from askui.models.shared.settings import CachingSettings + +# When recording, use dynamic values as normal +# The system automatically detects patterns like dates and user-specific data +with VisionAgent() as agent: + agent.act( + goal="Create a new task for today with the title 'Review PR'", + caching_settings=CachingSettings( + strategy="write", + cache_dir=".cache", + filename="create_task.json" + ) + ) + +# Later, when replaying, the agent can provide placeholder values +# If the cache file contains {{current_date}} or {{task_title}}, provide them: +with VisionAgent() as agent: + agent.act( + goal="Create a task using the cached flow", + caching_settings=CachingSettings( + strategy="read", + cache_dir=".cache" + ) + ) + # The agent will automatically detect required placeholders and can provide them + # via the placeholder_values parameter when calling ExecuteCachedTrajectory +``` + +Placeholders use the syntax `{{variable_name}}` and are automatically detected during cache file creation. When executing a trajectory with placeholders, the agent must provide values for all required placeholders. + +### Handling Non-Cacheable Steps + +**New in v0.1:** Some tools cannot be cached and require the agent to execute them live. Examples include debugging tools, contextual decisions, or tools that depend on runtime state. + +```python +from askui import VisionAgent +from askui.models.shared.settings import CachingSettings + +with VisionAgent() as agent: + agent.act( + goal="Debug the login form by checking element states", + caching_settings=CachingSettings( + strategy="read", + cache_dir=".cache" + ) + ) + # If the cached trajectory contains non-cacheable steps: + # 1. Execution pauses when reaching the non-cacheable step + # 2. Agent receives NEEDS_AGENT status with current step index + # 3. Agent executes the non-cacheable step manually + # 4. Agent uses ExecuteCachedTrajectory with start_from_step_index to resume +``` + +Tools can be marked as non-cacheable by setting `is_cacheable=False` in their definition. When trajectory execution reaches a non-cacheable tool, it pauses and returns control to the agent for manual execution. -1. **`retrieve_available_trajectories_tool`**: Lists all available cache files in the cache directory -2. **`execute_cached_executions_tool`**: Executes a specific cached trajectory +### Continuing from a Specific Step -The agent will automatically check if a relevant cached trajectory exists and use it if appropriate. After executing a cached trajectory, the agent will verify the results and make corrections if needed. 
+**New in v0.1:** After handling a non-cacheable step or recovering from a failure, the agent can continue execution from a specific step index using the `start_from_step_index` parameter: + +```python +# The agent uses ExecuteCachedTrajectory with start_from_step_index like this: +result = execute_cached_trajectory_tool( + trajectory_file=".cache/my_test.json", + start_from_step_index=5, # Continue from step 5 + placeholder_values={"date": "2025-12-11"} # Provide any required placeholders +) +``` + +This is particularly useful for: +- Resuming after manual execution of non-cacheable steps +- Recovering from partial failures +- Skipping steps that are no longer needed ### Referencing Cache Files in Goal Prompts @@ -124,8 +207,8 @@ from askui.models.shared.settings import CachingSettings with VisionAgent() as agent: agent.act( goal="""Open the website in Google Chrome. - - If the cache file "open_website_in_chrome.json" is available, please use it + + If the cache file "open_website_in_chrome.json" is available, please use it for this execution. It will open a new window in Chrome and navigate to the website.""", caching_settings=CachingSettings( strategy="read", @@ -147,8 +230,8 @@ test_id = "TEST_001" with VisionAgent() as agent: agent.act( goal=f"""Execute test {test_id} according to the test definition. - - Check if a cache file named "{test_id}.json" exists. If it does, use it to + + Check if a cache file named "{test_id}.json" exists. If it does, use it to replay the test actions, then verify the results.""", caching_settings=CachingSettings( strategy="read", @@ -168,9 +251,9 @@ from askui.models.shared.settings import CachingSettings with VisionAgent() as agent: agent.act( goal="""Fill out the user registration form. - - Look for cache files that match the pattern "user_registration_*.json". - Choose the most recent one if multiple are available, as it likely contains + + Look for cache files that match the pattern "user_registration_*.json". + Choose the most recent one if multiple are available, as it likely contains the most up-to-date interaction sequence.""", caching_settings=CachingSettings( strategy="read", @@ -190,11 +273,11 @@ from askui.models.shared.settings import CachingSettings with VisionAgent() as agent: agent.act( goal="""Complete the full checkout process: - + 1. If "login.json" exists, use it to log in 2. If "add_to_cart.json" exists, use it to add items to cart 3. If "checkout.json" exists, use it to complete the checkout - + After each cached execution, verify the step completed successfully before proceeding.""", caching_settings=CachingSettings( strategy="read", @@ -261,108 +344,680 @@ In this mode: ## Cache File Format -Cache files are JSON files containing an array of tool use blocks. Each block represents a single tool invocation with the following structure: +**New in v0.1:** Cache files now use an enhanced format with metadata tracking, placeholder support, and execution history. 
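+
+On the Python side this structure corresponds to the `CacheFile` Pydantic model in `askui.models.shared.settings`, so cache files can be loaded and inspected programmatically. A minimal sketch (for legacy v0.0 files, prefer `CacheWriter.read_cache_file`, which migrates automatically):
+
+```python
+import json
+from pathlib import Path
+
+from askui.models.shared.settings import CacheFile
+
+raw = json.loads(Path(".cache/login_test.json").read_text())
+cache = CacheFile.model_validate(raw)  # validates metadata, trajectory, placeholders
+print(cache.metadata.version, len(cache.trajectory), cache.placeholders)
+```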
+ +### v0.1 Format (Current) + +Cache files are JSON objects with the following structure: ```json -[ +{ + "metadata": { + "version": "0.1", + "created_at": "2025-12-11T10:30:00Z", + "goal": "Greet user {{user_name}} and log them in", + "last_executed_at": "2025-12-11T15:45:00Z", + "execution_attempts": 3, + "failures": [ + { + "timestamp": "2025-12-11T14:20:00Z", + "step_index": 5, + "error_message": "Element not found", + "failure_count_at_step": 1 + } + ], + "is_valid": true, + "invalidation_reason": null + }, + "trajectory": [ { - "type": "tool_use", - "id": "toolu_01AbCdEfGhIjKlMnOpQrStUv", - "name": "computer", - "input": { - "action": "mouse_move", - "coordinate": [150, 200] - } + "type": "tool_use", + "id": "toolu_01AbCdEfGhIjKlMnOpQrStUv", + "name": "computer", + "input": { + "action": "type", + "text": "Hello {{user_name}}!" + } }, { - "type": "tool_use", - "id": "toolu_02AbCdEfGhIjKlMnOpQrStUv", - "name": "computer", - "input": { - "action": "left_click" - } - }, - { - "type": "tool_use", - "id": "toolu_03AbCdEfGhIjKlMnOpQrStUv", - "name": "computer", - "input": { - "action": "type", - "text": "admin" - } + "type": "tool_use", + "id": "toolu_02XyZaBcDeFgHiJkLmNoPqRs", + "name": "print_debug_info", + "input": {} + } + ], + "placeholders": { + "user_name": "Name of the user to greet" + } +} +``` + +**Note:** In the example above, `print_debug_info` is marked as non-cacheable (`is_cacheable=False`), so its `input` field is blank (`{}`). This saves space and privacy since non-cacheable tools aren't executed from cache anyway. + +#### Metadata Fields + +- **`version`**: Cache file format version (currently "0.1") +- **`created_at`**: ISO 8601 timestamp when the cache was created +- **`goal`**: **New!** The original goal/instruction given to the agent when recording this trajectory. Placeholders are applied to the goal text just like in the trajectory, making it easy to understand what the cache was designed to accomplish. +- **`last_executed_at`**: ISO 8601 timestamp of the last execution (null if never executed) +- **`execution_attempts`**: Number of times this trajectory has been executed +- **`failures`**: List of failures encountered during execution (see [Failure Tracking](#failure-tracking)) +- **`is_valid`**: Boolean indicating if the cache is still considered valid +- **`invalidation_reason`**: Optional string explaining why the cache was invalidated + +#### Placeholders + +The `placeholders` object maps placeholder names to their descriptions. Placeholders in the trajectory use the syntax `{{placeholder_name}}` and must be substituted with actual values during execution. + +#### Failure Tracking + +Each failure record contains: +- **`timestamp`**: When the failure occurred +- **`step_index`**: Which step failed (0-indexed) +- **`error_message`**: The error that occurred +- **`failure_count_at_step`**: How many times this specific step has failed + +This information helps with cache invalidation decisions and debugging. + +### v0.0 Format (Legacy) + +The old format was a simple JSON array: + +```json +[ + { + "type": "tool_use", + "id": "toolu_01AbCdEfGhIjKlMnOpQrStUv", + "name": "computer", + "input": { + "action": "mouse_move", + "coordinate": [150, 200] } + } ] ``` -Note: Screenshot actions are excluded from cached trajectories as they don't modify the UI state. +**Backward Compatibility:** v0.0 cache files are automatically migrated to v0.1 format when read. The system adds default metadata and wraps the trajectory array in the new structure. 
This migration is transparent and requires no user intervention. ## How It Works +### Internal Architecture + +The caching system consists of several key components: + +- **`CacheWriter`**: Handles recording trajectories in write mode +- **`CacheExecutionManager`**: Manages cache execution state, flow control, and metadata updates during trajectory replay +- **`TrajectoryExecutor`**: Executes individual steps from cached trajectories +- **Agent**: Orchestrates the conversation flow and delegates cache execution to `CacheExecutionManager` + +When executing a cached trajectory, the `Agent` class delegates all cache-related logic to `CacheExecutionManager`, which handles: +- State management (execution mode, verification pending, etc.) +- Execution flow control (success, failure, needs agent, completed) +- Message history building and injection +- Metadata updates (execution attempts, failures, invalidation) + +This separation of concerns keeps the Agent focused on conversation orchestration while CacheExecutionManager handles all caching complexity. + ### Write Mode In write mode, the `CacheWriter` class: 1. Intercepts all assistant messages via a callback function 2. Extracts tool use blocks from the messages -3. Stores them in memory during execution -4. Writes them to a JSON file when the agent finishes (on `stop_reason="end_turn"`) +3. Stores tool blocks in memory during execution +4. When agent finishes (on `stop_reason="end_turn"`): + - **Automatically identifies placeholders** using AI (if `auto_identify_placeholders=True`) + - Analyzes trajectory to find dynamic values (dates, usernames, IDs, etc.) + - Generates descriptive placeholder definitions + - Replaces identified values with `{{placeholder_name}}` syntax in trajectory + - Applies same replacements to the goal text + - **Blanks non-cacheable tool inputs** by setting `input: {}` for tools with `is_cacheable=False` (saves space and privacy) + - **Writes to JSON file** with: + - v0.1 metadata (version, timestamps, goal with placeholders) + - Trajectory of tool use blocks (with placeholders and blanked inputs) + - Placeholder definitions with descriptions 5. Automatically skips writing if a cached execution was used (to avoid recording replays) ### Read Mode In read mode: -1. Two caching tools are added to the agent's toolbox -2. A special system prompt (`CACHE_USE_PROMPT`) is appended to instruct the agent on how to use trajectories -3. The agent can call `retrieve_available_trajectories_tool` to see available cache files -4. The agent can call `execute_cached_executions_tool` with a trajectory file path to replay it -5. During replay, each tool use block is executed sequentially with a configurable delay between actions (default: 0.5 seconds) -6. Screenshot and trajectory retrieval tools are skipped during replay -7. The agent is instructed to verify results after replay and make corrections if needed +1. Two caching tools are added to the agent's toolbox: + - `RetrieveCachedTestExecutions`: Lists available trajectories + - `ExecuteCachedTrajectory`: Executes from the beginning or continues from a specific step using `start_from_step_index` +2. A special system prompt (`CACHE_USE_PROMPT`) instructs the agent on: + - How to use trajectories + - Placeholder handling + - Non-cacheable step management + - Failure recovery strategies +3. The agent can list available cache files and choose appropriate ones +4. 
During execution via `TrajectoryExecutor`:
+   - Each step is executed sequentially with configurable delays
+   - All tools in the trajectory are executed, including screenshots and retrieval tools
+   - Non-cacheable tools trigger a pause with `NEEDS_AGENT` status
+   - Placeholders are validated and substituted before execution
+   - Message history is built with assistant (tool use) and user (tool result) messages
+   - Agent sees all screenshots and results in the message history
+5. Execution can pause for agent intervention:
+   - When reaching non-cacheable tools
+   - When errors occur (with failure details)
+6. Agent can resume execution:
+   - Using `ExecuteCachedTrajectory` with `start_from_step_index` from the pause point
+   - Skipping failed or irrelevant steps
+7. Results are verified by the agent, with corrections made as needed
+
+### Message History
+
+**New in v0.1:** During cached trajectory execution, a complete message history is built and returned to the agent. This includes:
+
+- **Assistant messages**: Containing `ToolUseBlockParam` for each action
+- **User messages**: Containing `ToolResultBlockParam` with:
+  - Text results from tool execution
+  - Screenshots (when available)
+  - Error messages (on failure)
+
+This visibility allows the agent to:
+- See the current UI state via screenshots
+- Understand what actions were taken
+- Detect when execution has diverged from expectations
+- Make informed decisions about corrections or retries
+
+### Non-Cacheable Tools
+
+Tools can be marked as non-cacheable by setting `is_cacheable=False` in their definition:
+
+```python
+from askui.models.shared.tools import Tool
+
+class DebugPrintTool(Tool):
+    # Tool is a Pydantic model, so field overrides need type annotations
+    name: str = "print_debug"
+    description: str = "Print debug information about current state"
+    is_cacheable: bool = False  # This tool requires agent context
+
+    def __call__(self, message: str) -> str:
+        # Tool implementation...
+        return message
+```
+
+During trajectory execution, when a non-cacheable tool is encountered:
+
+1. `TrajectoryExecutor` pauses execution
+2. Returns `ExecutionResult` with status `NEEDS_AGENT`
+3. Includes current step index and message history
+4. Agent receives control to execute the step manually
+5. Agent uses `ExecuteCachedTrajectory` with `start_from_step_index` to resume from next step
+
+This mechanism is essential for tools that:
+- Require runtime context (debugging, inspection)
+- Make decisions based on current state
+- Have side effects that shouldn't be blindly replayed
+- Depend on external systems that may have changed
+
+## Failure Handling
+
+**New in v0.1:** Enhanced failure handling provides the agent with detailed information about what went wrong and where.
+
+### When Execution Fails
+
+If a step fails during trajectory execution:
+
+1. Execution stops at the failed step
+2. `ExecutionResult` includes:
+   - Status: `FAILED`
+   - `step_index`: Which step failed
+   - `error_message`: The specific error
+   - `message_history`: All actions and results up to the failure
+3. Failure is recorded in cache metadata for tracking
+4. 
Agent receives the failure information and can decide: + - **Retry**: Execute remaining steps manually + - **Resume**: Fix the issue and use `ExecuteCachedTrajectory` with `start_from_step_index` from next step + - **Abort**: Report that cache needs re-recording + +### Failure Tracking + +Cache metadata tracks all failures: +```json +"failures": [ + { + "timestamp": "2025-12-11T14:20:00Z", + "step_index": 5, + "error_message": "Element not found: login button", + "failure_count_at_step": 2 + } +] +``` + +This information enables: +- Smart cache invalidation (too many failures → invalid cache) +- Debugging (which steps are problematic) +- Metrics (cache reliability over time) +- Auto-recovery strategies (skip commonly failing steps) + +### Agent Recovery Options + +The agent has several recovery strategies: + +1. **Manual Execution**: Execute remaining steps without cache +2. **Partial Resume**: Fix the issue (e.g., wait for element) then continue from next step +3. **Skip and Continue**: Skip the failed step and continue from a later step +4. **Report Invalid**: Mark the cache as outdated and request re-recording + +Example agent decision flow: +``` +Trajectory fails at step 5: "Element not found: submit button" +↓ +Agent takes screenshot to assess current state +↓ +Agent sees submit button is present but has different text +↓ +Agent clicks the button manually +↓ +Agent calls ExecuteCachedTrajectory(start_from_step_index=6) +↓ +Execution continues successfully +``` + +## Placeholders -The delay between actions can be customized using `CachedExecutionToolSettings` to accommodate different application response times. +**New in v0.1:** Placeholders enable dynamic value substitution in cached trajectories. -## Limitations +### Placeholder Syntax -- **UI State Sensitivity**: Cached trajectories assume the UI is in the same state as when they were recorded. If the UI has changed, the replay may fail or produce incorrect results. +Placeholders use double curly braces: `{{placeholder_name}}` + +Valid placeholder names: +- Must start with a letter or underscore +- Can contain letters, numbers, and underscores +- Examples: `{{date}}`, `{{user_name}}`, `{{order_id_123}}` + +### Automatic Placeholder Identification + +**New in v0.1!** The caching system uses AI to automatically identify and parameterize dynamic values when recording trajectories. + +#### How It Works + +When `auto_identify_placeholders=True` (the default), the system: + +1. **Records the trajectory** as normal during agent execution +2. **Analyzes the trajectory** using an LLM to identify dynamic values such as: + - Dates and timestamps (e.g., "2025-12-11", "10:30 AM") + - Usernames, emails, names (e.g., "john.doe", "test@example.com") + - Session IDs, tokens, UUIDs, API keys + - Dynamic text referencing current state or time + - File paths with user-specific or time-specific components + - Temporary or generated identifiers +3. **Generates placeholder definitions** with descriptive names and documentation: + ```json + { + "name": "current_date", + "value": "2025-12-11", + "description": "Current date in YYYY-MM-DD format" + } + ``` +4. **Replaces values with placeholders** in both the trajectory AND the goal: + - Original: `"text": "Login as john.doe"` + - Result: `"text": "Login as {{username}}"` +5. 
**Saves the templated trajectory** to the cache file + +#### Benefits + +✅ **No manual work** - Automatically identifies dynamic values +✅ **Smart detection** - LLM understands semantic meaning (dates vs coordinates) +✅ **Descriptive** - Generates helpful descriptions for each placeholder +✅ **Applies to goal** - Goal text also gets placeholder replacement + +#### What Gets Detected + +The AI identifies values that are likely to change between executions: + +**Will be detected as placeholders:** +- Dates: "2025-12-11", "Dec 11, 2025", "12/11/2025" +- Times: "10:30 AM", "14:45:00", "2025-12-11T10:30:00Z" +- Usernames: "john.doe", "admin_user", "test_account" +- Emails: "user@example.com", "test@domain.org" +- IDs: "uuid-1234-5678", "session_abc123", "order_9876" +- Names: "John Smith", "Jane Doe" +- Dynamic text: "Today is 2025-12-11", "Logged in as john.doe" + +**Will NOT be detected as placeholders:** +- UI coordinates: `{"x": 100, "y": 200}` +- Fixed button labels: "Submit", "Cancel", "OK" +- Configuration values: `{"timeout": 30, "retries": 3}` +- Generic actions: "click", "type", "scroll" +- Boolean values: `true`, `false` + +#### Disabling Auto-Identification + +If you prefer manual placeholder control: + +```python +caching_settings = CachingSettings( + strategy="write", + auto_identify_placeholders=False # Only detect {{...}} syntax +) +``` + +With `auto_identify_placeholders=False`, only manually specified placeholders using the `{{...}}` syntax will be detected. + +#### Logging + +To see what placeholders are being identified, enable INFO-level logging: + +```python +import logging +logging.basicConfig(level=logging.INFO) +``` + +You'll see output like: +``` +INFO: Using LLM to identify placeholders in trajectory +INFO: Identified 3 placeholders in trajectory +DEBUG: - current_date: 2025-12-11 (Current date in YYYY-MM-DD format) +DEBUG: - username: john.doe (Username for login) +DEBUG: - session_id: abc123 (Session identifier) +INFO: Replaced 3 placeholder values in trajectory +INFO: Applied placeholder replacement to goal: Login as john.doe -> Login as {{username}} +``` + +### Manual Placeholders + +You can also manually create placeholders when recording by using the syntax in your goal description. The system will preserve `{{...}}` patterns in tool inputs. + +### Providing Placeholder Values + +When executing a trajectory with placeholders, the agent must provide values: + +```python +# Via ExecuteCachedTrajectory +result = execute_cached_trajectory_tool( + trajectory_file=".cache/my_test.json", + placeholder_values={ + "current_date": "2025-12-11", + "user_email": "test@example.com" + } +) + +# Via ExecuteCachedTrajectory with start_from_step_index +result = execute_cached_trajectory_tool( + trajectory_file=".cache/my_test.json", + start_from_step_index=3, # Continue from step 3 + placeholder_values={ + "current_date": "2025-12-11", + "user_email": "test@example.com" + } +) +``` + +### Placeholder Validation + +Before execution, the system validates that: +- All required placeholders have values provided +- No required placeholders are missing + +If validation fails, execution is aborted with a clear error message listing missing placeholders. 
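+
+For intuition, the validation and substitution behavior can be sketched as follows (the helper below is illustrative, not the library API):
+
+```python
+import re
+
+PLACEHOLDER_RE = re.compile(r"\{\{([A-Za-z_][A-Za-z0-9_]*)\}\}")
+
+def substitute(text: str, values: dict[str, str]) -> str:
+    """Replace each {{name}} with its value, aborting on missing names."""
+    missing = set(PLACEHOLDER_RE.findall(text)) - values.keys()
+    if missing:
+        raise ValueError(f"Missing required placeholders: {sorted(missing)}")
+    return PLACEHOLDER_RE.sub(lambda m: values[m.group(1)], text)
+
+substitute("Schedule for {{date}}", {"date": "2025-12-11"})  # 'Schedule for 2025-12-11'
+```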
+ +### Use Cases + +Placeholders are particularly useful for: +- **Date-dependent workflows**: Testing with current/future dates +- **User-specific actions**: Different users, emails, names +- **Order/transaction IDs**: Testing with different identifiers +- **Environment-specific values**: API endpoints, credentials +- **Parameterized testing**: Running same flow with different data + +Example: +```json +{ + "name": "computer", + "input": { + "action": "type", + "text": "Schedule meeting for {{meeting_date}} with {{attendee_email}}" + } +} +``` + +## Limitations and Considerations + +### Current Limitations + +- **UI State Sensitivity**: Cached trajectories assume the UI is in the same state as when they were recorded. If the UI has changed significantly, replay may fail. - **No on_message Callback**: When using `strategy="write"` or `strategy="both"`, you cannot provide a custom `on_message` callback, as the caching system uses this callback to record actions. - **Verification Required**: After executing a cached trajectory, the agent should verify that the results are correct, as UI changes may cause partial failures. -## Example: Complete Test Workflow +### Best Practices + +1. **Always Verify Results**: After cached execution, verify the outcome matches expectations +2. **Handle Failures Gracefully**: Provide clear recovery paths when trajectories fail +3. **Use Placeholders Wisely**: Identify dynamic values that should be parameterized +4. **Mark Non-Cacheable Tools**: Properly mark tools that require agent intervention +5. **Monitor Cache Validity**: Track execution attempts and failures to identify stale caches +6. **Test Cache Replay**: Periodically test that cached trajectories still work +7. **Version Your Caches**: Use descriptive filenames or directories for different app versions +8. **Adjust Delays**: Tune `delay_time_between_action` based on your app's responsiveness -Here's a complete example showing how to record and replay a test: +### When to Re-Record + +Consider re-recording a cached trajectory when: +- UI layout or element positions have changed significantly +- Workflow steps have been added, removed, or reordered +- Failures occur consistently at the same steps +- Execution takes significantly longer than expected +- The cache has been marked invalid due to failure patterns + +## Migration from v0.0 to v0.1 + +**Automatic Migration:** All v0.0 cache files are automatically migrated when read by the v0.1 system. No manual intervention is required. + +### What Happens During Migration + +When a v0.0 cache file (simple JSON array) is read: + +1. System detects v0.0 format (array instead of object with metadata) +2. Wraps trajectory in v0.1 structure +3. Adds default metadata: + ```json + { + "version": "0.1", + "created_at": "", + "last_executed_at": null, + "execution_attempts": 0, + "failures": [], + "is_valid": true, + "invalidation_reason": null + } + ``` +4. Extracts any placeholders found in trajectory +5. 
Returns fully-formed `CacheFile` object + +### Compatibility Guarantees + +- All v0.0 cache files continue to work without modification +- Migration is performed on-the-fly during read +- Original files are not modified on disk (unless re-written) +- v0.1 system can read both formats seamlessly + +### Batch Migration with CLI Tool + +For batch migration of existing v0.0 caches to v0.1 format, use the migration CLI utility: + +```bash +# Migrate all caches in a directory +python -m askui.utils.cache_migration --cache-dir .cache + +# Dry run (preview what would be migrated) +python -m askui.utils.cache_migration --cache-dir .cache --dry-run + +# Create backups before migration +python -m askui.utils.cache_migration --cache-dir .cache --backup + +# Migrate specific file pattern +python -m askui.utils.cache_migration --cache-dir .cache --pattern "test_*.json" + +# Verbose output +python -m askui.utils.cache_migration --cache-dir .cache --verbose +``` + +The migration tool will: +- Find all cache files matching the pattern +- Check if each file is v0.0 or already v0.1 +- Migrate v0.0 files to v0.1 format +- Optionally create backups with `.v1.backup` suffix +- Report detailed statistics about the migration + +Example output: +``` +INFO: Found 5 cache files in .cache +INFO: ✓ Migrated: login_test.json +INFO: ✓ Migrated: checkout_flow.json +INFO: ⊘ Already v0.1: user_registration.json +INFO: ✓ Migrated: search_test.json +INFO: ✗ Error: invalid.json (invalid JSON) + +============================================================ +Migration Summary: + Total files: 5 + Migrated: 3 + Already v0.1: 1 + Errors: 1 +============================================================ + +INFO: ✓ Migration completed successfully! +``` + +### Manual Migration (Programmatic) + +To upgrade individual v0.0 caches to v0.1 format programmatically: + +```python +from pathlib import Path +from askui.utils.cache_writer import CacheWriter +import json + +# Read v0.0 file (auto-migrates to v0.1 in memory) +cache_file = CacheWriter.read_cache_file(Path(".cache/old_cache.json")) + +# Write back to disk in v0.1 format +with open(".cache/old_cache.json", "w") as f: + json.dump(cache_file.model_dump(mode="json"), f, indent=2, default=str) +``` + +**Note:** Batch migration is optional - all v0.0 caches are automatically migrated during read operations. 
Use the migration tool if you prefer to:
+- Pre-migrate all caches at once
+- Create backups before migration
+- Verify migration success across all files
+- Audit which caches need migration
+
-## Example: Complete Test Workflow
+## Example: Complete Test Workflow with v0.1 Features
 
-Here's a complete example showing how to record and replay a test:
+Here's a complete example showing advanced v0.1 features:
 
 ```python
 from askui import VisionAgent
 from askui.models.shared.settings import CachingSettings, CachedExecutionToolSettings
+from pathlib import Path
+from askui.utils.cache_writer import CacheWriter
 
-# Step 1: Record a successful login flow
-print("Recording login flow...")
+# Step 1: Record a workflow with dynamic values
+print("Recording user registration flow...")
 with VisionAgent() as agent:
     agent.act(
-        goal="Navigate to the login page and log in with username 'testuser' and password 'testpass123'",
+        goal="Register a new user with email 'john@example.com' and today's date",
         caching_settings=CachingSettings(
             strategy="write",
             cache_dir="test_cache",
-            filename="user_login.json"
+            filename="user_registration.json"
         )
    )
+# Cache file now contains placeholders for email and date
 
-# Step 2: Later, replay the login flow for regression testing
-print("\nReplaying login flow for regression test...")
+# Step 2: Replay with different values
+print("\nReplaying registration with new user...")
 with VisionAgent() as agent:
     agent.act(
-        goal="""Log in to the application.
-        
-        If the cache file "user_login.json" is available, please use it to replay
-        the login sequence. It contains the steps to navigate to the login page and
-        authenticate with the test credentials.""",
+        goal="""Register another user with the cached flow.
+
+        If the cache file "user_registration.json" is available, please use it to
+        replay the registration sequence, providing new values for the detected
+        placeholders.""",
         caching_settings=CachingSettings(
             strategy="read",
             cache_dir="test_cache",
             execute_cached_trajectory_tool_settings=CachedExecutionToolSettings(
-                delay_time_between_action=1.0
+                delay_time_between_action=0.75
             )
         )
    )
+# Agent will detect placeholders and provide new values:
+# - email: "jane@example.com"
+# - date: "2025-12-11"
+
+# Step 3: Handle partial failure and resume
+print("\nTesting with non-cacheable debug step...")
+with VisionAgent() as agent:
+    agent.act(
+        goal="Register user and debug if issues occur",
+        caching_settings=CachingSettings(
+            strategy="read",
+            cache_dir="test_cache"
+        )
+    )
+# If trajectory includes a non-cacheable debug tool:
+# 1. Execution pauses with NEEDS_AGENT status
+# 2. Agent manually executes debug tool
+# 3. Agent uses ExecuteCachedTrajectory with start_from_step_index to resume
+# 4. 
Remaining steps execute successfully + +# Step 4: Monitor cache health +print("\nChecking cache metadata...") +cache_file = CacheWriter.read_cache_file(Path("test_cache/user_registration.json")) +print(f"Execution attempts: {cache_file.metadata.execution_attempts}") +print(f"Failures: {len(cache_file.metadata.failures)}") +print(f"Valid: {cache_file.metadata.is_valid}") +if cache_file.metadata.failures: + print("Recent failures:") + for failure in cache_file.metadata.failures[-3:]: + print(f" - Step {failure.step_index}: {failure.error_message}") ``` + +## Future Enhancements + +Planned features for future versions: + +- **Visual Validation**: Screenshot comparison using perceptual hashing (aHash) to detect UI changes +- **Cache Invalidation Strategies**: Configurable validators for automatic cache invalidation +- **Cache Management Tools**: Tools for listing, validating, and invalidating caches +- **Smart Retry**: Automatic retry with adjustments when specific failure patterns are detected +- **Cache Analytics**: Metrics dashboard showing cache performance and reliability +- **Differential Caching**: Record only changed steps when updating existing caches + +## Troubleshooting + +### Common Issues + +**Issue**: Cached trajectory fails to execute +- **Cause**: UI has changed since recording +- **Solution**: Take a screenshot to compare, re-record the trajectory, or manually execute failing steps + +**Issue**: "Missing required placeholders" error +- **Cause**: Trajectory contains placeholders but values weren't provided +- **Solution**: Check cache metadata for required placeholders and provide values via `placeholder_values` parameter + +**Issue**: Execution pauses unexpectedly +- **Cause**: Trajectory contains non-cacheable tool +- **Solution**: Execute the non-cacheable step manually, then use `ExecuteCachedTrajectory` with `start_from_step_index` to resume + +**Issue**: Actions execute too quickly, causing failures +- **Cause**: `delay_time_between_action` is too short for your application +- **Solution**: Increase delay in `CachedExecutionToolSettings` (e.g., from 0.5 to 1.0 seconds) + +**Issue**: "Tool not found in toolbox" error +- **Cause**: Cached trajectory uses a tool that's no longer available +- **Solution**: Re-record the trajectory with current tools, or add the missing tool back + +### Debug Tips + +1. **Check message history**: After execution, review `message_history` in the result to see exactly what happened +2. **Monitor failure metadata**: Track `execution_attempts` and `failures` in cache metadata +3. **Test incrementally**: Use `ExecuteCachedTrajectory` with `start_from_step_index` to test specific sections of a trajectory +4. **Verify placeholders**: Print cache metadata to see what placeholders are expected +5. **Adjust delays**: If timing issues occur, increase `delay_time_between_action` incrementally + +For more help, see the [GitHub Issues](https://github.com/askui/vision-agent/issues) or contact support. 
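+
+As a final debugging aid, a quick way to check which placeholders a cache expects before replaying it (using `CacheWriter.read_cache_file` as in the workflow example above):
+
+```python
+from pathlib import Path
+
+from askui.utils.cache_writer import CacheWriter
+
+cache = CacheWriter.read_cache_file(Path(".cache/my_test.json"))
+for name, description in cache.placeholders.items():
+    print("{{" + name + "}}: " + description)
+```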
diff --git a/src/askui/agent_base.py b/src/askui/agent_base.py index 6512bc67..4ba50170 100644 --- a/src/askui/agent_base.py +++ b/src/askui/agent_base.py @@ -306,13 +306,15 @@ def act( caching_settings or self._get_default_caching_settings_for_act(_model) ) - tools, on_message, cached_execution_tool = self._patch_act_with_cache( - _caching_settings, _settings, tools, on_message - ) _tools = self._build_tools(tools, _model) - if cached_execution_tool: - cached_execution_tool.set_toolbox(_tools) + if _caching_settings.strategy != "no": + on_message = self._patch_act_with_cache( + _caching_settings, _settings, _tools, on_message, goal_str + ) + logger.info( + f"Starting agent act with caching enabled (strategy={_caching_settings.strategy})" + ) self._model_router.act( messages=messages, @@ -336,36 +338,40 @@ def _patch_act_with_cache( self, caching_settings: CachingSettings, settings: ActSettings, - tools: list[Tool] | ToolCollection | None, + toolbox: ToolCollection, on_message: OnMessageCb | None, - ) -> tuple[ - list[Tool] | ToolCollection, OnMessageCb | None, ExecuteCachedTrajectory | None - ]: - """Patch act settings and tools with caching functionality. + goal: str | None = None, + ) -> OnMessageCb | None: + """Patch act settings and toolbox with caching functionality. Args: caching_settings: The caching settings to apply settings: The act settings to modify - tools: The tools list to extend with caching tools + toolbox: The toolbox to extend with caching tools on_message: The message callback (may be replaced for write mode) + goal: The goal string (used for cache metadata) Returns: - A tuple of (modified_tools, modified_on_message, cached_execution_tool) + modified_on_message """ + logger.debug("Setting up caching") caching_tools: list[Tool] = [] - cached_execution_tool: ExecuteCachedTrajectory | None = None # Setup read mode: add caching tools and modify system prompt if caching_settings.strategy in ["read", "both"]: - cached_execution_tool = ExecuteCachedTrajectory( - caching_settings.execute_cached_trajectory_tool_settings - ) + from askui.tools.caching_tools import VerifyCacheExecution + caching_tools.extend( [ RetrieveCachedTestExecutions(caching_settings.cache_dir), - cached_execution_tool, + ExecuteCachedTrajectory( + toolbox=toolbox, + settings=caching_settings.execute_cached_trajectory_tool_settings, + ), + VerifyCacheExecution(), ] ) + if isinstance(settings.messages.system, str): settings.messages.system = ( settings.messages.system + "\n" + CACHE_USE_PROMPT @@ -377,27 +383,30 @@ def _patch_act_with_cache( ] else: # Omit or None settings.messages.system = CACHE_USE_PROMPT + logger.debug("Added cache usage instructions to system prompt") - # Add caching tools to the tools list - if isinstance(tools, list): - tools = caching_tools + tools - elif isinstance(tools, ToolCollection): - tools.append_tool(*caching_tools) - else: - tools = caching_tools + # Add caching tools to the toolbox + if caching_tools: + toolbox.append_tool(*caching_tools) # Setup write mode: create cache writer and set message callback + cache_writer = None if caching_settings.strategy in ["write", "both"]: cache_writer = CacheWriter( - caching_settings.cache_dir, caching_settings.filename + cache_dir=caching_settings.cache_dir, + file_name=caching_settings.filename, + caching_settings=caching_settings, + toolbox=toolbox, + goal=goal, ) if on_message is None: on_message = cache_writer.add_message_cb else: error_message = "Cannot use on_message callback when writing Cache" + logger.error(error_message) raise 
ValueError(error_message) - return tools, on_message, cached_execution_tool + return on_message def _get_default_settings_for_act(self, model: str) -> ActSettings: # noqa: ARG002 return ActSettings() diff --git a/src/askui/models/shared/agent.py b/src/askui/models/shared/agent.py index fdde6c32..4a0bebd1 100644 --- a/src/askui/models/shared/agent.py +++ b/src/askui/models/shared/agent.py @@ -1,4 +1,5 @@ import logging +from typing import TYPE_CHECKING from typing_extensions import override @@ -19,6 +20,11 @@ TruncationStrategyFactory, ) from askui.reporting import NULL_REPORTER, Reporter +from askui.utils.cache_execution_manager import CacheExecutionManager + +if TYPE_CHECKING: + from askui.models.shared.settings import CacheFile + from askui.utils.trajectory_executor import TrajectoryExecutor logger = logging.getLogger(__name__) @@ -50,6 +56,98 @@ def __init__( self._truncation_strategy_factory = ( truncation_strategy_factory or SimpleTruncationStrategyFactory() ) + # Cache execution manager handles all cache-related logic + self._cache_manager = CacheExecutionManager(reporter) + # Store current tool collection for cache executor access + self._tool_collection: ToolCollection | None = None + + + def _get_agent_response( + self, + model: str, + truncation_strategy: TruncationStrategy, + tool_collection: ToolCollection, + settings: ActSettings, + on_message: OnMessageCb, + ) -> MessageParam | None: + """Get response from agent API. + + Args: + model: Model to use + truncation_strategy: Message truncation strategy + tool_collection: Available tools + settings: Agent settings + on_message: Callback for messages + + Returns: + Assistant message or None if cancelled by callback + """ + response_message = self._messages_api.create_message( + messages=truncation_strategy.messages, + model=model, + tools=tool_collection, + max_tokens=settings.messages.max_tokens, + betas=settings.messages.betas, + system=settings.messages.system, + thinking=settings.messages.thinking, + tool_choice=settings.messages.tool_choice, + temperature=settings.messages.temperature, + ) + + message_by_assistant = self._call_on_message( + on_message, response_message, truncation_strategy.messages + ) + if message_by_assistant is None: + return None + + message_by_assistant_dict = message_by_assistant.model_dump(mode="json") + logger.debug(message_by_assistant_dict) + truncation_strategy.append_message(message_by_assistant) + self._reporter.add_message(self.__class__.__name__, message_by_assistant_dict) + + return message_by_assistant + + def _process_tool_execution( + self, + message_by_assistant: MessageParam, + tool_collection: ToolCollection, + on_message: OnMessageCb, + truncation_strategy: TruncationStrategy, + model: str, + settings: ActSettings, + ) -> None: + """Process tool execution and continue if needed. 
+ + Args: + message_by_assistant: Assistant message with potential tool uses + tool_collection: Available tools + on_message: Callback for messages + truncation_strategy: Message truncation strategy + model: Model to use + settings: Agent settings + """ + tool_result_message = self._use_tools(message_by_assistant, tool_collection) + if not tool_result_message: + return + + tool_result_message = self._call_on_message( + on_message, tool_result_message, truncation_strategy.messages + ) + if not tool_result_message: + return + + tool_result_message_dict = tool_result_message.model_dump(mode="json") + logger.debug(tool_result_message_dict) + truncation_strategy.append_message(tool_result_message) + + # Continue with next step recursively + self._step( + model=model, + tool_collection=tool_collection, + on_message=on_message, + settings=settings, + truncation_strategy=truncation_strategy, + ) def _step( self, @@ -65,59 +163,55 @@ def _step( blocks, this method is going to return immediately, as there is nothing to act upon. + When executing from cache (cache execution mode), messages from the cache + executor are added to the truncation strategy, which automatically manages + message history size by removing old messages when needed. + Args: model (str): The model to use for message creation. on_message (OnMessageCb): Callback on new messages settings (AgentSettings): The settings for the step. tool_collection (ToolCollection): The tools to use for the step. truncation_strategy (TruncationStrategy): The truncation strategy to use - for the step. + for the step. Manages message history size automatically. Returns: None """ + # Get or generate assistant message if truncation_strategy.messages[-1].role == "user": - response_message = self._messages_api.create_message( - messages=truncation_strategy.messages, - model=model, - tools=tool_collection, - max_tokens=settings.messages.max_tokens, - betas=settings.messages.betas, - system=settings.messages.system, - thinking=settings.messages.thinking, - tool_choice=settings.messages.tool_choice, - temperature=settings.messages.temperature, - ) - message_by_assistant = self._call_on_message( - on_message, response_message, truncation_strategy.messages + # Try to execute from cache first + if self._cache_manager.handle_execution_step( + on_message, + truncation_strategy, + model, + tool_collection, + settings, + self.__class__.__name__, + self._step, + ): + return # Cache step handled and recursion occurred + + # Normal flow: get agent response + message_by_assistant = self._get_agent_response( + model, truncation_strategy, tool_collection, settings, on_message ) if message_by_assistant is None: return - message_by_assistant_dict = message_by_assistant.model_dump(mode="json") - logger.debug(message_by_assistant_dict) - truncation_strategy.append_message(message_by_assistant) - self._reporter.add_message( - self.__class__.__name__, message_by_assistant_dict - ) else: + # Last message is already from assistant message_by_assistant = truncation_strategy.messages[-1] + + # Check stop reason and process tools self._handle_stop_reason(message_by_assistant, settings.messages.max_tokens) - if tool_result_message := self._use_tools( - message_by_assistant, tool_collection - ): - if tool_result_message := self._call_on_message( - on_message, tool_result_message, truncation_strategy.messages - ): - tool_result_message_dict = tool_result_message.model_dump(mode="json") - logger.debug(tool_result_message_dict) - truncation_strategy.append_message(tool_result_message) - 
self._step( - model=model, - tool_collection=tool_collection, - on_message=on_message, - settings=settings, - truncation_strategy=truncation_strategy, - ) + self._process_tool_execution( + message_by_assistant, + tool_collection, + on_message, + truncation_strategy, + model, + settings, + ) def _call_on_message( self, @@ -129,6 +223,27 @@ def _call_on_message( return message return on_message(OnMessageCbParam(message=message, messages=messages)) + def _setup_cache_tools(self, tool_collection: ToolCollection) -> None: + """Set agent reference on caching tools. + + This allows caching tools to access the agent state for + cache execution and verification. + + Args: + tool_collection: The tool collection to search for cache tools + """ + # Import here to avoid circular dependency + from askui.tools.caching_tools import ( + ExecuteCachedTrajectory, + VerifyCacheExecution, + ) + + # Iterate through tools and set agent on caching tools + for tool_name, tool in tool_collection.get_tools().items(): + if isinstance(tool, (ExecuteCachedTrajectory, VerifyCacheExecution)): + tool.set_agent(self) + logger.debug("Set agent reference on %s", tool_name) + @override def act( self, @@ -138,8 +253,17 @@ def act( tools: ToolCollection | None = None, settings: ActSettings | None = None, ) -> None: + # Reset cache execution state at the start of each act() call + self._cache_manager.reset_state() + _settings = settings or ActSettings() _tool_collection = tools or ToolCollection() + # Store tool collection so it can be accessed by caching tools + self._tool_collection = _tool_collection + + # Set agent reference on ExecuteCachedTrajectory tools + self._setup_cache_tools(_tool_collection) + truncation_strategy = ( self._truncation_strategy_factory.create_truncation_strategy( tools=_tool_collection.to_params(), @@ -192,3 +316,60 @@ def _handle_stop_reason(self, message: MessageParam, max_tokens: int) -> None: raise MaxTokensExceededError(max_tokens) if message.stop_reason == "refusal": raise ModelRefusalError + + # Public methods for cache management (used by caching tools) + # These delegate to the CacheExecutionManager + def activate_cache_execution( + self, + executor: "TrajectoryExecutor", + cache_file: "CacheFile", + cache_file_path: str, + ) -> None: + """Activate cache execution mode. + + Args: + executor: The trajectory executor to use + cache_file: The cache file being executed + cache_file_path: Path to the cache file + """ + self._cache_manager.activate_execution(executor, cache_file, cache_file_path) + + def get_cache_info(self) -> tuple["CacheFile | None", str | None]: + """Get current cache file and path. + + Returns: + Tuple of (cache_file, cache_file_path) + """ + return self._cache_manager.get_cache_info() + + def is_cache_verification_pending(self) -> bool: + """Check if cache verification is pending. + + Returns: + True if verification is pending + """ + return self._cache_manager.is_cache_verification_pending() + + def update_cache_metadata_on_completion(self, success: bool) -> None: + """Update cache metadata after execution completion (public API). + + Args: + success: Whether the execution was successful + """ + self._cache_manager.update_metadata_on_completion(success) + + def update_cache_metadata_on_failure( + self, step_index: int, error_message: str + ) -> None: + """Update cache metadata after execution failure (public API). 
+ + Args: + step_index: The step index where failure occurred + error_message: The error message + """ + self._cache_manager.update_metadata_on_failure(step_index, error_message) + + def clear_cache_state(self) -> None: + """Clear cache execution state.""" + self._cache_manager.clear_cache_state() + diff --git a/src/askui/models/shared/settings.py b/src/askui/models/shared/settings.py index 547d97b6..6265fa9a 100644 --- a/src/askui/models/shared/settings.py +++ b/src/askui/models/shared/settings.py @@ -1,3 +1,6 @@ +from datetime import datetime +from typing import Optional + from anthropic import Omit, omit from anthropic.types import AnthropicBetaParam from anthropic.types.beta import ( @@ -8,6 +11,8 @@ from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Literal +from askui.models.shared.agent_message_param import ToolUseBlockParam + COMPUTER_USE_20250124_BETA_FLAG = "computer-use-2025-01-24" COMPUTER_USE_20251124_BETA_FLAG = "computer-use-2025-11-24" @@ -39,6 +44,33 @@ class CachingSettings(BaseModel): strategy: CACHING_STRATEGY = "no" cache_dir: str = ".cache" filename: str = "" + auto_identify_placeholders: bool = True execute_cached_trajectory_tool_settings: CachedExecutionToolSettings = ( CachedExecutionToolSettings() ) + + +class CacheFailure(BaseModel): + timestamp: datetime + step_index: int + error_message: str + failure_count_at_step: int + + +class CacheMetadata(BaseModel): + version: str = "0.1" + created_at: datetime + goal: Optional[str] = None + last_executed_at: Optional[datetime] = None + execution_attempts: int = 0 + failures: list[CacheFailure] = Field(default_factory=list) + is_valid: bool = True + invalidation_reason: Optional[str] = None + + +class CacheFile(BaseModel): + """Cache file structure (v0.1) wrapping trajectory with metadata.""" + + metadata: CacheMetadata + trajectory: list[ToolUseBlockParam] + placeholders: dict[str, str] = Field(default_factory=dict) diff --git a/src/askui/models/shared/tools.py b/src/askui/models/shared/tools.py index da39ee2d..016b95f7 100644 --- a/src/askui/models/shared/tools.py +++ b/src/askui/models/shared/tools.py @@ -165,6 +165,15 @@ class Tool(BaseModel, ABC): default_factory=_default_input_schema, description="JSON schema for tool parameters", ) + is_cacheable: bool = Field( + default=True, + description=( + "Whether this tool's execution can be cached. " + "Set to False for tools with side effects that shouldn't be repeated " + "(e.g., print/output/notification/external API tools with state changes). " + "Default: True." + ), + ) @abstractmethod def __call__(self, *args: Any, **kwargs: Any) -> ToolCallResult: @@ -341,6 +350,14 @@ def append_tool(self, *tools: Tool) -> "Self": self._tool_map[tool.to_params()["name"]] = tool return self + def get_tools(self) -> dict[str, Tool]: + """Get all tools in the collection. 
+ + Returns: + Dictionary mapping tool names to Tool instances + """ + return dict(self._tool_map) + def reset_tools(self, tools: list[Tool] | None = None) -> "Self": """Reset the tools in the collection with new tools.""" _tools = tools or [] diff --git a/src/askui/prompts/caching.py b/src/askui/prompts/caching.py index a89cf224..bf643918 100644 --- a/src/askui/prompts/caching.py +++ b/src/askui/prompts/caching.py @@ -10,16 +10,89 @@ "typing actions from a previously successful execution.\n" " If there is a trajectory available for a step you need to take, " "always use it!\n" - " You can execute a trajectory with the ExecuteCachedExecution tool.\n" - " After a trajectory was executed, make sure to verify the results! " - "While it works most of the time, occasionally, the execution can be " - "(partly) incorrect. So make sure to verify if everything is filled out " - "as expected, and make corrections where necessary!\n" + "\n" + " EXECUTING TRAJECTORIES:\n" + " - Use ExecuteCachedTrajectory to execute a cached trajectory\n" + " - You will see all screenshots and results from the execution in the message history\n" + " - After execution completes, verify the results are correct\n" + " - If execution fails partway, you'll see exactly where it failed and can decide how to proceed\n" + "\n" + " PLACEHOLDERS:\n" + " - Trajectories may contain dynamic placeholders like {{current_date}} or {{user_name}}\n" + " - When executing a trajectory, check if it requires placeholder values\n" + " - Provide placeholder values using the placeholder_values parameter as a dictionary\n" + " - Example: ExecuteCachedTrajectory(trajectory_file='test.json', placeholder_values={'current_date': '2025-12-11'})\n" + " - If required placeholders are missing, execution will fail with a clear error message\n" + "\n" + " NON-CACHEABLE STEPS:\n" + " - Some tools cannot be cached and require your direct execution (e.g., print_debug, contextual decisions)\n" + " - When trajectory execution reaches a non-cacheable step, it will pause and return control to you\n" + " - You'll receive a NEEDS_AGENT status with the current step index\n" + " - Execute the non-cacheable step manually using your regular tools\n" + " - After completing the non-cacheable step, continue the trajectory using ExecuteCachedTrajectory with start_from_step_index\n" + "\n" + " CONTINUING TRAJECTORIES:\n" + " - Use ExecuteCachedTrajectory with start_from_step_index to resume execution after handling a non-cacheable step\n" + " - Provide the same trajectory file and the step index where execution should continue\n" + " - Example: ExecuteCachedTrajectory(trajectory_file='test.json', start_from_step_index=5, placeholder_values={...})\n" + " - The tool will execute remaining steps from that index onwards\n" + "\n" + " FAILURE HANDLING:\n" + " - If a trajectory fails during execution, you'll see the error message and the step where it failed\n" + " - Analyze the failure: Was it due to UI changes, timing issues, or incorrect state?\n" + " - Options for handling failures:\n" + " 1. Execute the remaining steps manually\n" + " 2. Fix the issue and retry from a specific step using ExecuteCachedTrajectory with start_from_step_index\n" + " 3. 
Report that the cached trajectory is outdated and needs re-recording\n" + "\n" + " BEST PRACTICES:\n" + " - Always verify results after trajectory execution completes\n" + " - While trajectories work most of the time, occasionally execution can be partly incorrect\n" + " - Make corrections where necessary after cached execution\n" + " - if you need to make any corrections after a trajectory execution, please mark the cached execution as failed\n" + " - If a trajectory consistently fails, it may be invalid and should be re-recorded\n" " \n" " \n" " There are several trajectories available to you.\n" " Their filename is a unique testID.\n" - " If executed using the ExecuteCachedExecution tool, a trajectory will " + " If executed using the ExecuteCachedTrajectory tool, a trajectory will " "automatically execute all necessary steps for the test with that id.\n" " \n" ) + +PLACEHOLDER_IDENTIFIER_SYSTEM_PROMPT = """You are analyzing UI automation trajectories \ +to identify values that should be parameterized as placeholders. + +Identify values that are likely to change between executions, such as: +- Dates and timestamps (e.g., "2025-12-11", "10:30 AM", "2025-12-11T14:30:00Z") +- Usernames, emails, names (e.g., "john.doe", "test@example.com", "John Smith") +- Session IDs, tokens, UUIDs, API keys +- Dynamic text that references current state or time-sensitive information +- File paths with user-specific or time-specific components +- Temporary or generated identifiers + +DO NOT mark as placeholders: +- UI element coordinates (x, y positions) +- Fixed button labels or static UI text +- Configuration values that don't change (e.g., timeouts, retry counts) +- Generic action names like "click", "type", "scroll" +- Tool names +- Boolean values or common constants + +For each placeholder, provide: +1. A descriptive name in snake_case (e.g., "current_date", "user_email") +2. The actual value found in the trajectory +3. A brief description of what it represents + +Return your analysis as a JSON object with this structure: +{ + "placeholders": [ + { + "name": "current_date", + "value": "2025-12-11", + "description": "Current date in YYYY-MM-DD format" + } + ] +} + +If no placeholders are found, return an empty placeholders array.""" diff --git a/src/askui/tools/caching_tools.py b/src/askui/tools/caching_tools.py index 4abb3a55..764a09b2 100644 --- a/src/askui/tools/caching_tools.py +++ b/src/askui/tools/caching_tools.py @@ -1,13 +1,22 @@ +import json import logging -import time from pathlib import Path +from typing import TYPE_CHECKING from pydantic import validate_call from typing_extensions import override from ..models.shared.settings import CachedExecutionToolSettings from ..models.shared.tools import Tool, ToolCollection +from ..utils.cache_manager import CacheManager from ..utils.cache_writer import CacheWriter +from ..utils.placeholder_handler import PlaceholderHandler + +if TYPE_CHECKING: + from ..models.shared.agent import Agent + from ..models.shared.agent_message_param import ToolUseBlockParam + from ..models.shared.settings import CacheFile + from ..utils.trajectory_executor import TrajectoryExecutor logger = logging.getLogger() @@ -27,28 +36,92 @@ def __init__(self, cache_dir: str, trajectories_format: str = ".json") -> None: "replayed using the execute_trajectory_tool. Call this tool " "first to see which trajectories are available before " "executing one. The tool returns a list of file paths to " - "available trajectory files." 
+ "available trajectory files.\n\n" + "By default, only valid (non-invalidated) caches are returned. " + "Set include_invalid=True to see all caches including those " + "marked as invalid due to repeated failures." ), + input_schema={ + "type": "object", + "properties": { + "include_invalid": { + "type": "boolean", + "description": ( + "Whether to include invalid/invalidated caches in the results. " + "Default is False (only show valid caches)." + ), + "default": False, + }, + }, + "required": [], + }, ) self._cache_dir = Path(cache_dir) self._trajectories_format = trajectories_format @override @validate_call - def __call__(self) -> list[str]: # type: ignore + def __call__(self, include_invalid: bool = False) -> list[str]: # type: ignore + logger.info( + "Retrieving cached trajectories from %s (include_invalid=%s)", + self._cache_dir, + include_invalid, + ) + if not Path.is_dir(self._cache_dir): error_msg = f"Trajectories directory not found: {self._cache_dir}" logger.error(error_msg) raise FileNotFoundError(error_msg) - available = [ - str(f) + all_files = [ + f for f in self._cache_dir.iterdir() if str(f).endswith(self._trajectories_format) ] + logger.debug("Found %d total cache files", len(all_files)) + + if not include_invalid: + # Filter out invalid caches + valid_files = [] + invalid_count = 0 + unreadable_count = 0 + for f in all_files: + try: + cache_file = CacheWriter.read_cache_file(f) + if cache_file.metadata.is_valid: + valid_files.append(str(f)) + else: + invalid_count += 1 + logger.debug( + "Excluding invalid cache: %s (reason: %s)", + f.name, + cache_file.metadata.invalidation_reason, + ) + except Exception: + unreadable_count += 1 + logger.exception("Failed to read cache file %s", f.name) + # If we can't read it, exclude it + continue + available = valid_files + logger.info( + "Found %d valid cache(s), excluded %d invalid, %d unreadable", + len(valid_files), + invalid_count, + unreadable_count, + ) + else: + available = [str(f) for f in all_files] + logger.info("Retrieved %d cache file(s) (all included)", len(available)) if not available: - warning_msg = f"Warning: No trajectory files found in {self._cache_dir}" + if include_invalid: + warning_msg = f"Warning: No trajectory files found in {self._cache_dir}" + else: + warning_msg = ( + f"Warning: No valid trajectory files found in " + f"{self._cache_dir}. " + "Try include_invalid=True to see all caches." + ) logger.warning(warning_msg) return available @@ -56,25 +129,36 @@ def __call__(self) -> list[str]: # type: ignore class ExecuteCachedTrajectory(Tool): """ - Execute a predefined trajectory to fast-forward through UI interactions + Execute or continue a predefined trajectory to fast-forward through UI interactions """ - def __init__(self, settings: CachedExecutionToolSettings | None = None) -> None: + def __init__( + self, + toolbox: ToolCollection, + settings: CachedExecutionToolSettings | None = None, + ) -> None: super().__init__( name="execute_cached_executions_tool", description=( - "Execute a pre-recorded trajectory to automatically perform a " - "sequence of UI interactions. This tool replays mouse movements, " - "clicks, and typing actions from a previously successful execution.\n\n" + "Activate cache execution mode to replay a pre-recorded trajectory. " + "This tool sets up the agent to execute cached UI interactions step-by-step.\n\n" "Before using this tool:\n" "1. Use retrieve_available_trajectories_tool to see which " "trajectory files are available\n" "2. 
Select the appropriate trajectory file path from the " "returned list\n" - "3. Pass the full file path to this tool\n\n" - "The trajectory will be executed step-by-step, and you should " - "verify the results afterward. Note: Trajectories may fail if " - "the UI state has changed since they were recorded." + "3. If the trajectory contains placeholders (e.g., {{current_date}}), " + "provide values for them in the placeholder_values parameter\n" + "4. Pass the full file path to this tool\n\n" + "Placeholders allow dynamic values to be injected during execution. " + "For example, if a trajectory types '{{current_date}}', you must " + "provide placeholder_values={'current_date': '2025-12-11'}.\n\n" + "To continue from a specific step (e.g., after manually handling a " + "non-cacheable step), use the start_from_step_index parameter. " + "By default, execution starts from the beginning (step 0).\n\n" + "Once activated, the agent will execute cached steps automatically. " + "If a non-cacheable step is encountered, the agent will be asked to " + "handle it manually before resuming cache execution." ), input_schema={ "type": "object", @@ -87,6 +171,26 @@ def __init__(self, settings: CachedExecutionToolSettings | None = None) -> None: "available files)" ), }, + "start_from_step_index": { + "type": "integer", + "description": ( + "Optional: The step index to start or resume execution from (0-based). " + "Use 0 (default) to start from the beginning. Use a higher index " + "to continue from a specific step, e.g., after manually handling " + "a non-cacheable step." + ), + "default": 0, + }, + "placeholder_values": { + "type": "object", + "description": ( + "Optional dictionary mapping placeholder names to their values. " + "Required if the trajectory contains placeholders like {{variable}}. " + "Example: {'current_date': '2025-12-11', 'user_name': 'Alice'}" + ), + "additionalProperties": {"type": "string"}, + "default": {}, + }, }, "required": ["trajectory_file"], }, @@ -94,51 +198,688 @@ def __init__(self, settings: CachedExecutionToolSettings | None = None) -> None: if not settings: settings = CachedExecutionToolSettings() self._settings = settings - - def set_toolbox(self, toolbox: ToolCollection) -> None: - """Set the AgentOS/AskUiControllerClient reference for executing actions.""" + self._agent: "Agent | None" = None # Will be set by set_agent() self._toolbox = toolbox + def set_agent(self, agent: "Agent") -> None: + """Set the agent reference for cache execution mode activation. + + Args: + agent: The Agent instance that will execute the cached trajectory + """ + self._agent = agent + + def _validate_trajectory_file(self, trajectory_file: str) -> str | None: + """Validate that trajectory file exists. + + Args: + trajectory_file: Path to the trajectory file + + Returns: + Error message if validation fails, None otherwise + """ + if not Path(trajectory_file).is_file(): + error_msg = ( + f"Trajectory file not found: {trajectory_file}\n" + "Use retrieve_available_trajectories_tool to see " + "available files." + ) + logger.error(error_msg) + return error_msg + return None + + def _validate_step_index( + self, start_from_step_index: int, trajectory_length: int + ) -> str | None: + """Validate step index is within bounds. 
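+
+        Illustrative bounds check: for a 10-step trajectory, indices 0-9 are
+        valid resume points; negative indices and 10 or above are rejected.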
+ + Args: + start_from_step_index: Index to start from + trajectory_length: Total number of steps in trajectory + + Returns: + Error message if validation fails, None otherwise + """ + logger.debug( + "Validating start_from_step_index=%d (trajectory has %d steps)", + start_from_step_index, + trajectory_length, + ) + if start_from_step_index < 0 or start_from_step_index >= trajectory_length: + error_msg = ( + f"Invalid start_from_step_index: {start_from_step_index}. " + f"Trajectory has {trajectory_length} steps " + f"(valid indices: 0-{trajectory_length - 1})." + ) + logger.error(error_msg) + return error_msg + return None + + def _validate_placeholders( + self, + trajectory: list["ToolUseBlockParam"], + placeholder_values: dict[str, str], + cache_placeholders: dict[str, str], + ) -> str | None: + """Validate placeholder values. + + Args: + trajectory: The cached trajectory + placeholder_values: User-provided placeholder values + cache_placeholders: Placeholders defined in cache file + + Returns: + Error message if validation fails, None otherwise + """ + logger.debug("Validating placeholder values") + is_valid, missing = PlaceholderHandler.validate_placeholders( + trajectory, placeholder_values + ) + if not is_valid: + error_msg = ( + f"Missing required placeholder values: {', '.join(missing)}\n" + f"The trajectory contains the following placeholders: " + f"{', '.join(cache_placeholders.keys())}\n" + f"Please provide values for all placeholders in the " + f"placeholder_values parameter." + ) + logger.error(error_msg) + return error_msg + return None + + def _create_executor( + self, + cache_file: "CacheFile", + placeholder_values: dict[str, str], + start_from_step_index: int, + ) -> "TrajectoryExecutor": + """Create and configure trajectory executor. + + Args: + cache_file: The cache file to execute + placeholder_values: Placeholder values to use + start_from_step_index: Index to start execution from + + Returns: + Configured TrajectoryExecutor instance + """ + logger.debug( + "Creating TrajectoryExecutor with delay=%ss", + self._settings.delay_time_between_action, + ) + + # Import here to avoid circular dependency + from askui.utils.trajectory_executor import TrajectoryExecutor + + executor = TrajectoryExecutor( + trajectory=cache_file.trajectory, + toolbox=self._toolbox, + placeholder_values=placeholder_values, + delay_time=self._settings.delay_time_between_action, + ) + + # Set the starting position if continuing + if start_from_step_index > 0: + executor.current_step_index = start_from_step_index + logger.debug( + "Set executor start position to step %d", start_from_step_index + ) + + return executor + + def _format_success_message( + self, + trajectory_file: str, + trajectory_length: int, + start_from_step_index: int, + placeholder_count: int, + ) -> str: + """Format success message. + + Args: + trajectory_file: Path to trajectory file + trajectory_length: Total steps in trajectory + start_from_step_index: Starting step index + placeholder_count: Number of placeholders used + + Returns: + Formatted success message + """ + if start_from_step_index == 0: + success_msg = ( + f"✓ Cache execution mode activated for " + f"{Path(trajectory_file).name}. " + f"Will execute {trajectory_length} cached steps." + ) + else: + remaining_steps = trajectory_length - start_from_step_index + success_msg = ( + f"✓ Cache execution mode activated, resuming from step " + f"{start_from_step_index}. " + f"Will execute {remaining_steps} remaining cached steps." 
+ ) + + if placeholder_count > 0: + success_msg += f" Using {placeholder_count} placeholder value(s)." + + return success_msg + @override @validate_call - def __call__(self, trajectory_file: str) -> str: - if not hasattr(self, "_toolbox"): - error_msg = "Toolbox not set. Call set_toolbox() first." + def __call__( + self, + trajectory_file: str, + start_from_step_index: int = 0, + placeholder_values: dict[str, str] | None = None, + ) -> str: + """Activate cache execution mode for the agent. + + This method validates the cache file and sets up the agent to execute + cached steps. The actual execution happens in the agent's step loop. + + Returns: + Success message indicating cache mode has been activated + """ + if placeholder_values is None: + placeholder_values = {} + + logger.info( + "Activating cache execution mode: %s (start_from_step=%d)", + Path(trajectory_file).name, + start_from_step_index, + ) + + # Validate agent is set + if not self._agent: + error_msg = "Agent not set. Call set_agent() first." logger.error(error_msg) raise RuntimeError(error_msg) + # Validate trajectory file exists + if error := self._validate_trajectory_file(trajectory_file): + return error + + # Load cache file + logger.debug("Loading cache file: %s", trajectory_file) + cache_file = CacheWriter.read_cache_file(Path(trajectory_file)) + + logger.debug( + "Cache loaded: %d steps, %d placeholders, valid=%s", + len(cache_file.trajectory), + len(cache_file.placeholders), + cache_file.metadata.is_valid, + ) + + # Warn if cache is invalid + if not cache_file.metadata.is_valid: + warning_msg = ( + f"WARNING: Using invalid cache from {Path(trajectory_file).name}. " + f"Reason: {cache_file.metadata.invalidation_reason}. " + "This cache may not work correctly." + ) + logger.warning(warning_msg) + + # Validate step index + if error := self._validate_step_index( + start_from_step_index, len(cache_file.trajectory) + ): + return error + + # Validate placeholders + if error := self._validate_placeholders( + cache_file.trajectory, placeholder_values, cache_file.placeholders + ): + return error + + # Create and configure executor + executor = self._create_executor( + cache_file, placeholder_values, start_from_step_index + ) + + # Store executor and cache info in agent state + self._agent.activate_cache_execution( + executor=executor, + cache_file=cache_file, + cache_file_path=trajectory_file, + ) + + # Format and return success message + success_msg = self._format_success_message( + trajectory_file, + len(cache_file.trajectory), + start_from_step_index, + len(placeholder_values), + ) + logger.info(success_msg) + return success_msg + + +class InspectCacheMetadata(Tool): + """ + Inspect detailed metadata for a cached trajectory file + """ + + def __init__(self) -> None: + super().__init__( + name="inspect_cache_metadata_tool", + description=( + "Inspect and display detailed metadata for a cached trajectory file. " + "This tool shows information about:\n" + "- Cache version and creation timestamp\n" + "- Execution statistics (attempts, last execution time)\n" + "- Validity status and invalidation reason (if invalid)\n" + "- Failure history with timestamps and error messages\n" + "- Placeholders and trajectory step count\n\n" + "Use this tool to debug cache issues or understand why a cache " + "might be failing or invalidated." + ), + input_schema={ + "type": "object", + "properties": { + "trajectory_file": { + "type": "string", + "description": ( + "Full path to the trajectory file to inspect. 
" + "Use retrieve_available_trajectories_tool to " + "find available files." + ), + }, + }, + "required": ["trajectory_file"], + }, + ) + + @override + @validate_call + def __call__(self, trajectory_file: str) -> str: + logger.info("Inspecting cache metadata: %s", Path(trajectory_file).name) + if not Path(trajectory_file).is_file(): error_msg = ( f"Trajectory file not found: {trajectory_file}\n" "Use retrieve_available_trajectories_tool to see available files." ) logger.error(error_msg) - raise FileNotFoundError(error_msg) + return error_msg + + try: + cache_file = CacheWriter.read_cache_file(Path(trajectory_file)) + except Exception: + error_msg = f"Failed to read cache file {Path(trajectory_file).name}" + logger.exception(error_msg) + return error_msg + + metadata = cache_file.metadata + logger.debug( + "Metadata loaded: version=%s, valid=%s, attempts=%d, failures=%d", + metadata.version, + metadata.is_valid, + metadata.execution_attempts, + len(metadata.failures), + ) + + # Format the metadata into a readable string + lines = [ + "=== Cache Metadata ===", + f"File: {trajectory_file}", + "", + "--- Basic Info ---", + f"Version: {metadata.version}", + f"Created: {metadata.created_at}", + f"Last Executed: {metadata.last_executed_at or 'Never'}", + "", + "--- Execution Statistics ---", + f"Total Execution Attempts: {metadata.execution_attempts}", + f"Total Failures: {len(metadata.failures)}", + "", + "--- Validity Status ---", + f"Is Valid: {metadata.is_valid}", + ] + + if not metadata.is_valid: + lines.append(f"Invalidation Reason: {metadata.invalidation_reason}") + + lines.append("") + lines.append("--- Trajectory Info ---") + lines.append(f"Total Steps: {len(cache_file.trajectory)}") + lines.append(f"Placeholders: {len(cache_file.placeholders)}") + if cache_file.placeholders: + lines.append( + f"Placeholder Names: {', '.join(cache_file.placeholders.keys())}" + ) + + if metadata.failures: + lines.append("") + lines.append("--- Failure History ---") + for i, failure in enumerate(metadata.failures, 1): + lines.append(f"Failure {i}:") + lines.append(f" Timestamp: {failure.timestamp}") + lines.append(f" Step Index: {failure.step_index}") + lines.append( + f" Failure Count at Step: {failure.failure_count_at_step}" + ) + lines.append(f" Error: {failure.error_message}") + + return "\n".join(lines) + + +class RevalidateCache(Tool): + """ + Manually mark a cache as valid (reset invalidation) + """ + + def __init__(self) -> None: + super().__init__( + name="revalidate_cache_tool", + description=( + "Manually mark a cache as valid, resetting any previous invalidation. " + "Use this tool when:\n" + "- A cache was invalidated but the underlying issue has been fixed\n" + "- You want to give a previously failing cache another chance\n" + "- You've manually verified the cache should work now\n\n" + "This will:\n" + "- Set is_valid=True\n" + "- Clear the invalidation_reason\n" + "- Keep existing failure history (for debugging)\n" + "- Keep execution attempt counters\n\n" + "Note: The cache can still be auto-invalidated again if it " + "continues to fail." + ), + input_schema={ + "type": "object", + "properties": { + "trajectory_file": { + "type": "string", + "description": ( + "Full path to the trajectory file to revalidate. " + "Use retrieve_available_trajectories_tool with " + "include_invalid=True to find invalidated caches." 
+ ), + }, + }, + "required": ["trajectory_file"], + }, + ) - # Load and execute trajectory - trajectory = CacheWriter.read_cache_file(Path(trajectory_file)) - info_msg = f"Executing cached trajectory from {trajectory_file}" - logger.info(info_msg) - for step in trajectory: - if ( - "screenshot" in step.name - or step.name == "retrieve_available_trajectories_tool" - ): - continue - try: - self._toolbox.run([step]) - except Exception as e: - error_msg = f"An error occured during the cached execution: {e}" - logger.exception(error_msg) - return ( - f"An error occured while executing the trajectory from " - f"{trajectory_file}. Please verify the UI state and " - "continue without cache." + @override + @validate_call + def __call__(self, trajectory_file: str) -> str: + if not Path(trajectory_file).is_file(): + error_msg = ( + f"Trajectory file not found: {trajectory_file}\n" + "Use retrieve_available_trajectories_tool to see available files." + ) + logger.error(error_msg) + return error_msg + + try: + cache_file = CacheWriter.read_cache_file(Path(trajectory_file)) + except Exception: + error_msg = f"Failed to read cache file {trajectory_file}" + logger.exception(error_msg) + return error_msg + + # Mark cache as valid + cache_manager = CacheManager() + was_invalid = not cache_file.metadata.is_valid + previous_reason = cache_file.metadata.invalidation_reason + + cache_manager.mark_cache_valid(cache_file) + + # Write back to disk + try: + cache_path = Path(trajectory_file) + with cache_path.open("w") as f: + json.dump( + cache_file.model_dump(mode="json"), + f, + indent=2, + default=str, ) - time.sleep(self._settings.delay_time_between_action) + except Exception: + error_msg = f"Failed to write cache file {trajectory_file}" + logger.exception(error_msg) + return error_msg + + if was_invalid: + logger.info("Cache revalidated: %s", trajectory_file) + return ( + f"Successfully revalidated cache: {trajectory_file}\n" + f"Previous invalidation reason was: {previous_reason}\n" + "The cache is now marked as valid and can be used again." + ) + logger.info("Cache was already valid: %s", trajectory_file) + return f"Cache {trajectory_file} was already valid. No changes made." + - logger.info("Finished executing cached trajectory") +class InvalidateCache(Tool): + """ + Manually mark a cache as invalid + """ + + def __init__(self) -> None: + super().__init__( + name="invalidate_cache_tool", + description=( + "Manually mark a cache as invalid with a custom reason. " + "Use this tool when:\n" + "- You've determined a cache is no longer reliable\n" + "- The UI has changed and the cached actions won't work\n" + "- You want to prevent automatic execution of a problematic cache\n\n" + "This will:\n" + "- Set is_valid=False\n" + "- Record your custom invalidation reason\n" + "- Keep all existing metadata (failures, execution attempts)\n" + "- Hide the cache from default trajectory listings\n\n" + "The cache can later be revalidated using revalidate_cache_tool " + "if the issue is resolved." + ), + input_schema={ + "type": "object", + "properties": { + "trajectory_file": { + "type": "string", + "description": ( + "Full path to the trajectory file to " + "invalidate. " + "Use retrieve_available_trajectories_tool to " + "find available files." + ), + }, + "reason": { + "type": "string", + "description": ( + "Reason for invalidating this cache. " + "Be specific about why " + "this cache should not be used " + "(e.g., 'UI changed - button moved', " + "'Workflow outdated', 'Replaced by new cache')." 
+ ), + }, + }, + "required": ["trajectory_file", "reason"], + }, + ) + + @override + @validate_call + def __call__(self, trajectory_file: str, reason: str) -> str: + if not Path(trajectory_file).is_file(): + error_msg = ( + f"Trajectory file not found: {trajectory_file}\n" + "Use retrieve_available_trajectories_tool to see available files." + ) + logger.error(error_msg) + return error_msg + + try: + cache_file = CacheWriter.read_cache_file(Path(trajectory_file)) + except Exception: + error_msg = f"Failed to read cache file {trajectory_file}" + logger.exception(error_msg) + return error_msg + + # Mark cache as invalid + cache_manager = CacheManager() + was_valid = cache_file.metadata.is_valid + + cache_manager.invalidate_cache(cache_file, reason=reason) + + # Write back to disk + try: + cache_path = Path(trajectory_file) + with cache_path.open("w") as f: + json.dump( + cache_file.model_dump(mode="json"), + f, + indent=2, + default=str, + ) + except Exception: + error_msg = f"Failed to write cache file {trajectory_file}" + logger.exception(error_msg) + return error_msg + + logger.info("Cache manually invalidated: %s", trajectory_file) + + if was_valid: + return ( + f"Successfully invalidated cache: {trajectory_file}\n" + f"Reason: {reason}\n" + "The cache will not appear in default trajectory listings. " + "Use revalidate_cache_tool to restore it if needed." + ) return ( - f"Successfully executed trajectory from {trajectory_file}. " - "Please verify the UI state." + f"Cache {trajectory_file} was already invalid.\n" + f"Updated invalidation reason to: {reason}" + ) + + +class VerifyCacheExecution(Tool): + """Tool for agent to explicitly report cache execution verification results.""" + + def __init__(self) -> None: + super().__init__( + name="verify_cache_execution", + description=( + "IMPORTANT: Call this tool immediately after reviewing a cached trajectory execution.\n\n" + "Report whether the cached execution successfully achieved the target system state.\n" + "You MUST call this tool to complete the cache verification process.\n\n" + "Set success=True if:\n" + "- The cached execution correctly achieved the intended goal\n" + "- The final state matches what was expected\n" + "- No corrections or additional actions were needed\n\n" + "Set success=False if:\n" + "- The execution did not achieve the target state\n" + "- You had to make corrections or perform additional actions\n" + "- The final state is incorrect or incomplete" + ), + input_schema={ + "type": "object", + "properties": { + "success": { + "type": "boolean", + "description": ( + "True if cached execution correctly " + "achieved target state, " + "False if execution was incorrect or " + "corrections were needed" + ), + }, + "verification_notes": { + "type": "string", + "description": ( + "Brief explanation of what you verified. " + "If success=False, describe what was " + "wrong and what corrections you made." + ), + }, + }, + "required": ["success", "verification_notes"], + }, ) + self.is_cacheable = False # Verification is not cacheable + self._agent: "Agent | None" = None + + def set_agent(self, agent: "Agent") -> None: + """Set agent reference for metadata updates.""" + self._agent = agent + + @override + @validate_call + def __call__(self, success: bool, verification_notes: str) -> str: + """Record cache verification result. 
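+
+        Illustrative outcome (notes text hypothetical): calling with
+        success=False and verification_notes="Date field left empty" records
+        a failure at step_index=-1 and asks the agent to finish manually.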
+ + Args: + success: Whether cache execution achieved target state + verification_notes: Explanation of verification result + + Returns: + Confirmation message + """ + logger.info( + "Cache verification reported: success=%s, notes=%s", + success, + verification_notes, + ) + + if not self._agent: + error_msg = "Agent not set. Cannot record verification result." + logger.error(error_msg) + return error_msg + + # Check if there's a cache file to update (more reliable than checking flag) + cache_file, cache_file_path = self._agent.get_cache_info() + if not (cache_file and cache_file_path): + warning_msg = ( + "No cache file to update. " + "Verification tool called without recent cache execution." + ) + logger.warning(warning_msg) + return warning_msg + + # Debug log if verification flag wasn't explicitly set + # (This can happen if verification is called directly without the flag, + # but we still proceed since we have the cache file) + if not self._agent.is_cache_verification_pending(): + logger.debug( + "Verification flag not set, but cache file exists. " + "This is normal for direct verification calls." + ) + + # Update cache metadata based on verification result + if success: + self._agent.update_cache_metadata_on_completion(success=True) + result_msg = ( + f"✓ Cache verification successful: {verification_notes}\n\n" + "The cached trajectory execution achieved the target " + "system state correctly. " + "You may now proceed with any additional tasks or " + "conclude the execution." + ) + logger.info(result_msg) + else: + error_msg = ( + f"Cache execution did not lead to target system state: " + f"{verification_notes}" + ) + self._agent.update_cache_metadata_on_failure( + step_index=-1, # -1 indicates verification failure + error_message=error_msg, + ) + result_msg = ( + f"✗ Cache verification failed: {verification_notes}\n\n" + "The cached trajectory did not achieve the target " + "system state correctly. " + "You should now continue to complete the task manually " + "from the current state. " + "Use your tools to finish achieving the goal, taking into " + "account what the cache attempted to do and what " + "corrections are needed." 
+ ) + logger.warning(result_msg) + + # Clear verification flag and cache references after verification + self._agent.clear_cache_state() + + return result_msg diff --git a/src/askui/utils/cache_execution_manager.py b/src/askui/utils/cache_execution_manager.py new file mode 100644 index 00000000..869b3b37 --- /dev/null +++ b/src/askui/utils/cache_execution_manager.py @@ -0,0 +1,339 @@ +"""Manager for cache execution flow and state.""" + +import json +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Callable + +from askui.models.shared.agent_message_param import MessageParam, TextBlockParam +from askui.models.shared.agent_on_message_cb import OnMessageCb +from askui.models.shared.settings import ActSettings +from askui.models.shared.tools import ToolCollection +from askui.models.shared.truncation_strategies import TruncationStrategy +from askui.reporting import Reporter +from askui.utils.trajectory_executor import ExecutionResult + +if TYPE_CHECKING: + from askui.models.shared.settings import CacheFile + from askui.utils.trajectory_executor import TrajectoryExecutor + +# Type for the step callback function (matches Agent._step signature) +StepCallback = Callable[ + [str, OnMessageCb, ActSettings, ToolCollection, TruncationStrategy], None +] + +logger = logging.getLogger(__name__) + + +class CacheExecutionManager: + """Manages cache execution flow, state, and metadata updates. + + This class encapsulates all cache-related logic, keeping the Agent class + focused on conversation orchestration. + """ + + def __init__(self, reporter: Reporter) -> None: + """Initialize cache execution manager. + + Args: + reporter: Reporter for logging messages and actions + """ + self._reporter = reporter + # Cache execution state + self._executing_from_cache: bool = False + self._cache_executor: "TrajectoryExecutor | None" = None + self._cache_file_path: str | None = None + self._cache_file: "CacheFile | None" = None + # Track cache verification after execution completes + self._cache_verification_pending: bool = False + + def reset_state(self) -> None: + """Reset cache execution state.""" + self._executing_from_cache = False + self._cache_executor = None + self._cache_file_path = None + self._cache_file = None + self._cache_verification_pending = False + logger.debug("Reset cache execution state") + + def activate_execution( + self, + executor: "TrajectoryExecutor", + cache_file: "CacheFile", + cache_file_path: str, + ) -> None: + """Activate cache execution mode. + + Args: + executor: The trajectory executor to use + cache_file: The cache file being executed + cache_file_path: Path to the cache file + """ + self._cache_executor = executor + self._cache_file = cache_file + self._cache_file_path = cache_file_path + self._executing_from_cache = True + + def get_cache_info(self) -> tuple["CacheFile | None", str | None]: + """Get current cache file and path. + + Returns: + Tuple of (cache_file, cache_file_path) + """ + return (self._cache_file, self._cache_file_path) + + def is_cache_verification_pending(self) -> bool: + """Check if cache verification is pending. 
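+
+        Set by _handle_cache_completed() when a trajectory finishes and
+        cleared by clear_cache_state() once verification has been reported.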
+ + Returns: + True if verification is pending + """ + return self._cache_verification_pending + + def clear_cache_state(self) -> None: + """Clear cache execution state after verification.""" + self._cache_verification_pending = False + self._cache_file = None + self._cache_file_path = None + + def handle_execution_step( + self, + on_message: OnMessageCb, + truncation_strategy: TruncationStrategy, + model: str, + tool_collection: ToolCollection, + settings: ActSettings, + agent_class_name: str, + step_callback: StepCallback, + ) -> bool: + """Handle cache execution step. + + Args: + on_message: Callback for messages + truncation_strategy: Message truncation strategy + model: Model to use + tool_collection: Available tools + settings: Agent settings + agent_class_name: Name of agent class for reporting + step_callback: Callback to continue agent step + + Returns: + True if cache step was handled and recursion occurred, + False if should continue with normal flow + """ + if not (self._executing_from_cache and self._cache_executor): + return False + + logger.debug("Executing next step from cache") + result: ExecutionResult = self._cache_executor.execute_next_step() + + if result.status == "SUCCESS": + return self._handle_cache_success( + result, + on_message, + truncation_strategy, + model, + tool_collection, + settings, + agent_class_name, + step_callback, + ) + if result.status == "NEEDS_AGENT": + return self._handle_cache_needs_agent(result) + if result.status == "COMPLETED": + return self._handle_cache_completed(truncation_strategy) + # result.status == "FAILED" + return self._handle_cache_failed(result) + + def _handle_cache_success( + self, + result: ExecutionResult, + on_message: OnMessageCb, + truncation_strategy: TruncationStrategy, + model: str, + tool_collection: ToolCollection, + settings: ActSettings, + agent_class_name: str, + step_callback: StepCallback, + ) -> bool: + """Handle successful cache step execution.""" + if len(result.message_history) < 2: + return False + + assistant_msg = result.message_history[-2] + user_msg = result.message_history[-1] + + # Add assistant message (tool use) + message_by_assistant = self._call_on_message( + on_message, assistant_msg, truncation_strategy.messages + ) + if message_by_assistant is None: + return True + + truncation_strategy.append_message(message_by_assistant) + self._reporter.add_message( + agent_class_name, message_by_assistant.model_dump(mode="json") + ) + + # Add user message (tool result) + user_msg_processed = self._call_on_message( + on_message, user_msg, truncation_strategy.messages + ) + if user_msg_processed is None: + return True + + truncation_strategy.append_message(user_msg_processed) + + # Continue with next step recursively + step_callback( + model, + on_message, + settings, + tool_collection, + truncation_strategy, + ) + return True + + def _handle_cache_needs_agent(self, result: ExecutionResult) -> bool: + """Handle cache execution pausing for non-cacheable tool.""" + logger.info( + "Paused cache execution at step %d " + "(non-cacheable tool - agent will handle this step)", + result.step_index, + ) + self._executing_from_cache = False + return False # Fall through to normal agent API call + + def _handle_cache_completed( + self, truncation_strategy: TruncationStrategy + ) -> bool: + """Handle cache execution completion.""" + logger.info( + "✓ Cache trajectory execution completed - " + "requesting agent verification" + ) + self._executing_from_cache = False + self._cache_verification_pending = True + + # Inject 
verification request message + verification_request = MessageParam( + role="user", + content=[ + TextBlockParam( + type="text", + text=( + "The cached trajectory execution has completed. " + "Please verify if the execution correctly achieved " + "the target system state. " + "Use the verify_cache_execution tool to report " + "your verification result." + ), + ) + ], + ) + truncation_strategy.append_message(verification_request) + logger.debug("Injected cache verification request message") + return False # Fall through to let agent verify execution + + def _handle_cache_failed(self, result: ExecutionResult) -> bool: + """Handle cache execution failure.""" + logger.error( + "✗ Cache execution failed at step %d: %s", + result.step_index, + result.error_message, + ) + self._executing_from_cache = False + + # Update cache metadata + if self._cache_file and self._cache_file_path: + self.update_metadata_on_failure( + step_index=result.step_index, + error_message=result.error_message or "Unknown error", + ) + + return False # Fall through to let agent continue + + def _call_on_message( + self, + on_message: OnMessageCb | None, + message: MessageParam, + messages: list[MessageParam], + ) -> MessageParam | None: + """Call on_message callback if provided.""" + if on_message is None: + return message + from askui.models.shared.agent_on_message_cb import OnMessageCbParam + + return on_message(OnMessageCbParam(message=message, messages=messages)) + + def update_metadata_on_completion(self, success: bool) -> None: + """Update cache metadata after execution completion. + + Args: + success: Whether the execution was successful + """ + if not self._cache_file or not self._cache_file_path: + return + + try: + from askui.utils.cache_manager import CacheManager + + cache_manager = CacheManager() + cache_manager.record_execution_attempt(self._cache_file, success=success) + + # Write updated metadata back to disk + cache_path = Path(self._cache_file_path) + with cache_path.open("w") as f: + json.dump( + self._cache_file.model_dump(mode="json"), + f, + indent=2, + default=str, + ) + logger.debug("Updated cache metadata: %s", cache_path.name) + except Exception: + logger.exception("Failed to update cache metadata") + + def update_metadata_on_failure( + self, step_index: int, error_message: str + ) -> None: + """Update cache metadata after execution failure. 
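+
+        Besides recording the failure, this runs the invalidation check:
+        with the default validators, three failures at the same step are
+        enough to invalidate the cache (see CacheManager).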
+ + Args: + step_index: The step index where failure occurred + error_message: The error message + """ + if not self._cache_file or not self._cache_file_path: + return + + try: + from askui.utils.cache_manager import CacheManager + + cache_manager = CacheManager() + cache_manager.record_execution_attempt(self._cache_file, success=False) + cache_manager.record_step_failure( + self._cache_file, + step_index=step_index, + error_message=error_message, + ) + + # Check if cache should be invalidated + should_inv, reason = cache_manager.should_invalidate( + self._cache_file, step_index=step_index + ) + if should_inv and reason: + logger.warning("Cache invalidated: %s", reason) + cache_manager.invalidate_cache(self._cache_file, reason=reason) + + # Write updated metadata back to disk + cache_path = Path(self._cache_file_path) + with cache_path.open("w") as f: + json.dump( + self._cache_file.model_dump(mode="json"), + f, + indent=2, + default=str, + ) + logger.debug("Updated cache metadata after failure: %s", cache_path.name) + except Exception: + logger.exception("Failed to update cache metadata") diff --git a/src/askui/utils/cache_manager.py b/src/askui/utils/cache_manager.py new file mode 100644 index 00000000..ed94a0cb --- /dev/null +++ b/src/askui/utils/cache_manager.py @@ -0,0 +1,145 @@ +"""Cache management utilities for tracking execution and invalidation. + +This module provides the CacheManager class that handles cache metadata updates, +failure tracking, and cache invalidation using configurable validation strategies. +""" + +from datetime import datetime, timezone +from typing import Optional + +from askui.models.shared.settings import CacheFailure, CacheFile +from askui.utils.cache_validator import ( + CacheValidator, + CompositeCacheValidator, + StaleCacheValidator, + StepFailureCountValidator, + TotalFailureRateValidator, +) + + +class CacheManager: + """Manages cache metadata updates and validation. + + Uses a CompositeCacheValidator for extensible invalidation logic. + Users can provide custom validators via the validator parameter. + + Example: + # Using default validators + manager = CacheManager() + + # Using custom validator + custom_validator = CompositeCacheValidator([ + StepFailureCountValidator(max_failures_per_step=5), + MyCustomValidator() + ]) + manager = CacheManager(validator=custom_validator) + """ + + def __init__(self, validator: Optional[CacheValidator] = None): + """Initialize cache manager. + + Args: + validator: Custom validator or None to use default composite validator + """ + if validator is None: + # Default validator with built-in strategies + self.validator = CompositeCacheValidator( + [ + StepFailureCountValidator(max_failures_per_step=3), + TotalFailureRateValidator(min_attempts=10, max_failure_rate=0.5), + StaleCacheValidator(max_age_days=30), + ] + ) + else: + self.validator = validator + + def record_execution_attempt( + self, + cache_file: CacheFile, + success: bool, + failure_info: Optional[CacheFailure] = None, + ) -> None: + """Record an execution attempt and update metadata. 
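+
+        Illustrative use with an already loaded cache file:
+
+            manager = CacheManager()
+            manager.record_execution_attempt(cache_file, success=True)
+            # execution_attempts is incremented; last_executed_at set (UTC)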
+ + Args: + cache_file: The cache file to update + success: Whether the execution was successful + failure_info: Optional failure information if execution failed + """ + cache_file.metadata.execution_attempts += 1 + + if success: + cache_file.metadata.last_executed_at = datetime.now(tz=timezone.utc) + # Successful execution - metadata updated + elif failure_info: + cache_file.metadata.failures.append(failure_info) + + def record_step_failure( + self, cache_file: CacheFile, step_index: int, error_message: str + ) -> None: + """Record a failure at specific step. + + Args: + cache_file: The cache file to update + step_index: Index of the step that failed + error_message: Description of the error + """ + failure = CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=step_index, + error_message=error_message, + failure_count_at_step=self.get_failure_count_for_step( + cache_file, step_index + ) + + 1, + ) + cache_file.metadata.failures.append(failure) + + def should_invalidate( + self, cache_file: CacheFile, step_index: Optional[int] = None + ) -> tuple[bool, Optional[str]]: + """Check if cache should be invalidated using the validator. + + Args: + cache_file: The cache file to check + step_index: Optional step index where failure occurred + + Returns: + Tuple of (should_invalidate: bool, reason: Optional[str]) + """ + return self.validator.should_invalidate(cache_file, step_index) + + def invalidate_cache(self, cache_file: CacheFile, reason: str) -> None: + """Mark cache as invalid. + + Args: + cache_file: The cache file to invalidate + reason: Reason for invalidation + """ + cache_file.metadata.is_valid = False + cache_file.metadata.invalidation_reason = reason + + def mark_cache_valid(self, cache_file: CacheFile) -> None: + """Mark cache as valid. + + Args: + cache_file: The cache file to mark as valid + """ + cache_file.metadata.is_valid = True + cache_file.metadata.invalidation_reason = None + + def get_failure_count_for_step( + self, cache_file: CacheFile, step_index: int + ) -> int: + """Get number of failures for a specific step. + + Args: + cache_file: The cache file to check + step_index: Index of the step to count failures for + + Returns: + Number of failures at this step + """ + return sum( + 1 for f in cache_file.metadata.failures if f.step_index == step_index + ) diff --git a/src/askui/utils/cache_migration.py b/src/askui/utils/cache_migration.py new file mode 100644 index 00000000..5be27edf --- /dev/null +++ b/src/askui/utils/cache_migration.py @@ -0,0 +1,303 @@ +"""Cache migration utilities for converting v0.0 caches to v0.1 format. + +This module provides tools to batch migrate existing v0.0 cache files to the new +v0.1 format with metadata support. Individual files are automatically migrated on +first read by CacheWriter, but this utility is useful for: + +1. Batch migration of all caches in a directory +2. Pre-migration without executing caches +3. Verification of migration success +4. 
Backup creation before migration + +Usage: + # Migrate all caches in a directory + python -m askui.utils.cache_migration --cache-dir .cache + + # Dry run (don't modify files) + python -m askui.utils.cache_migration --cache-dir .cache --dry-run + + # Create backups before migration + python -m askui.utils.cache_migration --cache-dir .cache --backup +""" + +import argparse +import json +import logging +import shutil +from pathlib import Path +from typing import Any + +from askui.utils.cache_writer import CacheWriter + +logger = logging.getLogger(__name__) + + +class CacheMigrationError(Exception): + """Raised when cache migration fails.""" + + +class CacheMigration: + """Handles migration of cache files from v0.0 to v0.1 format.""" + + def __init__( + self, + backup: bool = False, + backup_suffix: str = ".v1.backup", + ): + """Initialize cache migration utility. + + Args: + backup: Whether to create backup files before migration + backup_suffix: Suffix to add to backup files + """ + self.backup = backup + self.backup_suffix = backup_suffix + self.migrated_count = 0 + self.skipped_count = 0 + self.error_count = 0 + + def migrate_file(self, file_path: Path, dry_run: bool = False) -> tuple[bool, str]: + """Migrate a single cache file from v0.0 to v0.1. + + Args: + file_path: Path to the cache file + dry_run: If True, don't modify the file + + Returns: + Tuple of (success: bool, message: str) + """ + if not file_path.is_file(): + return False, f"File not found: {file_path}" + + try: + # Read the file + with open(file_path, "r") as f: + data = json.load(f) + + # Check if already v0.1 + if isinstance(data, dict) and "metadata" in data: + version = data.get("metadata", {}).get("version") + if version == "0.1": + return False, f"Already v0.1: {file_path.name}" + + # Use CacheWriter to read (automatically migrates) + try: + cache_file = CacheWriter.read_cache_file(file_path) + except Exception as e: + return False, f"Failed to read cache: {str(e)}" + + # Verify it's now v0.1 + if cache_file.metadata.version != "0.1": + return ( + False, + f"Migration failed: Version is {cache_file.metadata.version}", + ) + + if dry_run: + return True, f"Would migrate: {file_path.name}" + + # Create backup if requested + if self.backup: + backup_path = file_path.with_suffix( + file_path.suffix + self.backup_suffix + ) + shutil.copy2(file_path, backup_path) + logger.debug(f"Created backup: {backup_path}") + + # Write migrated version back + with open(file_path, "w") as f: + json.dump(cache_file.model_dump(mode="json"), f, indent=2, default=str) + + return True, f"Migrated: {file_path.name}" + + except Exception as e: + logger.error(f"Error migrating {file_path}: {e}", exc_info=True) + return False, f"Error: {str(e)}" + + def migrate_directory( + self, + cache_dir: Path, + file_pattern: str = "*.json", + dry_run: bool = False, + ) -> dict[str, Any]: + """Migrate all cache files in a directory. 
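+
+        Illustrative programmatic use (directory path assumed):
+
+            stats = CacheMigration(backup=True).migrate_directory(Path(".cache"))
+            print(stats["migrated"], stats["skipped"], stats["errors"])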
+ + Args: + cache_dir: Directory containing cache files + file_pattern: Glob pattern for cache files + dry_run: If True, don't modify files + + Returns: + Dictionary with migration statistics + """ + if not cache_dir.is_dir(): + raise CacheMigrationError(f"Directory not found: {cache_dir}") + + # Find all cache files + cache_files = list(cache_dir.glob(file_pattern)) + + if not cache_files: + logger.warning(f"No cache files found in {cache_dir}") + return { + "migrated": 0, + "skipped": 0, + "errors": 0, + "total": 0, + } + + logger.info(f"Found {len(cache_files)} cache files in {cache_dir}") + + # Reset counters + self.migrated_count = 0 + self.skipped_count = 0 + self.error_count = 0 + + # Migrate each file + results = [] + for file_path in cache_files: + success, message = self.migrate_file(file_path, dry_run=dry_run) + + if success: + self.migrated_count += 1 + logger.info(f"✓ {message}") + elif "Already v0.1" in message: + self.skipped_count += 1 + logger.debug(f"⊘ {message}") + else: + self.error_count += 1 + logger.error(f"✗ {message}") + + results.append( + {"file": file_path.name, "success": success, "message": message} + ) + + # Log summary + logger.info(f"\n{'=' * 60}") + logger.info("Migration Summary:") + logger.info(f" Total files: {len(cache_files)}") + logger.info(f" Migrated: {self.migrated_count}") + logger.info(f" Already v0.1: {self.skipped_count}") + logger.info(f" Errors: {self.error_count}") + logger.info(f"{'=' * 60}\n") + + return { + "migrated": self.migrated_count, + "skipped": self.skipped_count, + "errors": self.error_count, + "total": len(cache_files), + "results": results, + } + + +def main() -> int: + """CLI entry point for cache migration. + + Returns: + Exit code (0 for success, 1 for failure) + """ + parser = argparse.ArgumentParser( + description="Migrate cache files from v0.0 to v0.1 format", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Migrate all caches in .cache directory + python -m askui.utils.cache_migration --cache-dir .cache + + # Dry run (show what would be migrated) + python -m askui.utils.cache_migration --cache-dir .cache --dry-run + + # Create backups before migration + python -m askui.utils.cache_migration --cache-dir .cache --backup + + # Custom file pattern + python -m askui.utils.cache_migration --cache-dir .cache --pattern "test_*.json" + """, + ) + + parser.add_argument( + "--cache-dir", + type=str, + required=True, + help="Directory containing cache files to migrate", + ) + + parser.add_argument( + "--pattern", + type=str, + default="*.json", + help="Glob pattern for cache files (default: *.json)", + ) + + parser.add_argument( + "--dry-run", + action="store_true", + help="Show what would be migrated without modifying files", + ) + + parser.add_argument( + "--backup", + action="store_true", + help="Create backup files before migration (adds .v1.backup suffix)", + ) + + parser.add_argument( + "--backup-suffix", + type=str, + default=".v1.backup", + help="Suffix for backup files (default: .v1.backup)", + ) + + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Enable verbose logging", + ) + + args = parser.parse_args() + + # Configure logging + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.INFO, + format="%(levelname)s: %(message)s", + ) + + try: + cache_dir = Path(args.cache_dir) + + if args.dry_run: + logger.info("DRY RUN MODE - No files will be modified\n") + + # Create migration utility + migration = CacheMigration( + backup=args.backup, + 
backup_suffix=args.backup_suffix, + ) + + # Perform migration + stats = migration.migrate_directory( + cache_dir=cache_dir, + file_pattern=args.pattern, + dry_run=args.dry_run, + ) + + # Return success if no errors + if stats["errors"] == 0: + logger.info("✓ Migration completed successfully!") + return 0 + logger.error(f"✗ Migration completed with {stats['errors']} errors") + return 1 + + except CacheMigrationError as e: + logger.error(f"Migration failed: {e}") + return 1 + except KeyboardInterrupt: + logger.info("\nMigration cancelled by user") + return 1 + except Exception as e: + logger.error(f"Unexpected error: {e}", exc_info=True) + return 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/src/askui/utils/cache_validator.py b/src/askui/utils/cache_validator.py new file mode 100644 index 00000000..7cc811ac --- /dev/null +++ b/src/askui/utils/cache_validator.py @@ -0,0 +1,242 @@ +"""Cache validation strategies for automatic cache invalidation. + +This module provides an extensible validator pattern that allows users to +define custom cache invalidation logic. The system includes built-in validators +for common scenarios like step failure counts, overall failure rates, and stale caches. +""" + +from abc import ABC, abstractmethod +from datetime import datetime, timedelta, timezone +from typing import Optional + +from askui.models.shared.settings import CacheFile + + +class CacheValidator(ABC): + """Abstract base class for cache validation strategies. + + Users can implement custom validators by subclassing this and implementing + the should_invalidate method. + """ + + @abstractmethod + def should_invalidate( + self, cache_file: CacheFile, step_index: Optional[int] = None + ) -> tuple[bool, Optional[str]]: + """Check if cache should be invalidated. + + Args: + cache_file: The cache file with metadata and trajectory + step_index: Optional step index where failure occurred + + Returns: + Tuple of (should_invalidate: bool, reason: Optional[str]) + """ + pass + + @abstractmethod + def get_name(self) -> str: + """Return validator name for logging/debugging.""" + pass + + +class CompositeCacheValidator(CacheValidator): + """Composite validator that combines multiple validation strategies. + + Invalidates cache if ANY of the validators returns True. + Users can add custom validators via add_validator(). + """ + + def __init__(self, validators: Optional[list[CacheValidator]] = None): + """Initialize composite validator. + + Args: + validators: Optional list of validators to include + """ + self.validators: list[CacheValidator] = validators or [] + + def add_validator(self, validator: CacheValidator) -> None: + """Add a validator to the composite. + + Args: + validator: The validator to add + """ + self.validators.append(validator) + + def should_invalidate( + self, cache_file: CacheFile, step_index: Optional[int] = None + ) -> tuple[bool, Optional[str]]: + """Check all validators, invalidate if any returns True. 
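+
+        Illustrative combined result: if two validators both trigger, the
+        returned reason reads "StepFailureCount: ...; StaleCache: ...".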
+ + Args: + cache_file: The cache file with metadata and trajectory + step_index: Optional step index where failure occurred + + Returns: + Tuple of (should_invalidate: bool, reason: Optional[str]) + If multiple validators trigger, reasons are combined with "; " + """ + reasons = [] + for validator in self.validators: + should_inv, reason = validator.should_invalidate(cache_file, step_index) + if should_inv and reason: + reasons.append(f"{validator.get_name()}: {reason}") + + if reasons: + return True, "; ".join(reasons) + return False, None + + def get_name(self) -> str: + """Return validator name.""" + return "CompositeValidator" + + +# Built-in validators + + +class StepFailureCountValidator(CacheValidator): + """Invalidate if same step fails too many times. + + This validator counts how many times a specific step has failed + and invalidates the cache if it exceeds the threshold. + """ + + def __init__(self, max_failures_per_step: int = 3): + """Initialize validator. + + Args: + max_failures_per_step: Maximum number of failures allowed per step + """ + self.max_failures_per_step = max_failures_per_step + + def should_invalidate( + self, cache_file: CacheFile, step_index: Optional[int] = None + ) -> tuple[bool, Optional[str]]: + """Check if step has failed too many times. + + Args: + cache_file: The cache file with metadata and trajectory + step_index: The step index to check (required for this validator) + + Returns: + Tuple of (should_invalidate: bool, reason: Optional[str]) + """ + if step_index is None: + return False, None + + # Count failures at this specific step + failures_at_step = sum( + 1 for f in cache_file.metadata.failures if f.step_index == step_index + ) + + if failures_at_step >= self.max_failures_per_step: + return ( + True, + f"Step {step_index} failed {failures_at_step} times (max: {self.max_failures_per_step})", + ) + return False, None + + def get_name(self) -> str: + """Return validator name.""" + return "StepFailureCount" + + +class TotalFailureRateValidator(CacheValidator): + """Invalidate if overall failure rate is too high. + + This validator calculates the ratio of failures to execution attempts + and invalidates if the rate exceeds the threshold after a minimum + number of attempts. + """ + + def __init__(self, min_attempts: int = 10, max_failure_rate: float = 0.5): + """Initialize validator. + + Args: + min_attempts: Minimum execution attempts before checking rate + max_failure_rate: Maximum acceptable failure rate (0.0 to 1.0) + """ + self.min_attempts = min_attempts + self.max_failure_rate = max_failure_rate + + def should_invalidate( + self, cache_file: CacheFile, step_index: Optional[int] = None + ) -> tuple[bool, Optional[str]]: + """Check if overall failure rate is too high. + + Args: + cache_file: The cache file with metadata and trajectory + step_index: Unused for this validator + + Returns: + Tuple of (should_invalidate: bool, reason: Optional[str]) + """ + attempts = cache_file.metadata.execution_attempts + if attempts < self.min_attempts: + return False, None + + failures = len(cache_file.metadata.failures) + failure_rate = failures / attempts if attempts > 0 else 0.0 + + if failure_rate > self.max_failure_rate: + return ( + True, + f"Failure rate {failure_rate:.1%} exceeds {self.max_failure_rate:.1%} after {attempts} attempts", + ) + return False, None + + def get_name(self) -> str: + """Return validator name.""" + return "TotalFailureRate" + + +class StaleCacheValidator(CacheValidator): + """Invalidate if cache is old and has failures. 
+ + This validator checks if a cache hasn't been successfully executed + in a long time AND has failures. Caches without failures are not + considered stale regardless of age. + """ + + def __init__(self, max_age_days: int = 30): + """Initialize validator. + + Args: + max_age_days: Maximum age in days for cache with failures + """ + self.max_age_days = max_age_days + + def should_invalidate( + self, cache_file: CacheFile, step_index: Optional[int] = None + ) -> tuple[bool, Optional[str]]: + """Check if cache is stale (old + has failures). + + Args: + cache_file: The cache file with metadata and trajectory + step_index: Unused for this validator + + Returns: + Tuple of (should_invalidate: bool, reason: Optional[str]) + """ + if not cache_file.metadata.last_executed_at: + return False, None + + if not cache_file.metadata.failures: + return False, None # No failures, age doesn't matter + + # Ensure last_executed_at is timezone-aware + last_executed = cache_file.metadata.last_executed_at + if last_executed.tzinfo is None: + last_executed = last_executed.replace(tzinfo=timezone.utc) + + age = datetime.now(tz=timezone.utc) - last_executed + if age > timedelta(days=self.max_age_days): + return ( + True, + f"Cache not successfully executed in {age.days} days and has failures", + ) + return False, None + + def get_name(self) -> str: + """Return validator name.""" + return "StaleCache" diff --git a/src/askui/utils/cache_writer.py b/src/askui/utils/cache_writer.py index 36508c73..835f5878 100644 --- a/src/askui/utils/cache_writer.py +++ b/src/askui/utils/cache_writer.py @@ -3,14 +3,28 @@ from datetime import datetime, timezone from pathlib import Path +from askui.locators.serializers import VlmLocatorSerializer +from askui.models.anthropic.messages_api import AnthropicMessagesApi +from askui.models.model_router import create_api_client from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam from askui.models.shared.agent_on_message_cb import OnMessageCbParam +from askui.models.shared.settings import CacheFile, CacheMetadata, CachingSettings +from askui.models.shared.tools import ToolCollection +from askui.utils.placeholder_handler import PlaceholderHandler +from askui.utils.placeholder_identifier import identify_placeholders logger = logging.getLogger(__name__) class CacheWriter: - def __init__(self, cache_dir: str = ".cache", file_name: str = "") -> None: + def __init__( + self, + cache_dir: str = ".cache", + file_name: str = "", + caching_settings: CachingSettings | None = None, + toolbox: ToolCollection | None = None, + goal: str | None = None, + ) -> None: self.cache_dir = Path(cache_dir) self.cache_dir.mkdir(exist_ok=True) self.messages: list[ToolUseBlockParam] = [] @@ -18,6 +32,16 @@ def __init__(self, cache_dir: str = ".cache", file_name: str = "") -> None: file_name += ".json" self.file_name = file_name self.was_cached_execution = False + self._caching_settings = caching_settings or CachingSettings() + self._goal = goal + self._toolbox: ToolCollection | None = None + # Get messages_api for placeholder identification + self._messages_api = AnthropicMessagesApi( + client=create_api_client(api_provider="askui"), + locator_serializer=VlmLocatorSerializer(), + ) + # Set toolbox for cache writer so it can check which tools are cacheable + self._toolbox = toolbox def add_message_cb(self, param: OnMessageCbParam) -> MessageParam: """Add a message to cache.""" @@ -50,21 +74,198 @@ def generate(self) -> None: if self.was_cached_execution: logger.info("Will not write cache 
file as this was a cached execution") return + if not self.file_name: self.file_name = ( f"cached_trajectory_{datetime.now(tz=timezone.utc):%Y%m%d%H%M%S%f}.json" ) + cache_file_path = self.cache_dir / self.file_name - messages_json = [m.model_dump() for m in self.messages] - with cache_file_path.open("w", encoding="utf-8") as f: - json.dump(messages_json, f, indent=4) - info_msg = f"Cache File written at {str(cache_file_path)}" - logger.info(info_msg) + goal_to_save, trajectory_to_save, placeholders_dict = ( + self._replace_placeholders() + ) + + if self._toolbox is not None: + trajectory_to_save = self._blank_non_cacheable_tool_inputs( + trajectory_to_save + ) + else: + logger.info("No toolbox set, skipping non-cacheable tool input blanking") + + self._generate_cache_file( + goal_to_save, trajectory_to_save, placeholders_dict, cache_file_path + ) self.reset() + def _replace_placeholders( + self, + ) -> tuple[str | None, list[ToolUseBlockParam], dict[str, str]]: + # Determine which trajectory and placeholders to use + trajectory_to_save = self.messages + goal_to_save = self._goal + placeholders_dict: dict[str, str] = {} + + if self._caching_settings.auto_identify_placeholders and self.messages: + placeholders_dict, placeholder_definitions = identify_placeholders( + trajectory=self.messages, + messages_api=self._messages_api, + ) + n_placeholders = len(placeholder_definitions) + # Replace actual values with {{placeholder_name}} syntax in trajectory + if placeholder_definitions: + trajectory_to_save = ( + PlaceholderHandler.replace_values_with_placeholders( + trajectory=self.messages, + placeholder_definitions=placeholder_definitions, + ) + ) + + # Also apply placeholder replacement to the goal + if self._goal: + goal_to_save = self._goal + # Build replacement map: value -> placeholder syntax + replacements = { + str(p.value): f"{{{{{p.name}}}}}" + for p in placeholder_definitions + } + # Sort by length descending to replace longer matches first + for actual_value in sorted( + replacements.keys(), key=len, reverse=True + ): + if actual_value in goal_to_save: + goal_to_save = goal_to_save.replace( + actual_value, replacements[actual_value] + ) + else: + # Manual placeholder extraction + placeholder_names = PlaceholderHandler.extract_placeholders(self.messages) + placeholders_dict = { + name: f"Placeholder for {name}" # Generic description + for name in placeholder_names + } + n_placeholders = len(placeholder_names) + logger.info(f"Replaced {n_placeholders} placeholder values in trajectory") + return goal_to_save, trajectory_to_save, placeholders_dict + + def _blank_non_cacheable_tool_inputs( + self, trajectory: list[ToolUseBlockParam] + ) -> list[ToolUseBlockParam]: + """Blank out input fields for non-cacheable tools to save space. + + For tools marked as is_cacheable=False, we replace their input with an + empty dict since we won't be executing them from cache anyway. 
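+
+        For example, a hypothetical non-cacheable ask_user block with
+        input {"question": "Enter the MFA code"} is written to the cache
+        with an empty input dict; its id, name, type, and cache_control
+        are preserved unchanged.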
+ + Args: + trajectory: The trajectory to process + + Returns: + New trajectory with non-cacheable tool inputs blanked out + """ + if self._toolbox is None: + return trajectory + + blanked_count = 0 + result: list[ToolUseBlockParam] = [] + for tool_block in trajectory: + # Check if this tool is cacheable + tool = self._toolbox._tool_map.get(tool_block.name) + + # If tool is not cacheable, blank out its input + if tool is not None and not tool.is_cacheable: + logger.debug( + f"Blanking input for non-cacheable tool: {tool_block.name}" + ) + blanked_count += 1 + result.append( + ToolUseBlockParam( + id=tool_block.id, + name=tool_block.name, + input={}, # Blank out the input + type=tool_block.type, + cache_control=tool_block.cache_control, + ) + ) + else: + # Keep the tool block as-is + result.append(tool_block) + + if blanked_count > 0: + logger.info( + f"Blanked inputs for {blanked_count} non-cacheable tool(s) to save space" + ) + + return result + + def _generate_cache_file( + self, + goal_to_save: str | None, + trajectory_to_save: list[ToolUseBlockParam], + placeholders_dict: dict[str, str], + cache_file_path: Path, + ) -> None: + cache_file = CacheFile( + metadata=CacheMetadata( + version="0.1", + created_at=datetime.now(tz=timezone.utc), + goal=goal_to_save, + ), + trajectory=trajectory_to_save, + placeholders=placeholders_dict, + ) + + with cache_file_path.open("w", encoding="utf-8") as f: + json.dump(cache_file.model_dump(mode="json"), f, indent=4) + logger.info(f"Cache file successfully written: {cache_file_path} ") + @staticmethod - def read_cache_file(cache_file_path: Path) -> list[ToolUseBlockParam]: + def read_cache_file(cache_file_path: Path) -> CacheFile: + """Read cache file with backward compatibility for v0.0 format. + + Returns: + CacheFile object with metadata and trajectory + """ + logger.debug(f"Reading cache file: {cache_file_path}") with cache_file_path.open("r", encoding="utf-8") as f: - raw_trajectory = json.load(f) - return [ToolUseBlockParam(**step) for step in raw_trajectory] + raw_data = json.load(f) + + # Detect format version + if isinstance(raw_data, list): + # v0.0 format: just a list of tool use blocks + logger.info( + f"Detected v0.0 cache format in {cache_file_path.name}, migrating to v0.1" + ) + trajectory = [ToolUseBlockParam(**step) for step in raw_data] + # Create default metadata for v0.0 files (migrated to v0.1 format) + cache_file = CacheFile( + metadata=CacheMetadata( + version="0.1", # Migrated from v0.0 to v0.1 format + created_at=datetime.fromtimestamp( + cache_file_path.stat().st_ctime, tz=timezone.utc + ), + ), + trajectory=trajectory, + placeholders={}, + ) + logger.info( + f"Successfully loaded and migrated v0.0 cache: {len(trajectory)} steps, 0 placeholders" + ) + return cache_file + if isinstance(raw_data, dict) and "metadata" in raw_data: + # v0.1 format: structured with metadata + cache_file = CacheFile(**raw_data) + logger.info( + f"Successfully loaded v0.1 cache: {len(cache_file.trajectory)} steps, " + f"{len(cache_file.placeholders)} placeholders" + ) + if cache_file.metadata.goal: + logger.debug(f"Cache goal: {cache_file.metadata.goal}") + return cache_file + logger.error( + f"Unknown cache file format in {cache_file_path.name}. " + "Expected either a list (v0.0) or dict with 'metadata' key (v0.1)." + ) + raise ValueError( + f"Unknown cache file format in {cache_file_path}. " + "Expected either a list (v0.0) or dict with 'metadata' key (v0.1)." 
+ ) diff --git a/src/askui/utils/placeholder_handler.py b/src/askui/utils/placeholder_handler.py new file mode 100644 index 00000000..62915db0 --- /dev/null +++ b/src/askui/utils/placeholder_handler.py @@ -0,0 +1,298 @@ +"""Placeholder handling for cache trajectories. + +This module provides utilities for detecting, validating, and substituting +placeholders in cached trajectories. Placeholders use the {{variable_name}} +syntax and allow dynamic values to be injected during cache execution. +""" + +import re +from typing import Any + +from askui.models.shared.agent_message_param import ToolUseBlockParam + +# Regex pattern for matching placeholders: {{variable_name}} +# Allows alphanumeric characters and underscores, must start with letter/underscore +PLACEHOLDER_PATTERN = r"\{\{([a-zA-Z_][a-zA-Z0-9_]*)\}\}" + + +class PlaceholderHandler: + """Handler for placeholder detection, validation, and substitution.""" + + @staticmethod + def extract_placeholders(trajectory: list[ToolUseBlockParam]) -> set[str]: + """Extract all placeholder names from a trajectory. + + Scans all tool inputs for {{placeholder_name}} patterns and returns + a set of unique placeholder names. + + Args: + trajectory: List of tool use blocks to scan + + Returns: + Set of unique placeholder names found in the trajectory + + Example: + >>> trajectory = [ + ... ToolUseBlockParam( + ... id="1", + ... name="computer", + ... input={"action": "type", "text": "Today is {{current_date}}"}, + ... type="tool_use" + ... ) + ... ] + >>> PlaceholderHandler.extract_placeholders(trajectory) + {'current_date'} + """ + placeholders: set[str] = set() + + for step in trajectory: + # Recursively find placeholders in the input object + placeholders.update( + PlaceholderHandler._extract_from_value(step.input) + ) + + return placeholders + + @staticmethod + def _extract_from_value(value: Any) -> set[str]: + """Recursively extract placeholders from a value. + + Args: + value: Any value (str, dict, list, etc.) to search for placeholders + + Returns: + Set of placeholder names found + """ + placeholders: set[str] = set() + + if isinstance(value, str): + # Find all matches in the string + matches = re.finditer(PLACEHOLDER_PATTERN, value) + placeholders.update(match.group(1) for match in matches) + elif isinstance(value, dict): + # Recursively search dict values + for v in value.values(): + placeholders.update(PlaceholderHandler._extract_from_value(v)) + elif isinstance(value, list): + # Recursively search list items + for item in value: + placeholders.update(PlaceholderHandler._extract_from_value(item)) + + return placeholders + + @staticmethod + def validate_placeholders( + trajectory: list[ToolUseBlockParam], provided_values: dict[str, str] + ) -> tuple[bool, list[str]]: + """Validate that all required placeholders have values. + + Args: + trajectory: List of tool use blocks containing placeholders + provided_values: Dict of placeholder names to their values + + Returns: + Tuple of (is_valid, missing_placeholders) + - is_valid: True if all placeholders have values, False otherwise + - missing_placeholders: List of placeholder names that are missing values + + Example: + >>> trajectory = [...] # Contains {{current_date}} and {{user_name}} + >>> is_valid, missing = PlaceholderHandler.validate_placeholders( + ... trajectory, + ... {"current_date": "2025-12-11"} + ... 
) + >>> is_valid + False + >>> missing + ['user_name'] + """ + required_placeholders = PlaceholderHandler.extract_placeholders(trajectory) + missing = [ + name for name in required_placeholders if name not in provided_values + ] + + return len(missing) == 0, missing + + @staticmethod + def replace_values_with_placeholders( + trajectory: list[ToolUseBlockParam], + placeholder_definitions: list[Any], # list[PlaceholderDefinition] + ) -> list[ToolUseBlockParam]: + """Replace actual values in trajectory with {{placeholder_name}} syntax. + + This is the reverse of substitute_placeholders - it takes identified values + and replaces them with placeholder syntax for saving to cache. + + Args: + trajectory: The trajectory to templatize + placeholder_definitions: List of PlaceholderDefinition objects with + name and value attributes + + Returns: + New trajectory with values replaced by placeholders + + Example: + >>> trajectory = [ + ... ToolUseBlockParam( + ... id="1", + ... name="computer", + ... input={"action": "type", "text": "Date: 2025-12-11"}, + ... type="tool_use" + ... ) + ... ] + >>> placeholders = [ + ... PlaceholderDefinition( + ... name="current_date", + ... value="2025-12-11", + ... description="Current date" + ... ) + ... ] + >>> result = PlaceholderHandler.replace_values_with_placeholders( + ... trajectory, placeholders + ... ) + >>> result[0].input["text"] + 'Date: {{current_date}}' + """ + # Build replacement map: value -> placeholder name + replacements = { + str(p.value): f"{{{{{p.name}}}}}" for p in placeholder_definitions + } + + # Apply replacements to each tool block + templated_trajectory = [] + for tool_block in trajectory: + templated_input = PlaceholderHandler._replace_values_in_value( + tool_block.input, replacements + ) + + templated_trajectory.append( + ToolUseBlockParam( + id=tool_block.id, + name=tool_block.name, + input=templated_input, + type=tool_block.type, + cache_control=tool_block.cache_control, + ) + ) + + return templated_trajectory + + @staticmethod + def _replace_values_in_value( + value: Any, replacements: dict[str, str] + ) -> Any: + """Recursively replace actual values with placeholder syntax. + + Args: + value: Any value (str, dict, list, etc.) to process + replacements: Dict mapping actual values to placeholder syntax + + Returns: + New value with replacements applied + """ + if isinstance(value, str): + # Replace exact matches and substring matches + result = value + # Sort by length descending to replace longer matches first + # This prevents partial replacements + for actual_value in sorted(replacements.keys(), key=len, reverse=True): + if actual_value in result: + result = result.replace(actual_value, replacements[actual_value]) + return result + elif isinstance(value, dict): + # Recursively replace in dict values + return { + k: PlaceholderHandler._replace_values_in_value(v, replacements) + for k, v in value.items() + } + elif isinstance(value, list): + # Recursively replace in list items + return [ + PlaceholderHandler._replace_values_in_value(item, replacements) + for item in value + ] + else: + # For non-string types, check if the value matches exactly + str_value = str(value) + if str_value in replacements: + # Return the placeholder as a string + return replacements[str_value] + return value + + @staticmethod + def substitute_placeholders( + tool_block: ToolUseBlockParam, placeholder_values: dict[str, str] + ) -> ToolUseBlockParam: + """Replace placeholders in a tool block with actual values. 
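+
+        Substitution is regex-based: each {{name}} occurrence is matched with
+        the placeholder name passed through re.escape, so names are treated
+        literally rather than as regex patterns.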
+ + Creates a new ToolUseBlockParam with all {{placeholder}} occurrences + replaced with their corresponding values from placeholder_values. + + Args: + tool_block: The tool use block containing placeholders + placeholder_values: Dict mapping placeholder names to replacement values + + Returns: + New ToolUseBlockParam with placeholders substituted + + Example: + >>> tool_block = ToolUseBlockParam( + ... id="1", + ... name="computer", + ... input={"action": "type", "text": "Date: {{current_date}}"}, + ... type="tool_use" + ... ) + >>> result = PlaceholderHandler.substitute_placeholders( + ... tool_block, + ... {"current_date": "2025-12-11"} + ... ) + >>> result.input["text"] + 'Date: 2025-12-11' + """ + # Deep copy the input and substitute placeholders + substituted_input = PlaceholderHandler._substitute_in_value( + tool_block.input, placeholder_values + ) + + # Create new ToolUseBlockParam with substituted values + return ToolUseBlockParam( + id=tool_block.id, + name=tool_block.name, + input=substituted_input, + type=tool_block.type, + cache_control=tool_block.cache_control, + ) + + @staticmethod + def _substitute_in_value(value: Any, placeholder_values: dict[str, str]) -> Any: + """Recursively substitute placeholders in a value. + + Args: + value: Any value (str, dict, list, etc.) containing placeholders + placeholder_values: Dict of placeholder names to replacement values + + Returns: + New value with placeholders substituted + """ + if isinstance(value, str): + # Replace all placeholders in the string + result = value + for name, replacement in placeholder_values.items(): + pattern = r"\{\{" + re.escape(name) + r"\}\}" + result = re.sub(pattern, replacement, result) + return result + elif isinstance(value, dict): + # Recursively substitute in dict values + return { + k: PlaceholderHandler._substitute_in_value(v, placeholder_values) + for k, v in value.items() + } + elif isinstance(value, list): + # Recursively substitute in list items + return [ + PlaceholderHandler._substitute_in_value(item, placeholder_values) + for item in value + ] + else: + # Return other types as-is + return value diff --git a/src/askui/utils/placeholder_identifier.py b/src/askui/utils/placeholder_identifier.py new file mode 100644 index 00000000..49c9569d --- /dev/null +++ b/src/askui/utils/placeholder_identifier.py @@ -0,0 +1,134 @@ +"""Module for identifying placeholders in trajectories using LLM analysis.""" + +import json +import logging +from typing import Any + +from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam +from askui.models.shared.messages_api import MessagesApi +from askui.prompts.caching import PLACEHOLDER_IDENTIFIER_SYSTEM_PROMPT + +logger = logging.getLogger(__name__) + + +class PlaceholderDefinition: + """Represents a placeholder identified in a trajectory.""" + + def __init__(self, name: str, value: Any, description: str) -> None: + self.name = name + self.value = value + self.description = description + + def __repr__(self) -> str: + return f"PlaceholderDefinition(name={self.name}, value={self.value})" + + +def identify_placeholders( + trajectory: list[ToolUseBlockParam], + messages_api: MessagesApi, + model: str = "claude-sonnet-4-5-20250929", +) -> tuple[dict[str, str], list[PlaceholderDefinition]]: + """Identify placeholders in a trajectory using LLM analysis. 
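+
+    The model is prompted to return a single JSON object; the parsing below
+    assumes a shape like the following (a sketch, not a guaranteed schema):
+
+        {"placeholders": [{"name": "current_date",
+                           "value": "2025-12-11",
+                           "description": "Current date"}]}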
+ + Args: + trajectory: The trajectory to analyze (list of tool use blocks) + messages_api: Messages API instance for LLM calls + model: Model to use for analysis + + Returns: + Tuple of: + - Dict mapping placeholder names to descriptions + - List of PlaceholderDefinition objects with name, value, and description + """ + if not trajectory: + logger.debug("Empty trajectory provided, skipping placeholder identification") + return {}, [] + + logger.info( + f"Starting placeholder identification for trajectory with {len(trajectory)} steps" + ) + + # Convert trajectory to serializable format for analysis + trajectory_data = [tool.model_dump(mode="json") for tool in trajectory] + logger.debug(f"Converted {len(trajectory_data)} tool blocks to JSON format") + + user_message = f"""Analyze this UI automation trajectory and identify all values that should be placeholders: + +```json +{json.dumps(trajectory_data, indent=2)} +``` + +Return only the JSON object with identified placeholders. Be thorough but conservative - only mark values that are clearly dynamic or time-sensitive.""" + + response_text = "" # Initialize for error logging + try: + # Make single API call + logger.debug(f"Calling LLM ({model}) to analyze trajectory for placeholders") + response = messages_api.create_message( + messages=[MessageParam(role="user", content=user_message)], + model=model, + system=PLACEHOLDER_IDENTIFIER_SYSTEM_PROMPT, + max_tokens=4096, + temperature=0.0, # Deterministic for analysis + ) + logger.debug("Received response from LLM") + + # Extract text from response + if isinstance(response.content, list): + response_text = next( + (block.text for block in response.content if hasattr(block, "text")), + "", + ) + else: + response_text = str(response.content) + + # Parse the JSON response + logger.debug("Parsing LLM response to extract placeholder definitions") + # Handle markdown code blocks if present + if "```json" in response_text: + logger.debug("Removing JSON markdown code block wrapper from response") + response_text = response_text.split("```json")[1].split("```")[0].strip() + elif "```" in response_text: + logger.debug("Removing code block wrapper from response") + response_text = response_text.split("```")[1].split("```")[0].strip() + + placeholder_data = json.loads(response_text) + logger.debug(f"Successfully parsed JSON response with {len(placeholder_data.get('placeholders', []))} placeholders") + + # Convert to our data structures + placeholder_definitions = [ + PlaceholderDefinition( + name=p["name"], value=p["value"], description=p["description"] + ) + for p in placeholder_data.get("placeholders", []) + ] + + placeholder_dict = { + p.name: p.description for p in placeholder_definitions + } + + if placeholder_definitions: + logger.info( + f"Successfully identified {len(placeholder_definitions)} placeholders in trajectory" + ) + for p in placeholder_definitions: + logger.debug(f" - {p.name}: {p.value} ({p.description})") + else: + logger.info( + "No placeholders identified in trajectory (this is normal for trajectories with only static values)" + ) + + return placeholder_dict, placeholder_definitions + + except json.JSONDecodeError as e: + logger.warning( + f"Failed to parse LLM response as JSON: {e}. Falling back to empty placeholder list.", + extra={"response_text": response_text[:500]}, # Log first 500 chars + ) + return {}, [] + except Exception as e: # noqa: BLE001 + logger.warning( + f"Failed to identify placeholders with LLM: {e}. 
Falling back to empty placeholder list.", + exc_info=True, + ) + return {}, [] diff --git a/src/askui/utils/trajectory_executor.py b/src/askui/utils/trajectory_executor.py new file mode 100644 index 00000000..61917f78 --- /dev/null +++ b/src/askui/utils/trajectory_executor.py @@ -0,0 +1,333 @@ +"""Trajectory executor for step-by-step cache execution. + +This module provides the TrajectoryExecutor class that enables controlled +execution of cached trajectories with support for pausing at non-cacheable +steps, error handling, and agent intervention. +""" + +import logging +import time +from typing import Any, Optional + +from pydantic import BaseModel, Field +from typing_extensions import Literal + +from askui.models.shared.agent_message_param import ( + ImageBlockParam, + MessageParam, + TextBlockParam, + ToolResultBlockParam, + ToolUseBlockParam, +) +from askui.models.shared.tools import ToolCollection +from askui.utils.placeholder_handler import PlaceholderHandler + +logger = logging.getLogger(__name__) + + +class ExecutionResult(BaseModel): + """Result of executing a single step in a trajectory. + + Attributes: + status: Execution status (SUCCESS, FAILED, NEEDS_AGENT, COMPLETED) + step_index: Index of the step that was executed + tool_result: The ToolResultBlockParam returned by the tool (if any), + preserving proper data types like ImageBlockParam for screenshots + error_message: Error message if execution failed + screenshots_taken: List of screenshots captured during this step + message_history: List of MessageParam representing the conversation history, + with proper content types (ImageBlockParam, TextBlockParam, etc.) + """ + + status: Literal["SUCCESS", "FAILED", "NEEDS_AGENT", "COMPLETED"] + step_index: int + tool_result: Optional[Any] = None + error_message: Optional[str] = None + screenshots_taken: list[Any] = Field(default_factory=list) + message_history: list[MessageParam] = Field(default_factory=list) + + +class TrajectoryExecutor: + """Executes cached trajectories step-by-step with control flow. + + Supports pausing at non-cacheable steps, placeholder substitution, + and collecting execution results for the agent to review. + """ + + def __init__( + self, + trajectory: list[ToolUseBlockParam], + toolbox: ToolCollection, + placeholder_values: dict[str, str] | None = None, + delay_time: float = 0.5, + visual_validation_enabled: bool = False, + ): + """Initialize the trajectory executor. + + Args: + trajectory: List of tool use blocks to execute + toolbox: ToolCollection for executing tools + placeholder_values: Dict of placeholder names to values + delay_time: Seconds to wait between step executions + visual_validation_enabled: Enable visual validation (future feature) + """ + self.trajectory = trajectory + self.toolbox = toolbox + self.placeholder_values = placeholder_values or {} + self.delay_time = delay_time + self.visual_validation_enabled = visual_validation_enabled + self.current_step_index = 0 + self.message_history: list[MessageParam] = [] + + def execute_next_step(self) -> ExecutionResult: + """Execute the next step in the trajectory. + + Returns: + ExecutionResult with status and details of the execution + + The method will: + 1. Check if there are more steps to execute + 2. Check if the step should be skipped (screenshots, retrieval tools) + 3. Check if the step is non-cacheable (needs agent) + 4. Substitute placeholders + 5. Execute the tool and build messages with proper data types + 6. 
Return result with updated message history + + Note: Tool results are preserved with their proper data types (e.g., + ImageBlockParam for screenshots) and added to message history. The + agent's truncation strategy will manage message history size. + """ + # Check if we've completed all steps + if self.current_step_index >= len(self.trajectory): + return ExecutionResult( + status="COMPLETED", + step_index=self.current_step_index - 1, + message_history=self.message_history, + ) + + step = self.trajectory[self.current_step_index] + step_index = self.current_step_index + + # Check if step should be skipped + if self._should_skip_step(step): + logger.debug(f"Skipping step {step_index}: {step.name}") + self.current_step_index += 1 + # Recursively execute next step + return self.execute_next_step() + + # Check if step needs agent intervention (non-cacheable) + if self.should_pause_for_agent(step): + logger.info( + f"Pausing at step {step_index}: {step.name} (non-cacheable tool)" + ) + return ExecutionResult( + status="NEEDS_AGENT", + step_index=step_index, + message_history=self.message_history, + ) + + # Visual validation (future feature - currently always passes) + # Extension point for aHash-based UI validation + if self.visual_validation_enabled: + is_valid, error_msg = self.validate_step_visually(step) + if not is_valid: + logger.warning( + f"Visual validation failed at step {step_index}: {error_msg}" + ) + return ExecutionResult( + status="FAILED", + step_index=step_index, + error_message=error_msg, + message_history=self.message_history.copy(), + ) + + # Substitute placeholders + substituted_step = PlaceholderHandler.substitute_placeholders( + step, self.placeholder_values + ) + + # Execute the tool + try: + logger.debug(f"Executing step {step_index}: {step.name}") + + # Add assistant message (tool use) to history + assistant_message = MessageParam( + role="assistant", + content=[substituted_step], + ) + self.message_history.append(assistant_message) + + # Execute the tool + tool_results = self.toolbox.run([substituted_step]) + + # toolbox.run() returns a list of content blocks (ToolResultBlockParam, etc.) + # We use these directly without converting to strings - this preserves + # proper data types like ImageBlockParam + + # Add user message (tool result) to history + user_message = MessageParam( + role="user", + content=tool_results if tool_results else [], + ) + self.message_history.append(user_message) + + # Move to next step + self.current_step_index += 1 + + # Add delay between actions + if self.current_step_index < len(self.trajectory): + time.sleep(self.delay_time) + + return ExecutionResult( + status="SUCCESS", + step_index=step_index, + tool_result=tool_results[0] if tool_results else None, + message_history=self.message_history.copy(), + ) + + except Exception as e: + logger.error( + f"Error executing step {step_index}: {step.name}", + exc_info=True, + ) + return ExecutionResult( + status="FAILED", + step_index=step_index, + error_message=str(e), + message_history=self.message_history.copy(), + ) + + def execute_all(self) -> list[ExecutionResult]: + """Execute all steps in the trajectory until completion or pause. 
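+
+        Example (a sketch with a hypothetical trajectory and toolbox; the
+        final status may equally be "FAILED" or "NEEDS_AGENT"):
+            >>> executor = TrajectoryExecutor(trajectory, toolbox)
+            >>> results = executor.execute_all()
+            >>> results[-1].status
+            'COMPLETED'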
+ + Returns: + List of ExecutionResult for all executed steps + + Execution stops when: + - All steps are completed + - A step fails + - A non-cacheable step is encountered + """ + results: list[ExecutionResult] = [] + + while True: + result = self.execute_next_step() + results.append(result) + + # Stop if we've completed, failed, or need agent + if result.status in ["COMPLETED", "FAILED", "NEEDS_AGENT"]: + break + + return results + + def should_pause_for_agent(self, step: ToolUseBlockParam) -> bool: + """Check if execution should pause for agent intervention. + + Args: + step: The tool use block to check + + Returns: + True if agent should execute this step, False if it can be cached + + Currently checks if the tool is marked as non-cacheable. + """ + # Get the tool from toolbox + tool = self.toolbox._tool_map.get(step.name) + + if tool is None: + # Tool not found in regular tools, might be MCP tool + # For now, assume MCP tools are cacheable + return False + + # Check if tool is marked as non-cacheable + return not tool.is_cacheable + + def get_current_step_index(self) -> int: + """Get the index of the current step. + + Returns: + Current step index + """ + return self.current_step_index + + def get_remaining_trajectory(self) -> list[ToolUseBlockParam]: + """Get the remaining steps in the trajectory. + + Returns: + List of tool use blocks that haven't been executed yet + """ + return self.trajectory[self.current_step_index :] + + def skip_current_step(self) -> None: + """Skip the current step and move to the next one. + + Useful when the agent manually executes a non-cacheable step. + """ + if self.current_step_index < len(self.trajectory): + self.current_step_index += 1 + + def _should_skip_step(self, step: ToolUseBlockParam) -> bool: + """Check if a step should be skipped during execution. + + Args: + step: The tool use block to check + + Returns: + True if step should be skipped, False otherwise + + Note: As of v0.1, no steps are skipped. All tools in the trajectory + are executed, including screenshots and trajectory retrieval tools. + """ + return False + + def validate_step_visually( + self, step: ToolUseBlockParam, current_screenshot: Any = None + ) -> tuple[bool, str | None]: + """Hook for visual validation of cached steps using aHash comparison. + + This is an extension point for future visual validation implementation. + Currently returns (True, None) - no validation performed. + + Future implementation will: + 1. Check if step has visual_validation_required=True + 2. Compute aHash of current screen region + 3. Compare with stored visual_hash + 4. 
Return validation result based on Hamming distance threshold + + Args: + step: The trajectory step to validate + current_screenshot: Optional current screen capture (future use) + + Returns: + Tuple of (is_valid: bool, error_message: str | None) + - (True, None) if validation passes or is disabled + - (False, error_msg) if validation fails + + Example future implementation: + if not self.visual_validation_enabled: + return True, None + + if not step.visual_validation_required: + return True, None + + if step.visual_hash is None: + return True, None # No hash stored, skip validation + + # Capture current screen region + current_hash = compute_ahash(current_screenshot) + + # Compare hashes + distance = hamming_distance(step.visual_hash, current_hash) + threshold = 10 # Configurable + + if distance > threshold: + return False, ( + f"Visual validation failed: UI changed significantly " + f"(distance: {distance} > threshold: {threshold})" + ) + + return True, None + """ + # Future: Implement aHash comparison + # For now, always return True (no validation) + return True, None diff --git a/tests/unit/tools/test_caching_tools.py b/tests/unit/tools/test_caching_tools.py index a4404114..11fbcfe5 100644 --- a/tests/unit/tools/test_caching_tools.py +++ b/tests/unit/tools/test_caching_tools.py @@ -2,12 +2,15 @@ import json import tempfile +from datetime import datetime, timezone from pathlib import Path from typing import Any -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest +from askui.models.shared.agent import Agent +from askui.models.shared.messages_api import MessagesApi from askui.models.shared.settings import CachedExecutionToolSettings from askui.models.shared.tools import ToolCollection from askui.tools.caching_tools import ( @@ -16,14 +19,32 @@ ) +# ============================================================================ +# RetrieveCachedTestExecutions Tests (unchanged from before) +# ============================================================================ + + def test_retrieve_cached_test_executions_lists_json_files() -> None: """Test that RetrieveCachedTestExecutions lists all JSON files in cache dir.""" with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) - # Create some cache files - (cache_dir / "cache1.json").write_text("{}", encoding="utf-8") - (cache_dir / "cache2.json").write_text("{}", encoding="utf-8") + # Create valid cache files with v0.1 format + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [], + "placeholders": {}, + } + (cache_dir / "cache1.json").write_text(json.dumps(cache_data), encoding="utf-8") + (cache_dir / "cache2.json").write_text(json.dumps(cache_data), encoding="utf-8") (cache_dir / "not_cache.txt").write_text("text", encoding="utf-8") tool = RetrieveCachedTestExecutions(cache_dir=str(cache_dir)) @@ -59,9 +80,22 @@ def test_retrieve_cached_test_executions_respects_custom_format() -> None: with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) - # Create files with different extensions - (cache_dir / "cache1.json").write_text("{}", encoding="utf-8") - (cache_dir / "cache2.traj").write_text("{}", encoding="utf-8") + # Create valid cache files with different extensions + cache_data = { + "metadata": { + "version": "0.1", + "created_at": 
datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [], + "placeholders": {}, + } + (cache_dir / "cache1.json").write_text(json.dumps(cache_data), encoding="utf-8") + (cache_dir / "cache2.traj").write_text(json.dumps(cache_data), encoding="utf-8") # Default format (.json) tool_json = RetrieveCachedTestExecutions( @@ -80,150 +114,272 @@ def test_retrieve_cached_test_executions_respects_custom_format() -> None: assert "cache2.traj" in result_traj[0] +def test_retrieve_caches_filters_invalid_by_default(tmp_path): + """Test that RetrieveCachedTestExecutions filters out invalid caches by default.""" + cache_dir = tmp_path / "cache" + cache_dir.mkdir() + + # Create a valid cache + valid_cache = cache_dir / "valid.json" + valid_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [], + "placeholders": {}, + } + with valid_cache.open("w") as f: + json.dump(valid_data, f) + + # Create an invalid cache + invalid_cache = cache_dir / "invalid.json" + invalid_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 10, + "failures": [], + "is_valid": False, + "invalidation_reason": "Too many failures", + }, + "trajectory": [], + "placeholders": {}, + } + with invalid_cache.open("w") as f: + json.dump(invalid_data, f) + + tool = RetrieveCachedTestExecutions(cache_dir=str(cache_dir)) + + # Should only return valid cache + results = tool() + assert len(results) == 1 + assert str(valid_cache) in results + assert str(invalid_cache) not in results + + +def test_retrieve_caches_includes_invalid_when_requested(tmp_path): + """Test that RetrieveCachedTestExecutions includes invalid caches when requested.""" + cache_dir = tmp_path / "cache" + cache_dir.mkdir() + + # Create a valid cache + valid_cache = cache_dir / "valid.json" + valid_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [], + "placeholders": {}, + } + with valid_cache.open("w") as f: + json.dump(valid_data, f) + + # Create an invalid cache + invalid_cache = cache_dir / "invalid.json" + invalid_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 10, + "failures": [], + "is_valid": False, + "invalidation_reason": "Too many failures", + }, + "trajectory": [], + "placeholders": {}, + } + with invalid_cache.open("w") as f: + json.dump(invalid_data, f) + + tool = RetrieveCachedTestExecutions(cache_dir=str(cache_dir)) + + # Should return both caches when include_invalid=True + results = tool(include_invalid=True) + assert len(results) == 2 + + +# ============================================================================ +# ExecuteCachedTrajectory Tests (refactored for new behavior) +# ============================================================================ + + def test_execute_cached_execution_initializes_without_toolbox() -> None: - """Test that ExecuteCachedExecution can be initialized without toolbox.""" + """Test that ExecuteCachedTrajectory can be initialized without toolbox.""" tool = ExecuteCachedTrajectory() assert tool.name == 
"execute_cached_executions_tool" + assert tool._toolbox is None # noqa: SLF001 + assert tool._agent is None # noqa: SLF001 -def test_execute_cached_execution_raises_error_without_toolbox() -> None: - """Test that ExecuteCachedExecution raises error when toolbox not set.""" - tool = ExecuteCachedTrajectory() +def test_execute_cached_execution_raises_error_without_toolbox_or_agent() -> None: + """Test that ExecuteCachedTrajectory raises error when neither toolbox nor agent set.""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_file = Path(temp_dir) / "test.json" + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [], + "placeholders": {}, + } + cache_file.write_text(json.dumps(cache_data), encoding="utf-8") + + tool = ExecuteCachedTrajectory() - with pytest.raises(RuntimeError, match="Toolbox not set"): - tool(trajectory_file="some_file.json") + with pytest.raises(RuntimeError, match="Agent not set"): + tool(trajectory_file=str(cache_file)) -def test_execute_cached_execution_raises_error_when_file_not_found() -> None: - """Test that ExecuteCachedExecution raises error if trajectory file doesn't exist""" +def test_execute_cached_execution_returns_error_when_file_not_found() -> None: + """Test that ExecuteCachedTrajectory returns error message if file doesn't exist.""" tool = ExecuteCachedTrajectory() - mock_toolbox = MagicMock(spec=ToolCollection) - tool.set_toolbox(mock_toolbox) + mock_agent = MagicMock(spec=Agent) + mock_agent._tool_collection = MagicMock(spec=ToolCollection) + tool.set_agent(mock_agent) + + result = tool(trajectory_file="/non/existent/file.json") - with pytest.raises(FileNotFoundError, match="Trajectory file not found"): - tool(trajectory_file="/non/existent/file.json") + # New behavior: returns error message string instead of raising + assert isinstance(result, str) + assert "Trajectory file not found" in result -def test_execute_cached_execution_executes_trajectory() -> None: - """Test that ExecuteCachedExecution executes tools from trajectory file.""" +def test_execute_cached_execution_activates_cache_mode() -> None: + """Test that ExecuteCachedTrajectory activates cache mode in the agent.""" with tempfile.TemporaryDirectory() as temp_dir: cache_file = Path(temp_dir) / "test_trajectory.json" - # Create a trajectory file - trajectory: list[dict[str, Any]] = [ - { - "id": "tool1", - "name": "click_tool", - "input": {"x": 100, "y": 200}, - "type": "tool_use", - }, - { - "id": "tool2", - "name": "type_tool", - "input": {"text": "hello"}, - "type": "tool_use", + # Create a trajectory file with v0.1 format + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": True, + "invalidation_reason": None, }, - ] + "trajectory": [ + { + "id": "tool1", + "name": "click_tool", + "input": {"x": 100, "y": 200}, + "type": "tool_use", + }, + { + "id": "tool2", + "name": "type_tool", + "input": {"text": "hello"}, + "type": "tool_use", + }, + ], + "placeholders": {}, + } with cache_file.open("w", encoding="utf-8") as f: - json.dump(trajectory, f) + json.dump(cache_data, f) - # Execute the trajectory - tool = ExecuteCachedTrajectory() + # Create mock agent with toolbox + mock_messages_api = MagicMock(spec=MessagesApi) + mock_agent = 
Agent(messages_api=mock_messages_api) mock_toolbox = MagicMock(spec=ToolCollection) - tool.set_toolbox(mock_toolbox) - - result = tool(trajectory_file=str(cache_file)) - - # Verify success message - assert "Successfully executed trajectory" in result - # Verify toolbox.run was called for each tool (2 calls) - assert mock_toolbox.run.call_count == 2 - - -def test_execute_cached_execution_skips_screenshot_tools() -> None: - """Test that ExecuteCachedExecution skips screenshot-related tools.""" - with tempfile.TemporaryDirectory() as temp_dir: - cache_file = Path(temp_dir) / "test_trajectory.json" - - # Create a trajectory with screenshot tools - trajectory: list[dict[str, Any]] = [ - { - "id": "tool1", - "name": "screenshot", - "input": {}, - "type": "tool_use", - }, - { - "id": "tool2", - "name": "click_tool", - "input": {"x": 100, "y": 200}, - "type": "tool_use", - }, - { - "id": "tool3", - "name": "retrieve_available_trajectories_tool", - "input": {}, - "type": "tool_use", - }, - ] - - with cache_file.open("w", encoding="utf-8") as f: - json.dump(trajectory, f) + mock_toolbox._tool_map = {} + mock_agent._tool_collection = mock_toolbox # noqa: SLF001 - # Execute the trajectory + # Create and configure tool tool = ExecuteCachedTrajectory() - mock_toolbox = MagicMock(spec=ToolCollection) - tool.set_toolbox(mock_toolbox) + tool.set_agent(mock_agent) + # Call the tool result = tool(trajectory_file=str(cache_file)) - # Verify only click_tool was executed (screenshot and retrieve tools skipped) - assert mock_toolbox.run.call_count == 1 - assert "Successfully executed trajectory" in result + # Verify return type is string + assert isinstance(result, str) + assert "✓ Cache execution mode activated" in result + assert "2 cached steps" in result + + # Verify agent state was set + assert mock_agent._executing_from_cache is True # noqa: SLF001 + assert mock_agent._cache_executor is not None # noqa: SLF001 + assert mock_agent._cache_file is not None # noqa: SLF001 + assert mock_agent._cache_file_path == str(cache_file) # noqa: SLF001 -def test_execute_cached_execution_handles_errors_gracefully() -> None: - """Test that ExecuteCachedExecution handles errors during execution.""" +def test_execute_cached_execution_works_with_set_toolbox() -> None: + """Test that ExecuteCachedTrajectory works with set_toolbox (legacy approach).""" with tempfile.TemporaryDirectory() as temp_dir: cache_file = Path(temp_dir) / "test_trajectory.json" - # Create a trajectory - trajectory: list[dict[str, Any]] = [ - { - "id": "tool1", - "name": "failing_tool", - "input": {}, - "type": "tool_use", + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": True, + "invalidation_reason": None, }, - ] + "trajectory": [ + { + "id": "tool1", + "name": "test_tool", + "input": {}, + "type": "tool_use", + } + ], + "placeholders": {}, + } with cache_file.open("w", encoding="utf-8") as f: - json.dump(trajectory, f) + json.dump(cache_data, f) + + # Create mock agent without toolbox + mock_messages_api = MagicMock(spec=MessagesApi) + mock_agent = Agent(messages_api=mock_messages_api) - # Execute the trajectory with a failing tool + # Create tool and set toolbox directly tool = ExecuteCachedTrajectory() mock_toolbox = MagicMock(spec=ToolCollection) - mock_toolbox.run.side_effect = Exception("Tool execution failed") + mock_toolbox._tool_map = {} tool.set_toolbox(mock_toolbox) + tool.set_agent(mock_agent) 
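+        # The toolbox set via set_toolbox should drive activation here,
+        # since the agent above was created without a toolbox (legacy path).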
result = tool(trajectory_file=str(cache_file)) - # Verify error message - assert "error occured" in result.lower() - assert "verify the UI state" in result + # Should succeed using the toolbox + assert isinstance(result, str) + assert "✓ Cache execution mode activated" in result -def test_execute_cached_execution_set_toolbox() -> None: - """Test that set_toolbox properly sets the toolbox reference.""" +def test_execute_cached_execution_set_agent_and_toolbox() -> None: + """Test that set_agent and set_toolbox properly set references.""" tool = ExecuteCachedTrajectory() + mock_agent = MagicMock(spec=Agent) mock_toolbox = MagicMock(spec=ToolCollection) + tool.set_agent(mock_agent) tool.set_toolbox(mock_toolbox) - # After setting toolbox, should be able to access it - assert hasattr(tool, "_toolbox") - assert tool._toolbox == mock_toolbox + assert tool._agent == mock_agent # noqa: SLF001 + assert tool._toolbox == mock_toolbox # noqa: SLF001 def test_execute_cached_execution_initializes_with_default_settings() -> None: @@ -232,6 +388,7 @@ def test_execute_cached_execution_initializes_with_default_settings() -> None: # Should have default settings initialized assert hasattr(tool, "_settings") + assert tool._settings.delay_time_between_action == 0.5 # noqa: SLF001 def test_execute_cached_execution_initializes_with_custom_settings() -> None: @@ -241,93 +398,677 @@ def test_execute_cached_execution_initializes_with_custom_settings() -> None: # Should have custom settings initialized assert hasattr(tool, "_settings") + assert tool._settings.delay_time_between_action == 1.0 # noqa: SLF001 -def test_execute_cached_execution_uses_delay_time_between_actions() -> None: - """Test that ExecuteCachedTrajectory uses the configured delay time.""" +def test_execute_cached_execution_with_placeholders() -> None: + """Test that ExecuteCachedTrajectory validates placeholders.""" with tempfile.TemporaryDirectory() as temp_dir: cache_file = Path(temp_dir) / "test_trajectory.json" - # Create a trajectory with 3 actions - trajectory: list[dict[str, Any]] = [ - { - "id": "tool1", - "name": "click_tool", - "input": {"x": 100, "y": 200}, - "type": "tool_use", + # Create a v0.1 cache file with placeholders + cache_data = { + "metadata": { + "version": "0.1", + "created_at": "2025-12-11T10:00:00Z", + "last_executed_at": None, + "execution_attempts": 0, + "failures": [], + "is_valid": True, + "invalidation_reason": None, }, - { - "id": "tool2", - "name": "type_tool", - "input": {"text": "hello"}, - "type": "tool_use", + "trajectory": [ + { + "id": "tool1", + "name": "type_tool", + "input": {"text": "Today is {{current_date}}"}, + "type": "tool_use", + }, + ], + "placeholders": { + "current_date": "Current date", }, - { - "id": "tool3", - "name": "move_tool", - "input": {"x": 300, "y": 400}, - "type": "tool_use", + } + + with cache_file.open("w", encoding="utf-8") as f: + json.dump(cache_data, f) + + # Create mock agent + mock_messages_api = MagicMock(spec=MessagesApi) + mock_agent = Agent(messages_api=mock_messages_api) + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox._tool_map = {} + mock_agent._tool_collection = mock_toolbox # noqa: SLF001 + + tool = ExecuteCachedTrajectory() + tool.set_agent(mock_agent) + + result = tool( + trajectory_file=str(cache_file), + placeholder_values={"current_date": "2025-12-11"}, + ) + + # Verify success + assert isinstance(result, str) + assert "✓ Cache execution mode activated" in result + assert "1 placeholder value" in result + + +def 
test_execute_cached_execution_missing_placeholders() -> None: + """Test that ExecuteCachedTrajectory returns error for missing placeholders.""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_file = Path(temp_dir) / "test_trajectory.json" + + # Create a v0.1 cache file with placeholders + cache_data = { + "metadata": { + "version": "0.1", + "created_at": "2025-12-11T10:00:00Z", + "last_executed_at": None, + "execution_attempts": 0, + "failures": [], + "is_valid": True, + "invalidation_reason": None, }, - ] + "trajectory": [ + { + "id": "tool1", + "name": "type_tool", + "input": {"text": "Date: {{current_date}}, User: {{user_name}}"}, + "type": "tool_use", + } + ], + "placeholders": { + "current_date": "Current date", + "user_name": "User name", + }, + } with cache_file.open("w", encoding="utf-8") as f: - json.dump(trajectory, f) + json.dump(cache_data, f) - # Execute with custom delay time - custom_settings = CachedExecutionToolSettings(delay_time_between_action=0.1) - tool = ExecuteCachedTrajectory(settings=custom_settings) + # Create mock agent + mock_messages_api = MagicMock(spec=MessagesApi) + mock_agent = Agent(messages_api=mock_messages_api) mock_toolbox = MagicMock(spec=ToolCollection) - tool.set_toolbox(mock_toolbox) + mock_agent._tool_collection = mock_toolbox # noqa: SLF001 - # Mock time.sleep to verify it's called with correct delay - with patch("time.sleep") as mock_sleep: - result = tool(trajectory_file=str(cache_file)) + tool = ExecuteCachedTrajectory() + tool.set_agent(mock_agent) - # Verify success - assert "Successfully executed trajectory" in result - # Verify sleep was called 3 times (once after each action) - assert mock_sleep.call_count == 3 - # Verify it was called with the configured delay time - for call in mock_sleep.call_args_list: - assert call[0][0] == 0.1 + result = tool(trajectory_file=str(cache_file)) + # Verify error message + assert isinstance(result, str) + assert "Missing required placeholder values" in result + assert "current_date" in result + assert "user_name" in result -def test_execute_cached_execution_default_delay_time() -> None: - """Test that ExecuteCachedTrajectory uses default delay time of 0.5s.""" + +def test_execute_cached_execution_no_placeholders_backward_compat() -> None: + """Test backward compatibility: trajectories without placeholders work fine.""" with tempfile.TemporaryDirectory() as temp_dir: cache_file = Path(temp_dir) / "test_trajectory.json" - # Create a trajectory with 2 actions + # Create a v0.0 cache file (old format, no placeholders) trajectory: list[dict[str, Any]] = [ { "id": "tool1", "name": "click_tool", "input": {"x": 100, "y": 200}, "type": "tool_use", - }, - { - "id": "tool2", - "name": "type_tool", - "input": {"text": "hello"}, - "type": "tool_use", - }, + } ] with cache_file.open("w", encoding="utf-8") as f: json.dump(trajectory, f) - # Execute with default settings + # Create mock agent + mock_messages_api = MagicMock(spec=MessagesApi) + mock_agent = Agent(messages_api=mock_messages_api) + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox._tool_map = {} + mock_agent._tool_collection = mock_toolbox # noqa: SLF001 + tool = ExecuteCachedTrajectory() + tool.set_agent(mock_agent) + + result = tool(trajectory_file=str(cache_file)) + + # Verify success + assert isinstance(result, str) + assert "✓ Cache execution mode activated" in result + + +def test_continue_cached_trajectory_from_middle() -> None: + """Test continuing execution from middle of trajectory.""" + with tempfile.TemporaryDirectory() as 
temp_dir: + cache_file = Path(temp_dir) / "test_trajectory.json" + + # Create a trajectory with 5 steps + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [ + {"id": "1", "name": "tool1", "input": {}, "type": "tool_use"}, + {"id": "2", "name": "tool2", "input": {}, "type": "tool_use"}, + {"id": "3", "name": "tool3", "input": {}, "type": "tool_use"}, + {"id": "4", "name": "tool4", "input": {}, "type": "tool_use"}, + {"id": "5", "name": "tool5", "input": {}, "type": "tool_use"}, + ], + "placeholders": {}, + } + + with cache_file.open("w", encoding="utf-8") as f: + json.dump(cache_data, f) + + # Create mock agent + mock_messages_api = MagicMock(spec=MessagesApi) + mock_agent = Agent(messages_api=mock_messages_api) mock_toolbox = MagicMock(spec=ToolCollection) - tool.set_toolbox(mock_toolbox) + mock_toolbox._tool_map = {} + mock_agent._tool_collection = mock_toolbox # noqa: SLF001 - # Mock time.sleep to verify default delay is used - with patch("time.sleep") as mock_sleep: - result = tool(trajectory_file=str(cache_file)) + tool = ExecuteCachedTrajectory() + tool.set_agent(mock_agent) - # Verify success - assert "Successfully executed trajectory" in result - # Verify sleep was called with default delay of 0.5s - assert mock_sleep.call_count == 2 - for call in mock_sleep.call_args_list: - assert call[0][0] == 0.5 + result = tool(trajectory_file=str(cache_file), start_from_step_index=2) + + # Verify success message + assert isinstance(result, str) + assert "✓ Cache execution mode activated" in result + assert "resuming from step 2" in result + assert "3 remaining cached steps" in result + + +def test_continue_cached_trajectory_invalid_step_index_negative() -> None: + """Test that negative step index returns error.""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_file = Path(temp_dir) / "test_trajectory.json" + + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [ + {"id": "1", "name": "tool1", "input": {}, "type": "tool_use"}, + ], + "placeholders": {}, + } + + with cache_file.open("w", encoding="utf-8") as f: + json.dump(cache_data, f) + + mock_messages_api = MagicMock(spec=MessagesApi) + mock_agent = Agent(messages_api=mock_messages_api) + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox._tool_map = {} + mock_agent._tool_collection = mock_toolbox # noqa: SLF001 + + tool = ExecuteCachedTrajectory() + tool.set_agent(mock_agent) + + result = tool(trajectory_file=str(cache_file), start_from_step_index=-1) + + # Verify error message + assert isinstance(result, str) + assert "Invalid start_from_step_index" in result + + +def test_continue_cached_trajectory_invalid_step_index_too_large() -> None: + """Test that step index beyond trajectory length returns error.""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_file = Path(temp_dir) / "test_trajectory.json" + + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [ + {"id": "1", "name": "tool1", "input": {}, "type": 
"tool_use"}, + {"id": "2", "name": "tool2", "input": {}, "type": "tool_use"}, + ], + "placeholders": {}, + } + + with cache_file.open("w", encoding="utf-8") as f: + json.dump(cache_data, f) + + mock_messages_api = MagicMock(spec=MessagesApi) + mock_agent = Agent(messages_api=mock_messages_api) + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox._tool_map = {} + mock_agent._tool_collection = mock_toolbox # noqa: SLF001 + + tool = ExecuteCachedTrajectory() + tool.set_agent(mock_agent) + + result = tool(trajectory_file=str(cache_file), start_from_step_index=5) + + # Verify error message + assert isinstance(result, str) + assert "Invalid start_from_step_index" in result + assert "valid indices: 0-1" in result + + +def test_continue_cached_trajectory_with_placeholders() -> None: + """Test continuing execution with placeholder substitution.""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_file = Path(temp_dir) / "test_trajectory.json" + + # Create a v0.1 cache file with placeholders + cache_data = { + "metadata": { + "version": "0.1", + "created_at": "2025-12-11T10:00:00Z", + "last_executed_at": None, + "execution_attempts": 0, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [ + { + "id": "1", + "name": "tool1", + "input": {"text": "Step 1"}, + "type": "tool_use", + }, + { + "id": "2", + "name": "tool2", + "input": {"text": "Date: {{current_date}}"}, + "type": "tool_use", + }, + { + "id": "3", + "name": "tool3", + "input": {"text": "User: {{user_name}}"}, + "type": "tool_use", + }, + ], + "placeholders": { + "current_date": "Current date", + "user_name": "User name", + }, + } + + with cache_file.open("w", encoding="utf-8") as f: + json.dump(cache_data, f) + + # Create mock agent + mock_messages_api = MagicMock(spec=MessagesApi) + mock_agent = Agent(messages_api=mock_messages_api) + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox._tool_map = {} + mock_agent._tool_collection = mock_toolbox # noqa: SLF001 + + tool = ExecuteCachedTrajectory() + tool.set_agent(mock_agent) + + result = tool( + trajectory_file=str(cache_file), + start_from_step_index=1, + placeholder_values={"current_date": "2025-12-11", "user_name": "Alice"}, + ) + + # Verify success + assert isinstance(result, str) + assert "✓ Cache execution mode activated" in result + assert "resuming from step 1" in result + + +def test_execute_cached_trajectory_warns_if_invalid(tmp_path, caplog): + """Test that ExecuteCachedTrajectory warns when activating with invalid cache.""" + import logging + + caplog.set_level(logging.WARNING) + + cache_file = tmp_path / "test.json" + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 10, + "last_executed_at": None, + "failures": [], + "is_valid": False, + "invalidation_reason": "Cache marked invalid for testing", + }, + "trajectory": [ + {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, + ], + "placeholders": {}, + } + with cache_file.open("w") as f: + json.dump(cache_data, f) + + # Create mock agent + mock_messages_api = MagicMock(spec=MessagesApi) + mock_agent = Agent(messages_api=mock_messages_api) + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox._tool_map = {} + mock_agent._tool_collection = mock_toolbox # noqa: SLF001 + + tool = ExecuteCachedTrajectory() + tool.set_agent(mock_agent) + + result = tool(trajectory_file=str(cache_file)) + + # Should still activate but log warning + assert isinstance(result, str) 
+ assert "✓ Cache execution mode activated" in result + + # Verify warning was logged + assert any("WARNING" in record.levelname for record in caplog.records) + assert any("invalid cache" in record.message.lower() for record in caplog.records) + + +# ============================================================================ +# InspectCacheMetadata Tests (unchanged from before) +# ============================================================================ + + +def test_inspect_cache_metadata_shows_basic_info(tmp_path): + """Test that InspectCacheMetadata displays basic cache information.""" + from askui.tools.caching_tools import InspectCacheMetadata + + cache_file = tmp_path / "test.json" + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 5, + "last_executed_at": datetime.now(tz=timezone.utc).isoformat(), + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [ + {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, + {"id": "2", "name": "type", "input": {"text": "test"}, "type": "tool_use"}, + ], + "placeholders": {"current_date": "{{current_date}}"}, + } + with cache_file.open("w") as f: + json.dump(cache_data, f) + + tool = InspectCacheMetadata() + result = tool(trajectory_file=str(cache_file)) + + # Verify output contains key information + assert "=== Cache Metadata ===" in result + assert "Version: 0.1" in result + assert "Total Execution Attempts: 5" in result + assert "Is Valid: True" in result + assert "Total Steps: 2" in result + assert "Placeholders: 1" in result + assert "current_date" in result + + +def test_inspect_cache_metadata_shows_failures(tmp_path): + """Test that InspectCacheMetadata displays failure history.""" + from askui.tools.caching_tools import InspectCacheMetadata + + cache_file = tmp_path / "test.json" + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 3, + "last_executed_at": None, + "failures": [ + { + "timestamp": datetime.now(tz=timezone.utc).isoformat(), + "step_index": 1, + "error_message": "Click failed", + "failure_count_at_step": 1, + }, + { + "timestamp": datetime.now(tz=timezone.utc).isoformat(), + "step_index": 1, + "error_message": "Click failed again", + "failure_count_at_step": 2, + }, + ], + "is_valid": False, + "invalidation_reason": "Too many failures at step 1", + }, + "trajectory": [ + {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, + ], + "placeholders": {}, + } + with cache_file.open("w") as f: + json.dump(cache_data, f) + + tool = InspectCacheMetadata() + result = tool(trajectory_file=str(cache_file)) + + # Verify failure information + assert "--- Failure History ---" in result + assert "Failure 1:" in result + assert "Failure 2:" in result + assert "Step Index: 1" in result + assert "Click failed" in result + assert "Is Valid: False" in result + assert "Invalidation Reason: Too many failures at step 1" in result + + +def test_inspect_cache_metadata_file_not_found(): + """Test that InspectCacheMetadata handles missing files.""" + from askui.tools.caching_tools import InspectCacheMetadata + + tool = InspectCacheMetadata() + result = tool(trajectory_file="/nonexistent/file.json") + + assert "Trajectory file not found" in result + + +# ============================================================================ +# RevalidateCache Tests (unchanged from before) +# 
============================================================================ + + +def test_revalidate_cache_marks_invalid_as_valid(tmp_path): + """Test that RevalidateCache marks invalid cache as valid.""" + from askui.tools.caching_tools import RevalidateCache + + cache_file = tmp_path / "test.json" + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 3, + "last_executed_at": None, + "failures": [ + { + "timestamp": datetime.now(tz=timezone.utc).isoformat(), + "step_index": 1, + "error_message": "Error", + "failure_count_at_step": 1, + } + ], + "is_valid": False, + "invalidation_reason": "Manual invalidation", + }, + "trajectory": [ + {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, + ], + "placeholders": {}, + } + with cache_file.open("w") as f: + json.dump(cache_data, f) + + tool = RevalidateCache() + result = tool(trajectory_file=str(cache_file)) + + # Verify success message + assert "Successfully revalidated" in result + assert "Manual invalidation" in result + + # Read updated cache file + with cache_file.open("r") as f: + updated_data = json.load(f) + + # Verify cache is now valid + assert updated_data["metadata"]["is_valid"] is True + assert updated_data["metadata"]["invalidation_reason"] is None + # Failure history should still be there + assert len(updated_data["metadata"]["failures"]) == 1 + + +def test_revalidate_cache_already_valid(tmp_path): + """Test that RevalidateCache handles already valid cache.""" + from askui.tools.caching_tools import RevalidateCache + + cache_file = tmp_path / "test.json" + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [ + {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, + ], + "placeholders": {}, + } + with cache_file.open("w") as f: + json.dump(cache_data, f) + + tool = RevalidateCache() + result = tool(trajectory_file=str(cache_file)) + + # Verify message indicates already valid + assert "already valid" in result + assert "No changes made" in result + + +def test_revalidate_cache_file_not_found(): + """Test that RevalidateCache handles missing files.""" + from askui.tools.caching_tools import RevalidateCache + + tool = RevalidateCache() + result = tool(trajectory_file="/nonexistent/file.json") + + assert "Trajectory file not found" in result + + +# ============================================================================ +# InvalidateCache Tests (unchanged from before) +# ============================================================================ + + +def test_invalidate_cache_marks_valid_as_invalid(tmp_path): + """Test that InvalidateCache marks valid cache as invalid.""" + from askui.tools.caching_tools import InvalidateCache + + cache_file = tmp_path / "test.json" + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 2, + "last_executed_at": datetime.now(tz=timezone.utc).isoformat(), + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [ + {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, + ], + "placeholders": {}, + } + with cache_file.open("w") as f: + json.dump(cache_data, f) + + tool = InvalidateCache() + result = tool(trajectory_file=str(cache_file), reason="UI changed 
- button moved") + + # Verify success message + assert "Successfully invalidated" in result + assert "UI changed - button moved" in result + + # Read updated cache file + with cache_file.open("r") as f: + updated_data = json.load(f) + + # Verify cache is now invalid + assert updated_data["metadata"]["is_valid"] is False + assert ( + updated_data["metadata"]["invalidation_reason"] == "UI changed - button moved" + ) + # Other metadata should be preserved + assert updated_data["metadata"]["execution_attempts"] == 2 + + +def test_invalidate_cache_updates_reason_if_already_invalid(tmp_path): + """Test that InvalidateCache updates reason if already invalid.""" + from askui.tools.caching_tools import InvalidateCache + + cache_file = tmp_path / "test.json" + cache_data = { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": False, + "invalidation_reason": "Old reason", + }, + "trajectory": [ + {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, + ], + "placeholders": {}, + } + with cache_file.open("w") as f: + json.dump(cache_data, f) + + tool = InvalidateCache() + result = tool(trajectory_file=str(cache_file), reason="New reason") + + # Verify message indicates update + assert "already invalid" in result + assert "Updated invalidation reason to: New reason" in result + + # Read updated cache file + with cache_file.open("r") as f: + updated_data = json.load(f) + + # Verify reason was updated + assert updated_data["metadata"]["invalidation_reason"] == "New reason" + + +def test_invalidate_cache_file_not_found(): + """Test that InvalidateCache handles missing files.""" + from askui.tools.caching_tools import InvalidateCache + + tool = InvalidateCache() + result = tool(trajectory_file="/nonexistent/file.json", reason="Test") + + assert "Trajectory file not found" in result diff --git a/tests/unit/utils/test_cache_manager.py b/tests/unit/utils/test_cache_manager.py new file mode 100644 index 00000000..2464345e --- /dev/null +++ b/tests/unit/utils/test_cache_manager.py @@ -0,0 +1,378 @@ +"""Tests for cache manager.""" + +from datetime import datetime, timedelta, timezone +from unittest.mock import MagicMock + +import pytest + +from askui.models.shared.agent_message_param import ToolUseBlockParam +from askui.models.shared.settings import CacheFailure, CacheFile, CacheMetadata +from askui.utils.cache_manager import CacheManager +from askui.utils.cache_validator import ( + CacheValidator, + CompositeCacheValidator, + StepFailureCountValidator, +) + + +@pytest.fixture +def sample_cache_file(): + """Create a sample cache file for testing.""" + return CacheFile( + metadata=CacheMetadata( + version="0.1", + created_at=datetime.now(tz=timezone.utc), + execution_attempts=0, + is_valid=True, + ), + trajectory=[ + ToolUseBlockParam( + id="1", name="click", input={"x": 100}, type="tool_use" + ), + ToolUseBlockParam(id="2", name="type", input={"text": "test"}, type="tool_use"), + ], + placeholders={}, + ) + + +# Initialization Tests + + +def test_cache_manager_default_initialization(): + """Test cache manager initializes with default validator.""" + manager = CacheManager() + assert manager.validator is not None + assert isinstance(manager.validator, CompositeCacheValidator) + assert len(manager.validator.validators) == 3 # 3 built-in validators + + +def test_cache_manager_custom_validator(): + """Test cache manager with custom validator.""" + custom_validator = 
StepFailureCountValidator(max_failures_per_step=5) + manager = CacheManager(validator=custom_validator) + assert manager.validator is custom_validator + + +# Record Execution Attempt Tests + + +def test_record_execution_attempt_success(sample_cache_file): + """Test recording successful execution attempt.""" + manager = CacheManager() + initial_attempts = sample_cache_file.metadata.execution_attempts + initial_last_executed = sample_cache_file.metadata.last_executed_at + + manager.record_execution_attempt(sample_cache_file, success=True) + + assert sample_cache_file.metadata.execution_attempts == initial_attempts + 1 + assert sample_cache_file.metadata.last_executed_at is not None + assert sample_cache_file.metadata.last_executed_at != initial_last_executed + + +def test_record_execution_attempt_failure_with_info(sample_cache_file): + """Test recording failed execution attempt with failure info.""" + manager = CacheManager() + initial_attempts = sample_cache_file.metadata.execution_attempts + initial_failures = len(sample_cache_file.metadata.failures) + + failure_info = CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Test error", + failure_count_at_step=1, + ) + + manager.record_execution_attempt( + sample_cache_file, success=False, failure_info=failure_info + ) + + assert sample_cache_file.metadata.execution_attempts == initial_attempts + 1 + assert len(sample_cache_file.metadata.failures) == initial_failures + 1 + assert sample_cache_file.metadata.failures[-1] == failure_info + + +def test_record_execution_attempt_failure_without_info(sample_cache_file): + """Test recording failed execution attempt without failure info.""" + manager = CacheManager() + initial_attempts = sample_cache_file.metadata.execution_attempts + initial_failures = len(sample_cache_file.metadata.failures) + + manager.record_execution_attempt(sample_cache_file, success=False, failure_info=None) + + assert sample_cache_file.metadata.execution_attempts == initial_attempts + 1 + assert len(sample_cache_file.metadata.failures) == initial_failures # No new failure added + + +# Record Step Failure Tests + + +def test_record_step_failure_first_failure(sample_cache_file): + """Test recording the first failure at a step.""" + manager = CacheManager() + + manager.record_step_failure(sample_cache_file, step_index=1, error_message="First error") + + assert len(sample_cache_file.metadata.failures) == 1 + failure = sample_cache_file.metadata.failures[0] + assert failure.step_index == 1 + assert failure.error_message == "First error" + assert failure.failure_count_at_step == 1 + + +def test_record_step_failure_multiple_at_same_step(sample_cache_file): + """Test recording multiple failures at the same step.""" + manager = CacheManager() + + manager.record_step_failure(sample_cache_file, step_index=1, error_message="Error 1") + manager.record_step_failure(sample_cache_file, step_index=1, error_message="Error 2") + manager.record_step_failure(sample_cache_file, step_index=1, error_message="Error 3") + + assert len(sample_cache_file.metadata.failures) == 3 + assert sample_cache_file.metadata.failures[0].failure_count_at_step == 1 + assert sample_cache_file.metadata.failures[1].failure_count_at_step == 2 + assert sample_cache_file.metadata.failures[2].failure_count_at_step == 3 + + +def test_record_step_failure_different_steps(sample_cache_file): + """Test recording failures at different steps.""" + manager = CacheManager() + + manager.record_step_failure(sample_cache_file, step_index=1, 
error_message="Error at step 1") + manager.record_step_failure(sample_cache_file, step_index=2, error_message="Error at step 2") + manager.record_step_failure(sample_cache_file, step_index=1, error_message="Another at step 1") + + assert len(sample_cache_file.metadata.failures) == 3 + + step_1_failures = [f for f in sample_cache_file.metadata.failures if f.step_index == 1] + step_2_failures = [f for f in sample_cache_file.metadata.failures if f.step_index == 2] + + assert len(step_1_failures) == 2 + assert len(step_2_failures) == 1 + assert step_1_failures[1].failure_count_at_step == 2 # Second failure at step 1 + + +# Should Invalidate Tests + + +def test_should_invalidate_delegates_to_validator(sample_cache_file): + """Test that should_invalidate delegates to the validator.""" + mock_validator = MagicMock(spec=CacheValidator) + mock_validator.should_invalidate.return_value = (True, "Test reason") + + manager = CacheManager(validator=mock_validator) + should_inv, reason = manager.should_invalidate(sample_cache_file, step_index=1) + + assert should_inv is True + assert reason == "Test reason" + mock_validator.should_invalidate.assert_called_once_with(sample_cache_file, 1) + + +def test_should_invalidate_with_default_validator(sample_cache_file): + """Test should_invalidate with default built-in validators.""" + manager = CacheManager() + + # Add failures that exceed default thresholds + sample_cache_file.metadata.execution_attempts = 10 + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message=f"Error {i}", + failure_count_at_step=i + 1, + ) + for i in range(6) + ] # 6/10 = 60% failure rate (exceeds default 50%) + + should_inv, reason = manager.should_invalidate(sample_cache_file) + assert should_inv is True + assert "Failure rate" in reason + + +# Invalidate Cache Tests + + +def test_invalidate_cache(sample_cache_file): + """Test marking cache as invalid.""" + manager = CacheManager() + assert sample_cache_file.metadata.is_valid is True + assert sample_cache_file.metadata.invalidation_reason is None + + manager.invalidate_cache(sample_cache_file, reason="Test invalidation") + + assert sample_cache_file.metadata.is_valid is False + assert sample_cache_file.metadata.invalidation_reason == "Test invalidation" + + +def test_invalidate_cache_multiple_times(sample_cache_file): + """Test invalidating cache multiple times updates reason.""" + manager = CacheManager() + + manager.invalidate_cache(sample_cache_file, reason="First reason") + assert sample_cache_file.metadata.invalidation_reason == "First reason" + + manager.invalidate_cache(sample_cache_file, reason="Second reason") + assert sample_cache_file.metadata.invalidation_reason == "Second reason" + + +# Mark Cache Valid Tests + + +def test_mark_cache_valid(sample_cache_file): + """Test marking cache as valid.""" + manager = CacheManager() + + # First invalidate + sample_cache_file.metadata.is_valid = False + sample_cache_file.metadata.invalidation_reason = "Was invalid" + + # Then mark valid + manager.mark_cache_valid(sample_cache_file) + + assert sample_cache_file.metadata.is_valid is True + assert sample_cache_file.metadata.invalidation_reason is None + + +def test_mark_cache_valid_already_valid(sample_cache_file): + """Test marking already valid cache as valid.""" + manager = CacheManager() + assert sample_cache_file.metadata.is_valid is True + + manager.mark_cache_valid(sample_cache_file) + + assert sample_cache_file.metadata.is_valid is True + assert 
sample_cache_file.metadata.invalidation_reason is None + + +# Get Failure Count for Step Tests + + +def test_get_failure_count_for_step_no_failures(sample_cache_file): + """Test getting failure count when no failures exist.""" + manager = CacheManager() + + count = manager.get_failure_count_for_step(sample_cache_file, step_index=1) + assert count == 0 + + +def test_get_failure_count_for_step_with_failures(sample_cache_file): + """Test getting failure count for specific step.""" + manager = CacheManager() + + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error 1", + failure_count_at_step=1, + ), + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=2, + error_message="Error 2", + failure_count_at_step=1, + ), + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error 3", + failure_count_at_step=2, + ), + ] + + count_step_1 = manager.get_failure_count_for_step(sample_cache_file, step_index=1) + count_step_2 = manager.get_failure_count_for_step(sample_cache_file, step_index=2) + + assert count_step_1 == 2 + assert count_step_2 == 1 + + +def test_get_failure_count_for_step_nonexistent_step(sample_cache_file): + """Test getting failure count for step that hasn't failed.""" + manager = CacheManager() + + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error", + failure_count_at_step=1, + ) + ] + + count = manager.get_failure_count_for_step(sample_cache_file, step_index=99) + assert count == 0 + + +# Integration Tests + + +def test_full_workflow_with_failure_detection(sample_cache_file): + """Test complete workflow: record failures, detect threshold, invalidate.""" + manager = CacheManager() + + # Record 3 failures at step 1 (default threshold is 3) + for i in range(3): + manager.record_step_failure( + sample_cache_file, step_index=1, error_message=f"Error {i+1}" + ) + + # Check if should invalidate + should_inv, reason = manager.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is True + assert "Step 1 failed 3 times" in reason + + # Invalidate + manager.invalidate_cache(sample_cache_file, reason=reason) + assert sample_cache_file.metadata.is_valid is False + + +def test_full_workflow_below_threshold(sample_cache_file): + """Test workflow where failures don't reach threshold.""" + manager = CacheManager() + + # Record 2 failures at step 1 (below default threshold of 3) + for i in range(2): + manager.record_step_failure( + sample_cache_file, step_index=1, error_message=f"Error {i+1}" + ) + + # Check if should invalidate + should_inv, reason = manager.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is False + + # Cache should still be valid + assert sample_cache_file.metadata.is_valid is True + + +def test_workflow_with_custom_validator(sample_cache_file): + """Test workflow with custom validator with lower threshold.""" + # Custom validator with lower threshold + custom_validator = CompositeCacheValidator( + [StepFailureCountValidator(max_failures_per_step=2)] + ) + manager = CacheManager(validator=custom_validator) + + # Record 2 failures (enough to trigger custom validator) + for i in range(2): + manager.record_step_failure( + sample_cache_file, step_index=1, error_message=f"Error {i+1}" + ) + + should_inv, reason = manager.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is True + assert "Step 1 failed 2 
times" in reason + + +def test_workflow_successful_execution_updates_timestamp(sample_cache_file): + """Test that successful execution updates last_executed_at.""" + manager = CacheManager() + + # Record some failures first + manager.record_step_failure(sample_cache_file, step_index=1, error_message="Error") + assert sample_cache_file.metadata.last_executed_at is None + + # Record successful execution + manager.record_execution_attempt(sample_cache_file, success=True) + + assert sample_cache_file.metadata.last_executed_at is not None + assert sample_cache_file.metadata.execution_attempts == 1 diff --git a/tests/unit/utils/test_cache_migration.py b/tests/unit/utils/test_cache_migration.py new file mode 100644 index 00000000..bf31b4ca --- /dev/null +++ b/tests/unit/utils/test_cache_migration.py @@ -0,0 +1,360 @@ +"""Tests for cache migration utilities.""" + +import json +from datetime import datetime, timezone +from pathlib import Path + +import pytest + +from askui.utils.cache_migration import CacheMigration, CacheMigrationError + + +@pytest.fixture +def temp_cache_dir(tmp_path): + """Create a temporary cache directory.""" + cache_dir = tmp_path / "caches" + cache_dir.mkdir() + return cache_dir + + +@pytest.fixture +def v1_cache_data(): + """Sample v0.0 cache data (just a trajectory list).""" + return [ + {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, + {"id": "2", "name": "type", "input": {"text": "test"}, "type": "tool_use"}, + ] + + +@pytest.fixture +def v2_cache_data(): + """Sample v0.1 cache data (with metadata).""" + return { + "metadata": { + "version": "0.1", + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "execution_attempts": 0, + "last_executed_at": None, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [ + {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, + ], + "placeholders": {}, + } + + +# Initialization Tests + + +def test_cache_migration_initialization(): + """Test CacheMigration initializes with correct defaults.""" + migration = CacheMigration() + assert migration.backup is False + assert migration.backup_suffix == ".v1.backup" + assert migration.migrated_count == 0 + assert migration.skipped_count == 0 + assert migration.error_count == 0 + + +def test_cache_migration_initialization_with_backup(): + """Test CacheMigration initializes with backup enabled.""" + migration = CacheMigration(backup=True, backup_suffix=".bak") + assert migration.backup is True + assert migration.backup_suffix == ".bak" + + +# Single File Migration Tests + + +def test_migrate_file_v1_to_v2(temp_cache_dir, v1_cache_data): + """Test migrating a v0.0 cache file to v0.1.""" + cache_file = temp_cache_dir / "test.json" + with cache_file.open("w") as f: + json.dump(v1_cache_data, f) + + migration = CacheMigration() + success, message = migration.migrate_file(cache_file, dry_run=False) + + assert success is True + assert "Migrated" in message + + # Verify file was updated to v0.1 + with cache_file.open("r") as f: + data = json.load(f) + + assert isinstance(data, dict) + assert "metadata" in data + assert data["metadata"]["version"] == "0.1" + assert "trajectory" in data + assert "placeholders" in data + + +def test_migrate_file_already_v2(temp_cache_dir, v2_cache_data): + """Test that v0.1 files are skipped.""" + cache_file = temp_cache_dir / "test.json" + with cache_file.open("w") as f: + json.dump(v2_cache_data, f) + + migration = CacheMigration() + success, message = migration.migrate_file(cache_file, 
dry_run=False) + + assert success is False + assert "Already v0.1" in message + + +def test_migrate_file_dry_run(temp_cache_dir, v1_cache_data): + """Test dry run doesn't modify files.""" + cache_file = temp_cache_dir / "test.json" + with cache_file.open("w") as f: + json.dump(v1_cache_data, f) + + # Store original content + original_content = cache_file.read_text() + + migration = CacheMigration() + success, message = migration.migrate_file(cache_file, dry_run=True) + + assert success is True + assert "Would migrate" in message + + # Verify file wasn't modified + assert cache_file.read_text() == original_content + + +def test_migrate_file_creates_backup(temp_cache_dir, v1_cache_data): + """Test that backup is created when requested.""" + cache_file = temp_cache_dir / "test.json" + with cache_file.open("w") as f: + json.dump(v1_cache_data, f) + + migration = CacheMigration(backup=True, backup_suffix=".backup") + success, message = migration.migrate_file(cache_file, dry_run=False) + + assert success is True + + # Verify backup exists + backup_file = temp_cache_dir / "test.json.backup" + assert backup_file.exists() + + # Verify backup contains original v0.0 data + with backup_file.open("r") as f: + backup_data = json.load(f) + assert backup_data == v1_cache_data + + +def test_migrate_file_not_found(temp_cache_dir): + """Test handling of missing file.""" + cache_file = temp_cache_dir / "nonexistent.json" + + migration = CacheMigration() + success, message = migration.migrate_file(cache_file, dry_run=False) + + assert success is False + assert "File not found" in message + + +def test_migrate_file_invalid_json(temp_cache_dir): + """Test handling of invalid JSON.""" + cache_file = temp_cache_dir / "invalid.json" + cache_file.write_text("not valid json{") + + migration = CacheMigration() + success, message = migration.migrate_file(cache_file, dry_run=False) + + assert success is False + assert "Error" in message + + +# Directory Migration Tests + + +def test_migrate_directory_multiple_files(temp_cache_dir, v1_cache_data): + """Test migrating multiple files in a directory.""" + # Create several v0.0 cache files + for i in range(3): + cache_file = temp_cache_dir / f"cache_{i}.json" + with cache_file.open("w") as f: + json.dump(v1_cache_data, f) + + migration = CacheMigration() + stats = migration.migrate_directory(temp_cache_dir, dry_run=False) + + assert stats["total"] == 3 + assert stats["migrated"] == 3 + assert stats["skipped"] == 0 + assert stats["errors"] == 0 + + +def test_migrate_directory_mixed_versions(temp_cache_dir, v1_cache_data, v2_cache_data): + """Test migrating directory with mixed v0.0 and v0.1 files.""" + # Create v0.0 files + for i in range(2): + cache_file = temp_cache_dir / f"v1_cache_{i}.json" + with cache_file.open("w") as f: + json.dump(v1_cache_data, f) + + # Create v0.1 files + for i in range(2): + cache_file = temp_cache_dir / f"v2_cache_{i}.json" + with cache_file.open("w") as f: + json.dump(v2_cache_data, f) + + migration = CacheMigration() + stats = migration.migrate_directory(temp_cache_dir, dry_run=False) + + assert stats["total"] == 4 + assert stats["migrated"] == 2 # Only v0.0 files migrated + assert stats["skipped"] == 2 # v0.1 files skipped + assert stats["errors"] == 0 + + +def test_migrate_directory_dry_run(temp_cache_dir, v1_cache_data): + """Test dry run on directory doesn't modify files.""" + cache_file = temp_cache_dir / "test.json" + with cache_file.open("w") as f: + json.dump(v1_cache_data, f) + + original_content = cache_file.read_text() + + migration 
= CacheMigration() + stats = migration.migrate_directory(temp_cache_dir, dry_run=True) + + assert stats["migrated"] == 1 + # Verify file wasn't modified + assert cache_file.read_text() == original_content + + +def test_migrate_directory_with_pattern(temp_cache_dir, v1_cache_data): + """Test migrating directory with custom file pattern.""" + # Create files with different extensions + for ext in ["json", "cache", "txt"]: + cache_file = temp_cache_dir / f"test.{ext}" + with cache_file.open("w") as f: + json.dump(v1_cache_data, f) + + migration = CacheMigration() + stats = migration.migrate_directory( + temp_cache_dir, file_pattern="*.cache", dry_run=False + ) + + # Only .cache file should be processed + assert stats["total"] == 1 + assert stats["migrated"] == 1 + + +def test_migrate_directory_not_found(): + """Test handling of non-existent directory.""" + migration = CacheMigration() + + with pytest.raises(CacheMigrationError, match="Directory not found"): + migration.migrate_directory(Path("/nonexistent/directory")) + + +def test_migrate_directory_empty(temp_cache_dir): + """Test migrating empty directory.""" + migration = CacheMigration() + stats = migration.migrate_directory(temp_cache_dir, dry_run=False) + + assert stats["total"] == 0 + assert stats["migrated"] == 0 + assert stats["skipped"] == 0 + assert stats["errors"] == 0 + + +def test_migrate_directory_with_errors(temp_cache_dir, v1_cache_data): + """Test directory migration handles errors gracefully.""" + # Create valid v0.0 file + valid_file = temp_cache_dir / "valid.json" + with valid_file.open("w") as f: + json.dump(v1_cache_data, f) + + # Create invalid file + invalid_file = temp_cache_dir / "invalid.json" + invalid_file.write_text("not valid json{") + + migration = CacheMigration() + stats = migration.migrate_directory(temp_cache_dir, dry_run=False) + + assert stats["total"] == 2 + assert stats["migrated"] == 1 # Valid file migrated + assert stats["errors"] == 1 # Invalid file failed + assert stats["skipped"] == 0 + + +def test_migrate_directory_creates_backups(temp_cache_dir, v1_cache_data): + """Test directory migration creates backups for all files.""" + # Create v0.0 files + for i in range(2): + cache_file = temp_cache_dir / f"cache_{i}.json" + with cache_file.open("w") as f: + json.dump(v1_cache_data, f) + + migration = CacheMigration(backup=True, backup_suffix=".bak") + stats = migration.migrate_directory(temp_cache_dir, dry_run=False) + + assert stats["migrated"] == 2 + + # Verify backups exist + for i in range(2): + backup_file = temp_cache_dir / f"cache_{i}.json.bak" + assert backup_file.exists() + + +# Integration Tests + + +def test_full_migration_workflow(temp_cache_dir, v1_cache_data): + """Test complete migration workflow from v0.0 to v0.1.""" + # Create v0.0 cache + cache_file = temp_cache_dir / "workflow_test.json" + with cache_file.open("w") as f: + json.dump(v1_cache_data, f) + + # Perform migration with backup + migration = CacheMigration(backup=True) + success, message = migration.migrate_file(cache_file, dry_run=False) + + assert success is True + + # Verify v0.1 structure + with cache_file.open("r") as f: + data = json.load(f) + + assert data["metadata"]["version"] == "0.1" + assert data["metadata"]["execution_attempts"] == 0 + assert data["metadata"]["is_valid"] is True + assert len(data["trajectory"]) == 2 + assert data["placeholders"] == {} + + # Verify backup + backup_file = cache_file.with_suffix(cache_file.suffix + ".v1.backup") + assert backup_file.exists() + + # Attempt to migrate again (should 
skip) + success, message = migration.migrate_file(cache_file, dry_run=False) + assert success is False + assert "Already v0.1" in message + + +def test_migration_preserves_trajectory_data(temp_cache_dir, v1_cache_data): + """Test that migration preserves all trajectory data.""" + cache_file = temp_cache_dir / "preserve_test.json" + with cache_file.open("w") as f: + json.dump(v1_cache_data, f) + + migration = CacheMigration() + migration.migrate_file(cache_file, dry_run=False) + + # Load migrated file + with cache_file.open("r") as f: + data = json.load(f) + + # Verify trajectory preserved + assert len(data["trajectory"]) == len(v1_cache_data) + for i, step in enumerate(data["trajectory"]): + assert step["id"] == v1_cache_data[i]["id"] + assert step["name"] == v1_cache_data[i]["name"] + assert step["input"] == v1_cache_data[i]["input"] diff --git a/tests/unit/utils/test_cache_validator.py b/tests/unit/utils/test_cache_validator.py new file mode 100644 index 00000000..d252cd31 --- /dev/null +++ b/tests/unit/utils/test_cache_validator.py @@ -0,0 +1,486 @@ +"""Tests for cache validation strategies.""" + +from datetime import datetime, timedelta, timezone + +import pytest + +from askui.models.shared.agent_message_param import ToolUseBlockParam +from askui.models.shared.settings import CacheFailure, CacheFile, CacheMetadata +from askui.utils.cache_validator import ( + CacheValidator, + CompositeCacheValidator, + StaleCacheValidator, + StepFailureCountValidator, + TotalFailureRateValidator, +) + + +@pytest.fixture +def sample_cache_file(): + """Create a sample cache file for testing.""" + return CacheFile( + metadata=CacheMetadata( + version="0.1", + created_at=datetime.now(tz=timezone.utc), + execution_attempts=0, + is_valid=True, + ), + trajectory=[ + ToolUseBlockParam( + id="1", name="click", input={"x": 100}, type="tool_use" + ), + ToolUseBlockParam(id="2", name="type", input={"text": "test"}, type="tool_use"), + ], + placeholders={}, + ) + + +# StepFailureCountValidator Tests + + +def test_step_failure_count_validator_below_threshold(sample_cache_file): + """Test validator does not invalidate when failures are below threshold.""" + validator = StepFailureCountValidator(max_failures_per_step=3) + + # Add 2 failures at step 1 + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error 1", + failure_count_at_step=1, + ), + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error 2", + failure_count_at_step=2, + ), + ] + + should_inv, reason = validator.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is False + assert reason is None + + +def test_step_failure_count_validator_at_threshold(sample_cache_file): + """Test validator invalidates when failures reach threshold.""" + validator = StepFailureCountValidator(max_failures_per_step=3) + + # Add 3 failures at step 1 + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message=f"Error {i}", + failure_count_at_step=i + 1, + ) + for i in range(3) + ] + + should_inv, reason = validator.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is True + assert "Step 1 failed 3 times" in reason + + +def test_step_failure_count_validator_different_steps(sample_cache_file): + """Test validator only counts failures for specific step.""" + validator = StepFailureCountValidator(max_failures_per_step=3) + + # Add failures at different 
steps + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error at step 1", + failure_count_at_step=1, + ), + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=2, + error_message="Error at step 2", + failure_count_at_step=1, + ), + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error at step 1 again", + failure_count_at_step=2, + ), + ] + + # Step 1 has 2 failures (below threshold) + should_inv, reason = validator.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is False + + # Step 2 has 1 failure (below threshold) + should_inv, reason = validator.should_invalidate(sample_cache_file, step_index=2) + assert should_inv is False + + +def test_step_failure_count_validator_no_step_index(sample_cache_file): + """Test validator returns False when no step_index provided.""" + validator = StepFailureCountValidator(max_failures_per_step=3) + + should_inv, reason = validator.should_invalidate(sample_cache_file, step_index=None) + assert should_inv is False + assert reason is None + + +def test_step_failure_count_validator_name(): + """Test validator returns correct name.""" + validator = StepFailureCountValidator() + assert validator.get_name() == "StepFailureCount" + + +# TotalFailureRateValidator Tests + + +def test_total_failure_rate_validator_below_min_attempts(sample_cache_file): + """Test validator does not check rate below minimum attempts.""" + validator = TotalFailureRateValidator(min_attempts=10, max_failure_rate=0.5) + + sample_cache_file.metadata.execution_attempts = 5 + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error", + failure_count_at_step=1, + ) + for _ in range(4) + ] # 4/5 = 80% failure rate + + should_inv, reason = validator.should_invalidate(sample_cache_file) + assert should_inv is False # Too few attempts to judge + + +def test_total_failure_rate_validator_above_threshold(sample_cache_file): + """Test validator invalidates when failure rate exceeds threshold.""" + validator = TotalFailureRateValidator(min_attempts=10, max_failure_rate=0.5) + + sample_cache_file.metadata.execution_attempts = 10 + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=i % 2, + error_message=f"Error {i}", + failure_count_at_step=1, + ) + for i in range(6) + ] # 6/10 = 60% failure rate + + should_inv, reason = validator.should_invalidate(sample_cache_file) + assert should_inv is True + assert "60.0%" in reason + assert "50.0%" in reason + + +def test_total_failure_rate_validator_below_threshold(sample_cache_file): + """Test validator does not invalidate when rate is acceptable.""" + validator = TotalFailureRateValidator(min_attempts=10, max_failure_rate=0.5) + + sample_cache_file.metadata.execution_attempts = 10 + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message=f"Error {i}", + failure_count_at_step=1, + ) + for i in range(4) + ] # 4/10 = 40% failure rate + + should_inv, reason = validator.should_invalidate(sample_cache_file) + assert should_inv is False + + +def test_total_failure_rate_validator_zero_attempts(sample_cache_file): + """Test validator handles zero attempts correctly.""" + validator = TotalFailureRateValidator(min_attempts=10, max_failure_rate=0.5) + + 
sample_cache_file.metadata.execution_attempts = 0 + sample_cache_file.metadata.failures = [] + + should_inv, reason = validator.should_invalidate(sample_cache_file) + assert should_inv is False + + +def test_total_failure_rate_validator_name(): + """Test validator returns correct name.""" + validator = TotalFailureRateValidator() + assert validator.get_name() == "TotalFailureRate" + + +# StaleCacheValidator Tests + + +def test_stale_cache_validator_not_stale(sample_cache_file): + """Test validator does not invalidate recent cache.""" + validator = StaleCacheValidator(max_age_days=30) + + sample_cache_file.metadata.last_executed_at = datetime.now(tz=timezone.utc) - timedelta(days=10) + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error", + failure_count_at_step=1, + ) + ] + + should_inv, reason = validator.should_invalidate(sample_cache_file) + assert should_inv is False + + +def test_stale_cache_validator_is_stale(sample_cache_file): + """Test validator invalidates old cache with failures.""" + validator = StaleCacheValidator(max_age_days=30) + + sample_cache_file.metadata.last_executed_at = datetime.now(tz=timezone.utc) - timedelta(days=35) + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error", + failure_count_at_step=1, + ) + ] + + should_inv, reason = validator.should_invalidate(sample_cache_file) + assert should_inv is True + assert "35 days" in reason + + +def test_stale_cache_validator_old_but_no_failures(sample_cache_file): + """Test validator does not invalidate old cache without failures.""" + validator = StaleCacheValidator(max_age_days=30) + + sample_cache_file.metadata.last_executed_at = datetime.now(tz=timezone.utc) - timedelta(days=100) + sample_cache_file.metadata.failures = [] + + should_inv, reason = validator.should_invalidate(sample_cache_file) + assert should_inv is False # Old but no failures = still valid + + +def test_stale_cache_validator_never_executed(sample_cache_file): + """Test validator handles cache that was never executed.""" + validator = StaleCacheValidator(max_age_days=30) + + sample_cache_file.metadata.last_executed_at = None + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error", + failure_count_at_step=1, + ) + ] + + should_inv, reason = validator.should_invalidate(sample_cache_file) + assert should_inv is False # Never executed = can't be stale + + +def test_stale_cache_validator_name(): + """Test validator returns correct name.""" + validator = StaleCacheValidator() + assert validator.get_name() == "StaleCache" + + +# CompositeCacheValidator Tests + + +def test_composite_validator_empty(): + """Test composite validator with no validators.""" + validator = CompositeCacheValidator([]) + cache_file = CacheFile( + metadata=CacheMetadata( + version="0.1", + created_at=datetime.now(tz=timezone.utc), + execution_attempts=0, + is_valid=True, + ), + trajectory=[], + placeholders={}, + ) + + should_inv, reason = validator.should_invalidate(cache_file) + assert should_inv is False + assert reason is None + + +def test_composite_validator_single_validator_triggers(sample_cache_file): + """Test composite validator with one validator that triggers.""" + step_validator = StepFailureCountValidator(max_failures_per_step=2) + composite = CompositeCacheValidator([step_validator]) + + # Add 2 failures + 
sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message=f"Error {i}", + failure_count_at_step=i + 1, + ) + for i in range(2) + ] + + should_inv, reason = composite.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is True + assert "StepFailureCount" in reason + + +def test_composite_validator_multiple_validators_all_pass(sample_cache_file): + """Test composite validator when all validators pass.""" + composite = CompositeCacheValidator( + [ + StepFailureCountValidator(max_failures_per_step=3), + TotalFailureRateValidator(min_attempts=10, max_failure_rate=0.5), + ] + ) + + sample_cache_file.metadata.execution_attempts = 10 + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error", + failure_count_at_step=1, + ) + ] # 1/10 = 10% rate, and only 1 failure at step 1 + + should_inv, reason = composite.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is False + + +def test_composite_validator_multiple_validators_one_triggers(sample_cache_file): + """Test composite validator when one validator triggers.""" + composite = CompositeCacheValidator( + [ + StepFailureCountValidator(max_failures_per_step=2), + TotalFailureRateValidator(min_attempts=100, max_failure_rate=0.5), + ] + ) + + sample_cache_file.metadata.execution_attempts = 10 # Below min_attempts + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message=f"Error {i}", + failure_count_at_step=i + 1, + ) + for i in range(3) + ] # 3 failures at step 1 (exceeds threshold of 2) + + should_inv, reason = composite.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is True + assert "StepFailureCount" in reason + assert "Step 1 failed 3 times" in reason + + +def test_composite_validator_multiple_validators_multiple_trigger(sample_cache_file): + """Test composite validator when multiple validators trigger.""" + composite = CompositeCacheValidator( + [ + StepFailureCountValidator(max_failures_per_step=2), + TotalFailureRateValidator(min_attempts=5, max_failure_rate=0.5), + ] + ) + + sample_cache_file.metadata.execution_attempts = 5 + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message=f"Error {i}", + failure_count_at_step=i + 1, + ) + for i in range(4) + ] # 4/5 = 80% rate (exceeds 50%), and 4 failures at step 1 (exceeds 2) + + should_inv, reason = composite.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is True + assert "StepFailureCount" in reason + assert "TotalFailureRate" in reason + assert ";" in reason # Multiple reasons combined + + +def test_composite_validator_add_validator(sample_cache_file): + """Test adding validator to composite after initialization.""" + composite = CompositeCacheValidator([]) + assert len(composite.validators) == 0 + + composite.add_validator(StepFailureCountValidator(max_failures_per_step=1)) + assert len(composite.validators) == 1 + + # Add failure + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error", + failure_count_at_step=1, + ) + ] + + should_inv, reason = composite.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is True + + +def test_composite_validator_name(): + """Test composite validator returns correct name.""" 
+ composite = CompositeCacheValidator([]) + assert composite.get_name() == "CompositeValidator" + + +# Custom Validator Tests + + +class CustomTestValidator(CacheValidator): + """Custom test validator for testing extensibility.""" + + def __init__(self, should_trigger: bool = False): + self.should_trigger = should_trigger + + def should_invalidate(self, cache_file, step_index=None): + if self.should_trigger: + return True, "Custom validation failed" + return False, None + + def get_name(self): + return "CustomTest" + + +def test_custom_validator_integration(sample_cache_file): + """Test that custom validators work with composite.""" + custom = CustomTestValidator(should_trigger=True) + composite = CompositeCacheValidator([custom]) + + should_inv, reason = composite.should_invalidate(sample_cache_file) + assert should_inv is True + assert "CustomTest" in reason + assert "Custom validation failed" in reason + + +def test_custom_validator_with_built_in(sample_cache_file): + """Test custom validator alongside built-in validators.""" + custom = CustomTestValidator(should_trigger=False) + step_validator = StepFailureCountValidator(max_failures_per_step=1) + + composite = CompositeCacheValidator([custom, step_validator]) + + # Add failure to trigger step validator + sample_cache_file.metadata.failures = [ + CacheFailure( + timestamp=datetime.now(tz=timezone.utc), + step_index=1, + error_message="Error", + failure_count_at_step=1, + ) + ] + + should_inv, reason = composite.should_invalidate(sample_cache_file, step_index=1) + assert should_inv is True + assert "StepFailureCount" in reason + assert "CustomTest" not in reason # Custom didn't trigger diff --git a/tests/unit/utils/test_cache_writer.py b/tests/unit/utils/test_cache_writer.py index 2c875ae4..333a9725 100644 --- a/tests/unit/utils/test_cache_writer.py +++ b/tests/unit/utils/test_cache_writer.py @@ -7,6 +7,7 @@ from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam from askui.models.shared.agent_on_message_cb import OnMessageCbParam +from askui.models.shared.settings import CacheFile from askui.utils.cache_writer import CacheWriter @@ -138,7 +139,7 @@ def test_cache_writer_detects_cached_execution() -> None: def test_cache_writer_generate_writes_file() -> None: - """Test that generate() writes messages to a JSON file.""" + """Test that generate() writes messages to a JSON file in v0.1 format.""" with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) cache_writer = CacheWriter(cache_dir=str(cache_dir), file_name="output.json") @@ -164,15 +165,27 @@ def test_cache_writer_generate_writes_file() -> None: cache_file = cache_dir / "output.json" assert cache_file.exists() - # Verify file content + # Verify file content (v0.1 format) with cache_file.open("r", encoding="utf-8") as f: data = json.load(f) - assert len(data) == 2 - assert data[0]["id"] == "id1" - assert data[0]["name"] == "tool1" - assert data[1]["id"] == "id2" - assert data[1]["name"] == "tool2" + # Check v0.1 structure + assert "metadata" in data + assert "trajectory" in data + assert "placeholders" in data + + # Check metadata + assert data["metadata"]["version"] == "0.1" + assert "created_at" in data["metadata"] + assert data["metadata"]["execution_attempts"] == 0 + assert data["metadata"]["is_valid"] is True + + # Check trajectory + assert len(data["trajectory"]) == 2 + assert data["trajectory"][0]["id"] == "id1" + assert data["trajectory"][0]["name"] == "tool1" + assert data["trajectory"][1]["id"] == "id2" + assert 
data["trajectory"][1]["name"] == "tool2" def test_cache_writer_generate_auto_names_file() -> None: @@ -243,12 +256,12 @@ def test_cache_writer_reset() -> None: assert cache_writer.was_cached_execution is False -def test_cache_writer_read_cache_file() -> None: - """Test that read_cache_file() loads ToolUseBlockParam from JSON.""" +def test_cache_writer_read_cache_file_v1() -> None: + """Test backward compatibility: read_cache_file() loads v0.0 format.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_file = Path(temp_dir) / "test_cache.json" + cache_file_path = Path(temp_dir) / "test_cache.json" - # Create a cache file + # Create a v0.0 cache file (just a list) trajectory: list[dict[str, Any]] = [ { "id": "id1", @@ -264,19 +277,71 @@ def test_cache_writer_read_cache_file() -> None: }, ] - with cache_file.open("w", encoding="utf-8") as f: + with cache_file_path.open("w", encoding="utf-8") as f: json.dump(trajectory, f) # Read cache file - result = CacheWriter.read_cache_file(cache_file) + result = CacheWriter.read_cache_file(cache_file_path) + + # Should return CacheFile with migrated v0.0 data (now v0.1) + assert isinstance(result, CacheFile) + assert result.metadata.version == "0.1" # Migrated from v0.0 to v0.1 + assert len(result.trajectory) == 2 + assert isinstance(result.trajectory[0], ToolUseBlockParam) + assert result.trajectory[0].id == "id1" + assert result.trajectory[0].name == "tool1" + assert isinstance(result.trajectory[1], ToolUseBlockParam) + assert result.trajectory[1].id == "id2" + assert result.trajectory[1].name == "tool2" + + +def test_cache_writer_read_cache_file_v2() -> None: + """Test that read_cache_file() loads v0.1 format correctly.""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_file_path = Path(temp_dir) / "test_cache_v2.json" + + # Create a v0.1 cache file + cache_data = { + "metadata": { + "version": "0.1", + "created_at": "2025-12-11T10:00:00Z", + "last_executed_at": None, + "execution_attempts": 0, + "failures": [], + "is_valid": True, + "invalidation_reason": None, + }, + "trajectory": [ + { + "id": "id1", + "name": "tool1", + "input": {"param": "value1"}, + "type": "tool_use", + }, + { + "id": "id2", + "name": "tool2", + "input": {"param": "value2"}, + "type": "tool_use", + }, + ], + "placeholders": {"current_date": "Current date in YYYY-MM-DD format"}, + } + + with cache_file_path.open("w", encoding="utf-8") as f: + json.dump(cache_data, f) + + # Read cache file + result = CacheWriter.read_cache_file(cache_file_path) - assert len(result) == 2 - assert isinstance(result[0], ToolUseBlockParam) - assert result[0].id == "id1" - assert result[0].name == "tool1" - assert isinstance(result[1], ToolUseBlockParam) - assert result[1].id == "id2" - assert result[1].name == "tool2" + # Should return CacheFile + assert isinstance(result, CacheFile) + assert result.metadata.version == "0.1" + assert result.metadata.is_valid is True + assert len(result.trajectory) == 2 + assert result.trajectory[0].id == "id1" + assert result.trajectory[1].id == "id2" + assert "current_date" in result.placeholders def test_cache_writer_set_file_name() -> None: @@ -310,3 +375,67 @@ def test_cache_writer_generate_resets_after_writing() -> None: # After generate, messages should be empty assert cache_writer.messages == [] + + +def test_cache_writer_detects_and_stores_placeholders() -> None: + """Test that CacheWriter detects placeholders and stores them in metadata.""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_dir = Path(temp_dir) + cache_writer = 
CacheWriter(cache_dir=str(cache_dir), file_name="test.json")
+
+        # Add tool use blocks with placeholders
+        cache_writer.messages = [
+            ToolUseBlockParam(
+                id="id1",
+                name="computer",
+                input={"action": "type", "text": "Today is {{current_date}}"},
+                type="tool_use",
+            ),
+            ToolUseBlockParam(
+                id="id2",
+                name="computer",
+                input={"action": "type", "text": "User: {{user_name}}"},
+                type="tool_use",
+            ),
+        ]
+
+        cache_writer.generate()
+
+        # Read back the cache file
+        cache_file = cache_dir / "test.json"
+        with cache_file.open("r", encoding="utf-8") as f:
+            data = json.load(f)
+
+        # Verify placeholders were detected and stored
+        assert "placeholders" in data
+        assert "current_date" in data["placeholders"]
+        assert "user_name" in data["placeholders"]
+        assert len(data["placeholders"]) == 2
+
+
+def test_cache_writer_empty_placeholders_when_none_found() -> None:
+    """Test that placeholders dict is empty when no placeholders exist."""
+    with tempfile.TemporaryDirectory() as temp_dir:
+        cache_dir = Path(temp_dir)
+        cache_writer = CacheWriter(cache_dir=str(cache_dir), file_name="test.json")
+
+        # Add tool use blocks without placeholders
+        cache_writer.messages = [
+            ToolUseBlockParam(
+                id="id1",
+                name="computer",
+                input={"action": "click", "coordinate": [100, 200]},
+                type="tool_use",
+            )
+        ]
+
+        cache_writer.generate()
+
+        # Read back the cache file
+        cache_file = cache_dir / "test.json"
+        with cache_file.open("r", encoding="utf-8") as f:
+            data = json.load(f)
+
+        # Verify placeholders dict is empty
+        assert "placeholders" in data
+        assert data["placeholders"] == {}
diff --git a/tests/unit/utils/test_placeholder_handler.py b/tests/unit/utils/test_placeholder_handler.py
new file mode 100644
index 00000000..c0acd89d
--- /dev/null
+++ b/tests/unit/utils/test_placeholder_handler.py
@@ -0,0 +1,378 @@
+"""Unit tests for PlaceholderHandler."""
+
+import re
+
+import pytest
+
+from askui.models.shared.agent_message_param import ToolUseBlockParam
+from askui.utils.placeholder_handler import PLACEHOLDER_PATTERN, PlaceholderHandler
+
+
+def test_placeholder_pattern_matches_valid_placeholders() -> None:
+    """Test that the regex pattern matches valid placeholder syntax."""
+    valid_placeholders = [
+        "{{variable}}",
+        "{{current_date}}",
+        "{{user_name}}",
+        "{{_private}}",
+        "{{VAR123}}",
+    ]
+
+    for placeholder in valid_placeholders:
+        match = re.search(PLACEHOLDER_PATTERN, placeholder)
+        assert match is not None, f"Should match valid placeholder: {placeholder}"
+
+
+def test_placeholder_pattern_does_not_match_invalid() -> None:
+    """Test that the regex pattern rejects invalid placeholder syntax."""
+    invalid_placeholders = [
+        "{{123invalid}}",  # Starts with number
+        "{{var-name}}",  # Contains hyphen
+        "{{var name}}",  # Contains space
+        "{single}",  # Single braces
+        "{{}}",  # Empty
+    ]
+
+    for placeholder in invalid_placeholders:
+        match = re.search(PLACEHOLDER_PATTERN, placeholder)
+        if match and match.group(0) == placeholder:
+            pytest.fail(f"Should not match invalid placeholder: {placeholder}")
+
+
+def test_extract_placeholders_from_simple_string() -> None:
+    """Test extracting placeholders from a simple string input."""
+    trajectory = [
+        ToolUseBlockParam(
+            id="1",
+            name="computer",
+            input={"action": "type", "text": "Today is {{current_date}}"},
+            type="tool_use",
+        )
+    ]
+
+    placeholders = PlaceholderHandler.extract_placeholders(trajectory)
+    assert placeholders == {"current_date"}
+
+
+def test_extract_placeholders_multiple_in_one_string() -> None:
+    """Test extracting multiple 
placeholders from one string.""" + trajectory = [ + ToolUseBlockParam( + id="1", + name="computer", + input={ + "action": "type", + "text": "Hello {{user_name}}, today is {{current_date}}", + }, + type="tool_use", + ) + ] + + placeholders = PlaceholderHandler.extract_placeholders(trajectory) + assert placeholders == {"user_name", "current_date"} + + +def test_extract_placeholders_from_nested_dict() -> None: + """Test extracting placeholders from nested dictionary structures.""" + trajectory = [ + ToolUseBlockParam( + id="1", + name="complex_tool", + input={ + "outer": {"inner": {"text": "Value is {{nested_var}}"}}, + "another": "{{another_var}}", + }, + type="tool_use", + ) + ] + + placeholders = PlaceholderHandler.extract_placeholders(trajectory) + assert placeholders == {"nested_var", "another_var"} + + +def test_extract_placeholders_from_list() -> None: + """Test extracting placeholders from lists in input.""" + trajectory = [ + ToolUseBlockParam( + id="1", + name="tool", + input={ + "items": [ + "{{item1}}", + "{{item2}}", + {"nested": "{{item3}}"}, + ] + }, + type="tool_use", + ) + ] + + placeholders = PlaceholderHandler.extract_placeholders(trajectory) + assert placeholders == {"item1", "item2", "item3"} + + +def test_extract_placeholders_no_placeholders() -> None: + """Test that extracting from trajectory without placeholders returns empty set.""" + trajectory = [ + ToolUseBlockParam( + id="1", + name="computer", + input={"action": "click", "coordinate": [100, 200]}, + type="tool_use", + ) + ] + + placeholders = PlaceholderHandler.extract_placeholders(trajectory) + assert placeholders == set() + + +def test_extract_placeholders_from_multiple_steps() -> None: + """Test extracting placeholders from multiple trajectory steps.""" + trajectory = [ + ToolUseBlockParam( + id="1", + name="tool1", + input={"text": "{{var1}}"}, + type="tool_use", + ), + ToolUseBlockParam( + id="2", + name="tool2", + input={"text": "{{var2}}"}, + type="tool_use", + ), + ToolUseBlockParam( + id="3", + name="tool3", + input={"text": "{{var1}}"}, # Duplicate + type="tool_use", + ), + ] + + placeholders = PlaceholderHandler.extract_placeholders(trajectory) + assert placeholders == {"var1", "var2"} # No duplicates + + +def test_validate_placeholders_all_provided() -> None: + """Test validation passes when all placeholders have values.""" + trajectory = [ + ToolUseBlockParam( + id="1", + name="tool", + input={"text": "{{var1}} and {{var2}}"}, + type="tool_use", + ) + ] + + is_valid, missing = PlaceholderHandler.validate_placeholders( + trajectory, {"var1": "value1", "var2": "value2"} + ) + + assert is_valid is True + assert missing == [] + + +def test_validate_placeholders_missing_some() -> None: + """Test validation fails when some placeholders are missing.""" + trajectory = [ + ToolUseBlockParam( + id="1", + name="tool", + input={"text": "{{var1}} and {{var2}} and {{var3}}"}, + type="tool_use", + ) + ] + + is_valid, missing = PlaceholderHandler.validate_placeholders( + trajectory, {"var1": "value1"} + ) + + assert is_valid is False + assert set(missing) == {"var2", "var3"} + + +def test_validate_placeholders_extra_values_ok() -> None: + """Test validation passes when extra values are provided (they're ignored).""" + trajectory = [ + ToolUseBlockParam( + id="1", + name="tool", + input={"text": "{{var1}}"}, + type="tool_use", + ) + ] + + is_valid, missing = PlaceholderHandler.validate_placeholders( + trajectory, {"var1": "value1", "extra_var": "extra_value"} + ) + + assert is_valid is True + assert missing == [] + + 
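+# Note on the contract pinned down by the tests above: extract_placeholders
+# returns the set of placeholder *names* (without the {{...}} braces) found
+# anywhere in a trajectory, and validate_placeholders returns
+# (is_valid, missing_names), ignoring extra values with no matching placeholder.
+
+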
+def test_validate_placeholders_no_placeholders() -> None: + """Test validation passes when trajectory has no placeholders.""" + trajectory = [ + ToolUseBlockParam( + id="1", + name="tool", + input={"text": "No placeholders here"}, + type="tool_use", + ) + ] + + is_valid, missing = PlaceholderHandler.validate_placeholders(trajectory, {}) + + assert is_valid is True + assert missing == [] + + +def test_substitute_placeholders_simple_string() -> None: + """Test substituting placeholders in a simple string.""" + tool_block = ToolUseBlockParam( + id="1", + name="computer", + input={"action": "type", "text": "Today is {{current_date}}"}, + type="tool_use", + ) + + result = PlaceholderHandler.substitute_placeholders( + tool_block, {"current_date": "2025-12-11"} + ) + + assert result.input["text"] == "Today is 2025-12-11" + assert result.id == tool_block.id + assert result.name == tool_block.name + + +def test_substitute_placeholders_multiple() -> None: + """Test substituting multiple placeholders in one string.""" + tool_block = ToolUseBlockParam( + id="1", + name="computer", + input={ + "action": "type", + "text": "Hello {{user_name}}, date is {{current_date}}", + }, + type="tool_use", + ) + + result = PlaceholderHandler.substitute_placeholders( + tool_block, {"user_name": "Alice", "current_date": "2025-12-11"} + ) + + assert result.input["text"] == "Hello Alice, date is 2025-12-11" + + +def test_substitute_placeholders_nested_dict() -> None: + """Test substituting placeholders in nested dictionaries.""" + tool_block = ToolUseBlockParam( + id="1", + name="tool", + input={ + "outer": {"inner": {"text": "Value: {{var1}}"}}, + "another": "{{var2}}", + }, + type="tool_use", + ) + + result = PlaceholderHandler.substitute_placeholders( + tool_block, {"var1": "value1", "var2": "value2"} + ) + + assert result.input["outer"]["inner"]["text"] == "Value: value1" + assert result.input["another"] == "value2" + + +def test_substitute_placeholders_in_list() -> None: + """Test substituting placeholders in lists.""" + tool_block = ToolUseBlockParam( + id="1", + name="tool", + input={"items": ["{{item1}}", "static", {"nested": "{{item2}}"}]}, + type="tool_use", + ) + + result = PlaceholderHandler.substitute_placeholders( + tool_block, {"item1": "value1", "item2": "value2"} + ) + + assert result.input["items"][0] == "value1" + assert result.input["items"][1] == "static" + assert result.input["items"][2]["nested"] == "value2" + + +def test_substitute_placeholders_no_change_if_no_placeholders() -> None: + """Test that substitution doesn't change input without placeholders.""" + tool_block = ToolUseBlockParam( + id="1", + name="computer", + input={"action": "click", "coordinate": [100, 200]}, + type="tool_use", + ) + + result = PlaceholderHandler.substitute_placeholders(tool_block, {}) + + assert result.input == tool_block.input + + +def test_substitute_placeholders_partial_substitution() -> None: + """Test that only provided placeholders are substituted.""" + tool_block = ToolUseBlockParam( + id="1", + name="tool", + input={"text": "{{var1}} and {{var2}}"}, + type="tool_use", + ) + + result = PlaceholderHandler.substitute_placeholders(tool_block, {"var1": "value1"}) + + assert result.input["text"] == "value1 and {{var2}}" + + +def test_substitute_placeholders_preserves_original() -> None: + """Test that substitution creates a new object, doesn't modify original.""" + tool_block = ToolUseBlockParam( + id="1", + name="tool", + input={"text": "{{var1}}"}, + type="tool_use", + ) + + original_input = 
tool_block.input.copy() + PlaceholderHandler.substitute_placeholders(tool_block, {"var1": "value1"}) + + # Original should be unchanged + assert tool_block.input == original_input + + +def test_substitute_placeholders_with_special_characters() -> None: + """Test substitution with values containing special regex characters.""" + tool_block = ToolUseBlockParam( + id="1", + name="tool", + input={"text": "Pattern: {{pattern}}"}, + type="tool_use", + ) + + # Value contains regex special characters + result = PlaceholderHandler.substitute_placeholders( + tool_block, {"pattern": r".*[test]$"} + ) + + assert result.input["text"] == r"Pattern: .*[test]$" + + +def test_substitute_placeholders_same_placeholder_multiple_times() -> None: + """Test substituting the same placeholder appearing multiple times.""" + tool_block = ToolUseBlockParam( + id="1", + name="tool", + input={"text": "{{var}} is {{var}} is {{var}}"}, + type="tool_use", + ) + + result = PlaceholderHandler.substitute_placeholders(tool_block, {"var": "value"}) + + assert result.input["text"] == "value is value is value" diff --git a/tests/unit/utils/test_trajectory_executor.py b/tests/unit/utils/test_trajectory_executor.py new file mode 100644 index 00000000..9a7384e9 --- /dev/null +++ b/tests/unit/utils/test_trajectory_executor.py @@ -0,0 +1,754 @@ +"""Unit tests for TrajectoryExecutor.""" + +from unittest.mock import MagicMock + +import pytest + +from askui.models.shared.agent_message_param import ( + MessageParam, + TextBlockParam, + ToolResultBlockParam, + ToolUseBlockParam, +) +from askui.models.shared.tools import ToolCollection +from askui.utils.trajectory_executor import TrajectoryExecutor + + +def test_trajectory_executor_initialization() -> None: + """Test TrajectoryExecutor initialization.""" + trajectory = [ToolUseBlockParam(id="1", name="tool1", input={}, type="tool_use")] + toolbox = ToolCollection() + + executor = TrajectoryExecutor( + trajectory=trajectory, + toolbox=toolbox, + placeholder_values={"var": "value"}, + delay_time=0.1, + ) + + assert executor.trajectory == trajectory + assert executor.toolbox == toolbox + assert executor.placeholder_values == {"var": "value"} + assert executor.delay_time == 0.1 + assert executor.current_step_index == 0 + + +def test_trajectory_executor_execute_simple_step() -> None: + """Test executing a simple step.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = ["Tool result"] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam( + id="1", name="test_tool", input={"param": "value"}, type="tool_use" + ) + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + result = executor.execute_next_step() + + assert result.status == "SUCCESS" + assert result.step_index == 0 + assert result.error_message is None + assert executor.current_step_index == 1 + assert mock_toolbox.run.call_count == 1 + + +def test_trajectory_executor_execute_all_steps() -> None: + """Test executing all steps in a trajectory.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = ["Result"] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam(id="1", name="tool1", input={}, type="tool_use"), + ToolUseBlockParam(id="2", name="tool2", input={}, type="tool_use"), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + results = executor.execute_all() + + # Should have 3 results: 2 successful steps + 1 completed + assert 
len(results) == 3 + assert results[0].status == "SUCCESS" + assert results[0].step_index == 0 + assert results[1].status == "SUCCESS" + assert results[1].step_index == 1 + assert results[2].status == "COMPLETED" + assert executor.current_step_index == 2 + assert mock_toolbox.run.call_count == 2 + + +def test_trajectory_executor_executes_screenshot_tools() -> None: + """Test that screenshot tools are executed (not skipped).""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = ["Screenshot result"] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam(id="1", name="screenshot", input={}, type="tool_use"), + ToolUseBlockParam(id="2", name="tool1", input={}, type="tool_use"), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + result = executor.execute_next_step() + + # Should execute screenshot tool + assert result.status == "SUCCESS" + assert result.step_index == 0 # First step executed + assert mock_toolbox.run.call_count == 1 + # Verify screenshot tool was called + assert mock_toolbox.run.call_args[0][0][0].name == "screenshot" + + +def test_trajectory_executor_executes_retrieve_trajectories_tool() -> None: + """Test that retrieve_available_trajectories_tool is executed (not skipped).""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = ["Trajectory list"] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam( + id="1", + name="retrieve_available_trajectories_tool", + input={}, + type="tool_use", + ), + ToolUseBlockParam(id="2", name="tool1", input={}, type="tool_use"), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + result = executor.execute_next_step() + + # Should execute retrieve tool + assert result.status == "SUCCESS" + assert result.step_index == 0 # First step executed + assert mock_toolbox.run.call_count == 1 + # Verify retrieve tool was called + assert mock_toolbox.run.call_args[0][0][0].name == "retrieve_available_trajectories_tool" + + +def test_trajectory_executor_pauses_at_non_cacheable_tool() -> None: + """Test that execution pauses at non-cacheable tools.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = ["Result"] + + # Create mock tools in the toolbox + cacheable_tool = MagicMock() + cacheable_tool.is_cacheable = True + non_cacheable_tool = MagicMock() + non_cacheable_tool.is_cacheable = False + + mock_toolbox._tool_map = { + "cacheable": cacheable_tool, + "non_cacheable": non_cacheable_tool, + } + + trajectory = [ + ToolUseBlockParam(id="1", name="cacheable", input={}, type="tool_use"), + ToolUseBlockParam(id="2", name="non_cacheable", input={}, type="tool_use"), + ToolUseBlockParam(id="3", name="cacheable", input={}, type="tool_use"), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + results = executor.execute_all() + + # Should execute first step, then pause at non-cacheable + assert len(results) == 2 + assert results[0].status == "SUCCESS" + assert results[0].step_index == 0 + assert results[1].status == "NEEDS_AGENT" + assert results[1].step_index == 1 + assert mock_toolbox.run.call_count == 1 # Only first step executed + assert executor.current_step_index == 1 # Paused at step 1 + + +def test_trajectory_executor_handles_tool_error() -> None: + """Test that executor handles tool execution errors gracefully.""" + mock_toolbox = MagicMock(spec=ToolCollection) + 
mock_toolbox._tool_map = {} + mock_toolbox.run.side_effect = Exception("Tool execution failed") + + trajectory = [ + ToolUseBlockParam(id="1", name="failing_tool", input={}, type="tool_use") + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + result = executor.execute_next_step() + + assert result.status == "FAILED" + assert result.step_index == 0 + assert "Tool execution failed" in (result.error_message or "") + + +def test_trajectory_executor_substitutes_placeholders() -> None: + """Test that executor substitutes placeholders before execution.""" + captured_steps = [] + + def capture_run(steps): # type: ignore + captured_steps.extend(steps) + return ["Result"] + + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run = capture_run + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam( + id="1", + name="test_tool", + input={"text": "Hello {{name}}"}, + type="tool_use", + ) + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, + toolbox=mock_toolbox, + placeholder_values={"name": "Alice"}, + delay_time=0, + ) + + result = executor.execute_next_step() + + assert result.status == "SUCCESS" + assert len(captured_steps) == 1 + assert captured_steps[0].input["text"] == "Hello Alice" + + +def test_trajectory_executor_get_current_step_index() -> None: + """Test getting current step index.""" + toolbox = ToolCollection() + trajectory = [ToolUseBlockParam(id="1", name="tool1", input={}, type="tool_use")] + + executor = TrajectoryExecutor(trajectory=trajectory, toolbox=toolbox, delay_time=0) + + assert executor.get_current_step_index() == 0 + + executor.current_step_index = 5 + assert executor.get_current_step_index() == 5 + + +def test_trajectory_executor_get_remaining_trajectory() -> None: + """Test getting remaining trajectory steps.""" + toolbox = ToolCollection() + trajectory = [ + ToolUseBlockParam(id="1", name="tool1", input={}, type="tool_use"), + ToolUseBlockParam(id="2", name="tool2", input={}, type="tool_use"), + ToolUseBlockParam(id="3", name="tool3", input={}, type="tool_use"), + ] + + executor = TrajectoryExecutor(trajectory=trajectory, toolbox=toolbox, delay_time=0) + + # Initially, all steps remain + remaining = executor.get_remaining_trajectory() + assert len(remaining) == 3 + + # After advancing to step 1 + executor.current_step_index = 1 + remaining = executor.get_remaining_trajectory() + assert len(remaining) == 2 + assert remaining[0].id == "2" + assert remaining[1].id == "3" + + # At the end + executor.current_step_index = 3 + remaining = executor.get_remaining_trajectory() + assert len(remaining) == 0 + + +def test_trajectory_executor_skip_current_step() -> None: + """Test skipping the current step.""" + toolbox = ToolCollection() + trajectory = [ + ToolUseBlockParam(id="1", name="tool1", input={}, type="tool_use"), + ToolUseBlockParam(id="2", name="tool2", input={}, type="tool_use"), + ] + + executor = TrajectoryExecutor(trajectory=trajectory, toolbox=toolbox, delay_time=0) + + assert executor.current_step_index == 0 + executor.skip_current_step() + assert executor.current_step_index == 1 + executor.skip_current_step() + assert executor.current_step_index == 2 + + +def test_trajectory_executor_skip_at_end_does_nothing() -> None: + """Test that skipping at the end doesn't cause errors.""" + toolbox = ToolCollection() + trajectory = [ToolUseBlockParam(id="1", name="tool1", input={}, type="tool_use")] + + executor = TrajectoryExecutor(trajectory=trajectory, toolbox=toolbox, delay_time=0) + 
+ executor.current_step_index = 1 # Already at end + executor.skip_current_step() + assert executor.current_step_index == 1 # Stays at end + + +def test_trajectory_executor_completed_status_when_done() -> None: + """Test that executor returns COMPLETED when all steps are done.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = ["Result"] + mock_toolbox._tool_map = {} + + trajectory = [ToolUseBlockParam(id="1", name="tool1", input={}, type="tool_use")] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + # Execute the step + result1 = executor.execute_next_step() + assert result1.status == "SUCCESS" + + # Try to execute again (no more steps) + result2 = executor.execute_next_step() + assert result2.status == "COMPLETED" + + +def test_trajectory_executor_execute_all_stops_on_failure() -> None: + """Test that execute_all stops when a step fails.""" + # Mock to fail on second call + call_count = [0] + + def mock_run(steps): # type: ignore + call_count[0] += 1 + if call_count[0] == 2: + raise Exception("Second call fails") + return ["Result"] + + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run = mock_run + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam(id="1", name="tool1", input={}, type="tool_use"), + ToolUseBlockParam(id="2", name="tool1", input={}, type="tool_use"), + ToolUseBlockParam(id="3", name="tool1", input={}, type="tool_use"), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + results = executor.execute_all() + + # Should have 2 results: 1 success + 1 failure (stopped) + assert len(results) == 2 + assert results[0].status == "SUCCESS" + assert results[1].status == "FAILED" + assert executor.current_step_index == 1 # Stopped at failed step + + +def test_trajectory_executor_builds_message_history() -> None: + """Test that executor builds message history during execution.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = ["Result1"] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam(id="tool1", name="test_tool", input={"x": 100}, type="tool_use") + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + result = executor.execute_next_step() + + # Verify message history is built + assert result.status == "SUCCESS" + assert len(result.message_history) == 2 # Assistant + User message + + # Verify assistant message (tool use) + assert isinstance(result.message_history[0], MessageParam) + assert result.message_history[0].role == "assistant" + assert isinstance(result.message_history[0].content, list) + assert len(result.message_history[0].content) == 1 + assert isinstance(result.message_history[0].content[0], ToolUseBlockParam) + + # Verify user message (tool result) + assert isinstance(result.message_history[1], MessageParam) + assert result.message_history[1].role == "user" + assert isinstance(result.message_history[1].content, list) + assert len(result.message_history[1].content) == 1 + assert isinstance(result.message_history[1].content[0], ToolResultBlockParam) + + +def test_trajectory_executor_message_history_accumulates() -> None: + """Test that message history accumulates across multiple steps.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = ["Result"] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam(id="1", name="tool1", input={}, type="tool_use"), + 
ToolUseBlockParam(id="2", name="tool2", input={}, type="tool_use"), + ToolUseBlockParam(id="3", name="tool3", input={}, type="tool_use"), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + results = executor.execute_all() + + # Should have 4 results: 3 successful steps + 1 completed + assert len(results) == 4 + + # Check message history grows + assert len(results[0].message_history) == 2 # Step 1: assistant + user + assert len(results[1].message_history) == 4 # Step 2: + assistant + user + assert len(results[2].message_history) == 6 # Step 3: + assistant + user + assert len(results[3].message_history) == 6 # Completed: same as last step + + +def test_trajectory_executor_message_history_contains_tool_use_id() -> None: + """Test that tool result has correct tool_use_id matching tool use.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = ["Success"] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam(id="unique_id_123", name="tool", input={}, type="tool_use") + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + result = executor.execute_next_step() + + # Get tool use and tool result + tool_use = result.message_history[0].content[0] # type: ignore + tool_result = result.message_history[1].content[0] # type: ignore + + # Verify tool_use_id matches + assert isinstance(tool_use, ToolUseBlockParam) + assert isinstance(tool_result, ToolResultBlockParam) + assert tool_result.tool_use_id == tool_use.id + assert tool_result.tool_use_id == "unique_id_123" + + +def test_trajectory_executor_message_history_includes_text_result() -> None: + """Test that tool results include text content.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = ["Tool executed successfully"] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam(id="1", name="tool", input={}, type="tool_use") + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + result = executor.execute_next_step() + + # Get tool result + tool_result_block = result.message_history[1].content[0] # type: ignore + assert isinstance(tool_result_block, ToolResultBlockParam) + + # Verify text content is included + assert isinstance(tool_result_block.content, list) + assert len(tool_result_block.content) == 1 + assert isinstance(tool_result_block.content[0], TextBlockParam) + assert tool_result_block.content[0].text == "Tool executed successfully" + + +def test_trajectory_executor_message_history_on_failure() -> None: + """Test that message history is included even when execution fails.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox._tool_map = {} + mock_toolbox.run.side_effect = Exception("Tool failed") + + trajectory = [ + ToolUseBlockParam(id="1", name="failing_tool", input={}, type="tool_use") + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + result = executor.execute_next_step() + + assert result.status == "FAILED" + # Message history should include the assistant message (tool use) + # but not the user message (since execution failed) + assert len(result.message_history) == 1 + assert result.message_history[0].role == "assistant" + + +def test_trajectory_executor_message_history_on_pause() -> None: + """Test that message history is included when execution pauses.""" + mock_toolbox = MagicMock(spec=ToolCollection) + 
mock_toolbox.run.return_value = ["Result"] + + cacheable_tool = MagicMock() + cacheable_tool.is_cacheable = True + non_cacheable_tool = MagicMock() + non_cacheable_tool.is_cacheable = False + + mock_toolbox._tool_map = { + "cacheable": cacheable_tool, + "non_cacheable": non_cacheable_tool, + } + + trajectory = [ + ToolUseBlockParam(id="1", name="cacheable", input={}, type="tool_use"), + ToolUseBlockParam(id="2", name="non_cacheable", input={}, type="tool_use"), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + results = executor.execute_all() + + # Should pause at non-cacheable step + assert len(results) == 2 + assert results[0].status == "SUCCESS" + assert results[1].status == "NEEDS_AGENT" + + # Message history should include first step's execution + assert len(results[1].message_history) == 2 # First step: assistant + user + + +def test_trajectory_executor_message_history_order() -> None: + """Test that message history maintains correct order.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = ["Result"] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam(id="1", name="tool1", input={"step": 1}, type="tool_use"), + ToolUseBlockParam(id="2", name="tool2", input={"step": 2}, type="tool_use"), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + results = executor.execute_all() + + # Get final message history + final_history = results[-1].message_history + + # Should have 4 messages: assistant, user, assistant, user + assert len(final_history) == 4 + assert final_history[0].role == "assistant" # Step 1 tool use + assert final_history[1].role == "user" # Step 1 result + assert final_history[2].role == "assistant" # Step 2 tool use + assert final_history[3].role == "user" # Step 2 result + + # Verify step order in tool use + tool_use_1 = final_history[0].content[0] # type: ignore + tool_use_2 = final_history[2].content[0] # type: ignore + assert isinstance(tool_use_1, ToolUseBlockParam) + assert isinstance(tool_use_2, ToolUseBlockParam) + assert tool_use_1.input == {"step": 1} # type: ignore + assert tool_use_2.input == {"step": 2} # type: ignore + + +# Visual Validation Extension Point Tests + + +def test_visual_validation_disabled_by_default() -> None: + """Test that visual validation is disabled by default.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = ["Result"] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam( + id="1", + name="click", + input={"x": 100}, + type="tool_use", + visual_hash="abc123", + visual_validation_required=True, + ), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + # visual_validation_enabled should be False by default + assert executor.visual_validation_enabled is False + + # Should execute successfully without validation + results = executor.execute_all() + assert results[0].status == "SUCCESS" + + +def test_visual_validation_enabled_flag() -> None: + """Test that visual validation flag can be enabled.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = ["Result"] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam(id="1", name="click", input={"x": 100}, type="tool_use"), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, + toolbox=mock_toolbox, + delay_time=0, + visual_validation_enabled=True, + ) + + # visual_validation_enabled 
should be True + assert executor.visual_validation_enabled is True + + +def test_validate_step_visually_hook_exists() -> None: + """Test that validate_step_visually hook exists and returns correct signature.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam(id="1", name="click", input={"x": 100}, type="tool_use"), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 + ) + + # Hook should exist + assert hasattr(executor, "validate_step_visually") + assert callable(executor.validate_step_visually) + + # Hook should return correct signature + step = trajectory[0] + is_valid, error_msg = executor.validate_step_visually(step) + + assert isinstance(is_valid, bool) + assert is_valid is True # Currently always returns True + assert error_msg is None + + +def test_validate_step_visually_always_passes_when_disabled() -> None: + """Test that validation always passes when disabled (default behavior).""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = ["Result"] + mock_toolbox._tool_map = {} + + # Create step with visual validation fields + trajectory = [ + ToolUseBlockParam( + id="1", + name="click", + input={"x": 100}, + type="tool_use", + visual_hash="abc123", + visual_validation_required=True, + ), + ] + + # Validation disabled (default) + executor = TrajectoryExecutor( + trajectory=trajectory, + toolbox=mock_toolbox, + delay_time=0, + visual_validation_enabled=False, + ) + + # Should execute without calling validate_step_visually + results = executor.execute_all() + assert results[0].status == "SUCCESS" + + +def test_validate_step_visually_hook_called_when_enabled() -> None: + """Test that validate_step_visually is called when enabled.""" + mock_toolbox = MagicMock(spec=ToolCollection) + mock_toolbox.run.return_value = ["Result"] + mock_toolbox._tool_map = {} + + trajectory = [ + ToolUseBlockParam( + id="1", + name="click", + input={"x": 100}, + type="tool_use", + visual_hash="abc123", + visual_validation_required=True, + ), + ] + + executor = TrajectoryExecutor( + trajectory=trajectory, + toolbox=mock_toolbox, + delay_time=0, + visual_validation_enabled=True, + ) + + # Mock the validation hook to track calls + original_validate = executor.validate_step_visually + validation_called = [] + + def mock_validate(step, screenshot=None): + validation_called.append(step) + return original_validate(step, screenshot) + + executor.validate_step_visually = mock_validate + + # Execute trajectory + results = executor.execute_all() + + # Validation should have been called + assert len(validation_called) == 1 + assert results[0].status == "SUCCESS" + + +@pytest.mark.skip(reason="Visual validation fields not yet implemented - future feature") +def test_visual_validation_fields_on_tool_use_block() -> None: + """Test that ToolUseBlockParam supports visual validation fields. + + Note: This test is for future functionality. Visual validation fields + (visual_hash, visual_validation_required) are planned but not yet + implemented in the ToolUseBlockParam model. 
+    """
+    # Create step with visual validation fields
+    step = ToolUseBlockParam(
+        id="1",
+        name="click",
+        input={"x": 100, "y": 200},
+        type="tool_use",
+        visual_hash="a8f3c9e14b7d2056",
+        visual_validation_required=True,
+    )
+
+    # Fields should be accessible
+    assert step.visual_hash == "a8f3c9e14b7d2056"
+    assert step.visual_validation_required is True
+
+    # Default values should work
+    step_default = ToolUseBlockParam(
+        id="2", name="type", input={"text": "hello"}, type="tool_use"
+    )
+
+    assert step_default.visual_hash is None
+    assert step_default.visual_validation_required is False

From 93564dc8f69faac859b51dfd35c3629dd15c3725 Mon Sep 17 00:00:00 2001
From: philipph-askui
Date: Fri, 12 Dec 2025 16:02:49 +0100
Subject: [PATCH 02/30] chore(caching): add more configurability and fix some
 tests

---
 docs/caching.md                        | 23 +++++---
 src/askui/agent_base.py                |  2 +-
 src/askui/models/shared/settings.py    |  9 +++-
 src/askui/tools/caching_tools.py       |  8 +--
 src/askui/utils/cache_writer.py        | 32 +++++++----
 tests/unit/tools/test_caching_tools.py | 74 +++++++++++++-------------
 6 files changed, 85 insertions(+), 63 deletions(-)

diff --git a/docs/caching.md b/docs/caching.md
index d5bafe86..2fb78fdb 100644
--- a/docs/caching.md
+++ b/docs/caching.md
@@ -24,13 +24,15 @@ The caching mechanism supports four strategies, configured via the `caching_sett
 Caching is configured using the `CachingSettings` class:
 
 ```python
-from askui.models.shared.settings import CachingSettings, CachedExecutionToolSettings
+from askui.models.shared.settings import CachingSettings, CachedExecutionToolSettings, CacheWriterSettings
 
 caching_settings = CachingSettings(
     strategy="write", # One of: "read", "write", "both", "no"
     cache_dir=".cache", # Directory to store cache files
     filename="my_test.json", # Filename for the cache file (optional for write mode)
-    auto_identify_placeholders=True, # Auto-detect dynamic values (default: True)
+    cache_writer_settings=CacheWriterSettings(
+        placeholder_identification_strategy="llm",
+    ), # Auto-detect dynamic values (default: "llm")
     execute_cached_trajectory_tool_settings=CachedExecutionToolSettings(
         delay_time_between_action=0.5 # Delay in seconds between each cached action
     )
@@ -42,9 +44,14 @@ caching_settings = CachingSettings(
 
 - **`strategy`**: The caching strategy to use (`"read"`, `"write"`, `"both"`, or `"no"`).
 - **`cache_dir`**: Directory where cache files are stored. Defaults to `".cache"`.
 - **`filename`**: Name of the cache file to write to or read from. If not specified in write mode, a timestamped filename will be generated automatically (format: `cached_trajectory_YYYYMMDDHHMMSSffffff.json`).
-- **`auto_identify_placeholders`**: **New in v0.1!** When `True` (default), uses AI to automatically identify and parameterize dynamic values like dates, usernames, and IDs during cache recording. When `False`, only manually specified placeholders (using `{{...}}` syntax) are detected. See [Automatic Placeholder Identification](#automatic-placeholder-identification).
+- **`cache_writer_settings`**: **New in v0.1!** Configuration for the cache writer. See [CacheWriter Settings](#cachewriter-settings) below.
 - **`execute_cached_trajectory_tool_settings`**: Configuration for the trajectory execution tool (optional). See [Execution Settings](#execution-settings) below.
 
+### CacheWriter Settings
+
+- `placeholder_identification_strategy`: When `llm` (default), uses AI to automatically identify and parameterize dynamic values like dates, usernames, and IDs during cache recording.
When `preset`, only manually specified placeholders (using `{{...}}` syntax) are detected. See [Automatic Placeholder Identification](#automatic-placeholder-identification).
+- `llm_placeholder_id_api_provider`: The API provider used for the LLM during placeholder identification (only used when `placeholder_identification_strategy` is set to `llm`). Defaults to `askui`.
+
 ### Execution Settings
 
 The `CachedExecutionToolSettings` class allows you to configure how cached trajectories are executed:
@@ -466,7 +473,7 @@ In write mode, the `CacheWriter` class:
 2. Extracts tool use blocks from the messages
 3. Stores tool blocks in memory during execution
 4. When agent finishes (on `stop_reason="end_turn"`):
-   - **Automatically identifies placeholders** using AI (if `auto_identify_placeholders=True`)
+   - **Automatically identifies placeholders** using AI (if `placeholder_identification_strategy=llm`)
     - Analyzes trajectory to find dynamic values (dates, usernames, IDs, etc.)
     - Generates descriptive placeholder definitions
     - Replaces identified values with `{{placeholder_name}}` syntax in trajectory
@@ -636,7 +643,7 @@ Valid placeholder names:
 
 #### How It Works
 
-When `auto_identify_placeholders=True` (the default), the system:
+When `placeholder_identification_strategy=llm` (the default), the system:
 
 1. **Records the trajectory** as normal during agent execution
 2. **Analyzes the trajectory** using an LLM to identify dynamic values such as:
@@ -693,11 +700,13 @@ If you prefer manual placeholder control:
 
 ```python
 caching_settings = CachingSettings(
     strategy="write",
-    auto_identify_placeholders=False # Only detect {{...}} syntax
+    cache_writer_settings=CacheWriterSettings(
+        placeholder_identification_strategy="preset" # Only detect {{...}} syntax
+    )
 )
 ```
 
-With `auto_identify_placeholders=False`, only manually specified placeholders using the `{{...}}` syntax will be detected.
+With `placeholder_identification_strategy=preset`, only manually specified placeholders using the `{{...}}` syntax will be detected.
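+
+As a quick sketch of what this means in practice, the `PlaceholderHandler` utility used by the cache writer detects and substitutes manual placeholders roughly as follows (the step and values below are illustrative only):
+
+```python
+from askui.models.shared.agent_message_param import ToolUseBlockParam
+from askui.utils.placeholder_handler import PlaceholderHandler
+
+# A recorded step containing a manually specified {{...}} placeholder
+step = ToolUseBlockParam(
+    id="1",
+    name="computer",
+    input={"action": "type", "text": "Hello {{user_name}}"},
+    type="tool_use",
+)
+
+# The placeholder name (without braces) is detected in the trajectory
+assert PlaceholderHandler.extract_placeholders([step]) == {"user_name"}
+
+# Provided values are substituted into a copy at replay time; the original
+# step is left unchanged, and placeholders without values stay as-is.
+replayed = PlaceholderHandler.substitute_placeholders(step, {"user_name": "Alice"})
+assert replayed.input["text"] == "Hello Alice"
+```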
#### Logging diff --git a/src/askui/agent_base.py b/src/askui/agent_base.py index 4ba50170..db5b5d72 100644 --- a/src/askui/agent_base.py +++ b/src/askui/agent_base.py @@ -395,7 +395,7 @@ def _patch_act_with_cache( cache_writer = CacheWriter( cache_dir=caching_settings.cache_dir, file_name=caching_settings.filename, - caching_settings=caching_settings, + cache_writer_settings=caching_settings.cache_writer_settings, toolbox=toolbox, goal=goal, ) diff --git a/src/askui/models/shared/settings.py b/src/askui/models/shared/settings.py index 6265fa9a..fc8f3695 100644 --- a/src/askui/models/shared/settings.py +++ b/src/askui/models/shared/settings.py @@ -11,12 +11,14 @@ from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Literal +from askui.models.anthropic.factory import AnthropicApiProvider from askui.models.shared.agent_message_param import ToolUseBlockParam COMPUTER_USE_20250124_BETA_FLAG = "computer-use-2025-01-24" COMPUTER_USE_20251124_BETA_FLAG = "computer-use-2025-11-24" CACHING_STRATEGY = Literal["read", "write", "both", "no"] +PLACEHOLDER_IDENTIFICATION_STRATEGY = Literal["llm", "preset"] class MessageSettings(BaseModel): @@ -40,14 +42,19 @@ class CachedExecutionToolSettings(BaseModel): delay_time_between_action: float = 0.5 +class CacheWriterSettings(BaseModel): + placeholder_identification_strategy: PLACEHOLDER_IDENTIFICATION_STRATEGY = "llm" + llm_placeholder_id_api_provider: AnthropicApiProvider = "askui" + + class CachingSettings(BaseModel): strategy: CACHING_STRATEGY = "no" cache_dir: str = ".cache" filename: str = "" - auto_identify_placeholders: bool = True execute_cached_trajectory_tool_settings: CachedExecutionToolSettings = ( CachedExecutionToolSettings() ) + cache_writer_settings: CacheWriterSettings = CacheWriterSettings() class CacheFailure(BaseModel): diff --git a/src/askui/tools/caching_tools.py b/src/askui/tools/caching_tools.py index 764a09b2..c6867e69 100644 --- a/src/askui/tools/caching_tools.py +++ b/src/askui/tools/caching_tools.py @@ -850,13 +850,7 @@ def __call__(self, success: bool, verification_notes: str) -> str: # Update cache metadata based on verification result if success: self._agent.update_cache_metadata_on_completion(success=True) - result_msg = ( - f"✓ Cache verification successful: {verification_notes}\n\n" - "The cached trajectory execution achieved the target " - "system state correctly. " - "You may now proceed with any additional tasks or " - "conclude the execution." 
- ) + result_msg = f"✓ Cache verification successful: {verification_notes}" logger.info(result_msg) else: error_msg = ( diff --git a/src/askui/utils/cache_writer.py b/src/askui/utils/cache_writer.py index 835f5878..905caccf 100644 --- a/src/askui/utils/cache_writer.py +++ b/src/askui/utils/cache_writer.py @@ -8,7 +8,11 @@ from askui.models.model_router import create_api_client from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam from askui.models.shared.agent_on_message_cb import OnMessageCbParam -from askui.models.shared.settings import CacheFile, CacheMetadata, CachingSettings +from askui.models.shared.settings import ( + CacheFile, + CacheMetadata, + CacheWriterSettings, +) from askui.models.shared.tools import ToolCollection from askui.utils.placeholder_handler import PlaceholderHandler from askui.utils.placeholder_identifier import identify_placeholders @@ -21,7 +25,7 @@ def __init__( self, cache_dir: str = ".cache", file_name: str = "", - caching_settings: CachingSettings | None = None, + cache_writer_settings: CacheWriterSettings | None = None, toolbox: ToolCollection | None = None, goal: str | None = None, ) -> None: @@ -32,14 +36,10 @@ def __init__( file_name += ".json" self.file_name = file_name self.was_cached_execution = False - self._caching_settings = caching_settings or CachingSettings() + self._cache_writer_settings = cache_writer_settings or CacheWriterSettings() self._goal = goal self._toolbox: ToolCollection | None = None - # Get messages_api for placeholder identification - self._messages_api = AnthropicMessagesApi( - client=create_api_client(api_provider="askui"), - locator_serializer=VlmLocatorSerializer(), - ) + # Set toolbox for cache writer so it can check which tools are cacheable self._toolbox = toolbox @@ -106,10 +106,20 @@ def _replace_placeholders( goal_to_save = self._goal placeholders_dict: dict[str, str] = {} - if self._caching_settings.auto_identify_placeholders and self.messages: + if ( + self._cache_writer_settings.placeholder_identification_strategy == "llm" + and self.messages + ): + # Get messages_api for placeholder identification + messages_api = AnthropicMessagesApi( + client=create_api_client( + self._cache_writer_settings.llm_placeholder_id_api_provider + ), + locator_serializer=VlmLocatorSerializer(), + ) placeholders_dict, placeholder_definitions = identify_placeholders( trajectory=self.messages, - messages_api=self._messages_api, + messages_api=messages_api, ) n_placeholders = len(placeholder_definitions) # Replace actual values with {{placeholder_name}} syntax in trajectory @@ -169,7 +179,7 @@ def _blank_non_cacheable_tool_inputs( result: list[ToolUseBlockParam] = [] for tool_block in trajectory: # Check if this tool is cacheable - tool = self._toolbox._tool_map.get(tool_block.name) + tool = self._toolbox.get_tools().get(tool_block.name) # If tool is not cacheable, blank out its input if tool is not None and not tool.is_cacheable: diff --git a/tests/unit/tools/test_caching_tools.py b/tests/unit/tools/test_caching_tools.py index 11fbcfe5..c105a95f 100644 --- a/tests/unit/tools/test_caching_tools.py +++ b/tests/unit/tools/test_caching_tools.py @@ -213,16 +213,17 @@ def test_retrieve_caches_includes_invalid_when_requested(tmp_path): # ============================================================================ -def test_execute_cached_execution_initializes_without_toolbox() -> None: - """Test that ExecuteCachedTrajectory can be initialized without toolbox.""" - tool = ExecuteCachedTrajectory() +def 
test_execute_cached_execution_initializes_with_toolbox() -> None: + """Test that ExecuteCachedTrajectory can be initialized with toolbox.""" + mock_toolbox = MagicMock(spec=ToolCollection) + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) assert tool.name == "execute_cached_executions_tool" - assert tool._toolbox is None # noqa: SLF001 + assert tool._toolbox is mock_toolbox # noqa: SLF001 assert tool._agent is None # noqa: SLF001 -def test_execute_cached_execution_raises_error_without_toolbox_or_agent() -> None: - """Test that ExecuteCachedTrajectory raises error when neither toolbox nor agent set.""" +def test_execute_cached_execution_raises_error_without_agent() -> None: + """Test that ExecuteCachedTrajectory raises error when agent not set.""" with tempfile.TemporaryDirectory() as temp_dir: cache_file = Path(temp_dir) / "test.json" cache_data = { @@ -240,7 +241,8 @@ def test_execute_cached_execution_raises_error_without_toolbox_or_agent() -> Non } cache_file.write_text(json.dumps(cache_data), encoding="utf-8") - tool = ExecuteCachedTrajectory() + mock_toolbox = MagicMock(spec=ToolCollection) + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) with pytest.raises(RuntimeError, match="Agent not set"): tool(trajectory_file=str(cache_file)) @@ -248,9 +250,10 @@ def test_execute_cached_execution_raises_error_without_toolbox_or_agent() -> Non def test_execute_cached_execution_returns_error_when_file_not_found() -> None: """Test that ExecuteCachedTrajectory returns error message if file doesn't exist.""" - tool = ExecuteCachedTrajectory() + mock_toolbox = MagicMock(spec=ToolCollection) + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) mock_agent = MagicMock(spec=Agent) - mock_agent._tool_collection = MagicMock(spec=ToolCollection) + mock_agent._tool_collection = mock_toolbox tool.set_agent(mock_agent) result = tool(trajectory_file="/non/existent/file.json") @@ -304,7 +307,7 @@ def test_execute_cached_execution_activates_cache_mode() -> None: mock_agent._tool_collection = mock_toolbox # noqa: SLF001 # Create and configure tool - tool = ExecuteCachedTrajectory() + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) tool.set_agent(mock_agent) # Call the tool @@ -315,15 +318,14 @@ def test_execute_cached_execution_activates_cache_mode() -> None: assert "✓ Cache execution mode activated" in result assert "2 cached steps" in result - # Verify agent state was set - assert mock_agent._executing_from_cache is True # noqa: SLF001 - assert mock_agent._cache_executor is not None # noqa: SLF001 - assert mock_agent._cache_file is not None # noqa: SLF001 - assert mock_agent._cache_file_path == str(cache_file) # noqa: SLF001 + # Verify cache info was set using public API + cache_file_obj, cache_file_path = mock_agent.get_cache_info() + assert cache_file_obj is not None + assert cache_file_path == str(cache_file) -def test_execute_cached_execution_works_with_set_toolbox() -> None: - """Test that ExecuteCachedTrajectory works with set_toolbox (legacy approach).""" +def test_execute_cached_execution_works_with_toolbox() -> None: + """Test that ExecuteCachedTrajectory works with toolbox provided.""" with tempfile.TemporaryDirectory() as temp_dir: cache_file = Path(temp_dir) / "test_trajectory.json" @@ -351,15 +353,14 @@ def test_execute_cached_execution_works_with_set_toolbox() -> None: with cache_file.open("w", encoding="utf-8") as f: json.dump(cache_data, f) - # Create mock agent without toolbox + # Create mock agent mock_messages_api = MagicMock(spec=MessagesApi) mock_agent = 
Agent(messages_api=mock_messages_api) - # Create tool and set toolbox directly - tool = ExecuteCachedTrajectory() + # Create tool with toolbox mock_toolbox = MagicMock(spec=ToolCollection) mock_toolbox._tool_map = {} - tool.set_toolbox(mock_toolbox) + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) tool.set_agent(mock_agent) result = tool(trajectory_file=str(cache_file)) @@ -369,14 +370,13 @@ def test_execute_cached_execution_works_with_set_toolbox() -> None: assert "✓ Cache execution mode activated" in result -def test_execute_cached_execution_set_agent_and_toolbox() -> None: - """Test that set_agent and set_toolbox properly set references.""" - tool = ExecuteCachedTrajectory() - mock_agent = MagicMock(spec=Agent) +def test_execute_cached_execution_set_agent() -> None: + """Test that set_agent properly sets reference.""" mock_toolbox = MagicMock(spec=ToolCollection) + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) + mock_agent = MagicMock(spec=Agent) tool.set_agent(mock_agent) - tool.set_toolbox(mock_toolbox) assert tool._agent == mock_agent # noqa: SLF001 assert tool._toolbox == mock_toolbox # noqa: SLF001 @@ -384,7 +384,8 @@ def test_execute_cached_execution_set_agent_and_toolbox() -> None: def test_execute_cached_execution_initializes_with_default_settings() -> None: """Test that ExecuteCachedTrajectory uses default settings when none provided.""" - tool = ExecuteCachedTrajectory() + mock_toolbox = MagicMock(spec=ToolCollection) + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) # Should have default settings initialized assert hasattr(tool, "_settings") @@ -393,8 +394,9 @@ def test_execute_cached_execution_initializes_with_default_settings() -> None: def test_execute_cached_execution_initializes_with_custom_settings() -> None: """Test that ExecuteCachedTrajectory accepts custom settings.""" + mock_toolbox = MagicMock(spec=ToolCollection) custom_settings = CachedExecutionToolSettings(delay_time_between_action=1.0) - tool = ExecuteCachedTrajectory(settings=custom_settings) + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox, settings=custom_settings) # Should have custom settings initialized assert hasattr(tool, "_settings") @@ -440,7 +442,7 @@ def test_execute_cached_execution_with_placeholders() -> None: mock_toolbox._tool_map = {} mock_agent._tool_collection = mock_toolbox # noqa: SLF001 - tool = ExecuteCachedTrajectory() + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) tool.set_agent(mock_agent) result = tool( @@ -493,7 +495,7 @@ def test_execute_cached_execution_missing_placeholders() -> None: mock_toolbox = MagicMock(spec=ToolCollection) mock_agent._tool_collection = mock_toolbox # noqa: SLF001 - tool = ExecuteCachedTrajectory() + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) tool.set_agent(mock_agent) result = tool(trajectory_file=str(cache_file)) @@ -530,7 +532,7 @@ def test_execute_cached_execution_no_placeholders_backward_compat() -> None: mock_toolbox._tool_map = {} mock_agent._tool_collection = mock_toolbox # noqa: SLF001 - tool = ExecuteCachedTrajectory() + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) tool.set_agent(mock_agent) result = tool(trajectory_file=str(cache_file)) @@ -576,7 +578,7 @@ def test_continue_cached_trajectory_from_middle() -> None: mock_toolbox._tool_map = {} mock_agent._tool_collection = mock_toolbox # noqa: SLF001 - tool = ExecuteCachedTrajectory() + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) tool.set_agent(mock_agent) result = tool(trajectory_file=str(cache_file), start_from_step_index=2) @@ -618,7 
+620,7 @@ def test_continue_cached_trajectory_invalid_step_index_negative() -> None: mock_toolbox._tool_map = {} mock_agent._tool_collection = mock_toolbox # noqa: SLF001 - tool = ExecuteCachedTrajectory() + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) tool.set_agent(mock_agent) result = tool(trajectory_file=str(cache_file), start_from_step_index=-1) @@ -659,7 +661,7 @@ def test_continue_cached_trajectory_invalid_step_index_too_large() -> None: mock_toolbox._tool_map = {} mock_agent._tool_collection = mock_toolbox # noqa: SLF001 - tool = ExecuteCachedTrajectory() + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) tool.set_agent(mock_agent) result = tool(trajectory_file=str(cache_file), start_from_step_index=5) @@ -722,7 +724,7 @@ def test_continue_cached_trajectory_with_placeholders() -> None: mock_toolbox._tool_map = {} mock_agent._tool_collection = mock_toolbox # noqa: SLF001 - tool = ExecuteCachedTrajectory() + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) tool.set_agent(mock_agent) result = tool( @@ -769,7 +771,7 @@ def test_execute_cached_trajectory_warns_if_invalid(tmp_path, caplog): mock_toolbox._tool_map = {} mock_agent._tool_collection = mock_toolbox # noqa: SLF001 - tool = ExecuteCachedTrajectory() + tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) tool.set_agent(mock_agent) result = tool(trajectory_file=str(cache_file)) From a56cf2f0dc782f695bc2aaa97dfa9a562385d48e Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Fri, 12 Dec 2025 16:10:12 +0100 Subject: [PATCH 03/30] fix(caching): fix unit tests for cache writer --- tests/unit/utils/test_cache_writer.py | 42 +++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/tests/unit/utils/test_cache_writer.py b/tests/unit/utils/test_cache_writer.py index 333a9725..f418550c 100644 --- a/tests/unit/utils/test_cache_writer.py +++ b/tests/unit/utils/test_cache_writer.py @@ -7,7 +7,7 @@ from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam from askui.models.shared.agent_on_message_cb import OnMessageCbParam -from askui.models.shared.settings import CacheFile +from askui.models.shared.settings import CacheFile, CacheWriterSettings from askui.utils.cache_writer import CacheWriter @@ -142,7 +142,13 @@ def test_cache_writer_generate_writes_file() -> None: """Test that generate() writes messages to a JSON file in v0.1 format.""" with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) - cache_writer = CacheWriter(cache_dir=str(cache_dir), file_name="output.json") + cache_writer = CacheWriter( + cache_dir=str(cache_dir), + file_name="output.json", + cache_writer_settings=CacheWriterSettings( + placeholder_identification_strategy="preset" + ), + ) # Add some tool use blocks tool_use1 = ToolUseBlockParam( @@ -192,7 +198,13 @@ def test_cache_writer_generate_auto_names_file() -> None: """Test that generate() auto-generates filename if not provided.""" with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) - cache_writer = CacheWriter(cache_dir=str(cache_dir), file_name="") + cache_writer = CacheWriter( + cache_dir=str(cache_dir), + file_name="", + cache_writer_settings=CacheWriterSettings( + placeholder_identification_strategy="preset" + ), + ) tool_use = ToolUseBlockParam( id="id1", @@ -360,7 +372,13 @@ def test_cache_writer_generate_resets_after_writing() -> None: """Test that generate() calls reset() after writing the file.""" with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) - cache_writer = 
CacheWriter(cache_dir=str(cache_dir), file_name="test.json") + cache_writer = CacheWriter( + cache_dir=str(cache_dir), + file_name="test.json", + cache_writer_settings=CacheWriterSettings( + placeholder_identification_strategy="preset" + ), + ) cache_writer.messages = [ ToolUseBlockParam( @@ -381,7 +399,13 @@ def test_cache_writer_detects_and_stores_placeholders() -> None: """Test that CacheWriter detects placeholders and stores them in metadata.""" with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) - cache_writer = CacheWriter(cache_dir=str(cache_dir), file_name="test.json") + cache_writer = CacheWriter( + cache_dir=str(cache_dir), + file_name="test.json", + cache_writer_settings=CacheWriterSettings( + placeholder_identification_strategy="preset" + ), + ) # Add tool use blocks with placeholders cache_writer.messages = [ @@ -417,7 +441,13 @@ def test_cache_writer_empty_placeholders_when_none_found() -> None: """Test that placeholders dict is empty when no placeholders exist.""" with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) - cache_writer = CacheWriter(cache_dir=str(cache_dir), file_name="test.json") + cache_writer = CacheWriter( + cache_dir=str(cache_dir), + file_name="test.json", + cache_writer_settings=CacheWriterSettings( + placeholder_identification_strategy="preset" + ), + ) # Add tool use blocks without placeholders cache_writer.messages = [ From 11c67bb32e3771eda8827b58b853474c2b0168c6 Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Mon, 15 Dec 2025 06:29:39 +0100 Subject: [PATCH 04/30] fix(caching): fix unit tests for trajectory executor --- tests/unit/utils/test_trajectory_executor.py | 130 ++++++++++++++++--- 1 file changed, 110 insertions(+), 20 deletions(-) diff --git a/tests/unit/utils/test_trajectory_executor.py b/tests/unit/utils/test_trajectory_executor.py index 9a7384e9..9bef68b6 100644 --- a/tests/unit/utils/test_trajectory_executor.py +++ b/tests/unit/utils/test_trajectory_executor.py @@ -36,7 +36,12 @@ def test_trajectory_executor_initialization() -> None: def test_trajectory_executor_execute_simple_step() -> None: """Test executing a simple step.""" mock_toolbox = MagicMock(spec=ToolCollection) - mock_toolbox.run.return_value = ["Tool result"] + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Tool result")], + ) + ] mock_toolbox._tool_map = {} trajectory = [ @@ -61,7 +66,12 @@ def test_trajectory_executor_execute_simple_step() -> None: def test_trajectory_executor_execute_all_steps() -> None: """Test executing all steps in a trajectory.""" mock_toolbox = MagicMock(spec=ToolCollection) - mock_toolbox.run.return_value = ["Result"] + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] mock_toolbox._tool_map = {} trajectory = [ @@ -89,7 +99,12 @@ def test_trajectory_executor_execute_all_steps() -> None: def test_trajectory_executor_executes_screenshot_tools() -> None: """Test that screenshot tools are executed (not skipped).""" mock_toolbox = MagicMock(spec=ToolCollection) - mock_toolbox.run.return_value = ["Screenshot result"] + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Screenshot result")], + ) + ] mock_toolbox._tool_map = {} trajectory = [ @@ -114,7 +129,12 @@ def test_trajectory_executor_executes_screenshot_tools() -> None: def 
test_trajectory_executor_executes_retrieve_trajectories_tool() -> None: """Test that retrieve_available_trajectories_tool is executed (not skipped).""" mock_toolbox = MagicMock(spec=ToolCollection) - mock_toolbox.run.return_value = ["Trajectory list"] + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Trajectory list")], + ) + ] mock_toolbox._tool_map = {} trajectory = [ @@ -144,7 +164,12 @@ def test_trajectory_executor_executes_retrieve_trajectories_tool() -> None: def test_trajectory_executor_pauses_at_non_cacheable_tool() -> None: """Test that execution pauses at non-cacheable tools.""" mock_toolbox = MagicMock(spec=ToolCollection) - mock_toolbox.run.return_value = ["Result"] + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] # Create mock tools in the toolbox cacheable_tool = MagicMock() @@ -206,7 +231,12 @@ def test_trajectory_executor_substitutes_placeholders() -> None: def capture_run(steps): # type: ignore captured_steps.extend(steps) - return ["Result"] + return [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] mock_toolbox = MagicMock(spec=ToolCollection) mock_toolbox.run = capture_run @@ -308,7 +338,12 @@ def test_trajectory_executor_skip_at_end_does_nothing() -> None: def test_trajectory_executor_completed_status_when_done() -> None: """Test that executor returns COMPLETED when all steps are done.""" mock_toolbox = MagicMock(spec=ToolCollection) - mock_toolbox.run.return_value = ["Result"] + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] mock_toolbox._tool_map = {} trajectory = [ToolUseBlockParam(id="1", name="tool1", input={}, type="tool_use")] @@ -335,7 +370,12 @@ def mock_run(steps): # type: ignore call_count[0] += 1 if call_count[0] == 2: raise Exception("Second call fails") - return ["Result"] + return [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] mock_toolbox = MagicMock(spec=ToolCollection) mock_toolbox.run = mock_run @@ -363,7 +403,12 @@ def mock_run(steps): # type: ignore def test_trajectory_executor_builds_message_history() -> None: """Test that executor builds message history during execution.""" mock_toolbox = MagicMock(spec=ToolCollection) - mock_toolbox.run.return_value = ["Result1"] + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result1")], + ) + ] mock_toolbox._tool_map = {} trajectory = [ @@ -398,7 +443,12 @@ def test_trajectory_executor_builds_message_history() -> None: def test_trajectory_executor_message_history_accumulates() -> None: """Test that message history accumulates across multiple steps.""" mock_toolbox = MagicMock(spec=ToolCollection) - mock_toolbox.run.return_value = ["Result"] + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] mock_toolbox._tool_map = {} trajectory = [ @@ -426,11 +476,16 @@ def test_trajectory_executor_message_history_accumulates() -> None: def test_trajectory_executor_message_history_contains_tool_use_id() -> None: """Test that tool result has correct tool_use_id matching tool use.""" mock_toolbox = MagicMock(spec=ToolCollection) - mock_toolbox.run.return_value = ["Success"] + 
mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Success")], + ) + ] mock_toolbox._tool_map = {} trajectory = [ - ToolUseBlockParam(id="unique_id_123", name="tool", input={}, type="tool_use") + ToolUseBlockParam(id="1", name="tool", input={}, type="tool_use") ] executor = TrajectoryExecutor( @@ -447,13 +502,18 @@ def test_trajectory_executor_message_history_contains_tool_use_id() -> None: assert isinstance(tool_use, ToolUseBlockParam) assert isinstance(tool_result, ToolResultBlockParam) assert tool_result.tool_use_id == tool_use.id - assert tool_result.tool_use_id == "unique_id_123" + assert tool_result.tool_use_id == "1" def test_trajectory_executor_message_history_includes_text_result() -> None: """Test that tool results include text content.""" mock_toolbox = MagicMock(spec=ToolCollection) - mock_toolbox.run.return_value = ["Tool executed successfully"] + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Tool executed successfully")], + ) + ] mock_toolbox._tool_map = {} trajectory = [ @@ -503,7 +563,12 @@ def test_trajectory_executor_message_history_on_failure() -> None: def test_trajectory_executor_message_history_on_pause() -> None: """Test that message history is included when execution pauses.""" mock_toolbox = MagicMock(spec=ToolCollection) - mock_toolbox.run.return_value = ["Result"] + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] cacheable_tool = MagicMock() cacheable_tool.is_cacheable = True @@ -538,7 +603,12 @@ def test_trajectory_executor_message_history_on_pause() -> None: def test_trajectory_executor_message_history_order() -> None: """Test that message history maintains correct order.""" mock_toolbox = MagicMock(spec=ToolCollection) - mock_toolbox.run.return_value = ["Result"] + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] mock_toolbox._tool_map = {} trajectory = [ @@ -577,7 +647,12 @@ def test_trajectory_executor_message_history_order() -> None: def test_visual_validation_disabled_by_default() -> None: """Test that visual validation is disabled by default.""" mock_toolbox = MagicMock(spec=ToolCollection) - mock_toolbox.run.return_value = ["Result"] + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] mock_toolbox._tool_map = {} trajectory = [ @@ -606,7 +681,12 @@ def test_visual_validation_disabled_by_default() -> None: def test_visual_validation_enabled_flag() -> None: """Test that visual validation flag can be enabled.""" mock_toolbox = MagicMock(spec=ToolCollection) - mock_toolbox.run.return_value = ["Result"] + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] mock_toolbox._tool_map = {} trajectory = [ @@ -653,7 +733,12 @@ def test_validate_step_visually_hook_exists() -> None: def test_validate_step_visually_always_passes_when_disabled() -> None: """Test that validation always passes when disabled (default behavior).""" mock_toolbox = MagicMock(spec=ToolCollection) - mock_toolbox.run.return_value = ["Result"] + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] 
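+    # The run() mock returns ToolResultBlockParam objects rather than bare
+    # strings, matching what TrajectoryExecutor consumes when it builds the
+    # message history.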
mock_toolbox._tool_map = {} # Create step with visual validation fields @@ -684,7 +769,12 @@ def test_validate_step_visually_always_passes_when_disabled() -> None: def test_validate_step_visually_hook_called_when_enabled() -> None: """Test that validate_step_visually is called when enabled.""" mock_toolbox = MagicMock(spec=ToolCollection) - mock_toolbox.run.return_value = ["Result"] + mock_toolbox.run.return_value = [ + ToolResultBlockParam( + tool_use_id="1", + content=[TextBlockParam(type="text", text="Result")], + ) + ] mock_toolbox._tool_map = {} trajectory = [ From 7e969c8734a792f0c963f0ee28c6d4d85fcfc8e9 Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Mon, 15 Dec 2025 07:05:47 +0100 Subject: [PATCH 05/30] chore(caching): pass CacheExecutionManager to caching tools i/o the Agent, to make the flow simpler --- src/askui/models/shared/agent.py | 65 +----------------- src/askui/tools/caching_tools.py | 47 +++++++------ tests/unit/tools/test_caching_tools.py | 94 ++++++++++---------------- 3 files changed, 66 insertions(+), 140 deletions(-) diff --git a/src/askui/models/shared/agent.py b/src/askui/models/shared/agent.py index 4a0bebd1..1082053f 100644 --- a/src/askui/models/shared/agent.py +++ b/src/askui/models/shared/agent.py @@ -1,5 +1,4 @@ import logging -from typing import TYPE_CHECKING from typing_extensions import override @@ -22,10 +21,6 @@ from askui.reporting import NULL_REPORTER, Reporter from askui.utils.cache_execution_manager import CacheExecutionManager -if TYPE_CHECKING: - from askui.models.shared.settings import CacheFile - from askui.utils.trajectory_executor import TrajectoryExecutor - logger = logging.getLogger(__name__) @@ -61,7 +56,6 @@ def __init__( # Store current tool collection for cache executor access self._tool_collection: ToolCollection | None = None - def _get_agent_response( self, model: str, @@ -241,7 +235,7 @@ def _setup_cache_tools(self, tool_collection: ToolCollection) -> None: # Iterate through tools and set agent on caching tools for tool_name, tool in tool_collection.get_tools().items(): if isinstance(tool, (ExecuteCachedTrajectory, VerifyCacheExecution)): - tool.set_agent(self) + tool.set_cache_execution_manager(self._cache_manager) logger.debug("Set agent reference on %s", tool_name) @override @@ -316,60 +310,3 @@ def _handle_stop_reason(self, message: MessageParam, max_tokens: int) -> None: raise MaxTokensExceededError(max_tokens) if message.stop_reason == "refusal": raise ModelRefusalError - - # Public methods for cache management (used by caching tools) - # These delegate to the CacheExecutionManager - def activate_cache_execution( - self, - executor: "TrajectoryExecutor", - cache_file: "CacheFile", - cache_file_path: str, - ) -> None: - """Activate cache execution mode. - - Args: - executor: The trajectory executor to use - cache_file: The cache file being executed - cache_file_path: Path to the cache file - """ - self._cache_manager.activate_execution(executor, cache_file, cache_file_path) - - def get_cache_info(self) -> tuple["CacheFile | None", str | None]: - """Get current cache file and path. - - Returns: - Tuple of (cache_file, cache_file_path) - """ - return self._cache_manager.get_cache_info() - - def is_cache_verification_pending(self) -> bool: - """Check if cache verification is pending. 
- - Returns: - True if verification is pending - """ - return self._cache_manager.is_cache_verification_pending() - - def update_cache_metadata_on_completion(self, success: bool) -> None: - """Update cache metadata after execution completion (public API). - - Args: - success: Whether the execution was successful - """ - self._cache_manager.update_metadata_on_completion(success) - - def update_cache_metadata_on_failure( - self, step_index: int, error_message: str - ) -> None: - """Update cache metadata after execution failure (public API). - - Args: - step_index: The step index where failure occurred - error_message: The error message - """ - self._cache_manager.update_metadata_on_failure(step_index, error_message) - - def clear_cache_state(self) -> None: - """Clear cache execution state.""" - self._cache_manager.clear_cache_state() - diff --git a/src/askui/tools/caching_tools.py b/src/askui/tools/caching_tools.py index c6867e69..07aa85e2 100644 --- a/src/askui/tools/caching_tools.py +++ b/src/askui/tools/caching_tools.py @@ -8,12 +8,12 @@ from ..models.shared.settings import CachedExecutionToolSettings from ..models.shared.tools import Tool, ToolCollection +from ..utils.cache_execution_manager import CacheExecutionManager from ..utils.cache_manager import CacheManager from ..utils.cache_writer import CacheWriter from ..utils.placeholder_handler import PlaceholderHandler if TYPE_CHECKING: - from ..models.shared.agent import Agent from ..models.shared.agent_message_param import ToolUseBlockParam from ..models.shared.settings import CacheFile from ..utils.trajectory_executor import TrajectoryExecutor @@ -198,16 +198,18 @@ def __init__( if not settings: settings = CachedExecutionToolSettings() self._settings = settings - self._agent: "Agent | None" = None # Will be set by set_agent() + self._cache_execution_manager: CacheExecutionManager | None = None self._toolbox = toolbox - def set_agent(self, agent: "Agent") -> None: + def set_cache_execution_manager( + self, cache_execution_manager: CacheExecutionManager + ) -> None: """Set the agent reference for cache execution mode activation. Args: agent: The Agent instance that will execute the cached trajectory """ - self._agent = agent + self._cache_execution_manager = cache_execution_manager def _validate_trajectory_file(self, trajectory_file: str) -> str | None: """Validate that trajectory file exists. @@ -390,8 +392,8 @@ def __call__( ) # Validate agent is set - if not self._agent: - error_msg = "Agent not set. Call set_agent() first." + if not self._cache_execution_manager: + error_msg = "Cache Execution Manager not set. Call set_cache_execution_manager() first." logger.error(error_msg) raise RuntimeError(error_msg) @@ -437,7 +439,7 @@ def __call__( ) # Store executor and cache info in agent state - self._agent.activate_cache_execution( + self._cache_execution_manager.activate_execution( executor=executor, cache_file=cache_file, cache_file_path=trajectory_file, @@ -799,11 +801,17 @@ def __init__(self) -> None: }, ) self.is_cacheable = False # Verification is not cacheable - self._agent: "Agent | None" = None + self._cache_execution_manager: CacheExecutionManager | None = None - def set_agent(self, agent: "Agent") -> None: - """Set agent reference for metadata updates.""" - self._agent = agent + def set_cache_execution_manager( + self, cache_execution_manager: CacheExecutionManager + ) -> None: + """Set the agent reference for cache execution mode activation. 
+ + Args: + agent: The Agent instance that will execute the cached trajectory + """ + self._cache_execution_manager = cache_execution_manager @override @validate_call @@ -822,14 +830,15 @@ def __call__(self, success: bool, verification_notes: str) -> str: success, verification_notes, ) - - if not self._agent: - error_msg = "Agent not set. Cannot record verification result." + if not self._cache_execution_manager: + error_msg = ( + "Cache Execution Manager not set. Cannot record verification result." + ) logger.error(error_msg) return error_msg # Check if there's a cache file to update (more reliable than checking flag) - cache_file, cache_file_path = self._agent.get_cache_info() + cache_file, cache_file_path = self._cache_execution_manager.get_cache_info() if not (cache_file and cache_file_path): warning_msg = ( "No cache file to update. " @@ -841,7 +850,7 @@ def __call__(self, success: bool, verification_notes: str) -> str: # Debug log if verification flag wasn't explicitly set # (This can happen if verification is called directly without the flag, # but we still proceed since we have the cache file) - if not self._agent.is_cache_verification_pending(): + if not self._cache_execution_manager.is_cache_verification_pending(): logger.debug( "Verification flag not set, but cache file exists. " "This is normal for direct verification calls." @@ -849,7 +858,7 @@ def __call__(self, success: bool, verification_notes: str) -> str: # Update cache metadata based on verification result if success: - self._agent.update_cache_metadata_on_completion(success=True) + self._cache_execution_manager.update_metadata_on_completion(success=True) result_msg = f"✓ Cache verification successful: {verification_notes}" logger.info(result_msg) else: @@ -857,7 +866,7 @@ def __call__(self, success: bool, verification_notes: str) -> str: f"Cache execution did not lead to target system state: " f"{verification_notes}" ) - self._agent.update_cache_metadata_on_failure( + self._cache_execution_manager.update_metadata_on_failure( step_index=-1, # -1 indicates verification failure error_message=error_msg, ) @@ -874,6 +883,6 @@ def __call__(self, success: bool, verification_notes: str) -> str: logger.warning(result_msg) # Clear verification flag and cache references after verification - self._agent.clear_cache_state() + self._cache_execution_manager.clear_cache_state() return result_msg diff --git a/tests/unit/tools/test_caching_tools.py b/tests/unit/tools/test_caching_tools.py index c105a95f..620f5072 100644 --- a/tests/unit/tools/test_caching_tools.py +++ b/tests/unit/tools/test_caching_tools.py @@ -9,14 +9,13 @@ import pytest -from askui.models.shared.agent import Agent -from askui.models.shared.messages_api import MessagesApi from askui.models.shared.settings import CachedExecutionToolSettings from askui.models.shared.tools import ToolCollection from askui.tools.caching_tools import ( ExecuteCachedTrajectory, RetrieveCachedTestExecutions, ) +from askui.utils.cache_execution_manager import CacheExecutionManager # ============================================================================ @@ -219,11 +218,11 @@ def test_execute_cached_execution_initializes_with_toolbox() -> None: tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) assert tool.name == "execute_cached_executions_tool" assert tool._toolbox is mock_toolbox # noqa: SLF001 - assert tool._agent is None # noqa: SLF001 + assert tool._cache_execution_manager is None # noqa: SLF001 -def test_execute_cached_execution_raises_error_without_agent() -> None: - """Test that 
ExecuteCachedTrajectory raises error when agent not set.""" +def test_execute_cached_execution_raises_error_without_cache_manager() -> None: + """Test that ExecuteCachedTrajectory raises error when cache manager not set.""" with tempfile.TemporaryDirectory() as temp_dir: cache_file = Path(temp_dir) / "test.json" cache_data = { @@ -244,7 +243,7 @@ def test_execute_cached_execution_raises_error_without_agent() -> None: mock_toolbox = MagicMock(spec=ToolCollection) tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) - with pytest.raises(RuntimeError, match="Agent not set"): + with pytest.raises(RuntimeError, match="Cache Execution Manager not set"): tool(trajectory_file=str(cache_file)) @@ -252,9 +251,8 @@ def test_execute_cached_execution_returns_error_when_file_not_found() -> None: """Test that ExecuteCachedTrajectory returns error message if file doesn't exist.""" mock_toolbox = MagicMock(spec=ToolCollection) tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) - mock_agent = MagicMock(spec=Agent) - mock_agent._tool_collection = mock_toolbox - tool.set_agent(mock_agent) + mock_cache_manager = MagicMock(spec=CacheExecutionManager) + tool.set_cache_execution_manager(mock_cache_manager) result = tool(trajectory_file="/non/existent/file.json") @@ -300,15 +298,13 @@ def test_execute_cached_execution_activates_cache_mode() -> None: json.dump(cache_data, f) # Create mock agent with toolbox - mock_messages_api = MagicMock(spec=MessagesApi) - mock_agent = Agent(messages_api=mock_messages_api) + mock_cache_manager = MagicMock(spec=CacheExecutionManager) mock_toolbox = MagicMock(spec=ToolCollection) mock_toolbox._tool_map = {} - mock_agent._tool_collection = mock_toolbox # noqa: SLF001 # Create and configure tool tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) - tool.set_agent(mock_agent) + tool.set_cache_execution_manager(mock_cache_manager) # Call the tool result = tool(trajectory_file=str(cache_file)) @@ -318,10 +314,11 @@ def test_execute_cached_execution_activates_cache_mode() -> None: assert "✓ Cache execution mode activated" in result assert "2 cached steps" in result - # Verify cache info was set using public API - cache_file_obj, cache_file_path = mock_agent.get_cache_info() - assert cache_file_obj is not None - assert cache_file_path == str(cache_file) + # Verify activate_execution was called on the cache manager + mock_cache_manager.activate_execution.assert_called_once() + # Verify the cache file path was passed correctly + call_args = mock_cache_manager.activate_execution.call_args + assert call_args.kwargs["cache_file_path"] == str(cache_file) def test_execute_cached_execution_works_with_toolbox() -> None: @@ -354,14 +351,13 @@ def test_execute_cached_execution_works_with_toolbox() -> None: json.dump(cache_data, f) # Create mock agent - mock_messages_api = MagicMock(spec=MessagesApi) - mock_agent = Agent(messages_api=mock_messages_api) + mock_cache_manager = MagicMock(spec=CacheExecutionManager) # Create tool with toolbox mock_toolbox = MagicMock(spec=ToolCollection) mock_toolbox._tool_map = {} tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) - tool.set_agent(mock_agent) + tool.set_cache_execution_manager(mock_cache_manager) result = tool(trajectory_file=str(cache_file)) @@ -370,15 +366,15 @@ def test_execute_cached_execution_works_with_toolbox() -> None: assert "✓ Cache execution mode activated" in result -def test_execute_cached_execution_set_agent() -> None: - """Test that set_agent properly sets reference.""" +def test_execute_cached_execution_set_cache_manager() -> None: 
+ """Test that set_cache_execution_manager properly sets reference.""" mock_toolbox = MagicMock(spec=ToolCollection) tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) - mock_agent = MagicMock(spec=Agent) + mock_cache_manager = MagicMock(spec=CacheExecutionManager) - tool.set_agent(mock_agent) + tool.set_cache_execution_manager(mock_cache_manager) - assert tool._agent == mock_agent # noqa: SLF001 + assert tool._cache_execution_manager == mock_cache_manager # noqa: SLF001 assert tool._toolbox == mock_toolbox # noqa: SLF001 @@ -436,14 +432,12 @@ def test_execute_cached_execution_with_placeholders() -> None: json.dump(cache_data, f) # Create mock agent - mock_messages_api = MagicMock(spec=MessagesApi) - mock_agent = Agent(messages_api=mock_messages_api) + mock_cache_manager = MagicMock(spec=CacheExecutionManager) mock_toolbox = MagicMock(spec=ToolCollection) mock_toolbox._tool_map = {} - mock_agent._tool_collection = mock_toolbox # noqa: SLF001 tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) - tool.set_agent(mock_agent) + tool.set_cache_execution_manager(mock_cache_manager) result = tool( trajectory_file=str(cache_file), @@ -490,13 +484,11 @@ def test_execute_cached_execution_missing_placeholders() -> None: json.dump(cache_data, f) # Create mock agent - mock_messages_api = MagicMock(spec=MessagesApi) - mock_agent = Agent(messages_api=mock_messages_api) + mock_cache_manager = MagicMock(spec=CacheExecutionManager) mock_toolbox = MagicMock(spec=ToolCollection) - mock_agent._tool_collection = mock_toolbox # noqa: SLF001 tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) - tool.set_agent(mock_agent) + tool.set_cache_execution_manager(mock_cache_manager) result = tool(trajectory_file=str(cache_file)) @@ -526,14 +518,12 @@ def test_execute_cached_execution_no_placeholders_backward_compat() -> None: json.dump(trajectory, f) # Create mock agent - mock_messages_api = MagicMock(spec=MessagesApi) - mock_agent = Agent(messages_api=mock_messages_api) + mock_cache_manager = MagicMock(spec=CacheExecutionManager) mock_toolbox = MagicMock(spec=ToolCollection) mock_toolbox._tool_map = {} - mock_agent._tool_collection = mock_toolbox # noqa: SLF001 tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) - tool.set_agent(mock_agent) + tool.set_cache_execution_manager(mock_cache_manager) result = tool(trajectory_file=str(cache_file)) @@ -572,14 +562,12 @@ def test_continue_cached_trajectory_from_middle() -> None: json.dump(cache_data, f) # Create mock agent - mock_messages_api = MagicMock(spec=MessagesApi) - mock_agent = Agent(messages_api=mock_messages_api) + mock_cache_manager = MagicMock(spec=CacheExecutionManager) mock_toolbox = MagicMock(spec=ToolCollection) mock_toolbox._tool_map = {} - mock_agent._tool_collection = mock_toolbox # noqa: SLF001 tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) - tool.set_agent(mock_agent) + tool.set_cache_execution_manager(mock_cache_manager) result = tool(trajectory_file=str(cache_file), start_from_step_index=2) @@ -614,14 +602,12 @@ def test_continue_cached_trajectory_invalid_step_index_negative() -> None: with cache_file.open("w", encoding="utf-8") as f: json.dump(cache_data, f) - mock_messages_api = MagicMock(spec=MessagesApi) - mock_agent = Agent(messages_api=mock_messages_api) + mock_cache_manager = MagicMock(spec=CacheExecutionManager) mock_toolbox = MagicMock(spec=ToolCollection) mock_toolbox._tool_map = {} - mock_agent._tool_collection = mock_toolbox # noqa: SLF001 tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) - tool.set_agent(mock_agent) + 
tool.set_cache_execution_manager(mock_cache_manager) result = tool(trajectory_file=str(cache_file), start_from_step_index=-1) @@ -655,14 +641,12 @@ def test_continue_cached_trajectory_invalid_step_index_too_large() -> None: with cache_file.open("w", encoding="utf-8") as f: json.dump(cache_data, f) - mock_messages_api = MagicMock(spec=MessagesApi) - mock_agent = Agent(messages_api=mock_messages_api) + mock_cache_manager = MagicMock(spec=CacheExecutionManager) mock_toolbox = MagicMock(spec=ToolCollection) mock_toolbox._tool_map = {} - mock_agent._tool_collection = mock_toolbox # noqa: SLF001 tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) - tool.set_agent(mock_agent) + tool.set_cache_execution_manager(mock_cache_manager) result = tool(trajectory_file=str(cache_file), start_from_step_index=5) @@ -718,14 +702,12 @@ def test_continue_cached_trajectory_with_placeholders() -> None: json.dump(cache_data, f) # Create mock agent - mock_messages_api = MagicMock(spec=MessagesApi) - mock_agent = Agent(messages_api=mock_messages_api) + mock_cache_manager = MagicMock(spec=CacheExecutionManager) mock_toolbox = MagicMock(spec=ToolCollection) mock_toolbox._tool_map = {} - mock_agent._tool_collection = mock_toolbox # noqa: SLF001 tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) - tool.set_agent(mock_agent) + tool.set_cache_execution_manager(mock_cache_manager) result = tool( trajectory_file=str(cache_file), @@ -765,14 +747,12 @@ def test_execute_cached_trajectory_warns_if_invalid(tmp_path, caplog): json.dump(cache_data, f) # Create mock agent - mock_messages_api = MagicMock(spec=MessagesApi) - mock_agent = Agent(messages_api=mock_messages_api) + mock_cache_manager = MagicMock(spec=CacheExecutionManager) mock_toolbox = MagicMock(spec=ToolCollection) mock_toolbox._tool_map = {} - mock_agent._tool_collection = mock_toolbox # noqa: SLF001 tool = ExecuteCachedTrajectory(toolbox=mock_toolbox) - tool.set_agent(mock_agent) + tool.set_cache_execution_manager(mock_cache_manager) result = tool(trajectory_file=str(cache_file)) From 35b4b8c64fd251a7d64d87cd46807b5fd13e5af1 Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Mon, 15 Dec 2025 07:20:28 +0100 Subject: [PATCH 06/30] chore(caching): cleanup _step method --- src/askui/models/shared/agent.py | 58 ++++++++++++---------- src/askui/utils/cache_execution_manager.py | 43 ++++------------ 2 files changed, 41 insertions(+), 60 deletions(-) diff --git a/src/askui/models/shared/agent.py b/src/askui/models/shared/agent.py index 1082053f..1c6ddea0 100644 --- a/src/askui/models/shared/agent.py +++ b/src/askui/models/shared/agent.py @@ -107,41 +107,35 @@ def _process_tool_execution( tool_collection: ToolCollection, on_message: OnMessageCb, truncation_strategy: TruncationStrategy, - model: str, - settings: ActSettings, - ) -> None: - """Process tool execution and continue if needed. + ) -> bool: + """Process tool execution and return whether to continue. 
Args: message_by_assistant: Assistant message with potential tool uses tool_collection: Available tools on_message: Callback for messages truncation_strategy: Message truncation strategy - model: Model to use - settings: Agent settings + + Returns: + True if tool results were added and caller should recurse, + False otherwise """ tool_result_message = self._use_tools(message_by_assistant, tool_collection) if not tool_result_message: - return + return False tool_result_message = self._call_on_message( on_message, tool_result_message, truncation_strategy.messages ) if not tool_result_message: - return + return False tool_result_message_dict = tool_result_message.model_dump(mode="json") logger.debug(tool_result_message_dict) truncation_strategy.append_message(tool_result_message) - # Continue with next step recursively - self._step( - model=model, - tool_collection=tool_collection, - on_message=on_message, - settings=settings, - truncation_strategy=truncation_strategy, - ) + # Return True to indicate caller should recurse + return True def _step( self, @@ -175,16 +169,21 @@ def _step( # Get or generate assistant message if truncation_strategy.messages[-1].role == "user": # Try to execute from cache first - if self._cache_manager.handle_execution_step( + should_recurse = self._cache_manager.handle_execution_step( on_message, truncation_strategy, - model, - tool_collection, - settings, self.__class__.__name__, - self._step, - ): - return # Cache step handled and recursion occurred + ) + if should_recurse: + # Cache step handled, recurse to continue + self._step( + model=model, + on_message=on_message, + settings=settings, + tool_collection=tool_collection, + truncation_strategy=truncation_strategy, + ) + return # Normal flow: get agent response message_by_assistant = self._get_agent_response( @@ -198,14 +197,21 @@ def _step( # Check stop reason and process tools self._handle_stop_reason(message_by_assistant, settings.messages.max_tokens) - self._process_tool_execution( + should_recurse = self._process_tool_execution( message_by_assistant, tool_collection, on_message, truncation_strategy, - model, - settings, ) + if should_recurse: + # Tool results added, recurse to continue + self._step( + model=model, + on_message=on_message, + settings=settings, + tool_collection=tool_collection, + truncation_strategy=truncation_strategy, + ) def _call_on_message( self, diff --git a/src/askui/utils/cache_execution_manager.py b/src/askui/utils/cache_execution_manager.py index 869b3b37..d17a5888 100644 --- a/src/askui/utils/cache_execution_manager.py +++ b/src/askui/utils/cache_execution_manager.py @@ -3,12 +3,10 @@ import json import logging from pathlib import Path -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING from askui.models.shared.agent_message_param import MessageParam, TextBlockParam from askui.models.shared.agent_on_message_cb import OnMessageCb -from askui.models.shared.settings import ActSettings -from askui.models.shared.tools import ToolCollection from askui.models.shared.truncation_strategies import TruncationStrategy from askui.reporting import Reporter from askui.utils.trajectory_executor import ExecutionResult @@ -17,11 +15,6 @@ from askui.models.shared.settings import CacheFile from askui.utils.trajectory_executor import TrajectoryExecutor -# Type for the step callback function (matches Agent._step signature) -StepCallback = Callable[ - [str, OnMessageCb, ActSettings, ToolCollection, TruncationStrategy], None -] - logger = logging.getLogger(__name__) @@ -100,25 
+93,17 @@ def handle_execution_step( self, on_message: OnMessageCb, truncation_strategy: TruncationStrategy, - model: str, - tool_collection: ToolCollection, - settings: ActSettings, agent_class_name: str, - step_callback: StepCallback, ) -> bool: """Handle cache execution step. Args: on_message: Callback for messages truncation_strategy: Message truncation strategy - model: Model to use - tool_collection: Available tools - settings: Agent settings agent_class_name: Name of agent class for reporting - step_callback: Callback to continue agent step Returns: - True if cache step was handled and recursion occurred, + True if cache step was handled and caller should recurse, False if should continue with normal flow """ if not (self._executing_from_cache and self._cache_executor): @@ -132,11 +117,7 @@ def handle_execution_step( result, on_message, truncation_strategy, - model, - tool_collection, - settings, agent_class_name, - step_callback, ) if result.status == "NEEDS_AGENT": return self._handle_cache_needs_agent(result) @@ -150,13 +131,14 @@ def _handle_cache_success( result: ExecutionResult, on_message: OnMessageCb, truncation_strategy: TruncationStrategy, - model: str, - tool_collection: ToolCollection, - settings: ActSettings, agent_class_name: str, - step_callback: StepCallback, ) -> bool: - """Handle successful cache step execution.""" + """Handle successful cache step execution. + + Returns: + True if messages were added and caller should recurse, + False otherwise + """ if len(result.message_history) < 2: return False @@ -184,14 +166,7 @@ def _handle_cache_success( truncation_strategy.append_message(user_msg_processed) - # Continue with next step recursively - step_callback( - model, - on_message, - settings, - tool_collection, - truncation_strategy, - ) + # Return True to indicate caller should recurse return True def _handle_cache_needs_agent(self, result: ExecutionResult) -> bool: From bd85a7367347864841d217b4cafeaa2b88efa4d7 Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Mon, 15 Dec 2025 07:23:02 +0100 Subject: [PATCH 07/30] chore(caching): rename `cache_manager` to `cache_execution_manager` to avoid confusion with CacheManager class --- src/askui/models/shared/agent.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/askui/models/shared/agent.py b/src/askui/models/shared/agent.py index 1c6ddea0..da068118 100644 --- a/src/askui/models/shared/agent.py +++ b/src/askui/models/shared/agent.py @@ -52,7 +52,7 @@ def __init__( truncation_strategy_factory or SimpleTruncationStrategyFactory() ) # Cache execution manager handles all cache-related logic - self._cache_manager = CacheExecutionManager(reporter) + self._cache_execution_manager = CacheExecutionManager(reporter) # Store current tool collection for cache executor access self._tool_collection: ToolCollection | None = None @@ -169,7 +169,7 @@ def _step( # Get or generate assistant message if truncation_strategy.messages[-1].role == "user": # Try to execute from cache first - should_recurse = self._cache_manager.handle_execution_step( + should_recurse = self._cache_execution_manager.handle_execution_step( on_message, truncation_strategy, self.__class__.__name__, @@ -241,7 +241,7 @@ def _setup_cache_tools(self, tool_collection: ToolCollection) -> None: # Iterate through tools and set agent on caching tools for tool_name, tool in tool_collection.get_tools().items(): if isinstance(tool, (ExecuteCachedTrajectory, VerifyCacheExecution)): - tool.set_cache_execution_manager(self._cache_manager) + 
tool.set_cache_execution_manager(self._cache_execution_manager) logger.debug("Set agent reference on %s", tool_name) @override @@ -254,7 +254,7 @@ def act( settings: ActSettings | None = None, ) -> None: # Reset cache execution state at the start of each act() call - self._cache_manager.reset_state() + self._cache_execution_manager.reset_state() _settings = settings or ActSettings() _tool_collection = tools or ToolCollection() From 41ef97be825fcf03f1c302742bc23b2073d99340 Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Mon, 15 Dec 2025 07:43:47 +0100 Subject: [PATCH 08/30] fix(caching): fixes bug when delegating execution step back to agent --- src/askui/utils/cache_execution_manager.py | 64 ++++++++++++++++---- src/askui/utils/trajectory_executor.py | 6 +- tests/unit/utils/test_trajectory_executor.py | 13 +++- 3 files changed, 69 insertions(+), 14 deletions(-) diff --git a/src/askui/utils/cache_execution_manager.py b/src/askui/utils/cache_execution_manager.py index d17a5888..b62007af 100644 --- a/src/askui/utils/cache_execution_manager.py +++ b/src/askui/utils/cache_execution_manager.py @@ -120,7 +120,12 @@ def handle_execution_step( agent_class_name, ) if result.status == "NEEDS_AGENT": - return self._handle_cache_needs_agent(result) + return self._handle_cache_needs_agent( + result, + on_message, + truncation_strategy, + agent_class_name, + ) if result.status == "COMPLETED": return self._handle_cache_completed(truncation_strategy) # result.status == "FAILED" @@ -169,23 +174,62 @@ def _handle_cache_success( # Return True to indicate caller should recurse return True - def _handle_cache_needs_agent(self, result: ExecutionResult) -> bool: - """Handle cache execution pausing for non-cacheable tool.""" + def _handle_cache_needs_agent( + self, + result: ExecutionResult, + on_message: OnMessageCb, + truncation_strategy: TruncationStrategy, + agent_class_name: str, + ) -> bool: + """Handle cache execution pausing for non-cacheable tool. + + Injects a user message explaining that cache execution paused and + what the agent needs to execute next. + + Returns: + False to indicate normal agent flow should continue + """ logger.info( "Paused cache execution at step %d " "(non-cacheable tool - agent will handle this step)", result.step_index, ) self._executing_from_cache = False + + # Get the tool that needs to be executed + tool_to_execute = result.tool_result # This is the ToolUseBlockParam + + # Create a user message explaining what needs to be done + if tool_to_execute: + instruction_message = MessageParam( + role="user", + content=[ + TextBlockParam( + type="text", + text=( + f"Cache execution paused at step {result.step_index}. " + f"The previous steps were executed successfully from cache. " + f"The next step requires the '{tool_to_execute.name}' tool, " + f"which cannot be executed from cache. " + f"Please execute this tool with the necessary parameters." 
+ ), + ) + ], + ) + + # Add the instruction message to truncation strategy + instruction_msg = self._call_on_message( + on_message, instruction_message, truncation_strategy.messages + ) + if instruction_msg: + truncation_strategy.append_message(instruction_msg) + return False # Fall through to normal agent API call - def _handle_cache_completed( - self, truncation_strategy: TruncationStrategy - ) -> bool: + def _handle_cache_completed(self, truncation_strategy: TruncationStrategy) -> bool: """Handle cache execution completion.""" logger.info( - "✓ Cache trajectory execution completed - " - "requesting agent verification" + "✓ Cache trajectory execution completed - requesting agent verification" ) self._executing_from_cache = False self._cache_verification_pending = True @@ -269,9 +313,7 @@ def update_metadata_on_completion(self, success: bool) -> None: except Exception: logger.exception("Failed to update cache metadata") - def update_metadata_on_failure( - self, step_index: int, error_message: str - ) -> None: + def update_metadata_on_failure(self, step_index: int, error_message: str) -> None: """Update cache metadata after execution failure. Args: diff --git a/src/askui/utils/trajectory_executor.py b/src/askui/utils/trajectory_executor.py index 61917f78..9d33bd98 100644 --- a/src/askui/utils/trajectory_executor.py +++ b/src/askui/utils/trajectory_executor.py @@ -120,10 +120,14 @@ def execute_next_step(self) -> ExecutionResult: logger.info( f"Pausing at step {step_index}: {step.name} (non-cacheable tool)" ) + # Return result with current tool step info for the agent to handle + # Note: We don't add any messages here - the cache manager will + # inject a user message explaining what needs to be done return ExecutionResult( status="NEEDS_AGENT", step_index=step_index, - message_history=self.message_history, + message_history=self.message_history.copy(), + tool_result=step, # Pass the tool use block for reference ) # Visual validation (future feature - currently always passes) diff --git a/tests/unit/utils/test_trajectory_executor.py b/tests/unit/utils/test_trajectory_executor.py index 9bef68b6..8bacb28c 100644 --- a/tests/unit/utils/test_trajectory_executor.py +++ b/tests/unit/utils/test_trajectory_executor.py @@ -596,8 +596,17 @@ def test_trajectory_executor_message_history_on_pause() -> None: assert results[0].status == "SUCCESS" assert results[1].status == "NEEDS_AGENT" - # Message history should include first step's execution - assert len(results[1].message_history) == 2 # First step: assistant + user + # Message history should include only the successfully executed cacheable step: + # 1. First step: assistant message (tool use) + # 2. 
First step: user message (tool result) + # The non-cacheable tool is NOT in message history - instead it's in tool_result + assert len(results[1].message_history) == 2 + assert results[1].message_history[0].role == "assistant" # First cacheable tool use + assert results[1].message_history[1].role == "user" # First tool result + + # The non-cacheable tool should be in tool_result for reference + assert results[1].tool_result is not None + assert results[1].tool_result.name == "non_cacheable" def test_trajectory_executor_message_history_order() -> None: From e894073f57f4b783a830e864049be44aedb40110 Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Mon, 15 Dec 2025 07:54:04 +0100 Subject: [PATCH 09/30] chore(caching): cleanup imports --- src/askui/utils/cache_execution_manager.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/askui/utils/cache_execution_manager.py b/src/askui/utils/cache_execution_manager.py index b62007af..180abaaf 100644 --- a/src/askui/utils/cache_execution_manager.py +++ b/src/askui/utils/cache_execution_manager.py @@ -6,9 +6,10 @@ from typing import TYPE_CHECKING from askui.models.shared.agent_message_param import MessageParam, TextBlockParam -from askui.models.shared.agent_on_message_cb import OnMessageCb +from askui.models.shared.agent_on_message_cb import OnMessageCb, OnMessageCbParam from askui.models.shared.truncation_strategies import TruncationStrategy from askui.reporting import Reporter +from askui.utils.cache_manager import CacheManager from askui.utils.trajectory_executor import ExecutionResult if TYPE_CHECKING: @@ -281,7 +282,6 @@ def _call_on_message( """Call on_message callback if provided.""" if on_message is None: return message - from askui.models.shared.agent_on_message_cb import OnMessageCbParam return on_message(OnMessageCbParam(message=message, messages=messages)) @@ -295,8 +295,6 @@ def update_metadata_on_completion(self, success: bool) -> None: return try: - from askui.utils.cache_manager import CacheManager - cache_manager = CacheManager() cache_manager.record_execution_attempt(self._cache_file, success=success) @@ -309,7 +307,7 @@ def update_metadata_on_completion(self, success: bool) -> None: indent=2, default=str, ) - logger.debug("Updated cache metadata: %s", cache_path.name) + logger.info("Updated cache metadata: %s", cache_path.name) except Exception: logger.exception("Failed to update cache metadata") @@ -324,8 +322,6 @@ def update_metadata_on_failure(self, step_index: int, error_message: str) -> Non return try: - from askui.utils.cache_manager import CacheManager - cache_manager = CacheManager() cache_manager.record_execution_attempt(self._cache_file, success=False) cache_manager.record_step_failure( From 7b844db0e39d9eb71337b85aac8300eb48a5cf8e Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Mon, 15 Dec 2025 08:30:20 +0100 Subject: [PATCH 10/30] fix(caching): make cached steps appear under the correct role in reports (`CachedExecutionManager` i/o `Agent`) --- src/askui/models/shared/agent.py | 1 - src/askui/utils/cache_execution_manager.py | 7 +------ 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/src/askui/models/shared/agent.py b/src/askui/models/shared/agent.py index da068118..2e11c83c 100644 --- a/src/askui/models/shared/agent.py +++ b/src/askui/models/shared/agent.py @@ -172,7 +172,6 @@ def _step( should_recurse = self._cache_execution_manager.handle_execution_step( on_message, truncation_strategy, - self.__class__.__name__, ) if should_recurse: # Cache step handled, recurse to 
continue diff --git a/src/askui/utils/cache_execution_manager.py b/src/askui/utils/cache_execution_manager.py index 180abaaf..a3f339ee 100644 --- a/src/askui/utils/cache_execution_manager.py +++ b/src/askui/utils/cache_execution_manager.py @@ -94,7 +94,6 @@ def handle_execution_step( self, on_message: OnMessageCb, truncation_strategy: TruncationStrategy, - agent_class_name: str, ) -> bool: """Handle cache execution step. @@ -118,14 +117,12 @@ def handle_execution_step( result, on_message, truncation_strategy, - agent_class_name, ) if result.status == "NEEDS_AGENT": return self._handle_cache_needs_agent( result, on_message, truncation_strategy, - agent_class_name, ) if result.status == "COMPLETED": return self._handle_cache_completed(truncation_strategy) @@ -137,7 +134,6 @@ def _handle_cache_success( result: ExecutionResult, on_message: OnMessageCb, truncation_strategy: TruncationStrategy, - agent_class_name: str, ) -> bool: """Handle successful cache step execution. @@ -160,7 +156,7 @@ def _handle_cache_success( truncation_strategy.append_message(message_by_assistant) self._reporter.add_message( - agent_class_name, message_by_assistant.model_dump(mode="json") + self.__class__.__name__, message_by_assistant.model_dump(mode="json") ) # Add user message (tool result) @@ -180,7 +176,6 @@ def _handle_cache_needs_agent( result: ExecutionResult, on_message: OnMessageCb, truncation_strategy: TruncationStrategy, - agent_class_name: str, ) -> bool: """Handle cache execution pausing for non-cacheable tool. From 34effa9139462e568b3a2904cd4bc80fe5af2ecf Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Mon, 15 Dec 2025 08:33:06 +0100 Subject: [PATCH 11/30] chore(caching): pdm format --- src/askui/utils/cache_manager.py | 4 +- src/askui/utils/placeholder_handler.py | 8 +-- src/askui/utils/placeholder_identifier.py | 8 +-- tests/unit/utils/test_cache_manager.py | 54 ++++++++++++++------ tests/unit/utils/test_cache_validator.py | 16 ++++-- tests/unit/utils/test_trajectory_executor.py | 21 ++++---- 6 files changed, 68 insertions(+), 43 deletions(-) diff --git a/src/askui/utils/cache_manager.py b/src/askui/utils/cache_manager.py index ed94a0cb..5a288ed1 100644 --- a/src/askui/utils/cache_manager.py +++ b/src/askui/utils/cache_manager.py @@ -128,9 +128,7 @@ def mark_cache_valid(self, cache_file: CacheFile) -> None: cache_file.metadata.is_valid = True cache_file.metadata.invalidation_reason = None - def get_failure_count_for_step( - self, cache_file: CacheFile, step_index: int - ) -> int: + def get_failure_count_for_step(self, cache_file: CacheFile, step_index: int) -> int: """Get number of failures for a specific step. Args: diff --git a/src/askui/utils/placeholder_handler.py b/src/askui/utils/placeholder_handler.py index 62915db0..5bfd04b0 100644 --- a/src/askui/utils/placeholder_handler.py +++ b/src/askui/utils/placeholder_handler.py @@ -47,9 +47,7 @@ def extract_placeholders(trajectory: list[ToolUseBlockParam]) -> set[str]: for step in trajectory: # Recursively find placeholders in the input object - placeholders.update( - PlaceholderHandler._extract_from_value(step.input) - ) + placeholders.update(PlaceholderHandler._extract_from_value(step.input)) return placeholders @@ -178,9 +176,7 @@ def replace_values_with_placeholders( return templated_trajectory @staticmethod - def _replace_values_in_value( - value: Any, replacements: dict[str, str] - ) -> Any: + def _replace_values_in_value(value: Any, replacements: dict[str, str]) -> Any: """Recursively replace actual values with placeholder syntax. 
Args: diff --git a/src/askui/utils/placeholder_identifier.py b/src/askui/utils/placeholder_identifier.py index 49c9569d..cc6abfd5 100644 --- a/src/askui/utils/placeholder_identifier.py +++ b/src/askui/utils/placeholder_identifier.py @@ -93,7 +93,9 @@ def identify_placeholders( response_text = response_text.split("```")[1].split("```")[0].strip() placeholder_data = json.loads(response_text) - logger.debug(f"Successfully parsed JSON response with {len(placeholder_data.get('placeholders', []))} placeholders") + logger.debug( + f"Successfully parsed JSON response with {len(placeholder_data.get('placeholders', []))} placeholders" + ) # Convert to our data structures placeholder_definitions = [ @@ -103,9 +105,7 @@ def identify_placeholders( for p in placeholder_data.get("placeholders", []) ] - placeholder_dict = { - p.name: p.description for p in placeholder_definitions - } + placeholder_dict = {p.name: p.description for p in placeholder_definitions} if placeholder_definitions: logger.info( diff --git a/tests/unit/utils/test_cache_manager.py b/tests/unit/utils/test_cache_manager.py index 2464345e..e4bfdd91 100644 --- a/tests/unit/utils/test_cache_manager.py +++ b/tests/unit/utils/test_cache_manager.py @@ -26,10 +26,10 @@ def sample_cache_file(): is_valid=True, ), trajectory=[ + ToolUseBlockParam(id="1", name="click", input={"x": 100}, type="tool_use"), ToolUseBlockParam( - id="1", name="click", input={"x": 100}, type="tool_use" + id="2", name="type", input={"text": "test"}, type="tool_use" ), - ToolUseBlockParam(id="2", name="type", input={"text": "test"}, type="tool_use"), ], placeholders={}, ) @@ -97,10 +97,14 @@ def test_record_execution_attempt_failure_without_info(sample_cache_file): initial_attempts = sample_cache_file.metadata.execution_attempts initial_failures = len(sample_cache_file.metadata.failures) - manager.record_execution_attempt(sample_cache_file, success=False, failure_info=None) + manager.record_execution_attempt( + sample_cache_file, success=False, failure_info=None + ) assert sample_cache_file.metadata.execution_attempts == initial_attempts + 1 - assert len(sample_cache_file.metadata.failures) == initial_failures # No new failure added + assert ( + len(sample_cache_file.metadata.failures) == initial_failures + ) # No new failure added # Record Step Failure Tests @@ -110,7 +114,9 @@ def test_record_step_failure_first_failure(sample_cache_file): """Test recording the first failure at a step.""" manager = CacheManager() - manager.record_step_failure(sample_cache_file, step_index=1, error_message="First error") + manager.record_step_failure( + sample_cache_file, step_index=1, error_message="First error" + ) assert len(sample_cache_file.metadata.failures) == 1 failure = sample_cache_file.metadata.failures[0] @@ -123,9 +129,15 @@ def test_record_step_failure_multiple_at_same_step(sample_cache_file): """Test recording multiple failures at the same step.""" manager = CacheManager() - manager.record_step_failure(sample_cache_file, step_index=1, error_message="Error 1") - manager.record_step_failure(sample_cache_file, step_index=1, error_message="Error 2") - manager.record_step_failure(sample_cache_file, step_index=1, error_message="Error 3") + manager.record_step_failure( + sample_cache_file, step_index=1, error_message="Error 1" + ) + manager.record_step_failure( + sample_cache_file, step_index=1, error_message="Error 2" + ) + manager.record_step_failure( + sample_cache_file, step_index=1, error_message="Error 3" + ) assert len(sample_cache_file.metadata.failures) == 3 assert 
sample_cache_file.metadata.failures[0].failure_count_at_step == 1 @@ -137,14 +149,24 @@ def test_record_step_failure_different_steps(sample_cache_file): """Test recording failures at different steps.""" manager = CacheManager() - manager.record_step_failure(sample_cache_file, step_index=1, error_message="Error at step 1") - manager.record_step_failure(sample_cache_file, step_index=2, error_message="Error at step 2") - manager.record_step_failure(sample_cache_file, step_index=1, error_message="Another at step 1") + manager.record_step_failure( + sample_cache_file, step_index=1, error_message="Error at step 1" + ) + manager.record_step_failure( + sample_cache_file, step_index=2, error_message="Error at step 2" + ) + manager.record_step_failure( + sample_cache_file, step_index=1, error_message="Another at step 1" + ) assert len(sample_cache_file.metadata.failures) == 3 - step_1_failures = [f for f in sample_cache_file.metadata.failures if f.step_index == 1] - step_2_failures = [f for f in sample_cache_file.metadata.failures if f.step_index == 2] + step_1_failures = [ + f for f in sample_cache_file.metadata.failures if f.step_index == 1 + ] + step_2_failures = [ + f for f in sample_cache_file.metadata.failures if f.step_index == 2 + ] assert len(step_1_failures) == 2 assert len(step_2_failures) == 1 @@ -313,7 +335,7 @@ def test_full_workflow_with_failure_detection(sample_cache_file): # Record 3 failures at step 1 (default threshold is 3) for i in range(3): manager.record_step_failure( - sample_cache_file, step_index=1, error_message=f"Error {i+1}" + sample_cache_file, step_index=1, error_message=f"Error {i + 1}" ) # Check if should invalidate @@ -333,7 +355,7 @@ def test_full_workflow_below_threshold(sample_cache_file): # Record 2 failures at step 1 (below default threshold of 3) for i in range(2): manager.record_step_failure( - sample_cache_file, step_index=1, error_message=f"Error {i+1}" + sample_cache_file, step_index=1, error_message=f"Error {i + 1}" ) # Check if should invalidate @@ -355,7 +377,7 @@ def test_workflow_with_custom_validator(sample_cache_file): # Record 2 failures (enough to trigger custom validator) for i in range(2): manager.record_step_failure( - sample_cache_file, step_index=1, error_message=f"Error {i+1}" + sample_cache_file, step_index=1, error_message=f"Error {i + 1}" ) should_inv, reason = manager.should_invalidate(sample_cache_file, step_index=1) diff --git a/tests/unit/utils/test_cache_validator.py b/tests/unit/utils/test_cache_validator.py index d252cd31..e55d0acb 100644 --- a/tests/unit/utils/test_cache_validator.py +++ b/tests/unit/utils/test_cache_validator.py @@ -26,10 +26,10 @@ def sample_cache_file(): is_valid=True, ), trajectory=[ + ToolUseBlockParam(id="1", name="click", input={"x": 100}, type="tool_use"), ToolUseBlockParam( - id="1", name="click", input={"x": 100}, type="tool_use" + id="2", name="type", input={"text": "test"}, type="tool_use" ), - ToolUseBlockParam(id="2", name="type", input={"text": "test"}, type="tool_use"), ], placeholders={}, ) @@ -219,7 +219,9 @@ def test_stale_cache_validator_not_stale(sample_cache_file): """Test validator does not invalidate recent cache.""" validator = StaleCacheValidator(max_age_days=30) - sample_cache_file.metadata.last_executed_at = datetime.now(tz=timezone.utc) - timedelta(days=10) + sample_cache_file.metadata.last_executed_at = datetime.now( + tz=timezone.utc + ) - timedelta(days=10) sample_cache_file.metadata.failures = [ CacheFailure( timestamp=datetime.now(tz=timezone.utc), @@ -237,7 +239,9 @@ def 
test_stale_cache_validator_is_stale(sample_cache_file): """Test validator invalidates old cache with failures.""" validator = StaleCacheValidator(max_age_days=30) - sample_cache_file.metadata.last_executed_at = datetime.now(tz=timezone.utc) - timedelta(days=35) + sample_cache_file.metadata.last_executed_at = datetime.now( + tz=timezone.utc + ) - timedelta(days=35) sample_cache_file.metadata.failures = [ CacheFailure( timestamp=datetime.now(tz=timezone.utc), @@ -256,7 +260,9 @@ def test_stale_cache_validator_old_but_no_failures(sample_cache_file): """Test validator does not invalidate old cache without failures.""" validator = StaleCacheValidator(max_age_days=30) - sample_cache_file.metadata.last_executed_at = datetime.now(tz=timezone.utc) - timedelta(days=100) + sample_cache_file.metadata.last_executed_at = datetime.now( + tz=timezone.utc + ) - timedelta(days=100) sample_cache_file.metadata.failures = [] should_inv, reason = validator.should_invalidate(sample_cache_file) diff --git a/tests/unit/utils/test_trajectory_executor.py b/tests/unit/utils/test_trajectory_executor.py index 8bacb28c..46428305 100644 --- a/tests/unit/utils/test_trajectory_executor.py +++ b/tests/unit/utils/test_trajectory_executor.py @@ -158,7 +158,10 @@ def test_trajectory_executor_executes_retrieve_trajectories_tool() -> None: assert result.step_index == 0 # First step executed assert mock_toolbox.run.call_count == 1 # Verify retrieve tool was called - assert mock_toolbox.run.call_args[0][0][0].name == "retrieve_available_trajectories_tool" + assert ( + mock_toolbox.run.call_args[0][0][0].name + == "retrieve_available_trajectories_tool" + ) def test_trajectory_executor_pauses_at_non_cacheable_tool() -> None: @@ -412,7 +415,9 @@ def test_trajectory_executor_builds_message_history() -> None: mock_toolbox._tool_map = {} trajectory = [ - ToolUseBlockParam(id="tool1", name="test_tool", input={"x": 100}, type="tool_use") + ToolUseBlockParam( + id="tool1", name="test_tool", input={"x": 100}, type="tool_use" + ) ] executor = TrajectoryExecutor( @@ -484,9 +489,7 @@ def test_trajectory_executor_message_history_contains_tool_use_id() -> None: ] mock_toolbox._tool_map = {} - trajectory = [ - ToolUseBlockParam(id="1", name="tool", input={}, type="tool_use") - ] + trajectory = [ToolUseBlockParam(id="1", name="tool", input={}, type="tool_use")] executor = TrajectoryExecutor( trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 @@ -516,9 +519,7 @@ def test_trajectory_executor_message_history_includes_text_result() -> None: ] mock_toolbox._tool_map = {} - trajectory = [ - ToolUseBlockParam(id="1", name="tool", input={}, type="tool_use") - ] + trajectory = [ToolUseBlockParam(id="1", name="tool", input={}, type="tool_use")] executor = TrajectoryExecutor( trajectory=trajectory, toolbox=mock_toolbox, delay_time=0 @@ -822,7 +823,9 @@ def mock_validate(step, screenshot=None): assert results[0].status == "SUCCESS" -@pytest.mark.skip(reason="Visual validation fields not yet implemented - future feature") +@pytest.mark.skip( + reason="Visual validation fields not yet implemented - future feature" +) def test_visual_validation_fields_on_tool_use_block() -> None: """Test that ToolUseBlockParam supports visual validation fields. 
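
A minimal sketch of the failure-detection flow the reformatted tests above
exercise, assuming the `CacheManager` API from `askui.utils.cache_manager`
and a `CacheFile` such as the `sample_cache_file` fixture:

    from askui.utils.cache_manager import CacheManager

    manager = CacheManager()  # defaults to the three built-in validators

    # Three failures at the same step reach the default threshold of
    # StepFailureCountValidator(max_failures_per_step=3) ...
    for i in range(3):
        manager.record_step_failure(
            sample_cache_file, step_index=1, error_message=f"Error {i + 1}"
        )

    # ... so the manager now recommends invalidating the cache for that step.
    should_inv, reason = manager.should_invalidate(sample_cache_file, step_index=1)
    if should_inv:
        manager.invalidate_cache(sample_cache_file, reason or "threshold reached")
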
From 406855253803f6a8e5261279860f98272986ca4e Mon Sep 17 00:00:00 2001
From: philipph-askui
Date: Mon, 15 Dec 2025 08:45:32 +0100
Subject: [PATCH 12/30] chore(caching): make CacheManager API for validators
 consistent with VisionAgent API for reporters

---
 src/askui/utils/cache_manager.py       | 14 +++++++-------
 tests/unit/utils/test_cache_manager.py | 23 +++++++++++------------
 2 files changed, 18 insertions(+), 19 deletions(-)

diff --git a/src/askui/utils/cache_manager.py b/src/askui/utils/cache_manager.py
index 5a288ed1..1b5878ae 100644
--- a/src/askui/utils/cache_manager.py
+++ b/src/askui/utils/cache_manager.py
@@ -35,23 +35,23 @@ class CacheManager:
-        manager = CacheManager(validator=custom_validator)
+        manager = CacheManager(validators=[custom_validator])
     """
 
-    def __init__(self, validator: Optional[CacheValidator] = None):
+    def __init__(self, validators: Optional[list[CacheValidator]] = None):
         """Initialize cache manager.
 
         Args:
-            validator: Custom validator or None to use default composite validator
+            validators: Custom validators or None to use default composite validator
         """
-        if validator is None:
+        if validators is None:
             # Default validator with built-in strategies
-            self.validator = CompositeCacheValidator(
+            self.validators = CompositeCacheValidator(
                 [
                     StepFailureCountValidator(max_failures_per_step=3),
                     TotalFailureRateValidator(min_attempts=10, max_failure_rate=0.5),
                     StaleCacheValidator(max_age_days=30),
                 ]
             )
         else:
-            self.validator = validator
+            self.validators = CompositeCacheValidator(validators)
 
     def record_execution_attempt(
         self,
@@ -107,7 +107,7 @@ def should_invalidate(
         Returns:
             Tuple of (should_invalidate: bool, reason: Optional[str])
         """
-        return self.validator.should_invalidate(cache_file, step_index)
+        return self.validators.should_invalidate(cache_file, step_index)
 
     def invalidate_cache(self, cache_file: CacheFile, reason: str) -> None:
         """Mark cache as invalid.
diff --git a/tests/unit/utils/test_cache_manager.py b/tests/unit/utils/test_cache_manager.py index e4bfdd91..f33a76ed 100644 --- a/tests/unit/utils/test_cache_manager.py +++ b/tests/unit/utils/test_cache_manager.py @@ -1,6 +1,6 @@ """Tests for cache manager.""" -from datetime import datetime, timedelta, timezone +from datetime import datetime, timezone from unittest.mock import MagicMock import pytest @@ -41,16 +41,16 @@ def sample_cache_file(): def test_cache_manager_default_initialization(): """Test cache manager initializes with default validator.""" manager = CacheManager() - assert manager.validator is not None - assert isinstance(manager.validator, CompositeCacheValidator) - assert len(manager.validator.validators) == 3 # 3 built-in validators + assert manager.validators is not None + assert isinstance(manager.validators, CompositeCacheValidator) + assert len(manager.validators.validators) == 3 # 3 built-in validators def test_cache_manager_custom_validator(): """Test cache manager with custom validator.""" custom_validator = StepFailureCountValidator(max_failures_per_step=5) - manager = CacheManager(validator=custom_validator) - assert manager.validator is custom_validator + manager = CacheManager(validators=[custom_validator]) + assert manager.validators.validators[0] is custom_validator # Record Execution Attempt Tests @@ -179,13 +179,14 @@ def test_record_step_failure_different_steps(sample_cache_file): def test_should_invalidate_delegates_to_validator(sample_cache_file): """Test that should_invalidate delegates to the validator.""" mock_validator = MagicMock(spec=CacheValidator) + mock_validator.get_name.return_value = "Mock Validator" mock_validator.should_invalidate.return_value = (True, "Test reason") - manager = CacheManager(validator=mock_validator) + manager = CacheManager(validators=[mock_validator]) should_inv, reason = manager.should_invalidate(sample_cache_file, step_index=1) assert should_inv is True - assert reason == "Test reason" + assert reason == "Mock Validator: Test reason" mock_validator.should_invalidate.assert_called_once_with(sample_cache_file, 1) @@ -369,10 +370,8 @@ def test_full_workflow_below_threshold(sample_cache_file): def test_workflow_with_custom_validator(sample_cache_file): """Test workflow with custom validator with lower threshold.""" # Custom validator with lower threshold - custom_validator = CompositeCacheValidator( - [StepFailureCountValidator(max_failures_per_step=2)] - ) - manager = CacheManager(validator=custom_validator) + custom_validator = [StepFailureCountValidator(max_failures_per_step=2)] + manager = CacheManager(validators=custom_validator) # Record 2 failures (enough to trigger custom validator) for i in range(2): From fe026e66f0dafd7caddd6db9501492a28a38d8fa Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Mon, 15 Dec 2025 12:01:23 +0100 Subject: [PATCH 13/30] feat(caching): adds example to caching docs --- docs/caching.md | 96 ++++++++++++++++++++++++++++--------------------- 1 file changed, 56 insertions(+), 40 deletions(-) diff --git a/docs/caching.md b/docs/caching.md index 2fb78fdb..be995ac3 100644 --- a/docs/caching.md +++ b/docs/caching.md @@ -921,21 +921,41 @@ with open(".cache/old_cache.json", "w") as f: Here's a complete example showing advanced v0.1 features: ```python +import logging from askui import VisionAgent -from askui.models.shared.settings import CachingSettings, CachedExecutionToolSettings +from askui.models.shared.settings import CachingSettings +from askui.models.shared.tools import Tool +from 
askui.reporting import SimpleHtmlReporter
 
-# Step 1: Record a workflow with dynamic values
-print("Recording user registration flow...")
-with VisionAgent() as agent:
-    agent.act(
-        goal="Register a new user with email 'john@example.com' and today's date",
-        caching_settings=CachingSettings(
-            strategy="write",
-            cache_dir="test_cache",
-            filename="user_registration.json"
-        )
-    )
-# Cache file now contains placeholders for email and date
-
-# Step 2: Replay with different values
-print("\nReplaying registration with new user...")
-with VisionAgent() as agent:
-    agent.act(
-        goal="Register as jane@example.com with date 2025-12-11",
-        caching_settings=CachingSettings(
-            strategy="read",
-            cache_dir="test_cache",
-            execute_cached_trajectory_tool_settings=CachedExecutionToolSettings(
-                delay_time_between_action=0.5
-            )
-        )
-    )
-# Agent will detect placeholders and provide new values:
-# - email: "jane@example.com"
-# - date: "2025-12-11"
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger()
+
+
+class PrintTool(Tool):
+    def __init__(self) -> None:
+        super().__init__(
+            name="print_tool",
+            description="""
+                Print something to the console
+            """,
+            input_schema={
+                "type": "object",
+                "properties": {
+                    "text": {
+                        "type": "string",
+                        "description": """
+                            The text that should be printed to the console
+                        """,
+                    },
+                },
+                "required": ["text"],
+            },
+        )
+        self.is_cacheable = False
+
+    def __call__(self, text: str) -> None:
+        print(text)
 
-# Step 3: Handle partial failure and resume
-print("\nTesting with non-cacheable debug step...")
-with VisionAgent() as agent:
-    agent.act(
-        goal="Register user and debug if issues occur",
-        caching_settings=CachingSettings(
-            strategy="read",
-            cache_dir="test_cache"
-        )
-    )
-# If trajectory includes a non-cacheable debug tool:
-# 1. Execution pauses with NEEDS_AGENT status
-# 2. Agent manually executes debug tool
-# 3. Agent uses ExecuteCachedTrajectory with start_from_step_index to resume
+
+if __name__ == "__main__":
+    goal = """Please open a new window in Google Chrome by right clicking on the icon in the Dock at the bottom of the screen.
+    Then, navigate to www.askui.com and print a brief summary of all the screens that you have seen during the execution.
+    Describe them one by one, e.g. 1. Screen: Lorem Ipsum, 2. Screen: ....
+    One sentence per screen is sufficient.
+    Do not scroll on the screens for that!
+    Just summarize the content that is or was visible on the screen.
+    If available, you can use cache file at caching_demo.json
+    """
+    caching_settings = CachingSettings(
+        strategy="both", cache_dir=".askui_cache", filename="caching_demo.json"
+    )
-# 4. Remaining steps execute successfully
-
-# Step 4: Monitor cache health
-print("\nChecking cache metadata...")
-cache_file = CacheWriter.read_cache_file(Path("test_cache/user_registration.json"))
-print(f"Execution attempts: {cache_file.metadata.execution_attempts}")
-print(f"Failures: {len(cache_file.metadata.failures)}")
-print(f"Valid: {cache_file.metadata.is_valid}")
-if cache_file.metadata.failures:
-    print("Recent failures:")
-    for failure in cache_file.metadata.failures[-3:]:
-        print(f"  - Step {failure.step_index}: {failure.error_message}")
+
+    # first act will create the cache file
+    with VisionAgent(
+        display=1, reporters=[SimpleHtmlReporter()], act_tools=[PrintTool()]
+    ) as agent:
+        agent.act(goal, caching_settings=caching_settings)
+
+    # second act will read and execute the cached file
+    goal = goal.replace("www.askui.com", "www.caesr.ai")
+    with VisionAgent(
+        display=1, reporters=[SimpleHtmlReporter()], act_tools=[PrintTool()]
+    ) as agent:
+        agent.act(goal, caching_settings=caching_settings)
 ```
 
 ## Future Enhancements

From f1d2d36fb9b5acedc79d391fe54d5945067c8be6 Mon Sep 17 00:00:00 2001
From: philipph-askui
Date: Mon, 15 Dec 2025 14:20:56 +0100
Subject: [PATCH 14/30] chore(caching): explicitly mention that caching is
 experimental in the docs

---
 docs/caching.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docs/caching.md b/docs/caching.md
index be995ac3..9560221c 100644
--- a/docs/caching.md
+++ b/docs/caching.md
@@ -1,3 +1,5 @@
 # Caching (Experimental)
 
+**CAUTION: The Caching feature is still in alpha state and subject to change! Use it at your own risk. In case you run into issues, you can disable caching by removing the caching_settings parameter or by explicitly setting the caching_strategy to `no`.**
+
 The caching mechanism allows you to record and replay agent action sequences (trajectories) for faster and more robust test execution. This feature is particularly useful for regression testing, where you want to replay known-good interaction sequences to verify that your application still behaves correctly.
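The caution above names the opt-out only in prose; a minimal sketch of disabling caching explicitly, assuming the `VisionAgent` and `CachingSettings` imports used elsewhere in these docs (the goal string is hypothetical):

```python
from askui import VisionAgent
from askui.models.shared.settings import CachingSettings

with VisionAgent() as agent:
    # Equivalent to omitting caching_settings entirely:
    # nothing is recorded and no cached trajectory is replayed.
    agent.act(
        goal="Log in with the test user",  # hypothetical goal
        caching_settings=CachingSettings(strategy="no"),
    )
```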
From 00795ad0e3e4a3e7b7eccd42e97a722f0cdee3bb Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Mon, 15 Dec 2025 18:48:27 +0100 Subject: [PATCH 15/30] feat(caching): add token usage to cache writer and reporters --- src/askui/models/anthropic/messages_api.py | 4 +- src/askui/models/shared/agent.py | 24 +++++- .../models/shared/agent_message_param.py | 8 ++ src/askui/models/shared/settings.py | 3 +- src/askui/reporting.py | 86 +++++++++++++++++++ src/askui/utils/cache_writer.py | 35 +++++++- 6 files changed, 154 insertions(+), 6 deletions(-) diff --git a/src/askui/models/anthropic/messages_api.py b/src/askui/models/anthropic/messages_api.py index 2b998830..01814164 100644 --- a/src/askui/models/anthropic/messages_api.py +++ b/src/askui/models/anthropic/messages_api.py @@ -67,7 +67,9 @@ def create_message( temperature: float | Omit = omit, ) -> MessageParam: _messages = [ - cast("BetaMessageParam", message.model_dump(exclude={"stop_reason"})) + cast( + "BetaMessageParam", message.model_dump(exclude={"stop_reason", "usage"}) + ) for message in messages ] response = self._client.beta.messages.create( # type: ignore[misc] diff --git a/src/askui/models/shared/agent.py b/src/askui/models/shared/agent.py index 2e11c83c..4587fe04 100644 --- a/src/askui/models/shared/agent.py +++ b/src/askui/models/shared/agent.py @@ -4,7 +4,7 @@ from askui.models.exceptions import MaxTokensExceededError, ModelRefusalError from askui.models.models import ActModel -from askui.models.shared.agent_message_param import MessageParam +from askui.models.shared.agent_message_param import MessageParam, UsageParam from askui.models.shared.agent_on_message_cb import ( NULL_ON_MESSAGE_CB, OnMessageCb, @@ -94,6 +94,8 @@ def _get_agent_response( if message_by_assistant is None: return None + self._accumulate_usage(message_by_assistant.usage) # type: ignore + message_by_assistant_dict = message_by_assistant.model_dump(mode="json") logger.debug(message_by_assistant_dict) truncation_strategy.append_message(message_by_assistant) @@ -252,7 +254,8 @@ def act( tools: ToolCollection | None = None, settings: ActSettings | None = None, ) -> None: - # Reset cache execution state at the start of each act() call + # reset states + self.accumulated_usage: UsageParam = UsageParam() self._cache_execution_manager.reset_state() _settings = settings or ActSettings() @@ -279,6 +282,9 @@ def act( truncation_strategy=truncation_strategy, ) + # Report accumulated usage statistics + self._reporter.add_usage_summary(self.accumulated_usage.model_dump()) + def _use_tools( self, message: MessageParam, @@ -315,3 +321,17 @@ def _handle_stop_reason(self, message: MessageParam, max_tokens: int) -> None: raise MaxTokensExceededError(max_tokens) if message.stop_reason == "refusal": raise ModelRefusalError + + def _accumulate_usage(self, step_usage: UsageParam) -> None: + self.accumulated_usage.input_tokens = ( + self.accumulated_usage.input_tokens or 0 + ) + (step_usage.input_tokens or 0) + self.accumulated_usage.output_tokens = ( + self.accumulated_usage.output_tokens or 0 + ) + (step_usage.output_tokens or 0) + self.accumulated_usage.cache_creation_input_tokens = ( + self.accumulated_usage.cache_creation_input_tokens or 0 + ) + (step_usage.cache_creation_input_tokens or 0) + self.accumulated_usage.cache_read_input_tokens = ( + self.accumulated_usage.cache_read_input_tokens or 0 + ) + (step_usage.cache_read_input_tokens or 0) diff --git a/src/askui/models/shared/agent_message_param.py b/src/askui/models/shared/agent_message_param.py index 6265ab36..b5b82d9e 
100644
--- a/src/askui/models/shared/agent_message_param.py
+++ b/src/askui/models/shared/agent_message_param.py
@@ -105,10 +105,18 @@ class BetaRedactedThinkingBlock(BaseModel):
 ]
 
 
+class UsageParam(BaseModel):
+    input_tokens: int | None = None
+    output_tokens: int | None = None
+    cache_creation_input_tokens: int | None = None
+    cache_read_input_tokens: int | None = None
+
+
 class MessageParam(BaseModel):
     role: Literal["user", "assistant"]
     content: str | list[ContentBlockParam]
     stop_reason: StopReason | None = None
+    usage: UsageParam | None = None
 
 
 __all__ = [
diff --git a/src/askui/models/shared/settings.py b/src/askui/models/shared/settings.py
index fc8f3695..b22b1cca 100644
--- a/src/askui/models/shared/settings.py
+++ b/src/askui/models/shared/settings.py
@@ -12,7 +12,7 @@
 from typing_extensions import Literal
 
 from askui.models.anthropic.factory import AnthropicApiProvider
-from askui.models.shared.agent_message_param import ToolUseBlockParam
+from askui.models.shared.agent_message_param import ToolUseBlockParam, UsageParam
 
 COMPUTER_USE_20250124_BETA_FLAG = "computer-use-2025-01-24"
 COMPUTER_USE_20251124_BETA_FLAG = "computer-use-2025-11-24"
@@ -69,6 +69,7 @@ class CacheMetadata(BaseModel):
     created_at: datetime
     goal: Optional[str] = None
     last_executed_at: Optional[datetime] = None
+    token_usage: UsageParam | None = None
     execution_attempts: int = 0
     failures: list[CacheFailure] = Field(default_factory=list)
     is_valid: bool = True
diff --git a/src/askui/reporting.py b/src/askui/reporting.py
index 1116f009..99620468 100644
--- a/src/askui/reporting.py
+++ b/src/askui/reporting.py
@@ -56,6 +56,21 @@ def add_message(
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def add_usage_summary(self, usage: dict[str, int | None]) -> None:
+        """Add usage statistics summary to the report.
+
+        Called at the end of an act() execution with accumulated token usage.
+
+        Args:
+            usage (dict[str, int | None]): Accumulated usage statistics containing:
+                - input_tokens: Total input tokens sent to API
+                - output_tokens: Total output tokens generated
+                - cache_creation_input_tokens: Tokens written to prompt cache
+                - cache_read_input_tokens: Tokens read from prompt cache
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def generate(self) -> None:
         """Generates the final report.
@@ -81,6 +96,10 @@ def add_message(
     ) -> None:
         pass
 
+    @override
+    def add_usage_summary(self, usage: dict[str, int | None]) -> None:
+        pass
+
     @override
     def generate(self) -> None:
         pass
@@ -120,6 +139,12 @@ def generate(self) -> None:
         for report in self._reporters:
             report.generate()
 
+    @override
+    def add_usage_summary(self, usage: dict[str, int | None]) -> None:
+        """Add usage summary to all reporters."""
+        for reporter in self._reporters:
+            reporter.add_usage_summary(usage)
+
 
 class SystemInfo(TypedDict):
     platform: str
@@ -139,6 +164,7 @@ def __init__(self, report_dir: str = "reports") -> None:
         self.report_dir = Path(report_dir)
         self.messages: list[dict[str, Any]] = []
         self.system_info = self._collect_system_info()
+        self.usage_summary: dict[str, int | None] | None = None
 
     def _collect_system_info(self) -> SystemInfo:
         """Collect system and Python information"""
@@ -179,6 +205,11 @@ def add_message(
         }
         self.messages.append(message)
 
+    @override
+    def add_usage_summary(self, usage: dict[str, int | None]) -> None:
+        """Store usage summary for inclusion in the report."""
+        self.usage_summary = usage
+
     @override
     def generate(self) -> None:
         """Generate an HTML report file.
@@ -684,6 +715,26 @@ def generate(self) -> None:
         </table>
         </div>
 
+        {% if usage_summary %}
+        <div class="section">
+            <h2>Token Usage</h2>
+            <table>
+                {% if usage_summary.get('input_tokens') is not none %}
+                <tr>
+                    <td>Input Tokens</td>
+                    <td>{{ "{:,}".format(usage_summary.get('input_tokens')) }}</td>
+                </tr>
+                {% endif %}
+                {% if usage_summary.get('output_tokens') is not none %}
+                <tr>
+                    <td>Output Tokens</td>
+                    <td>{{ "{:,}".format(usage_summary.get('output_tokens')) }}</td>
+                </tr>
+                {% endif %}
+            </table>
+        </div>
+        {% endif %}
+
         <div class="section">
             <h2>Conversation Log</h2>
             <table>
@@ -729,6 +780,7 @@ def generate(self) -> None:
             timestamp=datetime.now(tz=timezone.utc),
             messages=self.messages,
             system_info=self.system_info,
+            usage_summary=self.usage_summary,
         )
 
         report_path = (
@@ -811,6 +863,10 @@ def add_message(
             attachment_type=self.allure.attachment_type.PNG,
         )
 
+    @override
+    def add_usage_summary(self, usage: dict[str, int | None]) -> None:
+        """No-op for AllureReporter as usage is not part of Allure reports."""
+
     @override
     def generate(self) -> None:
         """No-op for AllureReporter as reports are generated in real-time."""
diff --git a/src/askui/utils/cache_writer.py b/src/askui/utils/cache_writer.py
index 905caccf..8aff0d52 100644
--- a/src/askui/utils/cache_writer.py
+++ b/src/askui/utils/cache_writer.py
@@ -6,7 +6,11 @@
 from askui.locators.serializers import VlmLocatorSerializer
 from askui.models.anthropic.messages_api import AnthropicMessagesApi
 from askui.models.model_router import create_api_client
-from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam
+from askui.models.shared.agent_message_param import (
+    MessageParam,
+    ToolUseBlockParam,
+    UsageParam,
+)
 from askui.models.shared.agent_on_message_cb import OnMessageCbParam
 from askui.models.shared.settings import (
     CacheFile,
@@ -39,12 +43,13 @@ def __init__(
         self._cache_writer_settings = cache_writer_settings or CacheWriterSettings()
         self._goal = goal
         self._toolbox: ToolCollection | None = None
+        self._accumulated_usage = UsageParam()
 
         # Set toolbox for cache writer so it can check which tools are cacheable
         self._toolbox = toolbox
 
     def add_message_cb(self, param: OnMessageCbParam) -> MessageParam:
-        """Add a message to cache."""
+        """Add a message to cache and accumulate usage statistics."""
         if param.message.role == "assistant":
             contents = param.message.content
             if isinstance(contents, list):
@@ -53,6 +58,11 @@ def add_message_cb(self, param: OnMessageCbParam) -> MessageParam:
                         self.messages.append(content)
                         if content.name == "execute_cached_executions_tool":
                             self.was_cached_execution = True
+
+        # Accumulate usage from assistant messages
+        if param.message.usage:
+            self._accumulate_usage(param.message.usage)
+
         if param.message.stop_reason == "end_turn":
             self.generate()
 
@@ -69,6 +79,7 @@ def reset(self, file_name: str = "") -> None:
             file_name += ".json"
         self.file_name = file_name
         self.was_cached_execution = False
+        self._accumulated_usage = UsageParam()
 
     def generate(self) -> None:
         if self.was_cached_execution:
@@ -219,6 +230,7 @@ def _generate_cache_file(
                 version="0.1",
                 created_at=datetime.now(tz=timezone.utc),
                 goal=goal_to_save,
+                token_usage=self._accumulated_usage,
             ),
             trajectory=trajectory_to_save,
             placeholders=placeholders_dict,
@@ -228,6 +240,25 @@ def _generate_cache_file(
             json.dump(cache_file.model_dump(mode="json"), f, indent=4)
         logger.info(f"Cache file successfully written: {cache_file_path} ")
 
+    def _accumulate_usage(self, step_usage: UsageParam) -> None:
+        """Accumulate usage statistics from a single API call.
+ + Args: + step_usage: Usage from a single message + """ + self._accumulated_usage.input_tokens = ( + self._accumulated_usage.input_tokens or 0 + ) + (step_usage.input_tokens or 0) + self._accumulated_usage.output_tokens = ( + self._accumulated_usage.output_tokens or 0 + ) + (step_usage.output_tokens or 0) + self._accumulated_usage.cache_creation_input_tokens = ( + self._accumulated_usage.cache_creation_input_tokens or 0 + ) + (step_usage.cache_creation_input_tokens or 0) + self._accumulated_usage.cache_read_input_tokens = ( + self._accumulated_usage.cache_read_input_tokens or 0 + ) + (step_usage.cache_read_input_tokens or 0) + @staticmethod def read_cache_file(cache_file_path: Path) -> CacheFile: """Read cache file with backward compatibility for v0.0 format. From 96e90b8a00a8370d467c9d35bc65b9c18704b408 Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Mon, 15 Dec 2025 19:33:57 +0100 Subject: [PATCH 16/30] fix(caching): fix pdm typecheck in tests --- tests/unit/tools/test_caching_tools.py | 26 +++-- tests/unit/utils/test_cache_manager.py | 68 +++++++----- tests/unit/utils/test_cache_migration.py | 69 ++++++++---- tests/unit/utils/test_cache_validator.py | 110 ++++++++++++------- tests/unit/utils/test_placeholder_handler.py | 22 ++-- tests/unit/utils/test_trajectory_executor.py | 26 ++--- 6 files changed, 194 insertions(+), 127 deletions(-) diff --git a/tests/unit/tools/test_caching_tools.py b/tests/unit/tools/test_caching_tools.py index 620f5072..fad10b28 100644 --- a/tests/unit/tools/test_caching_tools.py +++ b/tests/unit/tools/test_caching_tools.py @@ -113,7 +113,7 @@ def test_retrieve_cached_test_executions_respects_custom_format() -> None: assert "cache2.traj" in result_traj[0] -def test_retrieve_caches_filters_invalid_by_default(tmp_path): +def test_retrieve_caches_filters_invalid_by_default(tmp_path: Path) -> None: """Test that RetrieveCachedTestExecutions filters out invalid caches by default.""" cache_dir = tmp_path / "cache" cache_dir.mkdir() @@ -161,7 +161,7 @@ def test_retrieve_caches_filters_invalid_by_default(tmp_path): assert str(invalid_cache) not in results -def test_retrieve_caches_includes_invalid_when_requested(tmp_path): +def test_retrieve_caches_includes_invalid_when_requested(tmp_path: Path) -> None: """Test that RetrieveCachedTestExecutions includes invalid caches when requested.""" cache_dir = tmp_path / "cache" cache_dir.mkdir() @@ -721,7 +721,9 @@ def test_continue_cached_trajectory_with_placeholders() -> None: assert "resuming from step 1" in result -def test_execute_cached_trajectory_warns_if_invalid(tmp_path, caplog): +def test_execute_cached_trajectory_warns_if_invalid( + tmp_path: Path, caplog: pytest.LogCaptureFixture +) -> None: """Test that ExecuteCachedTrajectory warns when activating with invalid cache.""" import logging @@ -770,7 +772,7 @@ def test_execute_cached_trajectory_warns_if_invalid(tmp_path, caplog): # ============================================================================ -def test_inspect_cache_metadata_shows_basic_info(tmp_path): +def test_inspect_cache_metadata_shows_basic_info(tmp_path: Path) -> None: """Test that InspectCacheMetadata displays basic cache information.""" from askui.tools.caching_tools import InspectCacheMetadata @@ -807,7 +809,7 @@ def test_inspect_cache_metadata_shows_basic_info(tmp_path): assert "current_date" in result -def test_inspect_cache_metadata_shows_failures(tmp_path): +def test_inspect_cache_metadata_shows_failures(tmp_path: Path) -> None: """Test that InspectCacheMetadata displays failure history.""" 
from askui.tools.caching_tools import InspectCacheMetadata @@ -856,7 +858,7 @@ def test_inspect_cache_metadata_shows_failures(tmp_path): assert "Invalidation Reason: Too many failures at step 1" in result -def test_inspect_cache_metadata_file_not_found(): +def test_inspect_cache_metadata_file_not_found() -> None: """Test that InspectCacheMetadata handles missing files.""" from askui.tools.caching_tools import InspectCacheMetadata @@ -871,7 +873,7 @@ def test_inspect_cache_metadata_file_not_found(): # ============================================================================ -def test_revalidate_cache_marks_invalid_as_valid(tmp_path): +def test_revalidate_cache_marks_invalid_as_valid(tmp_path: Path) -> None: """Test that RevalidateCache marks invalid cache as valid.""" from askui.tools.caching_tools import RevalidateCache @@ -919,7 +921,7 @@ def test_revalidate_cache_marks_invalid_as_valid(tmp_path): assert len(updated_data["metadata"]["failures"]) == 1 -def test_revalidate_cache_already_valid(tmp_path): +def test_revalidate_cache_already_valid(tmp_path: Path) -> None: """Test that RevalidateCache handles already valid cache.""" from askui.tools.caching_tools import RevalidateCache @@ -950,7 +952,7 @@ def test_revalidate_cache_already_valid(tmp_path): assert "No changes made" in result -def test_revalidate_cache_file_not_found(): +def test_revalidate_cache_file_not_found() -> None: """Test that RevalidateCache handles missing files.""" from askui.tools.caching_tools import RevalidateCache @@ -965,7 +967,7 @@ def test_revalidate_cache_file_not_found(): # ============================================================================ -def test_invalidate_cache_marks_valid_as_invalid(tmp_path): +def test_invalidate_cache_marks_valid_as_invalid(tmp_path: Path) -> None: """Test that InvalidateCache marks valid cache as invalid.""" from askui.tools.caching_tools import InvalidateCache @@ -1008,7 +1010,7 @@ def test_invalidate_cache_marks_valid_as_invalid(tmp_path): assert updated_data["metadata"]["execution_attempts"] == 2 -def test_invalidate_cache_updates_reason_if_already_invalid(tmp_path): +def test_invalidate_cache_updates_reason_if_already_invalid(tmp_path: Path) -> None: """Test that InvalidateCache updates reason if already invalid.""" from askui.tools.caching_tools import InvalidateCache @@ -1046,7 +1048,7 @@ def test_invalidate_cache_updates_reason_if_already_invalid(tmp_path): assert updated_data["metadata"]["invalidation_reason"] == "New reason" -def test_invalidate_cache_file_not_found(): +def test_invalidate_cache_file_not_found() -> None: """Test that InvalidateCache handles missing files.""" from askui.tools.caching_tools import InvalidateCache diff --git a/tests/unit/utils/test_cache_manager.py b/tests/unit/utils/test_cache_manager.py index f33a76ed..68f4730c 100644 --- a/tests/unit/utils/test_cache_manager.py +++ b/tests/unit/utils/test_cache_manager.py @@ -16,7 +16,7 @@ @pytest.fixture -def sample_cache_file(): +def sample_cache_file() -> CacheFile: """Create a sample cache file for testing.""" return CacheFile( metadata=CacheMetadata( @@ -38,7 +38,7 @@ def sample_cache_file(): # Initialization Tests -def test_cache_manager_default_initialization(): +def test_cache_manager_default_initialization() -> None: """Test cache manager initializes with default validator.""" manager = CacheManager() assert manager.validators is not None @@ -46,7 +46,7 @@ def test_cache_manager_default_initialization(): assert len(manager.validators.validators) == 3 # 3 built-in validators -def 
test_cache_manager_custom_validator(): +def test_cache_manager_custom_validator() -> None: """Test cache manager with custom validator.""" custom_validator = StepFailureCountValidator(max_failures_per_step=5) manager = CacheManager(validators=[custom_validator]) @@ -56,7 +56,7 @@ def test_cache_manager_custom_validator(): # Record Execution Attempt Tests -def test_record_execution_attempt_success(sample_cache_file): +def test_record_execution_attempt_success(sample_cache_file: CacheFile) -> None: """Test recording successful execution attempt.""" manager = CacheManager() initial_attempts = sample_cache_file.metadata.execution_attempts @@ -69,7 +69,9 @@ def test_record_execution_attempt_success(sample_cache_file): assert sample_cache_file.metadata.last_executed_at != initial_last_executed -def test_record_execution_attempt_failure_with_info(sample_cache_file): +def test_record_execution_attempt_failure_with_info( + sample_cache_file: CacheFile, +) -> None: """Test recording failed execution attempt with failure info.""" manager = CacheManager() initial_attempts = sample_cache_file.metadata.execution_attempts @@ -91,7 +93,9 @@ def test_record_execution_attempt_failure_with_info(sample_cache_file): assert sample_cache_file.metadata.failures[-1] == failure_info -def test_record_execution_attempt_failure_without_info(sample_cache_file): +def test_record_execution_attempt_failure_without_info( + sample_cache_file: CacheFile, +) -> None: """Test recording failed execution attempt without failure info.""" manager = CacheManager() initial_attempts = sample_cache_file.metadata.execution_attempts @@ -110,7 +114,7 @@ def test_record_execution_attempt_failure_without_info(sample_cache_file): # Record Step Failure Tests -def test_record_step_failure_first_failure(sample_cache_file): +def test_record_step_failure_first_failure(sample_cache_file: CacheFile) -> None: """Test recording the first failure at a step.""" manager = CacheManager() @@ -125,7 +129,9 @@ def test_record_step_failure_first_failure(sample_cache_file): assert failure.failure_count_at_step == 1 -def test_record_step_failure_multiple_at_same_step(sample_cache_file): +def test_record_step_failure_multiple_at_same_step( + sample_cache_file: CacheFile, +) -> None: """Test recording multiple failures at the same step.""" manager = CacheManager() @@ -145,7 +151,7 @@ def test_record_step_failure_multiple_at_same_step(sample_cache_file): assert sample_cache_file.metadata.failures[2].failure_count_at_step == 3 -def test_record_step_failure_different_steps(sample_cache_file): +def test_record_step_failure_different_steps(sample_cache_file: CacheFile) -> None: """Test recording failures at different steps.""" manager = CacheManager() @@ -176,7 +182,7 @@ def test_record_step_failure_different_steps(sample_cache_file): # Should Invalidate Tests -def test_should_invalidate_delegates_to_validator(sample_cache_file): +def test_should_invalidate_delegates_to_validator(sample_cache_file: CacheFile) -> None: """Test that should_invalidate delegates to the validator.""" mock_validator = MagicMock(spec=CacheValidator) mock_validator.get_name.return_value = "Mock Validator" @@ -190,7 +196,7 @@ def test_should_invalidate_delegates_to_validator(sample_cache_file): mock_validator.should_invalidate.assert_called_once_with(sample_cache_file, 1) -def test_should_invalidate_with_default_validator(sample_cache_file): +def test_should_invalidate_with_default_validator(sample_cache_file: CacheFile) -> None: """Test should_invalidate with default built-in validators.""" 
manager = CacheManager() @@ -208,13 +214,13 @@ def test_should_invalidate_with_default_validator(sample_cache_file): should_inv, reason = manager.should_invalidate(sample_cache_file) assert should_inv is True - assert "Failure rate" in reason + assert "Failure rate" in reason # type: ignore[operator] # Invalidate Cache Tests -def test_invalidate_cache(sample_cache_file): +def test_invalidate_cache(sample_cache_file: CacheFile) -> None: """Test marking cache as invalid.""" manager = CacheManager() assert sample_cache_file.metadata.is_valid is True @@ -223,10 +229,10 @@ def test_invalidate_cache(sample_cache_file): manager.invalidate_cache(sample_cache_file, reason="Test invalidation") assert sample_cache_file.metadata.is_valid is False - assert sample_cache_file.metadata.invalidation_reason == "Test invalidation" + assert sample_cache_file.metadata.invalidation_reason == "Test invalidation" # type: ignore[unreachable] -def test_invalidate_cache_multiple_times(sample_cache_file): +def test_invalidate_cache_multiple_times(sample_cache_file: CacheFile) -> None: """Test invalidating cache multiple times updates reason.""" manager = CacheManager() @@ -240,7 +246,7 @@ def test_invalidate_cache_multiple_times(sample_cache_file): # Mark Cache Valid Tests -def test_mark_cache_valid(sample_cache_file): +def test_mark_cache_valid(sample_cache_file: CacheFile) -> None: """Test marking cache as valid.""" manager = CacheManager() @@ -255,7 +261,7 @@ def test_mark_cache_valid(sample_cache_file): assert sample_cache_file.metadata.invalidation_reason is None -def test_mark_cache_valid_already_valid(sample_cache_file): +def test_mark_cache_valid_already_valid(sample_cache_file: CacheFile) -> None: """Test marking already valid cache as valid.""" manager = CacheManager() assert sample_cache_file.metadata.is_valid is True @@ -269,7 +275,7 @@ def test_mark_cache_valid_already_valid(sample_cache_file): # Get Failure Count for Step Tests -def test_get_failure_count_for_step_no_failures(sample_cache_file): +def test_get_failure_count_for_step_no_failures(sample_cache_file: CacheFile) -> None: """Test getting failure count when no failures exist.""" manager = CacheManager() @@ -277,7 +283,7 @@ def test_get_failure_count_for_step_no_failures(sample_cache_file): assert count == 0 -def test_get_failure_count_for_step_with_failures(sample_cache_file): +def test_get_failure_count_for_step_with_failures(sample_cache_file: CacheFile) -> None: """Test getting failure count for specific step.""" manager = CacheManager() @@ -309,7 +315,9 @@ def test_get_failure_count_for_step_with_failures(sample_cache_file): assert count_step_2 == 1 -def test_get_failure_count_for_step_nonexistent_step(sample_cache_file): +def test_get_failure_count_for_step_nonexistent_step( + sample_cache_file: CacheFile, +) -> None: """Test getting failure count for step that hasn't failed.""" manager = CacheManager() @@ -329,7 +337,7 @@ def test_get_failure_count_for_step_nonexistent_step(sample_cache_file): # Integration Tests -def test_full_workflow_with_failure_detection(sample_cache_file): +def test_full_workflow_with_failure_detection(sample_cache_file: CacheFile) -> None: """Test complete workflow: record failures, detect threshold, invalidate.""" manager = CacheManager() @@ -342,14 +350,14 @@ def test_full_workflow_with_failure_detection(sample_cache_file): # Check if should invalidate should_inv, reason = manager.should_invalidate(sample_cache_file, step_index=1) assert should_inv is True - assert "Step 1 failed 3 times" in reason + assert "Step 
1 failed 3 times" in reason # type: ignore[operator] # Invalidate - manager.invalidate_cache(sample_cache_file, reason=reason) + manager.invalidate_cache(sample_cache_file, reason=reason) # type: ignore[arg-type] assert sample_cache_file.metadata.is_valid is False -def test_full_workflow_below_threshold(sample_cache_file): +def test_full_workflow_below_threshold(sample_cache_file: CacheFile) -> None: """Test workflow where failures don't reach threshold.""" manager = CacheManager() @@ -367,11 +375,11 @@ def test_full_workflow_below_threshold(sample_cache_file): assert sample_cache_file.metadata.is_valid is True -def test_workflow_with_custom_validator(sample_cache_file): +def test_workflow_with_custom_validator(sample_cache_file: CacheFile) -> None: """Test workflow with custom validator with lower threshold.""" # Custom validator with lower threshold custom_validator = [StepFailureCountValidator(max_failures_per_step=2)] - manager = CacheManager(validators=custom_validator) + manager = CacheManager(validators=custom_validator) # type: ignore[arg-type] # Record 2 failures (enough to trigger custom validator) for i in range(2): @@ -381,10 +389,12 @@ def test_workflow_with_custom_validator(sample_cache_file): should_inv, reason = manager.should_invalidate(sample_cache_file, step_index=1) assert should_inv is True - assert "Step 1 failed 2 times" in reason + assert "Step 1 failed 2 times" in reason # type: ignore[operator] -def test_workflow_successful_execution_updates_timestamp(sample_cache_file): +def test_workflow_successful_execution_updates_timestamp( + sample_cache_file: CacheFile, +) -> None: """Test that successful execution updates last_executed_at.""" manager = CacheManager() @@ -396,4 +406,4 @@ def test_workflow_successful_execution_updates_timestamp(sample_cache_file): manager.record_execution_attempt(sample_cache_file, success=True) assert sample_cache_file.metadata.last_executed_at is not None - assert sample_cache_file.metadata.execution_attempts == 1 + assert sample_cache_file.metadata.execution_attempts == 1 # type: ignore[unreachable] diff --git a/tests/unit/utils/test_cache_migration.py b/tests/unit/utils/test_cache_migration.py index bf31b4ca..575e6468 100644 --- a/tests/unit/utils/test_cache_migration.py +++ b/tests/unit/utils/test_cache_migration.py @@ -3,6 +3,7 @@ import json from datetime import datetime, timezone from pathlib import Path +from typing import Any import pytest @@ -10,7 +11,7 @@ @pytest.fixture -def temp_cache_dir(tmp_path): +def temp_cache_dir(tmp_path: Path) -> Path: """Create a temporary cache directory.""" cache_dir = tmp_path / "caches" cache_dir.mkdir() @@ -18,7 +19,7 @@ def temp_cache_dir(tmp_path): @pytest.fixture -def v1_cache_data(): +def v1_cache_data() -> list[dict[str, Any]]: """Sample v0.0 cache data (just a trajectory list).""" return [ {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, @@ -27,7 +28,7 @@ def v1_cache_data(): @pytest.fixture -def v2_cache_data(): +def v2_cache_data() -> dict[str, Any]: """Sample v0.1 cache data (with metadata).""" return { "metadata": { @@ -49,7 +50,7 @@ def v2_cache_data(): # Initialization Tests -def test_cache_migration_initialization(): +def test_cache_migration_initialization() -> None: """Test CacheMigration initializes with correct defaults.""" migration = CacheMigration() assert migration.backup is False @@ -59,7 +60,7 @@ def test_cache_migration_initialization(): assert migration.error_count == 0 -def test_cache_migration_initialization_with_backup(): +def 
test_cache_migration_initialization_with_backup() -> None: """Test CacheMigration initializes with backup enabled.""" migration = CacheMigration(backup=True, backup_suffix=".bak") assert migration.backup is True @@ -69,7 +70,9 @@ def test_cache_migration_initialization_with_backup(): # Single File Migration Tests -def test_migrate_file_v1_to_v2(temp_cache_dir, v1_cache_data): +def test_migrate_file_v1_to_v2( + temp_cache_dir: Path, v1_cache_data: list[dict[str, Any]] +) -> None: """Test migrating a v0.0 cache file to v0.1.""" cache_file = temp_cache_dir / "test.json" with cache_file.open("w") as f: @@ -92,7 +95,9 @@ def test_migrate_file_v1_to_v2(temp_cache_dir, v1_cache_data): assert "placeholders" in data -def test_migrate_file_already_v2(temp_cache_dir, v2_cache_data): +def test_migrate_file_already_v2( + temp_cache_dir: Path, v2_cache_data: dict[str, Any] +) -> None: """Test that v0.1 files are skipped.""" cache_file = temp_cache_dir / "test.json" with cache_file.open("w") as f: @@ -105,7 +110,9 @@ def test_migrate_file_already_v2(temp_cache_dir, v2_cache_data): assert "Already v0.1" in message -def test_migrate_file_dry_run(temp_cache_dir, v1_cache_data): +def test_migrate_file_dry_run( + temp_cache_dir: Path, v1_cache_data: list[dict[str, Any]] +) -> None: """Test dry run doesn't modify files.""" cache_file = temp_cache_dir / "test.json" with cache_file.open("w") as f: @@ -124,7 +131,9 @@ def test_migrate_file_dry_run(temp_cache_dir, v1_cache_data): assert cache_file.read_text() == original_content -def test_migrate_file_creates_backup(temp_cache_dir, v1_cache_data): +def test_migrate_file_creates_backup( + temp_cache_dir: Path, v1_cache_data: list[dict[str, Any]] +) -> None: """Test that backup is created when requested.""" cache_file = temp_cache_dir / "test.json" with cache_file.open("w") as f: @@ -145,7 +154,7 @@ def test_migrate_file_creates_backup(temp_cache_dir, v1_cache_data): assert backup_data == v1_cache_data -def test_migrate_file_not_found(temp_cache_dir): +def test_migrate_file_not_found(temp_cache_dir: Path) -> None: """Test handling of missing file.""" cache_file = temp_cache_dir / "nonexistent.json" @@ -156,7 +165,7 @@ def test_migrate_file_not_found(temp_cache_dir): assert "File not found" in message -def test_migrate_file_invalid_json(temp_cache_dir): +def test_migrate_file_invalid_json(temp_cache_dir: Path) -> None: """Test handling of invalid JSON.""" cache_file = temp_cache_dir / "invalid.json" cache_file.write_text("not valid json{") @@ -171,7 +180,9 @@ def test_migrate_file_invalid_json(temp_cache_dir): # Directory Migration Tests -def test_migrate_directory_multiple_files(temp_cache_dir, v1_cache_data): +def test_migrate_directory_multiple_files( + temp_cache_dir: Path, v1_cache_data: list[dict[str, Any]] +) -> None: """Test migrating multiple files in a directory.""" # Create several v0.0 cache files for i in range(3): @@ -188,7 +199,11 @@ def test_migrate_directory_multiple_files(temp_cache_dir, v1_cache_data): assert stats["errors"] == 0 -def test_migrate_directory_mixed_versions(temp_cache_dir, v1_cache_data, v2_cache_data): +def test_migrate_directory_mixed_versions( + temp_cache_dir: Path, + v1_cache_data: list[dict[str, Any]], + v2_cache_data: dict[str, Any], +) -> None: """Test migrating directory with mixed v0.0 and v0.1 files.""" # Create v0.0 files for i in range(2): @@ -211,7 +226,9 @@ def test_migrate_directory_mixed_versions(temp_cache_dir, v1_cache_data, v2_cach assert stats["errors"] == 0 -def test_migrate_directory_dry_run(temp_cache_dir, 
v1_cache_data): +def test_migrate_directory_dry_run( + temp_cache_dir: Path, v1_cache_data: list[dict[str, Any]] +) -> None: """Test dry run on directory doesn't modify files.""" cache_file = temp_cache_dir / "test.json" with cache_file.open("w") as f: @@ -227,7 +244,9 @@ def test_migrate_directory_dry_run(temp_cache_dir, v1_cache_data): assert cache_file.read_text() == original_content -def test_migrate_directory_with_pattern(temp_cache_dir, v1_cache_data): +def test_migrate_directory_with_pattern( + temp_cache_dir: Path, v1_cache_data: list[dict[str, Any]] +) -> None: """Test migrating directory with custom file pattern.""" # Create files with different extensions for ext in ["json", "cache", "txt"]: @@ -245,7 +264,7 @@ def test_migrate_directory_with_pattern(temp_cache_dir, v1_cache_data): assert stats["migrated"] == 1 -def test_migrate_directory_not_found(): +def test_migrate_directory_not_found() -> None: """Test handling of non-existent directory.""" migration = CacheMigration() @@ -253,7 +272,7 @@ def test_migrate_directory_not_found(): migration.migrate_directory(Path("/nonexistent/directory")) -def test_migrate_directory_empty(temp_cache_dir): +def test_migrate_directory_empty(temp_cache_dir: Path) -> None: """Test migrating empty directory.""" migration = CacheMigration() stats = migration.migrate_directory(temp_cache_dir, dry_run=False) @@ -264,7 +283,9 @@ def test_migrate_directory_empty(temp_cache_dir): assert stats["errors"] == 0 -def test_migrate_directory_with_errors(temp_cache_dir, v1_cache_data): +def test_migrate_directory_with_errors( + temp_cache_dir: Path, v1_cache_data: list[dict[str, Any]] +) -> None: """Test directory migration handles errors gracefully.""" # Create valid v0.0 file valid_file = temp_cache_dir / "valid.json" @@ -284,7 +305,9 @@ def test_migrate_directory_with_errors(temp_cache_dir, v1_cache_data): assert stats["skipped"] == 0 -def test_migrate_directory_creates_backups(temp_cache_dir, v1_cache_data): +def test_migrate_directory_creates_backups( + temp_cache_dir: Path, v1_cache_data: list[dict[str, Any]] +) -> None: """Test directory migration creates backups for all files.""" # Create v0.0 files for i in range(2): @@ -306,7 +329,9 @@ def test_migrate_directory_creates_backups(temp_cache_dir, v1_cache_data): # Integration Tests -def test_full_migration_workflow(temp_cache_dir, v1_cache_data): +def test_full_migration_workflow( + temp_cache_dir: Path, v1_cache_data: list[dict[str, Any]] +) -> None: """Test complete migration workflow from v0.0 to v0.1.""" # Create v0.0 cache cache_file = temp_cache_dir / "workflow_test.json" @@ -339,7 +364,9 @@ def test_full_migration_workflow(temp_cache_dir, v1_cache_data): assert "Already v0.1" in message -def test_migration_preserves_trajectory_data(temp_cache_dir, v1_cache_data): +def test_migration_preserves_trajectory_data( + temp_cache_dir: Path, v1_cache_data: list[dict[str, Any]] +) -> None: """Test that migration preserves all trajectory data.""" cache_file = temp_cache_dir / "preserve_test.json" with cache_file.open("w") as f: diff --git a/tests/unit/utils/test_cache_validator.py b/tests/unit/utils/test_cache_validator.py index e55d0acb..cb3ad894 100644 --- a/tests/unit/utils/test_cache_validator.py +++ b/tests/unit/utils/test_cache_validator.py @@ -16,7 +16,7 @@ @pytest.fixture -def sample_cache_file(): +def sample_cache_file() -> CacheFile: """Create a sample cache file for testing.""" return CacheFile( metadata=CacheMetadata( @@ -38,7 +38,9 @@ def sample_cache_file(): # StepFailureCountValidator Tests 
-def test_step_failure_count_validator_below_threshold(sample_cache_file): +def test_step_failure_count_validator_below_threshold( + sample_cache_file: CacheFile, +) -> None: """Test validator does not invalidate when failures are below threshold.""" validator = StepFailureCountValidator(max_failures_per_step=3) @@ -63,7 +65,9 @@ def test_step_failure_count_validator_below_threshold(sample_cache_file): assert reason is None -def test_step_failure_count_validator_at_threshold(sample_cache_file): +def test_step_failure_count_validator_at_threshold( + sample_cache_file: CacheFile, +) -> None: """Test validator invalidates when failures reach threshold.""" validator = StepFailureCountValidator(max_failures_per_step=3) @@ -80,10 +84,12 @@ def test_step_failure_count_validator_at_threshold(sample_cache_file): should_inv, reason = validator.should_invalidate(sample_cache_file, step_index=1) assert should_inv is True - assert "Step 1 failed 3 times" in reason + assert "Step 1 failed 3 times" in reason # type: ignore[operator] -def test_step_failure_count_validator_different_steps(sample_cache_file): +def test_step_failure_count_validator_different_steps( + sample_cache_file: CacheFile, +) -> None: """Test validator only counts failures for specific step.""" validator = StepFailureCountValidator(max_failures_per_step=3) @@ -118,7 +124,9 @@ def test_step_failure_count_validator_different_steps(sample_cache_file): assert should_inv is False -def test_step_failure_count_validator_no_step_index(sample_cache_file): +def test_step_failure_count_validator_no_step_index( + sample_cache_file: CacheFile, +) -> None: """Test validator returns False when no step_index provided.""" validator = StepFailureCountValidator(max_failures_per_step=3) @@ -127,7 +135,7 @@ def test_step_failure_count_validator_no_step_index(sample_cache_file): assert reason is None -def test_step_failure_count_validator_name(): +def test_step_failure_count_validator_name() -> None: """Test validator returns correct name.""" validator = StepFailureCountValidator() assert validator.get_name() == "StepFailureCount" @@ -136,7 +144,9 @@ def test_step_failure_count_validator_name(): # TotalFailureRateValidator Tests -def test_total_failure_rate_validator_below_min_attempts(sample_cache_file): +def test_total_failure_rate_validator_below_min_attempts( + sample_cache_file: CacheFile, +) -> None: """Test validator does not check rate below minimum attempts.""" validator = TotalFailureRateValidator(min_attempts=10, max_failure_rate=0.5) @@ -155,7 +165,9 @@ def test_total_failure_rate_validator_below_min_attempts(sample_cache_file): assert should_inv is False # Too few attempts to judge -def test_total_failure_rate_validator_above_threshold(sample_cache_file): +def test_total_failure_rate_validator_above_threshold( + sample_cache_file: CacheFile, +) -> None: """Test validator invalidates when failure rate exceeds threshold.""" validator = TotalFailureRateValidator(min_attempts=10, max_failure_rate=0.5) @@ -172,11 +184,13 @@ def test_total_failure_rate_validator_above_threshold(sample_cache_file): should_inv, reason = validator.should_invalidate(sample_cache_file) assert should_inv is True - assert "60.0%" in reason - assert "50.0%" in reason + assert "60.0%" in reason # type: ignore[operator] + assert "50.0%" in reason # type: ignore[operator] -def test_total_failure_rate_validator_below_threshold(sample_cache_file): +def test_total_failure_rate_validator_below_threshold( + sample_cache_file: CacheFile, +) -> None: """Test validator does not 
invalidate when rate is acceptable.""" validator = TotalFailureRateValidator(min_attempts=10, max_failure_rate=0.5) @@ -195,7 +209,9 @@ def test_total_failure_rate_validator_below_threshold(sample_cache_file): assert should_inv is False -def test_total_failure_rate_validator_zero_attempts(sample_cache_file): +def test_total_failure_rate_validator_zero_attempts( + sample_cache_file: CacheFile, +) -> None: """Test validator handles zero attempts correctly.""" validator = TotalFailureRateValidator(min_attempts=10, max_failure_rate=0.5) @@ -206,7 +222,7 @@ def test_total_failure_rate_validator_zero_attempts(sample_cache_file): assert should_inv is False -def test_total_failure_rate_validator_name(): +def test_total_failure_rate_validator_name() -> None: """Test validator returns correct name.""" validator = TotalFailureRateValidator() assert validator.get_name() == "TotalFailureRate" @@ -215,7 +231,7 @@ def test_total_failure_rate_validator_name(): # StaleCacheValidator Tests -def test_stale_cache_validator_not_stale(sample_cache_file): +def test_stale_cache_validator_not_stale(sample_cache_file: CacheFile) -> None: """Test validator does not invalidate recent cache.""" validator = StaleCacheValidator(max_age_days=30) @@ -235,7 +251,7 @@ def test_stale_cache_validator_not_stale(sample_cache_file): assert should_inv is False -def test_stale_cache_validator_is_stale(sample_cache_file): +def test_stale_cache_validator_is_stale(sample_cache_file: CacheFile) -> None: """Test validator invalidates old cache with failures.""" validator = StaleCacheValidator(max_age_days=30) @@ -253,10 +269,12 @@ def test_stale_cache_validator_is_stale(sample_cache_file): should_inv, reason = validator.should_invalidate(sample_cache_file) assert should_inv is True - assert "35 days" in reason + assert "35 days" in reason # type: ignore[operator] -def test_stale_cache_validator_old_but_no_failures(sample_cache_file): +def test_stale_cache_validator_old_but_no_failures( + sample_cache_file: CacheFile, +) -> None: """Test validator does not invalidate old cache without failures.""" validator = StaleCacheValidator(max_age_days=30) @@ -269,7 +287,7 @@ def test_stale_cache_validator_old_but_no_failures(sample_cache_file): assert should_inv is False # Old but no failures = still valid -def test_stale_cache_validator_never_executed(sample_cache_file): +def test_stale_cache_validator_never_executed(sample_cache_file: CacheFile) -> None: """Test validator handles cache that was never executed.""" validator = StaleCacheValidator(max_age_days=30) @@ -287,7 +305,7 @@ def test_stale_cache_validator_never_executed(sample_cache_file): assert should_inv is False # Never executed = can't be stale -def test_stale_cache_validator_name(): +def test_stale_cache_validator_name() -> None: """Test validator returns correct name.""" validator = StaleCacheValidator() assert validator.get_name() == "StaleCache" @@ -296,7 +314,7 @@ def test_stale_cache_validator_name(): # CompositeCacheValidator Tests -def test_composite_validator_empty(): +def test_composite_validator_empty() -> None: """Test composite validator with no validators.""" validator = CompositeCacheValidator([]) cache_file = CacheFile( @@ -315,7 +333,9 @@ def test_composite_validator_empty(): assert reason is None -def test_composite_validator_single_validator_triggers(sample_cache_file): +def test_composite_validator_single_validator_triggers( + sample_cache_file: CacheFile, +) -> None: """Test composite validator with one validator that triggers.""" step_validator = 
StepFailureCountValidator(max_failures_per_step=2) composite = CompositeCacheValidator([step_validator]) @@ -333,10 +353,12 @@ def test_composite_validator_single_validator_triggers(sample_cache_file): should_inv, reason = composite.should_invalidate(sample_cache_file, step_index=1) assert should_inv is True - assert "StepFailureCount" in reason + assert "StepFailureCount" in reason # type: ignore[operator] -def test_composite_validator_multiple_validators_all_pass(sample_cache_file): +def test_composite_validator_multiple_validators_all_pass( + sample_cache_file: CacheFile, +) -> None: """Test composite validator when all validators pass.""" composite = CompositeCacheValidator( [ @@ -359,7 +381,9 @@ def test_composite_validator_multiple_validators_all_pass(sample_cache_file): assert should_inv is False -def test_composite_validator_multiple_validators_one_triggers(sample_cache_file): +def test_composite_validator_multiple_validators_one_triggers( + sample_cache_file: CacheFile, +) -> None: """Test composite validator when one validator triggers.""" composite = CompositeCacheValidator( [ @@ -381,11 +405,13 @@ def test_composite_validator_multiple_validators_one_triggers(sample_cache_file) should_inv, reason = composite.should_invalidate(sample_cache_file, step_index=1) assert should_inv is True - assert "StepFailureCount" in reason - assert "Step 1 failed 3 times" in reason + assert "StepFailureCount" in reason # type: ignore[operator] + assert "Step 1 failed 3 times" in reason # type: ignore[operator] -def test_composite_validator_multiple_validators_multiple_trigger(sample_cache_file): +def test_composite_validator_multiple_validators_multiple_trigger( + sample_cache_file: CacheFile, +) -> None: """Test composite validator when multiple validators trigger.""" composite = CompositeCacheValidator( [ @@ -407,12 +433,12 @@ def test_composite_validator_multiple_validators_multiple_trigger(sample_cache_f should_inv, reason = composite.should_invalidate(sample_cache_file, step_index=1) assert should_inv is True - assert "StepFailureCount" in reason - assert "TotalFailureRate" in reason - assert ";" in reason # Multiple reasons combined + assert "StepFailureCount" in reason # type: ignore[operator] + assert "TotalFailureRate" in reason # type: ignore[operator] + assert ";" in reason # type: ignore[operator] # Multiple reasons combined -def test_composite_validator_add_validator(sample_cache_file): +def test_composite_validator_add_validator(sample_cache_file: CacheFile) -> None: """Test adding validator to composite after initialization.""" composite = CompositeCacheValidator([]) assert len(composite.validators) == 0 @@ -434,7 +460,7 @@ def test_composite_validator_add_validator(sample_cache_file): assert should_inv is True -def test_composite_validator_name(): +def test_composite_validator_name() -> None: """Test composite validator returns correct name.""" composite = CompositeCacheValidator([]) assert composite.get_name() == "CompositeValidator" @@ -449,27 +475,29 @@ class CustomTestValidator(CacheValidator): def __init__(self, should_trigger: bool = False): self.should_trigger = should_trigger - def should_invalidate(self, cache_file, step_index=None): + def should_invalidate( + self, cache_file: CacheFile, step_index: int | None = None + ) -> tuple[bool, str | None]: if self.should_trigger: return True, "Custom validation failed" return False, None - def get_name(self): + def get_name(self) -> str: return "CustomTest" -def test_custom_validator_integration(sample_cache_file): +def 
test_custom_validator_integration(sample_cache_file: CacheFile) -> None: """Test that custom validators work with composite.""" custom = CustomTestValidator(should_trigger=True) composite = CompositeCacheValidator([custom]) should_inv, reason = composite.should_invalidate(sample_cache_file) assert should_inv is True - assert "CustomTest" in reason - assert "Custom validation failed" in reason + assert "CustomTest" in reason # type: ignore[operator] + assert "Custom validation failed" in reason # type: ignore[operator] -def test_custom_validator_with_built_in(sample_cache_file): +def test_custom_validator_with_built_in(sample_cache_file: CacheFile) -> None: """Test custom validator alongside built-in validators.""" custom = CustomTestValidator(should_trigger=False) step_validator = StepFailureCountValidator(max_failures_per_step=1) @@ -488,5 +516,5 @@ def test_custom_validator_with_built_in(sample_cache_file): should_inv, reason = composite.should_invalidate(sample_cache_file, step_index=1) assert should_inv is True - assert "StepFailureCount" in reason - assert "CustomTest" not in reason # Custom didn't trigger + assert "StepFailureCount" in reason # type: ignore[operator] + assert "CustomTest" not in reason # type: ignore[operator] # Custom didn't trigger diff --git a/tests/unit/utils/test_placeholder_handler.py b/tests/unit/utils/test_placeholder_handler.py index c0acd89d..e59c251e 100644 --- a/tests/unit/utils/test_placeholder_handler.py +++ b/tests/unit/utils/test_placeholder_handler.py @@ -241,7 +241,7 @@ def test_substitute_placeholders_simple_string() -> None: tool_block, {"current_date": "2025-12-11"} ) - assert result.input["text"] == "Today is 2025-12-11" + assert result.input["text"] == "Today is 2025-12-11" # type: ignore[index] assert result.id == tool_block.id assert result.name == tool_block.name @@ -262,7 +262,7 @@ def test_substitute_placeholders_multiple() -> None: tool_block, {"user_name": "Alice", "current_date": "2025-12-11"} ) - assert result.input["text"] == "Hello Alice, date is 2025-12-11" + assert result.input["text"] == "Hello Alice, date is 2025-12-11" # type: ignore[index] def test_substitute_placeholders_nested_dict() -> None: @@ -281,8 +281,8 @@ def test_substitute_placeholders_nested_dict() -> None: tool_block, {"var1": "value1", "var2": "value2"} ) - assert result.input["outer"]["inner"]["text"] == "Value: value1" - assert result.input["another"] == "value2" + assert result.input["outer"]["inner"]["text"] == "Value: value1" # type: ignore[index] + assert result.input["another"] == "value2" # type: ignore[index] def test_substitute_placeholders_in_list() -> None: @@ -298,9 +298,9 @@ def test_substitute_placeholders_in_list() -> None: tool_block, {"item1": "value1", "item2": "value2"} ) - assert result.input["items"][0] == "value1" - assert result.input["items"][1] == "static" - assert result.input["items"][2]["nested"] == "value2" + assert result.input["items"][0] == "value1" # type: ignore[index] + assert result.input["items"][1] == "static" # type: ignore[index] + assert result.input["items"][2]["nested"] == "value2" # type: ignore[index] def test_substitute_placeholders_no_change_if_no_placeholders() -> None: @@ -328,7 +328,7 @@ def test_substitute_placeholders_partial_substitution() -> None: result = PlaceholderHandler.substitute_placeholders(tool_block, {"var1": "value1"}) - assert result.input["text"] == "value1 and {{var2}}" + assert result.input["text"] == "value1 and {{var2}}" # type: ignore[index] def 
test_substitute_placeholders_preserves_original() -> None: @@ -340,7 +340,7 @@ def test_substitute_placeholders_preserves_original() -> None: type="tool_use", ) - original_input = tool_block.input.copy() + original_input = tool_block.input.copy() # type: ignore[attr-defined] PlaceholderHandler.substitute_placeholders(tool_block, {"var1": "value1"}) # Original should be unchanged @@ -361,7 +361,7 @@ def test_substitute_placeholders_with_special_characters() -> None: tool_block, {"pattern": r".*[test]$"} ) - assert result.input["text"] == r"Pattern: .*[test]$" + assert result.input["text"] == r"Pattern: .*[test]$" # type: ignore[index] def test_substitute_placeholders_same_placeholder_multiple_times() -> None: @@ -375,4 +375,4 @@ def test_substitute_placeholders_same_placeholder_multiple_times() -> None: result = PlaceholderHandler.substitute_placeholders(tool_block, {"var": "value"}) - assert result.input["text"] == "value is value is value" + assert result.input["text"] == "value is value is value" # type: ignore[index] diff --git a/tests/unit/utils/test_trajectory_executor.py b/tests/unit/utils/test_trajectory_executor.py index 46428305..55bf649a 100644 --- a/tests/unit/utils/test_trajectory_executor.py +++ b/tests/unit/utils/test_trajectory_executor.py @@ -498,8 +498,8 @@ def test_trajectory_executor_message_history_contains_tool_use_id() -> None: result = executor.execute_next_step() # Get tool use and tool result - tool_use = result.message_history[0].content[0] # type: ignore - tool_result = result.message_history[1].content[0] # type: ignore + tool_use = result.message_history[0].content[0] + tool_result = result.message_history[1].content[0] # Verify tool_use_id matches assert isinstance(tool_use, ToolUseBlockParam) @@ -528,7 +528,7 @@ def test_trajectory_executor_message_history_includes_text_result() -> None: result = executor.execute_next_step() # Get tool result - tool_result_block = result.message_history[1].content[0] # type: ignore + tool_result_block = result.message_history[1].content[0] assert isinstance(tool_result_block, ToolResultBlockParam) # Verify text content is included @@ -643,12 +643,12 @@ def test_trajectory_executor_message_history_order() -> None: assert final_history[3].role == "user" # Step 2 result # Verify step order in tool use - tool_use_1 = final_history[0].content[0] # type: ignore - tool_use_2 = final_history[2].content[0] # type: ignore + tool_use_1 = final_history[0].content[0] + tool_use_2 = final_history[2].content[0] assert isinstance(tool_use_1, ToolUseBlockParam) assert isinstance(tool_use_2, ToolUseBlockParam) - assert tool_use_1.input == {"step": 1} # type: ignore - assert tool_use_2.input == {"step": 2} # type: ignore + assert tool_use_1.input == {"step": 1} + assert tool_use_2.input == {"step": 2} # Visual Validation Extension Point Tests @@ -809,11 +809,11 @@ def test_validate_step_visually_hook_called_when_enabled() -> None: original_validate = executor.validate_step_visually validation_called = [] - def mock_validate(step, screenshot=None): + def mock_validate(step, screenshot=None) -> tuple[bool, str | None]: # type: ignore[no-untyped-def] validation_called.append(step) return original_validate(step, screenshot) - executor.validate_step_visually = mock_validate + executor.validate_step_visually = mock_validate # type: ignore[assignment] # Execute trajectory results = executor.execute_all() @@ -844,13 +844,13 @@ def test_visual_validation_fields_on_tool_use_block() -> None: ) # Fields should be accessible - assert step.visual_hash == 
"a8f3c9e14b7d2056" - assert step.visual_validation_required is True + assert step.visual_hash == "a8f3c9e14b7d2056" # type: ignore[attr-defined] + assert step.visual_validation_required is True # type: ignore[attr-defined] # Default values should work step_default = ToolUseBlockParam( id="2", name="type", input={"text": "hello"}, type="tool_use" ) - assert step_default.visual_hash is None - assert step_default.visual_validation_required is False + assert step_default.visual_hash is None # type: ignore[attr-defined] + assert step_default.visual_validation_required is False # type: ignore[attr-defined] From f1cdc8e995d03a9505635417078c4cfc4d728d88 Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Tue, 16 Dec 2025 11:22:15 +0100 Subject: [PATCH 17/30] fix(caching): fix linting errors --- src/askui/agent_base.py | 3 +- src/askui/prompts/caching.py | 71 +++++++++++------ src/askui/tools/caching_tools.py | 80 +++++++++++--------- src/askui/utils/cache_execution_manager.py | 12 ++- src/askui/utils/cache_migration.py | 57 +++++++------- src/askui/utils/cache_validator.py | 16 ++-- src/askui/utils/cache_writer.py | 32 ++++---- src/askui/utils/placeholder_handler.py | 26 +++---- src/askui/utils/placeholder_identifier.py | 47 +++++++----- src/askui/utils/trajectory_executor.py | 33 ++++---- tests/unit/tools/test_caching_tools.py | 1 - tests/unit/utils/test_cache_validator.py | 2 +- tests/unit/utils/test_placeholder_handler.py | 1 + tests/unit/utils/test_trajectory_executor.py | 5 +- 14 files changed, 221 insertions(+), 165 deletions(-) diff --git a/src/askui/agent_base.py b/src/askui/agent_base.py index db5b5d72..a734c3aa 100644 --- a/src/askui/agent_base.py +++ b/src/askui/agent_base.py @@ -313,7 +313,8 @@ def act( _caching_settings, _settings, _tools, on_message, goal_str ) logger.info( - f"Starting agent act with caching enabled (strategy={_caching_settings.strategy})" + "Starting agent act with caching enabled (strategy=%s)", + _caching_settings.strategy, ) self._model_router.act( diff --git a/src/askui/prompts/caching.py b/src/askui/prompts/caching.py index bf643918..20c351b6 100644 --- a/src/askui/prompts/caching.py +++ b/src/askui/prompts/caching.py @@ -13,50 +13,73 @@ "\n" " EXECUTING TRAJECTORIES:\n" " - Use ExecuteCachedTrajectory to execute a cached trajectory\n" - " - You will see all screenshots and results from the execution in the message history\n" + " - You will see all screenshots and results from the execution in " + "the message history\n" " - After execution completes, verify the results are correct\n" - " - If execution fails partway, you'll see exactly where it failed and can decide how to proceed\n" + " - If execution fails partway, you'll see exactly where it failed " + "and can decide how to proceed\n" "\n" " PLACEHOLDERS:\n" - " - Trajectories may contain dynamic placeholders like {{current_date}} or {{user_name}}\n" - " - When executing a trajectory, check if it requires placeholder values\n" - " - Provide placeholder values using the placeholder_values parameter as a dictionary\n" - " - Example: ExecuteCachedTrajectory(trajectory_file='test.json', placeholder_values={'current_date': '2025-12-11'})\n" - " - If required placeholders are missing, execution will fail with a clear error message\n" + " - Trajectories may contain dynamic placeholders like " + "{{current_date}} or {{user_name}}\n" + " - When executing a trajectory, check if it requires " + "placeholder values\n" + " - Provide placeholder values using the placeholder_values " + "parameter as a dictionary\n" + " 
- Example: ExecuteCachedTrajectory(trajectory_file='test.json', " + "placeholder_values={'current_date': '2025-12-11'})\n" + " - If required placeholders are missing, execution will fail with " + "a clear error message\n" "\n" " NON-CACHEABLE STEPS:\n" - " - Some tools cannot be cached and require your direct execution (e.g., print_debug, contextual decisions)\n" - " - When trajectory execution reaches a non-cacheable step, it will pause and return control to you\n" - " - You'll receive a NEEDS_AGENT status with the current step index\n" - " - Execute the non-cacheable step manually using your regular tools\n" - " - After completing the non-cacheable step, continue the trajectory using ExecuteCachedTrajectory with start_from_step_index\n" + " - Some tools cannot be cached and require your direct execution " + "(e.g., print_debug, contextual decisions)\n" + " - When trajectory execution reaches a non-cacheable step, it will " + "pause and return control to you\n" + " - You'll receive a NEEDS_AGENT status with the current " + "step index\n" + " - Execute the non-cacheable step manually using your " + "regular tools\n" + " - After completing the non-cacheable step, continue the trajectory " + "using ExecuteCachedTrajectory with start_from_step_index\n" "\n" " CONTINUING TRAJECTORIES:\n" - " - Use ExecuteCachedTrajectory with start_from_step_index to resume execution after handling a non-cacheable step\n" - " - Provide the same trajectory file and the step index where execution should continue\n" - " - Example: ExecuteCachedTrajectory(trajectory_file='test.json', start_from_step_index=5, placeholder_values={...})\n" + " - Use ExecuteCachedTrajectory with start_from_step_index to resume " + "execution after handling a non-cacheable step\n" + " - Provide the same trajectory file and the step index where " + "execution should continue\n" + " - Example: ExecuteCachedTrajectory(trajectory_file='test.json', " + "start_from_step_index=5, placeholder_values={...})\n" " - The tool will execute remaining steps from that index onwards\n" "\n" " FAILURE HANDLING:\n" - " - If a trajectory fails during execution, you'll see the error message and the step where it failed\n" - " - Analyze the failure: Was it due to UI changes, timing issues, or incorrect state?\n" + " - If a trajectory fails during execution, you'll see the error " + "message and the step where it failed\n" + " - Analyze the failure: Was it due to UI changes, timing issues, " + "or incorrect state?\n" " - Options for handling failures:\n" " 1. Execute the remaining steps manually\n" - " 2. Fix the issue and retry from a specific step using ExecuteCachedTrajectory with start_from_step_index\n" - " 3. Report that the cached trajectory is outdated and needs re-recording\n" + " 2. Fix the issue and retry from a specific step using " + "ExecuteCachedTrajectory with start_from_step_index\n" + " 3. 
Report that the cached trajectory is outdated and needs " + "re-recording\n" "\n" " BEST PRACTICES:\n" " - Always verify results after trajectory execution completes\n" - " - While trajectories work most of the time, occasionally execution can be partly incorrect\n" + " - While trajectories work most of the time, occasionally " + "execution can be partly incorrect\n" " - Make corrections where necessary after cached execution\n" - " - if you need to make any corrections after a trajectory execution, please mark the cached execution as failed\n" - " - If a trajectory consistently fails, it may be invalid and should be re-recorded\n" + " - if you need to make any corrections after a trajectory " + "execution, please mark the cached execution as failed\n" + " - If a trajectory consistently fails, it may be invalid and " + "should be re-recorded\n" " \n" " \n" " There are several trajectories available to you.\n" " Their filename is a unique testID.\n" - " If executed using the ExecuteCachedTrajectory tool, a trajectory will " - "automatically execute all necessary steps for the test with that id.\n" + " If executed using the ExecuteCachedTrajectory tool, a trajectory " + "will automatically execute all necessary steps for the test with " + "that id.\n" " \n" ) diff --git a/src/askui/tools/caching_tools.py b/src/askui/tools/caching_tools.py index 07aa85e2..9c5655d9 100644 --- a/src/askui/tools/caching_tools.py +++ b/src/askui/tools/caching_tools.py @@ -47,8 +47,9 @@ def __init__(self, cache_dir: str, trajectories_format: str = ".json") -> None: "include_invalid": { "type": "boolean", "description": ( - "Whether to include invalid/invalidated caches in the results. " - "Default is False (only show valid caches)." + "Whether to include invalid/invalidated caches in " + "the results. Default is False (only show valid " + "caches)." ), "default": False, }, @@ -97,7 +98,7 @@ def __call__(self, include_invalid: bool = False) -> list[str]: # type: ignore f.name, cache_file.metadata.invalidation_reason, ) - except Exception: + except Exception: # noqa: PERF203 unreadable_count += 1 logger.exception("Failed to read cache file %s", f.name) # If we can't read it, exclude it @@ -140,25 +141,30 @@ def __init__( super().__init__( name="execute_cached_executions_tool", description=( - "Activate cache execution mode to replay a pre-recorded trajectory. " - "This tool sets up the agent to execute cached UI interactions step-by-step.\n\n" + "Activate cache execution mode to replay a pre-recorded " + "trajectory. This tool sets up the agent to execute cached UI " + "interactions step-by-step.\n\n" "Before using this tool:\n" "1. Use retrieve_available_trajectories_tool to see which " "trajectory files are available\n" "2. Select the appropriate trajectory file path from the " "returned list\n" - "3. If the trajectory contains placeholders (e.g., {{current_date}}), " - "provide values for them in the placeholder_values parameter\n" + "3. If the trajectory contains placeholders (e.g., " + "{{current_date}}), provide values for them in the " + "placeholder_values parameter\n" "4. Pass the full file path to this tool\n\n" - "Placeholders allow dynamic values to be injected during execution. " - "For example, if a trajectory types '{{current_date}}', you must " - "provide placeholder_values={'current_date': '2025-12-11'}.\n\n" - "To continue from a specific step (e.g., after manually handling a " - "non-cacheable step), use the start_from_step_index parameter. 
" - "By default, execution starts from the beginning (step 0).\n\n" - "Once activated, the agent will execute cached steps automatically. " - "If a non-cacheable step is encountered, the agent will be asked to " - "handle it manually before resuming cache execution." + "Placeholders allow dynamic values to be injected during " + "execution. For example, if a trajectory types " + "'{{current_date}}', you must provide " + "placeholder_values={'current_date': '2025-12-11'}.\n\n" + "To continue from a specific step (e.g., after manually " + "handling a non-cacheable step), use the start_from_step_index " + "parameter. By default, execution starts from the beginning " + "(step 0).\n\n" + "Once activated, the agent will execute cached steps " + "automatically. If a non-cacheable step is encountered, the " + "agent will be asked to handle it manually before resuming " + "cache execution." ), input_schema={ "type": "object", @@ -174,9 +180,10 @@ def __init__( "start_from_step_index": { "type": "integer", "description": ( - "Optional: The step index to start or resume execution from (0-based). " - "Use 0 (default) to start from the beginning. Use a higher index " - "to continue from a specific step, e.g., after manually handling " + "Optional: The step index to start or resume " + "execution from (0-based). Use 0 (default) to start " + "from the beginning. Use a higher index to continue " + "from a specific step, e.g., after manually handling " "a non-cacheable step." ), "default": 0, @@ -184,9 +191,10 @@ def __init__( "placeholder_values": { "type": "object", "description": ( - "Optional dictionary mapping placeholder names to their values. " - "Required if the trajectory contains placeholders like {{variable}}. " - "Example: {'current_date': '2025-12-11', 'user_name': 'Alice'}" + "Optional dictionary mapping placeholder names to " + "their values. Required if the trajectory contains " + "placeholders like {{variable}}. Example: " + "{'current_date': '2025-12-11', 'user_name': 'Alice'}" ), "additionalProperties": {"type": "string"}, "default": {}, @@ -393,7 +401,10 @@ def __call__( # Validate agent is set if not self._cache_execution_manager: - error_msg = "Cache Execution Manager not set. Call set_cache_execution_manager() first." + error_msg = ( + "Cache Execution Manager not set. Call " + "set_cache_execution_manager() first." + ) logger.error(error_msg) raise RuntimeError(error_msg) @@ -415,7 +426,8 @@ def __call__( # Warn if cache is invalid if not cache_file.metadata.is_valid: warning_msg = ( - f"WARNING: Using invalid cache from {Path(trajectory_file).name}. " + f"WARNING: Using invalid cache from " + f"{Path(trajectory_file).name}. " f"Reason: {cache_file.metadata.invalidation_reason}. " "This cache may not work correctly." ) @@ -764,9 +776,11 @@ def __init__(self) -> None: super().__init__( name="verify_cache_execution", description=( - "IMPORTANT: Call this tool immediately after reviewing a cached trajectory execution.\n\n" - "Report whether the cached execution successfully achieved the target system state.\n" - "You MUST call this tool to complete the cache verification process.\n\n" + "IMPORTANT: Call this tool immediately after reviewing a " + "cached trajectory execution.\n\n" + "Report whether the cached execution successfully achieved " + "the target system state. 
You MUST call this tool to complete " + "the cache verification process.\n\n" "Set success=True if:\n" "- The cached execution correctly achieved the intended goal\n" "- The final state matches what was expected\n" @@ -872,13 +886,11 @@ def __call__(self, success: bool, verification_notes: str) -> str: ) result_msg = ( f"✗ Cache verification failed: {verification_notes}\n\n" - "The cached trajectory did not achieve the target " - "system state correctly. " - "You should now continue to complete the task manually " - "from the current state. " - "Use your tools to finish achieving the goal, taking into " - "account what the cache attempted to do and what " - "corrections are needed." + "The cached trajectory did not achieve the target system " + "state correctly. You should now continue to complete the " + "task manually from the current state. Use your tools to " + "finish achieving the goal, taking into account what the " + "cache attempted to do and what corrections are needed." ) logger.warning(result_msg) diff --git a/src/askui/utils/cache_execution_manager.py b/src/askui/utils/cache_execution_manager.py index a3f339ee..f5b7f262 100644 --- a/src/askui/utils/cache_execution_manager.py +++ b/src/askui/utils/cache_execution_manager.py @@ -203,11 +203,15 @@ def _handle_cache_needs_agent( TextBlockParam( type="text", text=( - f"Cache execution paused at step {result.step_index}. " - f"The previous steps were executed successfully from cache. " - f"The next step requires the '{tool_to_execute.name}' tool, " + f"Cache execution paused at step " + f"{result.step_index}. " + f"The previous steps were executed successfully " + f"from cache. " + f"The next step requires the " + f"'{tool_to_execute.name}' tool, " f"which cannot be executed from cache. " - f"Please execute this tool with the necessary parameters." + f"Please execute this tool with the necessary " + f"parameters." 
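The paused-at-step message above hands control back to the agent together with the exact step index. To make the resume handoff concrete, here is a minimal sketch of the follow-up tool call the agent would issue once it has handled the non-cacheable step; the parameter names come from the ExecuteCachedTrajectory input schema shown earlier, while the file name, indices, and placeholder values are illustrative:

```python
# Sketch: resume input for ExecuteCachedTrajectory after the agent has
# manually executed the non-cacheable step at index 5.
resume_input = {
    "trajectory_file": ".cache/login_test.json",  # illustrative file name
    "start_from_step_index": 6,                   # continue after step 5
    "placeholder_values": {"current_date": "2025-12-16"},
}
```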
), ) ], diff --git a/src/askui/utils/cache_migration.py b/src/askui/utils/cache_migration.py index 5be27edf..f3be0ac3 100644 --- a/src/askui/utils/cache_migration.py +++ b/src/askui/utils/cache_migration.py @@ -71,7 +71,7 @@ def migrate_file(self, file_path: Path, dry_run: bool = False) -> tuple[bool, st try: # Read the file - with open(file_path, "r") as f: + with file_path.open("r") as f: data = json.load(f) # Check if already v0.1 @@ -83,7 +83,7 @@ def migrate_file(self, file_path: Path, dry_run: bool = False) -> tuple[bool, st # Use CacheWriter to read (automatically migrates) try: cache_file = CacheWriter.read_cache_file(file_path) - except Exception as e: + except Exception as e: # noqa: BLE001 return False, f"Failed to read cache: {str(e)}" # Verify it's now v0.1 @@ -102,17 +102,17 @@ def migrate_file(self, file_path: Path, dry_run: bool = False) -> tuple[bool, st file_path.suffix + self.backup_suffix ) shutil.copy2(file_path, backup_path) - logger.debug(f"Created backup: {backup_path}") + logger.debug("Created backup: %s", backup_path) # Write migrated version back - with open(file_path, "w") as f: + with file_path.open("w") as f: json.dump(cache_file.model_dump(mode="json"), f, indent=2, default=str) - return True, f"Migrated: {file_path.name}" - - except Exception as e: - logger.error(f"Error migrating {file_path}: {e}", exc_info=True) + except Exception as e: # noqa: BLE001 + logger.exception("Error migrating %s", file_path) return False, f"Error: {str(e)}" + else: + return True, f"Migrated: {file_path.name}" def migrate_directory( self, @@ -131,13 +131,14 @@ def migrate_directory( Dictionary with migration statistics """ if not cache_dir.is_dir(): - raise CacheMigrationError(f"Directory not found: {cache_dir}") + msg = f"Directory not found: {cache_dir}" + raise CacheMigrationError(msg) # Find all cache files cache_files = list(cache_dir.glob(file_pattern)) if not cache_files: - logger.warning(f"No cache files found in {cache_dir}") + logger.warning("No cache files found in %s", cache_dir) return { "migrated": 0, "skipped": 0, @@ -145,7 +146,7 @@ def migrate_directory( "total": 0, } - logger.info(f"Found {len(cache_files)} cache files in {cache_dir}") + logger.info("Found %s cache files in %s", len(cache_files), cache_dir) # Reset counters self.migrated_count = 0 @@ -159,26 +160,26 @@ def migrate_directory( if success: self.migrated_count += 1 - logger.info(f"✓ {message}") + logger.info("✓ %s", message) elif "Already v0.1" in message: self.skipped_count += 1 - logger.debug(f"⊘ {message}") + logger.debug("⊘ %s", message) else: self.error_count += 1 - logger.error(f"✗ {message}") + logger.error("✗ %s", message) results.append( {"file": file_path.name, "success": success, "message": message} ) # Log summary - logger.info(f"\n{'=' * 60}") + logger.info("\n%s", "=" * 60) logger.info("Migration Summary:") - logger.info(f" Total files: {len(cache_files)}") - logger.info(f" Migrated: {self.migrated_count}") - logger.info(f" Already v0.1: {self.skipped_count}") - logger.info(f" Errors: {self.error_count}") - logger.info(f"{'=' * 60}\n") + logger.info(" Total files: %s", len(cache_files)) + logger.info(" Migrated: %s", self.migrated_count) + logger.info(" Already v0.1: %s", self.skipped_count) + logger.info(" Errors: %s", self.error_count) + logger.info("%s\n", "=" * 60) return { "migrated": self.migrated_count, @@ -284,19 +285,21 @@ def main() -> int: # Return success if no errors if stats["errors"] == 0: logger.info("✓ Migration completed successfully!") - return 0 - logger.error(f"✗ 
Migration completed with {stats['errors']} errors") - return 1 + else: + logger.error("✗ Migration completed with %s errors", stats["errors"]) + return 1 - except CacheMigrationError as e: - logger.error(f"Migration failed: {e}") + except CacheMigrationError: + logger.exception("Migration failed") return 1 except KeyboardInterrupt: logger.info("\nMigration cancelled by user") return 1 - except Exception as e: - logger.error(f"Unexpected error: {e}", exc_info=True) + except Exception: # noqa: BLE001 + logger.exception("Unexpected error") return 1 + else: + return 0 if __name__ == "__main__": diff --git a/src/askui/utils/cache_validator.py b/src/askui/utils/cache_validator.py index 7cc811ac..5b113792 100644 --- a/src/askui/utils/cache_validator.py +++ b/src/askui/utils/cache_validator.py @@ -32,12 +32,10 @@ def should_invalidate( Returns: Tuple of (should_invalidate: bool, reason: Optional[str]) """ - pass @abstractmethod def get_name(self) -> str: """Return validator name for logging/debugging.""" - pass class CompositeCacheValidator(CacheValidator): @@ -132,7 +130,8 @@ def should_invalidate( if failures_at_step >= self.max_failures_per_step: return ( True, - f"Step {step_index} failed {failures_at_step} times (max: {self.max_failures_per_step})", + f"Step {step_index} failed {failures_at_step} times " + f"(max: {self.max_failures_per_step})", ) return False, None @@ -160,13 +159,13 @@ def __init__(self, min_attempts: int = 10, max_failure_rate: float = 0.5): self.max_failure_rate = max_failure_rate def should_invalidate( - self, cache_file: CacheFile, step_index: Optional[int] = None + self, cache_file: CacheFile, _step_index: Optional[int] = None ) -> tuple[bool, Optional[str]]: """Check if overall failure rate is too high. Args: cache_file: The cache file with metadata and trajectory - step_index: Unused for this validator + _step_index: Unused for this validator Returns: Tuple of (should_invalidate: bool, reason: Optional[str]) @@ -181,7 +180,8 @@ def should_invalidate( if failure_rate > self.max_failure_rate: return ( True, - f"Failure rate {failure_rate:.1%} exceeds {self.max_failure_rate:.1%} after {attempts} attempts", + f"Failure rate {failure_rate:.1%} exceeds " + f"{self.max_failure_rate:.1%} after {attempts} attempts", ) return False, None @@ -207,13 +207,13 @@ def __init__(self, max_age_days: int = 30): self.max_age_days = max_age_days def should_invalidate( - self, cache_file: CacheFile, step_index: Optional[int] = None + self, cache_file: CacheFile, _step_index: Optional[int] = None ) -> tuple[bool, Optional[str]]: """Check if cache is stale (old + has failures). 
Args: cache_file: The cache file with metadata and trajectory - step_index: Unused for this validator + _step_index: Unused for this validator Returns: Tuple of (should_invalidate: bool, reason: Optional[str]) diff --git a/src/askui/utils/cache_writer.py b/src/askui/utils/cache_writer.py index 8aff0d52..2e7a1b53 100644 --- a/src/askui/utils/cache_writer.py +++ b/src/askui/utils/cache_writer.py @@ -166,7 +166,7 @@ def _replace_placeholders( for name in placeholder_names } n_placeholders = len(placeholder_names) - logger.info(f"Replaced {n_placeholders} placeholder values in trajectory") + logger.info("Replaced %s placeholder values in trajectory", n_placeholders) return goal_to_save, trajectory_to_save, placeholders_dict def _blank_non_cacheable_tool_inputs( @@ -195,7 +195,7 @@ def _blank_non_cacheable_tool_inputs( # If tool is not cacheable, blank out its input if tool is not None and not tool.is_cacheable: logger.debug( - f"Blanking input for non-cacheable tool: {tool_block.name}" + "Blanking input for non-cacheable tool: %s", tool_block.name ) blanked_count += 1 result.append( @@ -213,7 +213,8 @@ def _blank_non_cacheable_tool_inputs( if blanked_count > 0: logger.info( - f"Blanked inputs for {blanked_count} non-cacheable tool(s) to save space" + "Blanked inputs for %s non-cacheable tool(s) to save space", + blanked_count, ) return result @@ -238,7 +239,7 @@ def _generate_cache_file( with cache_file_path.open("w", encoding="utf-8") as f: json.dump(cache_file.model_dump(mode="json"), f, indent=4) - logger.info(f"Cache file successfully written: {cache_file_path} ") + logger.info("Cache file successfully written: %s ", cache_file_path) def _accumulate_usage(self, step_usage: UsageParam) -> None: """Accumulate usage statistics from a single API call. @@ -266,7 +267,7 @@ def read_cache_file(cache_file_path: Path) -> CacheFile: Returns: CacheFile object with metadata and trajectory """ - logger.debug(f"Reading cache file: {cache_file_path}") + logger.debug("Reading cache file: %s", cache_file_path) with cache_file_path.open("r", encoding="utf-8") as f: raw_data = json.load(f) @@ -274,7 +275,8 @@ def read_cache_file(cache_file_path: Path) -> CacheFile: if isinstance(raw_data, list): # v0.0 format: just a list of tool use blocks logger.info( - f"Detected v0.0 cache format in {cache_file_path.name}, migrating to v0.1" + "Detected v0.0 cache format in %s, migrating to v0.1", + cache_file_path.name, ) trajectory = [ToolUseBlockParam(**step) for step in raw_data] # Create default metadata for v0.0 files (migrated to v0.1 format) @@ -289,24 +291,28 @@ def read_cache_file(cache_file_path: Path) -> CacheFile: placeholders={}, ) logger.info( - f"Successfully loaded and migrated v0.0 cache: {len(trajectory)} steps, 0 placeholders" + "Successfully loaded and migrated v0.0 cache: %s steps, 0 placeholders", + len(trajectory), ) return cache_file if isinstance(raw_data, dict) and "metadata" in raw_data: # v0.1 format: structured with metadata cache_file = CacheFile(**raw_data) logger.info( - f"Successfully loaded v0.1 cache: {len(cache_file.trajectory)} steps, " - f"{len(cache_file.placeholders)} placeholders" + "Successfully loaded v0.1 cache: %s steps, %s placeholders", + len(cache_file.trajectory), + len(cache_file.placeholders), ) if cache_file.metadata.goal: - logger.debug(f"Cache goal: {cache_file.metadata.goal}") + logger.debug("Cache goal: %s", cache_file.metadata.goal) return cache_file logger.error( - f"Unknown cache file format in {cache_file_path.name}. 
" - "Expected either a list (v0.0) or dict with 'metadata' key (v0.1)." + "Unknown cache file format in %s. " + "Expected either a list (v0.0) or dict with 'metadata' key (v0.1).", + cache_file_path.name, ) - raise ValueError( + msg = ( f"Unknown cache file format in {cache_file_path}. " "Expected either a list (v0.0) or dict with 'metadata' key (v0.1)." ) + raise ValueError(msg) diff --git a/src/askui/utils/placeholder_handler.py b/src/askui/utils/placeholder_handler.py index 5bfd04b0..d491eeff 100644 --- a/src/askui/utils/placeholder_handler.py +++ b/src/askui/utils/placeholder_handler.py @@ -195,25 +195,24 @@ def _replace_values_in_value(value: Any, replacements: dict[str, str]) -> Any: if actual_value in result: result = result.replace(actual_value, replacements[actual_value]) return result - elif isinstance(value, dict): + if isinstance(value, dict): # Recursively replace in dict values return { k: PlaceholderHandler._replace_values_in_value(v, replacements) for k, v in value.items() } - elif isinstance(value, list): + if isinstance(value, list): # Recursively replace in list items return [ PlaceholderHandler._replace_values_in_value(item, replacements) for item in value ] - else: - # For non-string types, check if the value matches exactly - str_value = str(value) - if str_value in replacements: - # Return the placeholder as a string - return replacements[str_value] - return value + # For non-string types, check if the value matches exactly + str_value = str(value) + if str_value in replacements: + # Return the placeholder as a string + return replacements[str_value] + return value @staticmethod def substitute_placeholders( @@ -277,18 +276,17 @@ def _substitute_in_value(value: Any, placeholder_values: dict[str, str]) -> Any: pattern = r"\{\{" + re.escape(name) + r"\}\}" result = re.sub(pattern, replacement, result) return result - elif isinstance(value, dict): + if isinstance(value, dict): # Recursively substitute in dict values return { k: PlaceholderHandler._substitute_in_value(v, placeholder_values) for k, v in value.items() } - elif isinstance(value, list): + if isinstance(value, list): # Recursively substitute in list items return [ PlaceholderHandler._substitute_in_value(item, placeholder_values) for item in value ] - else: - # Return other types as-is - return value + # Return other types as-is + return value diff --git a/src/askui/utils/placeholder_identifier.py b/src/askui/utils/placeholder_identifier.py index cc6abfd5..a54ccee2 100644 --- a/src/askui/utils/placeholder_identifier.py +++ b/src/askui/utils/placeholder_identifier.py @@ -45,25 +45,27 @@ def identify_placeholders( return {}, [] logger.info( - f"Starting placeholder identification for trajectory with {len(trajectory)} steps" + "Starting placeholder identification for trajectory with %s steps", + len(trajectory), ) # Convert trajectory to serializable format for analysis trajectory_data = [tool.model_dump(mode="json") for tool in trajectory] - logger.debug(f"Converted {len(trajectory_data)} tool blocks to JSON format") - - user_message = f"""Analyze this UI automation trajectory and identify all values that should be placeholders: - -```json -{json.dumps(trajectory_data, indent=2)} -``` - -Return only the JSON object with identified placeholders. 
Be thorough but conservative - only mark values that are clearly dynamic or time-sensitive.""" + logger.debug("Converted %s tool blocks to JSON format", len(trajectory_data)) + + user_message = ( + "Analyze this UI automation trajectory and identify all values that " + "should be placeholders:\n\n" + f"```json\n{json.dumps(trajectory_data, indent=2)}\n```\n\n" + "Return only the JSON object with identified placeholders. " + "Be thorough but conservative - only mark values that are clearly " + "dynamic or time-sensitive." + ) response_text = "" # Initialize for error logging try: # Make single API call - logger.debug(f"Calling LLM ({model}) to analyze trajectory for placeholders") + logger.debug("Calling LLM (%s) to analyze trajectory for placeholders", model) response = messages_api.create_message( messages=[MessageParam(role="user", content=user_message)], model=model, @@ -94,7 +96,8 @@ def identify_placeholders( placeholder_data = json.loads(response_text) logger.debug( - f"Successfully parsed JSON response with {len(placeholder_data.get('placeholders', []))} placeholders" + "Successfully parsed JSON response with %s placeholders", + len(placeholder_data.get("placeholders", [])), ) # Convert to our data structures @@ -109,26 +112,32 @@ def identify_placeholders( if placeholder_definitions: logger.info( - f"Successfully identified {len(placeholder_definitions)} placeholders in trajectory" + "Successfully identified %s placeholders in trajectory", + len(placeholder_definitions), ) for p in placeholder_definitions: - logger.debug(f" - {p.name}: {p.value} ({p.description})") + logger.debug(" - %s: %s (%s)", p.name, p.value, p.description) else: logger.info( - "No placeholders identified in trajectory (this is normal for trajectories with only static values)" + "No placeholders identified in trajectory " + "(this is normal for trajectories with only static values)" ) - return placeholder_dict, placeholder_definitions - except json.JSONDecodeError as e: logger.warning( - f"Failed to parse LLM response as JSON: {e}. Falling back to empty placeholder list.", + "Failed to parse LLM response as JSON: %s. " + "Falling back to empty placeholder list.", + e, extra={"response_text": response_text[:500]}, # Log first 500 chars ) return {}, [] except Exception as e: # noqa: BLE001 logger.warning( - f"Failed to identify placeholders with LLM: {e}. Falling back to empty placeholder list.", + "Failed to identify placeholders with LLM: %s. 
" + "Falling back to empty placeholder list.", + e, exc_info=True, ) return {}, [] + else: + return placeholder_dict, placeholder_definitions diff --git a/src/askui/utils/trajectory_executor.py b/src/askui/utils/trajectory_executor.py index 9d33bd98..d953ff09 100644 --- a/src/askui/utils/trajectory_executor.py +++ b/src/askui/utils/trajectory_executor.py @@ -13,10 +13,7 @@ from typing_extensions import Literal from askui.models.shared.agent_message_param import ( - ImageBlockParam, MessageParam, - TextBlockParam, - ToolResultBlockParam, ToolUseBlockParam, ) from askui.models.shared.tools import ToolCollection @@ -110,7 +107,7 @@ def execute_next_step(self) -> ExecutionResult: # Check if step should be skipped if self._should_skip_step(step): - logger.debug(f"Skipping step {step_index}: {step.name}") + logger.debug("Skipping step %d: %s", step_index, step.name) self.current_step_index += 1 # Recursively execute next step return self.execute_next_step() @@ -118,7 +115,9 @@ def execute_next_step(self) -> ExecutionResult: # Check if step needs agent intervention (non-cacheable) if self.should_pause_for_agent(step): logger.info( - f"Pausing at step {step_index}: {step.name} (non-cacheable tool)" + "Pausing at step %d: %s (non-cacheable tool)", + step_index, + step.name, ) # Return result with current tool step info for the agent to handle # Note: We don't add any messages here - the cache manager will @@ -136,7 +135,9 @@ def execute_next_step(self) -> ExecutionResult: is_valid, error_msg = self.validate_step_visually(step) if not is_valid: logger.warning( - f"Visual validation failed at step {step_index}: {error_msg}" + "Visual validation failed at step %d: %s", + step_index, + error_msg, ) return ExecutionResult( status="FAILED", @@ -152,7 +153,7 @@ def execute_next_step(self) -> ExecutionResult: # Execute the tool try: - logger.debug(f"Executing step {step_index}: {step.name}") + logger.debug("Executing step %d: %s", step_index, step.name) # Add assistant message (tool use) to history assistant_message = MessageParam( @@ -164,9 +165,10 @@ def execute_next_step(self) -> ExecutionResult: # Execute the tool tool_results = self.toolbox.run([substituted_step]) - # toolbox.run() returns a list of content blocks (ToolResultBlockParam, etc.) - # We use these directly without converting to strings - this preserves - # proper data types like ImageBlockParam + # toolbox.run() returns a list of content blocks + # (ToolResultBlockParam, etc.) We use these directly without + # converting to strings - this preserves proper data types like + # ImageBlockParam # Add user message (tool result) to history user_message = MessageParam( @@ -190,10 +192,7 @@ def execute_next_step(self) -> ExecutionResult: ) except Exception as e: - logger.error( - f"Error executing step {step_index}: {step.name}", - exc_info=True, - ) + logger.exception("Error executing step %d: %s", step_index, step.name) return ExecutionResult( status="FAILED", step_index=step_index, @@ -236,7 +235,7 @@ def should_pause_for_agent(self, step: ToolUseBlockParam) -> bool: Currently checks if the tool is marked as non-cacheable. 
""" # Get the tool from toolbox - tool = self.toolbox._tool_map.get(step.name) + tool = self.toolbox._tool_map.get(step.name) # noqa: SLF001 if tool is None: # Tool not found in regular tools, might be MCP tool @@ -270,7 +269,7 @@ def skip_current_step(self) -> None: if self.current_step_index < len(self.trajectory): self.current_step_index += 1 - def _should_skip_step(self, step: ToolUseBlockParam) -> bool: + def _should_skip_step(self, _step: ToolUseBlockParam) -> bool: """Check if a step should be skipped during execution. Args: @@ -285,7 +284,7 @@ def _should_skip_step(self, step: ToolUseBlockParam) -> bool: return False def validate_step_visually( - self, step: ToolUseBlockParam, current_screenshot: Any = None + self, _step: ToolUseBlockParam, _current_screenshot: Any = None ) -> tuple[bool, str | None]: """Hook for visual validation of cached steps using aHash comparison. diff --git a/tests/unit/tools/test_caching_tools.py b/tests/unit/tools/test_caching_tools.py index fad10b28..eab5a370 100644 --- a/tests/unit/tools/test_caching_tools.py +++ b/tests/unit/tools/test_caching_tools.py @@ -17,7 +17,6 @@ ) from askui.utils.cache_execution_manager import CacheExecutionManager - # ============================================================================ # RetrieveCachedTestExecutions Tests (unchanged from before) # ============================================================================ diff --git a/tests/unit/utils/test_cache_validator.py b/tests/unit/utils/test_cache_validator.py index cb3ad894..a20cabab 100644 --- a/tests/unit/utils/test_cache_validator.py +++ b/tests/unit/utils/test_cache_validator.py @@ -476,7 +476,7 @@ def __init__(self, should_trigger: bool = False): self.should_trigger = should_trigger def should_invalidate( - self, cache_file: CacheFile, step_index: int | None = None + self, _cache_file: CacheFile, _step_index: int | None = None ) -> tuple[bool, str | None]: if self.should_trigger: return True, "Custom validation failed" diff --git a/tests/unit/utils/test_placeholder_handler.py b/tests/unit/utils/test_placeholder_handler.py index e59c251e..d022f5bc 100644 --- a/tests/unit/utils/test_placeholder_handler.py +++ b/tests/unit/utils/test_placeholder_handler.py @@ -1,6 +1,7 @@ """Unit tests for PlaceholderHandler.""" import pytest + from askui.models.shared.agent_message_param import ToolUseBlockParam from askui.utils.placeholder_handler import PLACEHOLDER_PATTERN, PlaceholderHandler diff --git a/tests/unit/utils/test_trajectory_executor.py b/tests/unit/utils/test_trajectory_executor.py index 55bf649a..d7f2aa6b 100644 --- a/tests/unit/utils/test_trajectory_executor.py +++ b/tests/unit/utils/test_trajectory_executor.py @@ -369,10 +369,11 @@ def test_trajectory_executor_execute_all_stops_on_failure() -> None: # Mock to fail on second call call_count = [0] - def mock_run(steps): # type: ignore + def mock_run(_steps): # type: ignore call_count[0] += 1 if call_count[0] == 2: - raise Exception("Second call fails") + msg = "Second call fails" + raise Exception(msg) # noqa: TRY002 return [ ToolResultBlockParam( tool_use_id="1", From cacceb5b5d2d0b487605b1840ea9379057a4037d Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Tue, 16 Dec 2025 11:37:59 +0100 Subject: [PATCH 18/30] fix(caching): add missing usage statistics and fix typos in docs --- docs/caching.md | 4 ---- src/askui/reporting.py | 39 --------------------------------------- 2 files changed, 43 deletions(-) diff --git a/docs/caching.md b/docs/caching.md index 9560221c..be995ac3 100644 --- a/docs/caching.md 
+++ b/docs/caching.md @@ -1,10 +1,6 @@ # Caching (Experimental) -<<<<<<< HEAD **CAUTION: The Caching feature is still in alpha state and subject to change! Use it at your own risk. In case you run into issues, you can disable caching by removing the caching_settings parameter or by explicitly setting the caching_strategy to `no`.** -======= -**CAUTION: THIS FEATURE IS STILL IN ALPHA STATE AND SUBJECT TO CHANGE! USE AT YOUR OWN RISK. IN CASE YOU RUN INTO ISSUES YOU CAN JUST DISABLE THE FEATURE BY REMOVING THE CACHING_SETTINGS OR SETTING THE CACHING STRATEGY TO `no`.** ->>>>>>> bda523e (chore(caching): explicitly mention that caching is experimental in the docs) The caching mechanism allows you to record and replay agent action sequences (trajectories) for faster and more robust test execution. This feature is particularly useful for regression testing, where you want to replay known-good interaction sequences to verify that your application still behaves correctly. diff --git a/src/askui/reporting.py b/src/askui/reporting.py index 99620468..96408061 100644 --- a/src/askui/reporting.py +++ b/src/askui/reporting.py @@ -715,44 +715,6 @@ def generate(self) -> None: -<<<<<<< HEAD -
-<<<<<<< HEAD
-        [removed: "Conversation Log" HTML table template; the markup was lost in
-         extraction. It rendered one table row per message with columns Time
-         ({{ msg.timestamp.strftime('%H:%M:%S.%f')[:-3] }} UTC), Role ({{ msg.role }}),
-         and Content ({{ msg.content }}, shown in a pre-formatted block when
-         msg.is_json), followed by any attached "Message image" elements via
-         {% for image in msg.images %} ... {% endfor %}]
-=======
         {% if usage_summary %}
         [kept: "Token Usage" HTML summary template; markup lost in extraction]
@@ -800,7 +762,6 @@ def generate(self) -> None:
         {% endfor %}
->>>>>>> c3fbf84 (feat(caching): add token usage to cache writer and reporters) """ From 1c89c6fa89b019154ceb3ab85da28c60ad36277e Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Tue, 16 Dec 2025 13:19:11 +0100 Subject: [PATCH 19/30] chore(caching): move caching utils to caching sub folder --- src/askui/agent_base.py | 2 +- src/askui/models/shared/agent.py | 2 +- src/askui/tools/caching_tools.py | 6 +++--- src/askui/utils/caching/__init__.py | 1 + src/askui/utils/{ => caching}/cache_execution_manager.py | 2 +- src/askui/utils/{ => caching}/cache_manager.py | 2 +- src/askui/utils/{ => caching}/cache_migration.py | 2 +- src/askui/utils/{ => caching}/cache_validator.py | 0 src/askui/utils/{ => caching}/cache_writer.py | 0 tests/unit/tools/test_caching_tools.py | 2 +- tests/unit/utils/test_cache_manager.py | 4 ++-- tests/unit/utils/test_cache_migration.py | 2 +- tests/unit/utils/test_cache_validator.py | 2 +- tests/unit/utils/test_cache_writer.py | 2 +- 14 files changed, 15 insertions(+), 14 deletions(-) create mode 100644 src/askui/utils/caching/__init__.py rename src/askui/utils/{ => caching}/cache_execution_manager.py (99%) rename src/askui/utils/{ => caching}/cache_manager.py (98%) rename src/askui/utils/{ => caching}/cache_migration.py (99%) rename src/askui/utils/{ => caching}/cache_validator.py (100%) rename src/askui/utils/{ => caching}/cache_writer.py (100%) diff --git a/src/askui/agent_base.py b/src/askui/agent_base.py index a734c3aa..3e2d7f4d 100644 --- a/src/askui/agent_base.py +++ b/src/askui/agent_base.py @@ -25,7 +25,7 @@ RetrieveCachedTestExecutions, ) from askui.utils.annotation_writer import AnnotationWriter -from askui.utils.cache_writer import CacheWriter +from askui.utils.caching.cache_writer import CacheWriter from askui.utils.image_utils import ImageSource from askui.utils.source_utils import InputSource, load_image_source diff --git a/src/askui/models/shared/agent.py b/src/askui/models/shared/agent.py index 4587fe04..adde725a 100644 --- a/src/askui/models/shared/agent.py +++ b/src/askui/models/shared/agent.py @@ -19,7 +19,7 @@ TruncationStrategyFactory, ) from askui.reporting import NULL_REPORTER, Reporter -from askui.utils.cache_execution_manager import CacheExecutionManager +from askui.utils.caching.cache_execution_manager import CacheExecutionManager logger = logging.getLogger(__name__) diff --git a/src/askui/tools/caching_tools.py b/src/askui/tools/caching_tools.py index 9c5655d9..7714b53e 100644 --- a/src/askui/tools/caching_tools.py +++ b/src/askui/tools/caching_tools.py @@ -8,9 +8,9 @@ from ..models.shared.settings import CachedExecutionToolSettings from ..models.shared.tools import Tool, ToolCollection -from ..utils.cache_execution_manager import CacheExecutionManager -from ..utils.cache_manager import CacheManager -from ..utils.cache_writer import CacheWriter +from ..utils.caching.cache_execution_manager import CacheExecutionManager +from ..utils.caching.cache_manager import CacheManager +from ..utils.caching.cache_writer import CacheWriter from ..utils.placeholder_handler import PlaceholderHandler if TYPE_CHECKING: diff --git a/src/askui/utils/caching/__init__.py b/src/askui/utils/caching/__init__.py new file mode 100644 index 00000000..be6e4b17 --- /dev/null +++ b/src/askui/utils/caching/__init__.py @@ -0,0 +1 @@ +"""Caching utilities for trajectory execution.""" diff --git a/src/askui/utils/cache_execution_manager.py b/src/askui/utils/caching/cache_execution_manager.py similarity index 99% rename from src/askui/utils/cache_execution_manager.py rename to 
src/askui/utils/caching/cache_execution_manager.py index f5b7f262..b9bedde4 100644 --- a/src/askui/utils/cache_execution_manager.py +++ b/src/askui/utils/caching/cache_execution_manager.py @@ -9,7 +9,7 @@ from askui.models.shared.agent_on_message_cb import OnMessageCb, OnMessageCbParam from askui.models.shared.truncation_strategies import TruncationStrategy from askui.reporting import Reporter -from askui.utils.cache_manager import CacheManager +from askui.utils.caching.cache_manager import CacheManager from askui.utils.trajectory_executor import ExecutionResult if TYPE_CHECKING: diff --git a/src/askui/utils/cache_manager.py b/src/askui/utils/caching/cache_manager.py similarity index 98% rename from src/askui/utils/cache_manager.py rename to src/askui/utils/caching/cache_manager.py index 1b5878ae..a3ca9356 100644 --- a/src/askui/utils/cache_manager.py +++ b/src/askui/utils/caching/cache_manager.py @@ -8,7 +8,7 @@ from typing import Optional from askui.models.shared.settings import CacheFailure, CacheFile -from askui.utils.cache_validator import ( +from askui.utils.caching.cache_validator import ( CacheValidator, CompositeCacheValidator, StaleCacheValidator, diff --git a/src/askui/utils/cache_migration.py b/src/askui/utils/caching/cache_migration.py similarity index 99% rename from src/askui/utils/cache_migration.py rename to src/askui/utils/caching/cache_migration.py index f3be0ac3..a8f98025 100644 --- a/src/askui/utils/cache_migration.py +++ b/src/askui/utils/caching/cache_migration.py @@ -27,7 +27,7 @@ from pathlib import Path from typing import Any -from askui.utils.cache_writer import CacheWriter +from askui.utils.caching.cache_writer import CacheWriter logger = logging.getLogger(__name__) diff --git a/src/askui/utils/cache_validator.py b/src/askui/utils/caching/cache_validator.py similarity index 100% rename from src/askui/utils/cache_validator.py rename to src/askui/utils/caching/cache_validator.py diff --git a/src/askui/utils/cache_writer.py b/src/askui/utils/caching/cache_writer.py similarity index 100% rename from src/askui/utils/cache_writer.py rename to src/askui/utils/caching/cache_writer.py diff --git a/tests/unit/tools/test_caching_tools.py b/tests/unit/tools/test_caching_tools.py index eab5a370..605eabed 100644 --- a/tests/unit/tools/test_caching_tools.py +++ b/tests/unit/tools/test_caching_tools.py @@ -15,7 +15,7 @@ ExecuteCachedTrajectory, RetrieveCachedTestExecutions, ) -from askui.utils.cache_execution_manager import CacheExecutionManager +from askui.utils.caching.cache_execution_manager import CacheExecutionManager # ============================================================================ # RetrieveCachedTestExecutions Tests (unchanged from before) diff --git a/tests/unit/utils/test_cache_manager.py b/tests/unit/utils/test_cache_manager.py index 68f4730c..f7c41ed9 100644 --- a/tests/unit/utils/test_cache_manager.py +++ b/tests/unit/utils/test_cache_manager.py @@ -7,8 +7,8 @@ from askui.models.shared.agent_message_param import ToolUseBlockParam from askui.models.shared.settings import CacheFailure, CacheFile, CacheMetadata -from askui.utils.cache_manager import CacheManager -from askui.utils.cache_validator import ( +from askui.utils.caching.cache_manager import CacheManager +from askui.utils.caching.cache_validator import ( CacheValidator, CompositeCacheValidator, StepFailureCountValidator, diff --git a/tests/unit/utils/test_cache_migration.py b/tests/unit/utils/test_cache_migration.py index 575e6468..780cd1a7 100644 --- a/tests/unit/utils/test_cache_migration.py 
+++ b/tests/unit/utils/test_cache_migration.py @@ -7,7 +7,7 @@ import pytest -from askui.utils.cache_migration import CacheMigration, CacheMigrationError +from askui.utils.caching.cache_migration import CacheMigration, CacheMigrationError @pytest.fixture diff --git a/tests/unit/utils/test_cache_validator.py b/tests/unit/utils/test_cache_validator.py index a20cabab..c7af8b12 100644 --- a/tests/unit/utils/test_cache_validator.py +++ b/tests/unit/utils/test_cache_validator.py @@ -6,7 +6,7 @@ from askui.models.shared.agent_message_param import ToolUseBlockParam from askui.models.shared.settings import CacheFailure, CacheFile, CacheMetadata -from askui.utils.cache_validator import ( +from askui.utils.caching.cache_validator import ( CacheValidator, CompositeCacheValidator, StaleCacheValidator, diff --git a/tests/unit/utils/test_cache_writer.py b/tests/unit/utils/test_cache_writer.py index f418550c..a6e1238f 100644 --- a/tests/unit/utils/test_cache_writer.py +++ b/tests/unit/utils/test_cache_writer.py @@ -8,7 +8,7 @@ from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam from askui.models.shared.agent_on_message_cb import OnMessageCbParam from askui.models.shared.settings import CacheFile, CacheWriterSettings -from askui.utils.cache_writer import CacheWriter +from askui.utils.caching.cache_writer import CacheWriter def test_cache_writer_initialization() -> None: From e1cbadfd472e36410cac6909080eb9dc3f6b77b7 Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Tue, 16 Dec 2025 13:20:08 +0100 Subject: [PATCH 20/30] chore(caching): removes unnecessary docstrings from CacheManager class to improve readability of the code --- src/askui/utils/caching/cache_manager.py | 74 +----------------------- 1 file changed, 2 insertions(+), 72 deletions(-) diff --git a/src/askui/utils/caching/cache_manager.py b/src/askui/utils/caching/cache_manager.py index a3ca9356..e874bde0 100644 --- a/src/askui/utils/caching/cache_manager.py +++ b/src/askui/utils/caching/cache_manager.py @@ -1,8 +1,4 @@ -"""Cache management utilities for tracking execution and invalidation. - -This module provides the CacheManager class that handles cache metadata updates, -failure tracking, and cache invalidation using configurable validation strategies. -""" +"""Cache management for tracking execution and invalidation.""" from datetime import datetime, timezone from typing import Optional @@ -18,31 +14,10 @@ class CacheManager: - """Manages cache metadata updates and validation. - - Uses a CompositeCacheValidator for extensible invalidation logic. - Users can provide custom validators via the validator parameter. - - Example: - # Using default validators - manager = CacheManager() - - # Using custom validator - custom_validator = CompositeCacheValidator([ - StepFailureCountValidator(max_failures_per_step=5), - MyCustomValidator() - ]) - manager = CacheManager(validator=custom_validator) - """ + """Manages cache metadata updates and validation using configurable validators.""" def __init__(self, validators: Optional[list[CacheValidator]] = None): - """Initialize cache manager. - - Args: - validator: Custom validator or None to use default composite validator - """ if validators is None: - # Default validator with built-in strategies self.validators = CompositeCacheValidator( [ StepFailureCountValidator(max_failures_per_step=3), @@ -59,31 +34,15 @@ def record_execution_attempt( success: bool, failure_info: Optional[CacheFailure] = None, ) -> None: - """Record an execution attempt and update metadata. 
- - Args: - cache_file: The cache file to update - success: Whether the execution was successful - failure_info: Optional failure information if execution failed - """ cache_file.metadata.execution_attempts += 1 - if success: cache_file.metadata.last_executed_at = datetime.now(tz=timezone.utc) - # Successful execution - metadata updated elif failure_info: cache_file.metadata.failures.append(failure_info) def record_step_failure( self, cache_file: CacheFile, step_index: int, error_message: str ) -> None: - """Record a failure at specific step. - - Args: - cache_file: The cache file to update - step_index: Index of the step that failed - error_message: Description of the error - """ failure = CacheFailure( timestamp=datetime.now(tz=timezone.utc), step_index=step_index, @@ -98,46 +57,17 @@ def record_step_failure( def should_invalidate( self, cache_file: CacheFile, step_index: Optional[int] = None ) -> tuple[bool, Optional[str]]: - """Check if cache should be invalidated using the validator. - - Args: - cache_file: The cache file to check - step_index: Optional step index where failure occurred - - Returns: - Tuple of (should_invalidate: bool, reason: Optional[str]) - """ return self.validators.should_invalidate(cache_file, step_index) def invalidate_cache(self, cache_file: CacheFile, reason: str) -> None: - """Mark cache as invalid. - - Args: - cache_file: The cache file to invalidate - reason: Reason for invalidation - """ cache_file.metadata.is_valid = False cache_file.metadata.invalidation_reason = reason def mark_cache_valid(self, cache_file: CacheFile) -> None: - """Mark cache as valid. - - Args: - cache_file: The cache file to mark as valid - """ cache_file.metadata.is_valid = True cache_file.metadata.invalidation_reason = None def get_failure_count_for_step(self, cache_file: CacheFile, step_index: int) -> int: - """Get number of failures for a specific step. - - Args: - cache_file: The cache file to check - step_index: Index of the step to count failures for - - Returns: - Number of failures at this step - """ return sum( 1 for f in cache_file.metadata.failures if f.step_index == step_index ) From 42928c6e962a14b0136a3ecba8e11be6baeab76a Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Tue, 16 Dec 2025 13:31:23 +0100 Subject: [PATCH 21/30] chore(caching): removes CacheMigration class. 
Cache migration will still be done automatically when reading v0.0 cache files --- docs/caching.md | 65 +--- src/askui/utils/caching/cache_migration.py | 306 ---------------- tests/unit/utils/test_cache_migration.py | 387 --------------------- 3 files changed, 7 insertions(+), 751 deletions(-) delete mode 100644 src/askui/utils/caching/cache_migration.py delete mode 100644 tests/unit/utils/test_cache_migration.py diff --git a/docs/caching.md b/docs/caching.md index be995ac3..8f9d8ead 100644 --- a/docs/caching.md +++ b/docs/caching.md @@ -845,57 +845,9 @@ When a v0.0 cache file (simple JSON array) is read: - Original files are not modified on disk (unless re-written) - v0.1 system can read both formats seamlessly -### Batch Migration with CLI Tool +### Programmatic Migration (Optional) -For batch migration of existing v0.0 caches to v0.1 format, use the migration CLI utility: - -```bash -# Migrate all caches in a directory -python -m askui.utils.cache_migration --cache-dir .cache - -# Dry run (preview what would be migrated) -python -m askui.utils.cache_migration --cache-dir .cache --dry-run - -# Create backups before migration -python -m askui.utils.cache_migration --cache-dir .cache --backup - -# Migrate specific file pattern -python -m askui.utils.cache_migration --cache-dir .cache --pattern "test_*.json" - -# Verbose output -python -m askui.utils.cache_migration --cache-dir .cache --verbose -``` - -The migration tool will: -- Find all cache files matching the pattern -- Check if each file is v0.0 or already v0.1 -- Migrate v0.0 files to v0.1 format -- Optionally create backups with `.v1.backup` suffix -- Report detailed statistics about the migration - -Example output: -``` -INFO: Found 5 cache files in .cache -INFO: ✓ Migrated: login_test.json -INFO: ✓ Migrated: checkout_flow.json -INFO: ⊘ Already v0.1: user_registration.json -INFO: ✓ Migrated: search_test.json -INFO: ✗ Error: invalid.json (invalid JSON) - -============================================================ -Migration Summary: - Total files: 5 - Migrated: 3 - Already v0.1: 1 - Errors: 1 -============================================================ - -INFO: ✓ Migration completed successfully! -``` - -### Manual Migration (Programmatic) - -To upgrade individual v0.0 caches to v0.1 format programmatically: +If you prefer to upgrade v0.0 cache files to v0.1 format on disk (rather than letting the system migrate them on-the-fly during read), you can do so programmatically: ```python from pathlib import Path @@ -903,18 +855,15 @@ from askui.utils.cache_writer import CacheWriter import json # Read v0.0 file (auto-migrates to v0.1 in memory) -cache_file = CacheWriter.read_cache_file(Path(".cache/old_cache.json")) +cache_path = Path(".cache/old_cache.json") +cached_trajectory = CacheWriter.read_cache_file(cache_path) # Write back to disk in v0.1 format -with open(".cache/old_cache.json", "w") as f: - json.dump(cache_file.model_dump(mode="json"), f, indent=2, default=str) +with cache_path.open("w", encoding="utf-8") as f: + json.dump(cached_trajectory.model_dump(mode="json"), f, indent=2, default=str) ``` -**Note:** Batch migration is optional - all v0.0 caches are automatically migrated during read operations. Use the migration tool if you prefer to: -- Pre-migrate all caches at once -- Create backups before migration -- Verify migration success across all files -- Audit which caches need migration +**Note:** Programmatic migration is optional - all v0.0 caches are automatically migrated during read operations. 
You only need to manually upgrade cache files if you want them in v0.1 format on disk immediately. ## Example: Complete Test Workflow with v0.1 Features diff --git a/src/askui/utils/caching/cache_migration.py b/src/askui/utils/caching/cache_migration.py deleted file mode 100644 index a8f98025..00000000 --- a/src/askui/utils/caching/cache_migration.py +++ /dev/null @@ -1,306 +0,0 @@ -"""Cache migration utilities for converting v0.0 caches to v0.1 format. - -This module provides tools to batch migrate existing v0.0 cache files to the new -v0.1 format with metadata support. Individual files are automatically migrated on -first read by CacheWriter, but this utility is useful for: - -1. Batch migration of all caches in a directory -2. Pre-migration without executing caches -3. Verification of migration success -4. Backup creation before migration - -Usage: - # Migrate all caches in a directory - python -m askui.utils.cache_migration --cache-dir .cache - - # Dry run (don't modify files) - python -m askui.utils.cache_migration --cache-dir .cache --dry-run - - # Create backups before migration - python -m askui.utils.cache_migration --cache-dir .cache --backup -""" - -import argparse -import json -import logging -import shutil -from pathlib import Path -from typing import Any - -from askui.utils.caching.cache_writer import CacheWriter - -logger = logging.getLogger(__name__) - - -class CacheMigrationError(Exception): - """Raised when cache migration fails.""" - - -class CacheMigration: - """Handles migration of cache files from v0.0 to v0.1 format.""" - - def __init__( - self, - backup: bool = False, - backup_suffix: str = ".v1.backup", - ): - """Initialize cache migration utility. - - Args: - backup: Whether to create backup files before migration - backup_suffix: Suffix to add to backup files - """ - self.backup = backup - self.backup_suffix = backup_suffix - self.migrated_count = 0 - self.skipped_count = 0 - self.error_count = 0 - - def migrate_file(self, file_path: Path, dry_run: bool = False) -> tuple[bool, str]: - """Migrate a single cache file from v0.0 to v0.1. 
- - Args: - file_path: Path to the cache file - dry_run: If True, don't modify the file - - Returns: - Tuple of (success: bool, message: str) - """ - if not file_path.is_file(): - return False, f"File not found: {file_path}" - - try: - # Read the file - with file_path.open("r") as f: - data = json.load(f) - - # Check if already v0.1 - if isinstance(data, dict) and "metadata" in data: - version = data.get("metadata", {}).get("version") - if version == "0.1": - return False, f"Already v0.1: {file_path.name}" - - # Use CacheWriter to read (automatically migrates) - try: - cache_file = CacheWriter.read_cache_file(file_path) - except Exception as e: # noqa: BLE001 - return False, f"Failed to read cache: {str(e)}" - - # Verify it's now v0.1 - if cache_file.metadata.version != "0.1": - return ( - False, - f"Migration failed: Version is {cache_file.metadata.version}", - ) - - if dry_run: - return True, f"Would migrate: {file_path.name}" - - # Create backup if requested - if self.backup: - backup_path = file_path.with_suffix( - file_path.suffix + self.backup_suffix - ) - shutil.copy2(file_path, backup_path) - logger.debug("Created backup: %s", backup_path) - - # Write migrated version back - with file_path.open("w") as f: - json.dump(cache_file.model_dump(mode="json"), f, indent=2, default=str) - - except Exception as e: # noqa: BLE001 - logger.exception("Error migrating %s", file_path) - return False, f"Error: {str(e)}" - else: - return True, f"Migrated: {file_path.name}" - - def migrate_directory( - self, - cache_dir: Path, - file_pattern: str = "*.json", - dry_run: bool = False, - ) -> dict[str, Any]: - """Migrate all cache files in a directory. - - Args: - cache_dir: Directory containing cache files - file_pattern: Glob pattern for cache files - dry_run: If True, don't modify files - - Returns: - Dictionary with migration statistics - """ - if not cache_dir.is_dir(): - msg = f"Directory not found: {cache_dir}" - raise CacheMigrationError(msg) - - # Find all cache files - cache_files = list(cache_dir.glob(file_pattern)) - - if not cache_files: - logger.warning("No cache files found in %s", cache_dir) - return { - "migrated": 0, - "skipped": 0, - "errors": 0, - "total": 0, - } - - logger.info("Found %s cache files in %s", len(cache_files), cache_dir) - - # Reset counters - self.migrated_count = 0 - self.skipped_count = 0 - self.error_count = 0 - - # Migrate each file - results = [] - for file_path in cache_files: - success, message = self.migrate_file(file_path, dry_run=dry_run) - - if success: - self.migrated_count += 1 - logger.info("✓ %s", message) - elif "Already v0.1" in message: - self.skipped_count += 1 - logger.debug("⊘ %s", message) - else: - self.error_count += 1 - logger.error("✗ %s", message) - - results.append( - {"file": file_path.name, "success": success, "message": message} - ) - - # Log summary - logger.info("\n%s", "=" * 60) - logger.info("Migration Summary:") - logger.info(" Total files: %s", len(cache_files)) - logger.info(" Migrated: %s", self.migrated_count) - logger.info(" Already v0.1: %s", self.skipped_count) - logger.info(" Errors: %s", self.error_count) - logger.info("%s\n", "=" * 60) - - return { - "migrated": self.migrated_count, - "skipped": self.skipped_count, - "errors": self.error_count, - "total": len(cache_files), - "results": results, - } - - -def main() -> int: - """CLI entry point for cache migration. 
- - Returns: - Exit code (0 for success, 1 for failure) - """ - parser = argparse.ArgumentParser( - description="Migrate cache files from v0.0 to v0.1 format", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Migrate all caches in .cache directory - python -m askui.utils.cache_migration --cache-dir .cache - - # Dry run (show what would be migrated) - python -m askui.utils.cache_migration --cache-dir .cache --dry-run - - # Create backups before migration - python -m askui.utils.cache_migration --cache-dir .cache --backup - - # Custom file pattern - python -m askui.utils.cache_migration --cache-dir .cache --pattern "test_*.json" - """, - ) - - parser.add_argument( - "--cache-dir", - type=str, - required=True, - help="Directory containing cache files to migrate", - ) - - parser.add_argument( - "--pattern", - type=str, - default="*.json", - help="Glob pattern for cache files (default: *.json)", - ) - - parser.add_argument( - "--dry-run", - action="store_true", - help="Show what would be migrated without modifying files", - ) - - parser.add_argument( - "--backup", - action="store_true", - help="Create backup files before migration (adds .v1.backup suffix)", - ) - - parser.add_argument( - "--backup-suffix", - type=str, - default=".v1.backup", - help="Suffix for backup files (default: .v1.backup)", - ) - - parser.add_argument( - "-v", - "--verbose", - action="store_true", - help="Enable verbose logging", - ) - - args = parser.parse_args() - - # Configure logging - logging.basicConfig( - level=logging.DEBUG if args.verbose else logging.INFO, - format="%(levelname)s: %(message)s", - ) - - try: - cache_dir = Path(args.cache_dir) - - if args.dry_run: - logger.info("DRY RUN MODE - No files will be modified\n") - - # Create migration utility - migration = CacheMigration( - backup=args.backup, - backup_suffix=args.backup_suffix, - ) - - # Perform migration - stats = migration.migrate_directory( - cache_dir=cache_dir, - file_pattern=args.pattern, - dry_run=args.dry_run, - ) - - # Return success if no errors - if stats["errors"] == 0: - logger.info("✓ Migration completed successfully!") - else: - logger.error("✗ Migration completed with %s errors", stats["errors"]) - return 1 - - except CacheMigrationError: - logger.exception("Migration failed") - return 1 - except KeyboardInterrupt: - logger.info("\nMigration cancelled by user") - return 1 - except Exception: # noqa: BLE001 - logger.exception("Unexpected error") - return 1 - else: - return 0 - - -if __name__ == "__main__": - exit(main()) diff --git a/tests/unit/utils/test_cache_migration.py b/tests/unit/utils/test_cache_migration.py deleted file mode 100644 index 780cd1a7..00000000 --- a/tests/unit/utils/test_cache_migration.py +++ /dev/null @@ -1,387 +0,0 @@ -"""Tests for cache migration utilities.""" - -import json -from datetime import datetime, timezone -from pathlib import Path -from typing import Any - -import pytest - -from askui.utils.caching.cache_migration import CacheMigration, CacheMigrationError - - -@pytest.fixture -def temp_cache_dir(tmp_path: Path) -> Path: - """Create a temporary cache directory.""" - cache_dir = tmp_path / "caches" - cache_dir.mkdir() - return cache_dir - - -@pytest.fixture -def v1_cache_data() -> list[dict[str, Any]]: - """Sample v0.0 cache data (just a trajectory list).""" - return [ - {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, - {"id": "2", "name": "type", "input": {"text": "test"}, "type": "tool_use"}, - ] - - -@pytest.fixture -def v2_cache_data() -> 
dict[str, Any]: - """Sample v0.1 cache data (with metadata).""" - return { - "metadata": { - "version": "0.1", - "created_at": datetime.now(tz=timezone.utc).isoformat(), - "execution_attempts": 0, - "last_executed_at": None, - "failures": [], - "is_valid": True, - "invalidation_reason": None, - }, - "trajectory": [ - {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, - ], - "placeholders": {}, - } - - -# Initialization Tests - - -def test_cache_migration_initialization() -> None: - """Test CacheMigration initializes with correct defaults.""" - migration = CacheMigration() - assert migration.backup is False - assert migration.backup_suffix == ".v1.backup" - assert migration.migrated_count == 0 - assert migration.skipped_count == 0 - assert migration.error_count == 0 - - -def test_cache_migration_initialization_with_backup() -> None: - """Test CacheMigration initializes with backup enabled.""" - migration = CacheMigration(backup=True, backup_suffix=".bak") - assert migration.backup is True - assert migration.backup_suffix == ".bak" - - -# Single File Migration Tests - - -def test_migrate_file_v1_to_v2( - temp_cache_dir: Path, v1_cache_data: list[dict[str, Any]] -) -> None: - """Test migrating a v0.0 cache file to v0.1.""" - cache_file = temp_cache_dir / "test.json" - with cache_file.open("w") as f: - json.dump(v1_cache_data, f) - - migration = CacheMigration() - success, message = migration.migrate_file(cache_file, dry_run=False) - - assert success is True - assert "Migrated" in message - - # Verify file was updated to v0.1 - with cache_file.open("r") as f: - data = json.load(f) - - assert isinstance(data, dict) - assert "metadata" in data - assert data["metadata"]["version"] == "0.1" - assert "trajectory" in data - assert "placeholders" in data - - -def test_migrate_file_already_v2( - temp_cache_dir: Path, v2_cache_data: dict[str, Any] -) -> None: - """Test that v0.1 files are skipped.""" - cache_file = temp_cache_dir / "test.json" - with cache_file.open("w") as f: - json.dump(v2_cache_data, f) - - migration = CacheMigration() - success, message = migration.migrate_file(cache_file, dry_run=False) - - assert success is False - assert "Already v0.1" in message - - -def test_migrate_file_dry_run( - temp_cache_dir: Path, v1_cache_data: list[dict[str, Any]] -) -> None: - """Test dry run doesn't modify files.""" - cache_file = temp_cache_dir / "test.json" - with cache_file.open("w") as f: - json.dump(v1_cache_data, f) - - # Store original content - original_content = cache_file.read_text() - - migration = CacheMigration() - success, message = migration.migrate_file(cache_file, dry_run=True) - - assert success is True - assert "Would migrate" in message - - # Verify file wasn't modified - assert cache_file.read_text() == original_content - - -def test_migrate_file_creates_backup( - temp_cache_dir: Path, v1_cache_data: list[dict[str, Any]] -) -> None: - """Test that backup is created when requested.""" - cache_file = temp_cache_dir / "test.json" - with cache_file.open("w") as f: - json.dump(v1_cache_data, f) - - migration = CacheMigration(backup=True, backup_suffix=".backup") - success, message = migration.migrate_file(cache_file, dry_run=False) - - assert success is True - - # Verify backup exists - backup_file = temp_cache_dir / "test.json.backup" - assert backup_file.exists() - - # Verify backup contains original v0.0 data - with backup_file.open("r") as f: - backup_data = json.load(f) - assert backup_data == v1_cache_data - - -def test_migrate_file_not_found(temp_cache_dir: 
Path) -> None: - """Test handling of missing file.""" - cache_file = temp_cache_dir / "nonexistent.json" - - migration = CacheMigration() - success, message = migration.migrate_file(cache_file, dry_run=False) - - assert success is False - assert "File not found" in message - - -def test_migrate_file_invalid_json(temp_cache_dir: Path) -> None: - """Test handling of invalid JSON.""" - cache_file = temp_cache_dir / "invalid.json" - cache_file.write_text("not valid json{") - - migration = CacheMigration() - success, message = migration.migrate_file(cache_file, dry_run=False) - - assert success is False - assert "Error" in message - - -# Directory Migration Tests - - -def test_migrate_directory_multiple_files( - temp_cache_dir: Path, v1_cache_data: list[dict[str, Any]] -) -> None: - """Test migrating multiple files in a directory.""" - # Create several v0.0 cache files - for i in range(3): - cache_file = temp_cache_dir / f"cache_{i}.json" - with cache_file.open("w") as f: - json.dump(v1_cache_data, f) - - migration = CacheMigration() - stats = migration.migrate_directory(temp_cache_dir, dry_run=False) - - assert stats["total"] == 3 - assert stats["migrated"] == 3 - assert stats["skipped"] == 0 - assert stats["errors"] == 0 - - -def test_migrate_directory_mixed_versions( - temp_cache_dir: Path, - v1_cache_data: list[dict[str, Any]], - v2_cache_data: dict[str, Any], -) -> None: - """Test migrating directory with mixed v0.0 and v0.1 files.""" - # Create v0.0 files - for i in range(2): - cache_file = temp_cache_dir / f"v1_cache_{i}.json" - with cache_file.open("w") as f: - json.dump(v1_cache_data, f) - - # Create v0.1 files - for i in range(2): - cache_file = temp_cache_dir / f"v2_cache_{i}.json" - with cache_file.open("w") as f: - json.dump(v2_cache_data, f) - - migration = CacheMigration() - stats = migration.migrate_directory(temp_cache_dir, dry_run=False) - - assert stats["total"] == 4 - assert stats["migrated"] == 2 # Only v0.0 files migrated - assert stats["skipped"] == 2 # v0.1 files skipped - assert stats["errors"] == 0 - - -def test_migrate_directory_dry_run( - temp_cache_dir: Path, v1_cache_data: list[dict[str, Any]] -) -> None: - """Test dry run on directory doesn't modify files.""" - cache_file = temp_cache_dir / "test.json" - with cache_file.open("w") as f: - json.dump(v1_cache_data, f) - - original_content = cache_file.read_text() - - migration = CacheMigration() - stats = migration.migrate_directory(temp_cache_dir, dry_run=True) - - assert stats["migrated"] == 1 - # Verify file wasn't modified - assert cache_file.read_text() == original_content - - -def test_migrate_directory_with_pattern( - temp_cache_dir: Path, v1_cache_data: list[dict[str, Any]] -) -> None: - """Test migrating directory with custom file pattern.""" - # Create files with different extensions - for ext in ["json", "cache", "txt"]: - cache_file = temp_cache_dir / f"test.{ext}" - with cache_file.open("w") as f: - json.dump(v1_cache_data, f) - - migration = CacheMigration() - stats = migration.migrate_directory( - temp_cache_dir, file_pattern="*.cache", dry_run=False - ) - - # Only .cache file should be processed - assert stats["total"] == 1 - assert stats["migrated"] == 1 - - -def test_migrate_directory_not_found() -> None: - """Test handling of non-existent directory.""" - migration = CacheMigration() - - with pytest.raises(CacheMigrationError, match="Directory not found"): - migration.migrate_directory(Path("/nonexistent/directory")) - - -def test_migrate_directory_empty(temp_cache_dir: Path) -> None: - """Test 
migrating empty directory.""" - migration = CacheMigration() - stats = migration.migrate_directory(temp_cache_dir, dry_run=False) - - assert stats["total"] == 0 - assert stats["migrated"] == 0 - assert stats["skipped"] == 0 - assert stats["errors"] == 0 - - -def test_migrate_directory_with_errors( - temp_cache_dir: Path, v1_cache_data: list[dict[str, Any]] -) -> None: - """Test directory migration handles errors gracefully.""" - # Create valid v0.0 file - valid_file = temp_cache_dir / "valid.json" - with valid_file.open("w") as f: - json.dump(v1_cache_data, f) - - # Create invalid file - invalid_file = temp_cache_dir / "invalid.json" - invalid_file.write_text("not valid json{") - - migration = CacheMigration() - stats = migration.migrate_directory(temp_cache_dir, dry_run=False) - - assert stats["total"] == 2 - assert stats["migrated"] == 1 # Valid file migrated - assert stats["errors"] == 1 # Invalid file failed - assert stats["skipped"] == 0 - - -def test_migrate_directory_creates_backups( - temp_cache_dir: Path, v1_cache_data: list[dict[str, Any]] -) -> None: - """Test directory migration creates backups for all files.""" - # Create v0.0 files - for i in range(2): - cache_file = temp_cache_dir / f"cache_{i}.json" - with cache_file.open("w") as f: - json.dump(v1_cache_data, f) - - migration = CacheMigration(backup=True, backup_suffix=".bak") - stats = migration.migrate_directory(temp_cache_dir, dry_run=False) - - assert stats["migrated"] == 2 - - # Verify backups exist - for i in range(2): - backup_file = temp_cache_dir / f"cache_{i}.json.bak" - assert backup_file.exists() - - -# Integration Tests - - -def test_full_migration_workflow( - temp_cache_dir: Path, v1_cache_data: list[dict[str, Any]] -) -> None: - """Test complete migration workflow from v0.0 to v0.1.""" - # Create v0.0 cache - cache_file = temp_cache_dir / "workflow_test.json" - with cache_file.open("w") as f: - json.dump(v1_cache_data, f) - - # Perform migration with backup - migration = CacheMigration(backup=True) - success, message = migration.migrate_file(cache_file, dry_run=False) - - assert success is True - - # Verify v0.1 structure - with cache_file.open("r") as f: - data = json.load(f) - - assert data["metadata"]["version"] == "0.1" - assert data["metadata"]["execution_attempts"] == 0 - assert data["metadata"]["is_valid"] is True - assert len(data["trajectory"]) == 2 - assert data["placeholders"] == {} - - # Verify backup - backup_file = cache_file.with_suffix(cache_file.suffix + ".v1.backup") - assert backup_file.exists() - - # Attempt to migrate again (should skip) - success, message = migration.migrate_file(cache_file, dry_run=False) - assert success is False - assert "Already v0.1" in message - - -def test_migration_preserves_trajectory_data( - temp_cache_dir: Path, v1_cache_data: list[dict[str, Any]] -) -> None: - """Test that migration preserves all trajectory data.""" - cache_file = temp_cache_dir / "preserve_test.json" - with cache_file.open("w") as f: - json.dump(v1_cache_data, f) - - migration = CacheMigration() - migration.migrate_file(cache_file, dry_run=False) - - # Load migrated file - with cache_file.open("r") as f: - data = json.load(f) - - # Verify trajectory preserved - assert len(data["trajectory"]) == len(v1_cache_data) - for i, step in enumerate(data["trajectory"]): - assert step["id"] == v1_cache_data[i]["id"] - assert step["name"] == v1_cache_data[i]["name"] - assert step["input"] == v1_cache_data[i]["input"] From 9d94a7b5d2db2d5022c950bbfb244806514a4fd7 Mon Sep 17 00:00:00 2001 From: 
philipph-askui
Date: Tue, 16 Dec 2025 14:55:47 +0100
Subject: [PATCH 22/30] chore(caching): distributes responsibilities between
 CacheWriter and ParameterHandler more consistently and renames `placeholder`
 to `caching_parameter`.

---
 docs/caching.md                               | 126 ++---
 src/askui/models/shared/settings.py           |   8 +-
 src/askui/prompts/caching.py                  |  26 +-
 src/askui/tools/caching_tools.py              |  86 +--
 src/askui/utils/cache_parameter_handler.py    | 489 ++++++++++++++++++
 src/askui/utils/caching/cache_writer.py       |  89 +---
 src/askui/utils/placeholder_handler.py        | 292 -----------
 src/askui/utils/placeholder_identifier.py     | 143 -----
 src/askui/utils/trajectory_executor.py        |  18 +-
 tests/unit/tools/test_caching_tools.py        |  78 +--
 tests/unit/utils/test_cache_manager.py        |   2 +-
 ...ler.py => test_cache_parameter_handler.py} | 155 +++---
 tests/unit/utils/test_cache_validator.py      |   4 +-
 tests/unit/utils/test_cache_writer.py         |  44 +-
 tests/unit/utils/test_trajectory_executor.py  |  10 +-
 15 files changed, 787 insertions(+), 783 deletions(-)
 create mode 100644 src/askui/utils/cache_parameter_handler.py
 delete mode 100644 src/askui/utils/placeholder_handler.py
 delete mode 100644 src/askui/utils/placeholder_identifier.py
 rename tests/unit/utils/{test_placeholder_handler.py => test_cache_parameter_handler.py} (58%)

diff --git a/docs/caching.md b/docs/caching.md
index 8f9d8ead..025c639a 100644
--- a/docs/caching.md
+++ b/docs/caching.md
@@ -8,7 +8,7 @@ The caching mechanism allows you to record and replay agent action sequences (tr
 
 The caching system works by recording all tool use actions (mouse movements, clicks, typing, etc.) performed by the agent during an `act()` execution. These recorded sequences can then be replayed in subsequent executions, allowing the agent to skip the decision-making process and execute the actions directly.
 
-**New in v0.1:** The caching system now includes advanced features like placeholder support for dynamic values, smart handling of non-cacheable tools that require agent intervention, comprehensive message history tracking, and automatic failure detection with recovery capabilities.
+**New in v0.1:** The caching system now includes advanced features like parameter support for dynamic values, smart handling of non-cacheable tools that require agent intervention, comprehensive message history tracking, and automatic failure detection with recovery capabilities.
 
 ## Caching Strategies
 
@@ -31,7 +31,7 @@ caching_settings = CachingSettings(
     cache_dir=".cache",  # Directory to store cache files
     filename="my_test.json",  # Filename for the cache file (optional for write mode)
     cache_writer_settings=CacheWriterSettings(
-        placeholder_identification_strategy="llm",
+        parameter_identification_strategy="llm",
     ),  # Auto-detect dynamic values (default: "llm")
     execute_cached_trajectory_tool_settings=CachedExecutionToolSettings(
         delay_time_between_action=0.5  # Delay in seconds between each cached action
@@ -49,8 +49,8 @@ ### CacheWriter Settings
 
-- `placeholder_identification_strategy`: When `llm` (default), uses AI to automatically identify and parameterize dynamic values like dates, usernames, and IDs during cache recording. When `preset`, only manually specified placeholders (using `{{...}}` syntax) are detected. See [Automatic Placeholder Identification](#automatic-placeholder-identification).
-- `llm_placeholder_id_api_provider`: The provider of that will be used for for the llm in the placeholder identification (will only be used if `placeholder_identification_strategy`is set to `llm`). Defaults to `askui`.
+- `parameter_identification_strategy`: When `llm` (default), uses AI to automatically identify and parameterize dynamic values like dates, usernames, and IDs during cache recording. When `preset`, only manually specified cache parameters (using `{{...}}` syntax) are detected. See [Automatic Cache Parameter Identification](#automatic-cache-parameter-identification).
+- `llm_parameter_id_api_provider`: The provider that will be used for the LLM during parameter identification (only used if `parameter_identification_strategy` is set to `llm`). Defaults to `askui`.
 
 ### Execution Settings
 
@@ -120,9 +120,9 @@ When using `strategy="read"`, the agent receives two tools:
 
 The agent will automatically check if a relevant cached trajectory exists and use it if appropriate. During execution, the agent can see all screenshots and results in the message history. After executing a cached trajectory, the agent will verify the results and make corrections if needed.
 
-### Using Placeholders for Dynamic Values
+### Using Cache Parameters for Dynamic Values
 
-**New in v0.1:** Trajectories can contain placeholders for dynamic values that change between executions:
+**New in v0.1:** Trajectories can contain cache parameters for dynamic values that change between executions:
 
 ```python
 from askui import VisionAgent
@@ -140,7 +140,7 @@ with VisionAgent() as agent:
     )
 )
 
-# Later, when replaying, the agent can provide placeholder values
+# Later, when replaying, the agent can provide parameter values
 # If the cache file contains {{current_date}} or {{task_title}}, provide them:
 with VisionAgent() as agent:
     agent.act(
@@ -150,11 +150,11 @@ with VisionAgent() as agent:
             cache_dir=".cache"
         )
     )
-    # The agent will automatically detect required placeholders and can provide them
-    # via the placeholder_values parameter when calling ExecuteCachedTrajectory
+    # The agent will automatically detect required cache parameters and can provide them
+    # via the parameter_values parameter when calling ExecuteCachedTrajectory
 ```
 
-Placeholders use the syntax `{{variable_name}}` and are automatically detected during cache file creation. When executing a trajectory with placeholders, the agent must provide values for all required placeholders.
+Cache parameters use the syntax `{{variable_name}}` and are automatically detected during cache file creation. When executing a trajectory with cache parameters, the agent must provide values for all of them.
 
 ### Handling Non-Cacheable Steps
 
@@ -190,7 +190,7 @@ Tools can be marked as non-cacheable by setting `is_cacheable=False` in their de
 result = execute_cached_trajectory_tool(
     trajectory_file=".cache/my_test.json",
     start_from_step_index=5,  # Continue from step 5
-    placeholder_values={"date": "2025-12-11"}  # Provide any required placeholders
+    parameter_values={"date": "2025-12-11"}  # Provide any required cache parameters
 )
 ```
 
@@ -351,7 +351,7 @@ In this mode:
 
 ## Cache File Format
 
-**New in v0.1:** Cache files now use an enhanced format with metadata tracking, placeholder support, and execution history.
+**New in v0.1:** Cache files now use an enhanced format with metadata tracking, cache parameter support, and execution history.
### v0.1 Format (Current) @@ -393,7 +393,7 @@ Cache files are JSON objects with the following structure: "input": {} } ], - "placeholders": { + "cache_parameters": { "user_name": "Name of the user to greet" } } @@ -405,16 +405,16 @@ Cache files are JSON objects with the following structure: - **`version`**: Cache file format version (currently "0.1") - **`created_at`**: ISO 8601 timestamp when the cache was created -- **`goal`**: **New!** The original goal/instruction given to the agent when recording this trajectory. Placeholders are applied to the goal text just like in the trajectory, making it easy to understand what the cache was designed to accomplish. +- **`goal`**: **New!** The original goal/instruction given to the agent when recording this trajectory. Cache Parameters are applied to the goal text just like in the trajectory, making it easy to understand what the cache was designed to accomplish. - **`last_executed_at`**: ISO 8601 timestamp of the last execution (null if never executed) - **`execution_attempts`**: Number of times this trajectory has been executed - **`failures`**: List of failures encountered during execution (see [Failure Tracking](#failure-tracking)) - **`is_valid`**: Boolean indicating if the cache is still considered valid - **`invalidation_reason`**: Optional string explaining why the cache was invalidated -#### Placeholders +#### Cache Parameters -The `placeholders` object maps placeholder names to their descriptions. Placeholders in the trajectory use the syntax `{{placeholder_name}}` and must be substituted with actual values during execution. +The `cache_parameters` object maps parameter names to their descriptions. Cache Parameters in the trajectory use the syntax `{{parameter_name}}` and must be substituted with actual values during execution. #### Failure Tracking @@ -473,16 +473,16 @@ In write mode, the `CacheWriter` class: 2. Extracts tool use blocks from the messages 3. Stores tool blocks in memory during execution 4. When agent finishes (on `stop_reason="end_turn"`): - - **Automatically identifies placeholders** using AI (if `placeholder_identification_strategy=llm`) + - **Automatically identifies cache_parameters** using AI (if `parameter_identification_strategy=llm`) - Analyzes trajectory to find dynamic values (dates, usernames, IDs, etc.) - - Generates descriptive placeholder definitions - - Replaces identified values with `{{placeholder_name}}` syntax in trajectory + - Generates descriptive parameter definitions + - Replaces identified values with `{{parameter_name}}` syntax in trajectory - Applies same replacements to the goal text - **Blanks non-cacheable tool inputs** by setting `input: {}` for tools with `is_cacheable=False` (saves space and privacy) - **Writes to JSON file** with: - - v0.1 metadata (version, timestamps, goal with placeholders) - - Trajectory of tool use blocks (with placeholders and blanked inputs) - - Placeholder definitions with descriptions + - v0.1 metadata (version, timestamps, goal with cache_parameters) + - Trajectory of tool use blocks (with cache_parameters and blanked inputs) + - Parameter definitions with descriptions 5. Automatically skips writing if a cached execution was used (to avoid recording replays) ### Read Mode @@ -494,7 +494,7 @@ In read mode: - `ExecuteCachedTrajectory`: Executes from the beginning or continues from a specific step using `start_from_step_index` 2. 
A special system prompt (`CACHE_USE_PROMPT`) instructs the agent on: - How to use trajectories - - Placeholder handling + - Parameter handling - Non-cacheable step management - Failure recovery strategies 3. The agent can list available cache files and choose appropriate ones @@ -502,7 +502,7 @@ In read mode: - Each step is executed sequentially with configurable delays - All tools in the trajectory are executed, including screenshots and retrieval tools - Non-cacheable tools trigger a pause with `NEEDS_AGENT` status - - Placeholders are validated and substituted before execution + - Cache Parameters are validated and substituted before execution - Message history is built with assistant (tool use) and user (tool result) messages - Agent sees all screenshots and results in the message history 5. Execution can pause for agent intervention: @@ -624,26 +624,26 @@ Agent calls ExecuteCachedTrajectory(start_from_step_index=6) Execution continues successfully ``` -## Placeholders +## Cache Parameters -**New in v0.1:** Placeholders enable dynamic value substitution in cached trajectories. +**New in v0.1:** Cache Parameters enable dynamic value substitution in cached trajectories. -### Placeholder Syntax +### Parameter Syntax -Placeholders use double curly braces: `{{placeholder_name}}` +Cache Parameters use double curly braces: `{{parameter_name}}` -Valid placeholder names: +Valid parameter names: - Must start with a letter or underscore - Can contain letters, numbers, and underscores - Examples: `{{date}}`, `{{user_name}}`, `{{order_id_123}}` -### Automatic Placeholder Identification +### Automatic Cache Parameter Identification **New in v0.1!** The caching system uses AI to automatically identify and parameterize dynamic values when recording trajectories. #### How It Works -When `placeholder_identification_strategy=llm` (the default), the system: +When `parameter_identification_strategy=llm` (the default), the system: 1. **Records the trajectory** as normal during agent execution 2. **Analyzes the trajectory** using an LLM to identify dynamic values such as: @@ -653,7 +653,7 @@ When `placeholder_identification_strategy=llm` (the default), the system: - Dynamic text referencing current state or time - File paths with user-specific or time-specific components - Temporary or generated identifiers -3. **Generates placeholder definitions** with descriptive names and documentation: +3. **Generates parameter definitions** with descriptive names and documentation: ```json { "name": "current_date", @@ -661,7 +661,7 @@ When `placeholder_identification_strategy=llm` (the default), the system: "description": "Current date in YYYY-MM-DD format" } ``` -4. **Replaces values with placeholders** in both the trajectory AND the goal: +4. **Replaces values with cache_parameters** in both the trajectory AND the goal: - Original: `"text": "Login as john.doe"` - Result: `"text": "Login as {{username}}"` 5. 
**Saves the templated trajectory** to the cache file
 
@@ -670,14 +670,14 @@ When `parameter_identification_strategy=llm` (the default), the system:
 
 ✅ **No manual work** - Automatically identifies dynamic values
 ✅ **Smart detection** - LLM understands semantic meaning (dates vs coordinates)
-✅ **Descriptive** - Generates helpful descriptions for each placeholder
-✅ **Applies to goal** - Goal text also gets placeholder replacement
+✅ **Descriptive** - Generates helpful descriptions for each parameter
+✅ **Applies to goal** - Goal text also gets parameter replacement
 
 #### What Gets Detected
 
 The AI identifies values that are likely to change between executions:
 
-**Will be detected as placeholders:**
+**Will be detected as cache parameters:**
 - Dates: "2025-12-11", "Dec 11, 2025", "12/11/2025"
 - Times: "10:30 AM", "14:45:00", "2025-12-11T10:30:00Z"
 - Usernames: "john.doe", "admin_user", "test_account"
@@ -686,7 +686,7 @@ The AI identifies values that are likely to change between executions:
 - Names: "John Smith", "Jane Doe"
 - Dynamic text: "Today is 2025-12-11", "Logged in as john.doe"
 
-**Will NOT be detected as placeholders:**
+**Will NOT be detected as cache parameters:**
 - UI coordinates: `{"x": 100, "y": 200}`
 - Fixed button labels: "Submit", "Cancel", "OK"
 - Configuration values: `{"timeout": 30, "retries": 3}`
@@ -695,22 +695,22 @@
 
 #### Disabling Auto-Identification
 
-If you prefer manual placeholder control:
+If you prefer manual parameter control:
 
 ```python
 caching_settings = CachingSettings(
     strategy="write",
     cache_writer_settings=CacheWriterSettings(
-        placeholder_identification_strategy="default" # Only detect {{...}} syntax
+        parameter_identification_strategy="preset",  # Only detect {{...}} syntax
     )
 )
 ```
 
-With `placeholder_identification_strategy=default`, only manually specified placeholders using the `{{...}}` syntax will be detected.
+With `parameter_identification_strategy="preset"`, only manually specified cache parameters using the `{{...}}` syntax will be detected.
 
 #### Logging
 
-To see what placeholders are being identified, enable INFO-level logging:
+To see which cache parameters are being identified, enable INFO-level logging:
 
 ```python
 import logging
@@ -719,28 +719,28 @@
 
 You'll see output like:
 ```
-INFO: Using LLM to identify placeholders in trajectory
-INFO: Identified 3 placeholders in trajectory
+INFO: Using LLM to identify cache parameters in trajectory
+INFO: Identified 3 cache parameters in trajectory
 DEBUG: - current_date: 2025-12-11 (Current date in YYYY-MM-DD format)
 DEBUG: - username: john.doe (Username for login)
 DEBUG: - session_id: abc123 (Session identifier)
-INFO: Replaced 3 placeholder values in trajectory
-INFO: Applied placeholder replacement to goal: Login as john.doe -> Login as {{username}}
+INFO: Replaced 3 parameter values in trajectory
+INFO: Applied parameter replacement to goal: Login as john.doe -> Login as {{username}}
 ```
 
-### Manual Placeholders
+### Manual Cache Parameters
 
-You can also manually create placeholders when recording by using the syntax in your goal description. The system will preserve `{{...}}` patterns in tool inputs.
+You can also manually create cache parameters when recording by using the `{{...}}` syntax in your goal description. The system will preserve `{{...}}` patterns in tool inputs, as shown in the sketch below.
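+
+For example, here is a minimal sketch of recording with a manual cache parameter. The goal text and filename are illustrative, and the settings classes are the ones defined in `src/askui/models/shared/settings.py`:
+
+```python
+from askui import VisionAgent
+from askui.models.shared.settings import CacheWriterSettings, CachingSettings
+
+with VisionAgent() as agent:
+    # "preset" keeps {{task_title}} verbatim as a manual cache parameter
+    # instead of asking the LLM to auto-identify dynamic values.
+    agent.act(
+        'Create a new task titled "{{task_title}}" in the task list',
+        caching_settings=CachingSettings(
+            strategy="write",
+            filename="create_task.json",
+            cache_writer_settings=CacheWriterSettings(
+                parameter_identification_strategy="preset",
+            ),
+        ),
+    )
+```
+
+The resulting cache file then lists `task_title` under `cache_parameters`, so replay runs can inject a fresh title via `parameter_values`.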
-### Providing Placeholder Values
+### Providing Parameter Values
 
-When executing a trajectory with placeholders, the agent must provide values:
+When executing a trajectory with cache parameters, the agent must provide values for them:
 
 ```python
 # Via ExecuteCachedTrajectory
 result = execute_cached_trajectory_tool(
     trajectory_file=".cache/my_test.json",
-    placeholder_values={
+    parameter_values={
         "current_date": "2025-12-11",
         "user_email": "test@example.com"
     }
 )
 
@@ -750,24 +750,24 @@ result = execute_cached_trajectory_tool(
 result = execute_cached_trajectory_tool(
     trajectory_file=".cache/my_test.json",
     start_from_step_index=3,  # Continue from step 3
-    placeholder_values={
+    parameter_values={
         "current_date": "2025-12-11",
         "user_email": "test@example.com"
     }
 )
 ```
 
-### Placeholder Validation
+### Parameter Validation
 
 Before execution, the system validates that:
-- All required placeholders have values provided
-- No required placeholders are missing
+- A value has been provided for every cache parameter referenced in the trajectory
+- No required cache parameters are missing from `parameter_values`
 
-If validation fails, execution is aborted with a clear error message listing missing placeholders.
+If validation fails, execution is aborted with a clear error message listing the missing cache parameters.
 
 ### Use Cases
 
-Placeholders are particularly useful for:
+Cache parameters are particularly useful for:
 - **Date-dependent workflows**: Testing with current/future dates
 - **User-specific actions**: Different users, emails, names
 - **Order/transaction IDs**: Testing with different identifiers
@@ -797,7 +797,7 @@ Example:
 
 1. **Always Verify Results**: After cached execution, verify the outcome matches expectations
 2. **Handle Failures Gracefully**: Provide clear recovery paths when trajectories fail
-3. **Use Placeholders Wisely**: Identify dynamic values that should be parameterized
+3. **Use Cache Parameters Wisely**: Identify dynamic values that should be parameterized
 4. **Mark Non-Cacheable Tools**: Properly mark tools that require agent intervention
 5. **Monitor Cache Validity**: Track execution attempts and failures to identify stale caches
 6. **Test Cache Replay**: Periodically test that cached trajectories still work
@@ -835,7 +835,7 @@ When a v0.0 cache file (simple JSON array) is read:
     "invalidation_reason": null
   }
 ```
-4. Extracts any placeholders found in trajectory
+4. Extracts any cache parameters found in the trajectory
 5. 
Returns fully-formed `CacheFile` object
 
 ### Compatibility Guarantees
 
@@ -902,7 +902,7 @@ class PrintTool(Tool):
         )
         self.is_cacheable = False
 
-    # Agent will detect placeholders and provide new values:
+    # The agent will detect cache parameters and provide new values:
     def __call__(self, text: str) -> None:
         print(text)
 
@@ -970,9 +970,9 @@ Planned features for future versions:
 
-**Issue**: "Missing required placeholders" error
-- **Cause**: Trajectory contains placeholders but values weren't provided
-- **Solution**: Check cache metadata for required placeholders and provide values via `placeholder_values` parameter
+**Issue**: "Missing required parameter values" error
+- **Cause**: The trajectory contains cache parameters, but values weren't provided
+- **Solution**: Check the cache metadata for the required cache parameters and provide values via the `parameter_values` parameter
 
 **Issue**: Execution pauses unexpectedly
 - **Cause**: Trajectory contains non-cacheable tool
@@ -991,7 +991,7 @@
 
 1. **Check message history**: After execution, review `message_history` in the result to see exactly what happened
 2. **Monitor failure metadata**: Track `execution_attempts` and `failures` in cache metadata
 3. **Test incrementally**: Use `ExecuteCachedTrajectory` with `start_from_step_index` to test specific sections of a trajectory
-4. **Verify placeholders**: Print cache metadata to see what placeholders are expected
+4. **Verify cache parameters**: Print the cache metadata to see which cache parameters are expected (see the sketch below)
 5. **Adjust delays**: If timing issues occur, increase `delay_time_between_action` incrementally
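+
+To put tips 2 and 4 into practice, here is a minimal sketch (it assumes only `CacheWriter.read_cache_file` and the `CacheFile` fields documented above) that prints a cache file's execution metadata and its expected cache parameters:
+
+```python
+from pathlib import Path
+
+from askui.utils.caching.cache_writer import CacheWriter
+
+cache_file = CacheWriter.read_cache_file(Path(".cache/my_test.json"))
+
+# Execution statistics and validity status (tip 2)
+print(f"Attempts: {cache_file.metadata.execution_attempts}")
+print(f"Valid: {cache_file.metadata.is_valid}")
+if cache_file.metadata.invalidation_reason:
+    print(f"Invalidated because: {cache_file.metadata.invalidation_reason}")
+for failure in cache_file.metadata.failures:
+    print(f"Step {failure.step_index} failed: {failure.error_message}")
+
+# Expected cache parameters and their descriptions (tip 4)
+for name, description in cache_file.cache_parameters.items():
+    print(f"{{{{{name}}}}}: {description}")
+```
+
 For more help, see the [GitHub Issues](https://github.com/askui/vision-agent/issues) or contact support.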
diff --git a/src/askui/models/shared/settings.py b/src/askui/models/shared/settings.py index b22b1cca..e2f1480a 100644 --- a/src/askui/models/shared/settings.py +++ b/src/askui/models/shared/settings.py @@ -18,7 +18,7 @@ COMPUTER_USE_20251124_BETA_FLAG = "computer-use-2025-11-24" CACHING_STRATEGY = Literal["read", "write", "both", "no"] -PLACEHOLDER_IDENTIFICATION_STRATEGY = Literal["llm", "preset"] +CACHE_PARAMETER_IDENTIFICATION_STRATEGY = Literal["llm", "preset"] class MessageSettings(BaseModel): @@ -43,8 +43,8 @@ class CachedExecutionToolSettings(BaseModel): class CacheWriterSettings(BaseModel): - placeholder_identification_strategy: PLACEHOLDER_IDENTIFICATION_STRATEGY = "llm" - llm_placeholder_id_api_provider: AnthropicApiProvider = "askui" + parameter_identification_strategy: CACHE_PARAMETER_IDENTIFICATION_STRATEGY = "llm" + llm_parameter_id_api_provider: AnthropicApiProvider = "askui" class CachingSettings(BaseModel): @@ -81,4 +81,4 @@ class CacheFile(BaseModel): metadata: CacheMetadata trajectory: list[ToolUseBlockParam] - placeholders: dict[str, str] = Field(default_factory=dict) + cache_parameters: dict[str, str] = Field(default_factory=dict) diff --git a/src/askui/prompts/caching.py b/src/askui/prompts/caching.py index 20c351b6..7be63b5b 100644 --- a/src/askui/prompts/caching.py +++ b/src/askui/prompts/caching.py @@ -19,16 +19,16 @@ " - If execution fails partway, you'll see exactly where it failed " "and can decide how to proceed\n" "\n" - " PLACEHOLDERS:\n" - " - Trajectories may contain dynamic placeholders like " + " CACHING_PARAMETERS:\n" + " - Trajectories may contain dynamic parameters like " "{{current_date}} or {{user_name}}\n" " - When executing a trajectory, check if it requires " - "placeholder values\n" - " - Provide placeholder values using the placeholder_values " + "parameter values\n" + " - Provide parameter values using the parameter_values " "parameter as a dictionary\n" " - Example: ExecuteCachedTrajectory(trajectory_file='test.json', " - "placeholder_values={'current_date': '2025-12-11'})\n" - " - If required placeholders are missing, execution will fail with " + "parameter_values={'current_date': '2025-12-11'})\n" + " - If required parameters are missing, execution will fail with " "a clear error message\n" "\n" " NON-CACHEABLE STEPS:\n" @@ -49,7 +49,7 @@ " - Provide the same trajectory file and the step index where " "execution should continue\n" " - Example: ExecuteCachedTrajectory(trajectory_file='test.json', " - "start_from_step_index=5, placeholder_values={...})\n" + "start_from_step_index=5, parameter_values={...})\n" " - The tool will execute remaining steps from that index onwards\n" "\n" " FAILURE HANDLING:\n" @@ -83,8 +83,8 @@ " \n" ) -PLACEHOLDER_IDENTIFIER_SYSTEM_PROMPT = """You are analyzing UI automation trajectories \ -to identify values that should be parameterized as placeholders. +CACHING_PARAMETER_IDENTIFIER_SYSTEM_PROMPT = """You are analyzing UI automation \ +trajectories to identify values that should be parameterized as parameters. 
Identify values that are likely to change between executions, such as: - Dates and timestamps (e.g., "2025-12-11", "10:30 AM", "2025-12-11T14:30:00Z") @@ -94,7 +94,7 @@ - File paths with user-specific or time-specific components - Temporary or generated identifiers -DO NOT mark as placeholders: +DO NOT mark as parameters: - UI element coordinates (x, y positions) - Fixed button labels or static UI text - Configuration values that don't change (e.g., timeouts, retry counts) @@ -102,14 +102,14 @@ - Tool names - Boolean values or common constants -For each placeholder, provide: +For each parameter, provide: 1. A descriptive name in snake_case (e.g., "current_date", "user_email") 2. The actual value found in the trajectory 3. A brief description of what it represents Return your analysis as a JSON object with this structure: { - "placeholders": [ + "parameters": [ { "name": "current_date", "value": "2025-12-11", @@ -118,4 +118,4 @@ ] } -If no placeholders are found, return an empty placeholders array.""" +If no parameters are found, return an empty parameters array.""" diff --git a/src/askui/tools/caching_tools.py b/src/askui/tools/caching_tools.py index 7714b53e..8705f203 100644 --- a/src/askui/tools/caching_tools.py +++ b/src/askui/tools/caching_tools.py @@ -8,10 +8,10 @@ from ..models.shared.settings import CachedExecutionToolSettings from ..models.shared.tools import Tool, ToolCollection +from ..utils.cache_parameter_handler import CacheParameterHandler from ..utils.caching.cache_execution_manager import CacheExecutionManager from ..utils.caching.cache_manager import CacheManager from ..utils.caching.cache_writer import CacheWriter -from ..utils.placeholder_handler import PlaceholderHandler if TYPE_CHECKING: from ..models.shared.agent_message_param import ToolUseBlockParam @@ -149,14 +149,14 @@ def __init__( "trajectory files are available\n" "2. Select the appropriate trajectory file path from the " "returned list\n" - "3. If the trajectory contains placeholders (e.g., " + "3. If the trajectory contains parameters (e.g., " "{{current_date}}), provide values for them in the " - "placeholder_values parameter\n" + "parameter_values parameter\n" "4. Pass the full file path to this tool\n\n" - "Placeholders allow dynamic values to be injected during " + "Cache parameters allow dynamic values to be injected during " "execution. For example, if a trajectory types " "'{{current_date}}', you must provide " - "placeholder_values={'current_date': '2025-12-11'}.\n\n" + "parameter_values={'current_date': '2025-12-11'}.\n\n" "To continue from a specific step (e.g., after manually " "handling a non-cacheable step), use the start_from_step_index " "parameter. By default, execution starts from the beginning " @@ -188,12 +188,12 @@ def __init__( ), "default": 0, }, - "placeholder_values": { + "parameter_values": { "type": "object", "description": ( - "Optional dictionary mapping placeholder names to " + "Optional dictionary mapping parameter names to " "their values. Required if the trajectory contains " - "placeholders like {{variable}}. Example: " + "parameters like {{variable}}. 
Example: " "{'current_date': '2025-12-11', 'user_name': 'Alice'}" ), "additionalProperties": {"type": "string"}, @@ -265,33 +265,33 @@ def _validate_step_index( return error_msg return None - def _validate_placeholders( + def _validate_parameters( self, trajectory: list["ToolUseBlockParam"], - placeholder_values: dict[str, str], - cache_placeholders: dict[str, str], + parameter_values: dict[str, str], + cache_parameters: dict[str, str], ) -> str | None: - """Validate placeholder values. + """Validate parameter values. Args: trajectory: The cached trajectory - placeholder_values: User-provided placeholder values - cache_placeholders: Placeholders defined in cache file + parameter_values: User-provided parameter values + cache_parameters: Parameters defined in cache file Returns: Error message if validation fails, None otherwise """ - logger.debug("Validating placeholder values") - is_valid, missing = PlaceholderHandler.validate_placeholders( - trajectory, placeholder_values + logger.debug("Validating parameter values") + is_valid, missing = CacheParameterHandler.validate_parameters( + trajectory, parameter_values ) if not is_valid: error_msg = ( - f"Missing required placeholder values: {', '.join(missing)}\n" - f"The trajectory contains the following placeholders: " - f"{', '.join(cache_placeholders.keys())}\n" - f"Please provide values for all placeholders in the " - f"placeholder_values parameter." + f"Missing required parameter values: {', '.join(missing)}\n" + f"The trajectory contains the following parameters: " + f"{', '.join(cache_parameters.keys())}\n" + f"Please provide values for all parameters in the " + f"parameter_values parameter." ) logger.error(error_msg) return error_msg @@ -300,14 +300,14 @@ def _validate_placeholders( def _create_executor( self, cache_file: "CacheFile", - placeholder_values: dict[str, str], + parameter_values: dict[str, str], start_from_step_index: int, ) -> "TrajectoryExecutor": """Create and configure trajectory executor. Args: cache_file: The cache file to execute - placeholder_values: Placeholder values to use + parameter_values: Parameter values to use start_from_step_index: Index to start execution from Returns: @@ -324,7 +324,7 @@ def _create_executor( executor = TrajectoryExecutor( trajectory=cache_file.trajectory, toolbox=self._toolbox, - placeholder_values=placeholder_values, + parameter_values=parameter_values, delay_time=self._settings.delay_time_between_action, ) @@ -342,7 +342,7 @@ def _format_success_message( trajectory_file: str, trajectory_length: int, start_from_step_index: int, - placeholder_count: int, + parameter_count: int, ) -> str: """Format success message. @@ -350,7 +350,7 @@ def _format_success_message( trajectory_file: Path to trajectory file trajectory_length: Total steps in trajectory start_from_step_index: Starting step index - placeholder_count: Number of placeholders used + parameter_count: Number of parameters used Returns: Formatted success message @@ -369,8 +369,8 @@ def _format_success_message( f"Will execute {remaining_steps} remaining cached steps." ) - if placeholder_count > 0: - success_msg += f" Using {placeholder_count} placeholder value(s)." + if parameter_count > 0: + success_msg += f" Using {parameter_count} parameter value(s)." return success_msg @@ -380,7 +380,7 @@ def __call__( self, trajectory_file: str, start_from_step_index: int = 0, - placeholder_values: dict[str, str] | None = None, + parameter_values: dict[str, str] | None = None, ) -> str: """Activate cache execution mode for the agent. 
@@ -390,8 +390,8 @@ def __call__( Returns: Success message indicating cache mode has been activated """ - if placeholder_values is None: - placeholder_values = {} + if parameter_values is None: + parameter_values = {} logger.info( "Activating cache execution mode: %s (start_from_step=%d)", @@ -417,9 +417,9 @@ def __call__( cache_file = CacheWriter.read_cache_file(Path(trajectory_file)) logger.debug( - "Cache loaded: %d steps, %d placeholders, valid=%s", + "Cache loaded: %d steps, %d parameters, valid=%s", len(cache_file.trajectory), - len(cache_file.placeholders), + len(cache_file.cache_parameters), cache_file.metadata.is_valid, ) @@ -439,15 +439,15 @@ def __call__( ): return error - # Validate placeholders - if error := self._validate_placeholders( - cache_file.trajectory, placeholder_values, cache_file.placeholders + # Validate parameters + if error := self._validate_parameters( + cache_file.trajectory, parameter_values, cache_file.cache_parameters ): return error # Create and configure executor executor = self._create_executor( - cache_file, placeholder_values, start_from_step_index + cache_file, parameter_values, start_from_step_index ) # Store executor and cache info in agent state @@ -462,7 +462,7 @@ def __call__( trajectory_file, len(cache_file.trajectory), start_from_step_index, - len(placeholder_values), + len(parameter_values), ) logger.info(success_msg) return success_msg @@ -483,7 +483,7 @@ def __init__(self) -> None: "- Execution statistics (attempts, last execution time)\n" "- Validity status and invalidation reason (if invalid)\n" "- Failure history with timestamps and error messages\n" - "- Placeholders and trajectory step count\n\n" + "- Parameters and trajectory step count\n\n" "Use this tool to debug cache issues or understand why a cache " "might be failing or invalidated." ), @@ -556,10 +556,10 @@ def __call__(self, trajectory_file: str) -> str: lines.append("") lines.append("--- Trajectory Info ---") lines.append(f"Total Steps: {len(cache_file.trajectory)}") - lines.append(f"Placeholders: {len(cache_file.placeholders)}") - if cache_file.placeholders: + lines.append(f"Parameters: {len(cache_file.cache_parameters)}") + if cache_file.cache_parameters: lines.append( - f"Placeholder Names: {', '.join(cache_file.placeholders.keys())}" + f"Parameter Names: {', '.join(cache_file.cache_parameters.keys())}" ) if metadata.failures: diff --git a/src/askui/utils/cache_parameter_handler.py b/src/askui/utils/cache_parameter_handler.py new file mode 100644 index 00000000..f036757f --- /dev/null +++ b/src/askui/utils/cache_parameter_handler.py @@ -0,0 +1,489 @@ +"""Cache parameter handling for trajectory recording and execution. + +This module provides utilities for: +- Identifying dynamic values that should become parameters (recording phase) +- Validating and substituting parameter values (execution phase) + +Cache parameters use the {{parameter_name}} syntax and allow dynamic values +to be injected during cache execution. 
+""" + +import json +import logging +import re +from typing import Any + +from askui.locators.serializers import VlmLocatorSerializer +from askui.models.anthropic.factory import AnthropicApiProvider +from askui.models.anthropic.messages_api import AnthropicMessagesApi +from askui.models.model_router import create_api_client +from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam +from askui.models.shared.messages_api import MessagesApi +from askui.prompts.caching import CACHING_PARAMETER_IDENTIFIER_SYSTEM_PROMPT + +logger = logging.getLogger(__name__) + +# Regex pattern for matching parameters: {{parameter_name}} +# Allows alphanumeric characters and underscores, must start with letter/underscore +CACHE_PARAMETER_PATTERN = r"\{\{([a-zA-Z_][a-zA-Z0-9_]*)\}\}" + + +class CacheParameterDefinition: + """Represents a cache parameter identified in a trajectory.""" + + def __init__(self, name: str, value: Any, description: str) -> None: + self.name = name + self.value = value + self.description = description + + def __repr__(self) -> str: + return f"CacheParameterDefinition(name={self.name}, value={self.value})" + + +class CacheParameterHandler: + """Handles all cache parameter operations for trajectory recording and execution.""" + + # ======================================================================== + # RECORDING PHASE: Parameter identification and templatization + # ======================================================================== + + @staticmethod + def identify_and_parameterize( + trajectory: list[ToolUseBlockParam], + goal: str | None, + identification_strategy: str, + api_provider: AnthropicApiProvider = "askui", + model: str = "claude-sonnet-4-5-20250929", + ) -> tuple[str | None, list[ToolUseBlockParam], dict[str, str]]: + """Identify parameters and return parameterized trajectory + goal. + + This is the main entry point for the recording phase. It orchestrates + parameter identification and templatization of both trajectory and goal. 
+ + Args: + trajectory: The trajectory to analyze and parameterize + goal: The goal text to parameterize (optional) + identification_strategy: "llm" for AI-based or "preset" for manual + api_provider: API provider for LLM calls (only used for "llm" strategy) + model: Model to use for LLM-based identification + + Returns: + Tuple of: + - Parameterized goal text (or None if no goal) + - Parameterized trajectory (with {{param}} syntax) + - Dict mapping parameter names to descriptions + """ + if identification_strategy == "llm" and trajectory: + # Create messages_api for LLM-based identification + messages_api = AnthropicMessagesApi( + client=create_api_client(api_provider), + locator_serializer=VlmLocatorSerializer(), + ) + + # Use LLM to identify parameters + parameters_dict, parameter_definitions = ( + CacheParameterHandler._identify_parameters_with_llm( + trajectory, messages_api, model + ) + ) + + if parameter_definitions: + # Replace values with {{parameter}} syntax in trajectory + parameterized_trajectory = ( + CacheParameterHandler._replace_values_with_parameters( + trajectory, parameter_definitions + ) + ) + + # Apply same replacement to goal text + parameterized_goal = goal + if goal: + parameterized_goal = ( + CacheParameterHandler._apply_parameters_to_text( + goal, parameter_definitions + ) + ) + + n_parameters = len(parameter_definitions) + logger.info("Replaced %s parameter values in trajectory", n_parameters) + return parameterized_goal, parameterized_trajectory, parameters_dict + + # No parameters identified + logger.info("No parameters identified in trajectory") + return goal, trajectory, {} + + # Manual extraction (preset strategy) + parameter_names = CacheParameterHandler.extract_parameters(trajectory) + parameters_dict = { + name: f"Parameter for {name}" + for name in parameter_names # Generic desc + } + n_parameters = len(parameter_names) + logger.info("Extracted %s manual parameters from trajectory", n_parameters) + return goal, trajectory, parameters_dict + + @staticmethod + def _identify_parameters_with_llm( + trajectory: list[ToolUseBlockParam], + messages_api: MessagesApi, + model: str = "claude-sonnet-4-5-20250929", + ) -> tuple[dict[str, str], list[CacheParameterDefinition]]: + """Identify parameters in a trajectory using LLM analysis. + + Args: + trajectory: The trajectory to analyze (list of tool use blocks) + messages_api: Messages API instance for LLM calls + model: Model to use for analysis + + Returns: + Tuple of: + - Dict mapping parameter names to descriptions + - List of CacheParameterDefinition objects with name, value, and description + """ + if not trajectory: + logger.debug("Empty trajectory provided, skipping parameter identification") + return {}, [] + + logger.info( + "Starting parameter identification for trajectory with %s steps", + len(trajectory), + ) + + # Convert trajectory to serializable format for analysis + trajectory_data = [tool.model_dump(mode="json") for tool in trajectory] + logger.debug("Converted %s tool blocks to JSON format", len(trajectory_data)) + + user_message = ( + "Analyze this UI automation trajectory and identify all values that " + "should be parameters:\n\n" + f"```json\n{json.dumps(trajectory_data, indent=2)}\n```\n\n" + "Return only the JSON object with identified parameters. " + "Be thorough but conservative - only mark values that are clearly " + "dynamic or time-sensitive." 
+ ) + + response_text = "" # Initialize for error logging + try: + # Make single API call + logger.debug("Calling LLM (%s) to analyze trajectory for parameters", model) + response = messages_api.create_message( + messages=[MessageParam(role="user", content=user_message)], + model=model, + system=CACHING_PARAMETER_IDENTIFIER_SYSTEM_PROMPT, + max_tokens=4096, + temperature=0.0, # Deterministic for analysis + ) + logger.debug("Received response from LLM") + + # Extract text from response + if isinstance(response.content, list): + response_text = next( + ( + block.text + for block in response.content + if hasattr(block, "text") + ), + "", + ) + else: + response_text = str(response.content) + + # Parse the JSON response + logger.debug("Parsing LLM response to extract parameter definitions") + # Handle markdown code blocks if present + if "```json" in response_text: + logger.debug("Removing JSON markdown code block wrapper from response") + response_text = ( + response_text.split("```json")[1].split("```")[0].strip() + ) + elif "```" in response_text: + logger.debug("Removing code block wrapper from response") + response_text = response_text.split("```")[1].split("```")[0].strip() + + parameter_data = json.loads(response_text) + logger.debug( + "Successfully parsed JSON response with %s parameters", + len(parameter_data.get("parameters", [])), + ) + + # Convert to our data structures + parameter_definitions = [ + CacheParameterDefinition( + name=p["name"], value=p["value"], description=p["description"] + ) + for p in parameter_data.get("parameters", []) + ] + + parameters_dict = {p.name: p.description for p in parameter_definitions} + + if parameter_definitions: + logger.info( + "Successfully identified %s parameters in trajectory", + len(parameter_definitions), + ) + for p in parameter_definitions: + logger.debug(" - %s: %s (%s)", p.name, p.value, p.description) + else: + logger.info( + "No parameters identified in trajectory " + "(this is normal for trajectories with only static values)" + ) + + except json.JSONDecodeError as e: + logger.warning( + "Failed to parse LLM response as JSON: %s. " + "Falling back to empty parameter list.", + e, + extra={"response_text": response_text[:500]}, # Log first 500 chars + ) + return {}, [] + except Exception as e: # noqa: BLE001 + logger.warning( + "Failed to identify parameters with LLM: %s. " + "Falling back to empty parameter list.", + e, + exc_info=True, + ) + return {}, [] + else: + return parameters_dict, parameter_definitions + + @staticmethod + def _replace_values_with_parameters( + trajectory: list[ToolUseBlockParam], + parameter_definitions: list[CacheParameterDefinition], + ) -> list[ToolUseBlockParam]: + """Replace actual values in trajectory with {{parameter_name}} syntax. + + This is the reverse of substitute_parameters - it takes identified values + and replaces them with parameter syntax for saving to cache. 
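+
+        Example (adapted from the doctest of the former PlaceholderHandler;
+        values are illustrative):
+            >>> trajectory = [
+            ...     ToolUseBlockParam(
+            ...         id="1",
+            ...         name="computer",
+            ...         input={"action": "type", "text": "Date: 2025-12-11"},
+            ...         type="tool_use",
+            ...     )
+            ... ]
+            >>> params = [
+            ...     CacheParameterDefinition(
+            ...         name="current_date",
+            ...         value="2025-12-11",
+            ...         description="Current date",
+            ...     )
+            ... ]
+            >>> result = CacheParameterHandler._replace_values_with_parameters(
+            ...     trajectory, params
+            ... )
+            >>> result[0].input["text"]
+            'Date: {{current_date}}'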
+ + Args: + trajectory: The trajectory to templatize + parameter_definitions: List of CacheParameterDefinition objects with + name and value attributes + + Returns: + New trajectory with values replaced by parameters + """ + # Build replacement map: value -> parameter name + replacements = { + str(p.value): f"{{{{{p.name}}}}}" for p in parameter_definitions + } + + # Apply replacements to each tool block + parameterized_trajectory = [] + for tool_block in trajectory: + parameterized_input = CacheParameterHandler._replace_values_in_value( + tool_block.input, replacements + ) + + parameterized_trajectory.append( + ToolUseBlockParam( + id=tool_block.id, + name=tool_block.name, + input=parameterized_input, + type=tool_block.type, + cache_control=tool_block.cache_control, + ) + ) + + return parameterized_trajectory + + @staticmethod + def _apply_parameters_to_text( + text: str, parameter_definitions: list[CacheParameterDefinition] + ) -> str: + """Apply parameter replacement to a text string (e.g., goal). + + Args: + text: The text to parameterize + parameter_definitions: List of parameter definitions + + Returns: + Text with values replaced by {{parameter}} syntax + """ + # Build replacement map: value -> parameter syntax + replacements = { + str(p.value): f"{{{{{p.name}}}}}" for p in parameter_definitions + } + # Sort by length descending to replace longer matches first + result = text + for actual_value in sorted(replacements.keys(), key=len, reverse=True): + if actual_value in result: + result = result.replace(actual_value, replacements[actual_value]) + return result + + @staticmethod + def _replace_values_in_value(value: Any, replacements: dict[str, str]) -> Any: + """Recursively replace actual values with parameter syntax. + + Args: + value: Any value (str, dict, list, etc.) to process + replacements: Dict mapping actual values to parameter syntax + + Returns: + New value with replacements applied + """ + if isinstance(value, str): + # Replace exact matches and substring matches + result = value + # Sort by length descending to replace longer matches first + # This prevents partial replacements + for actual_value in sorted(replacements.keys(), key=len, reverse=True): + if actual_value in result: + result = result.replace(actual_value, replacements[actual_value]) + return result + if isinstance(value, dict): + # Recursively replace in dict values + return { + k: CacheParameterHandler._replace_values_in_value(v, replacements) + for k, v in value.items() + } + if isinstance(value, list): + # Recursively replace in list items + return [ + CacheParameterHandler._replace_values_in_value(item, replacements) + for item in value + ] + # For non-string types, check if the value matches exactly + str_value = str(value) + if str_value in replacements: + # Return the parameter as a string + return replacements[str_value] + return value + + # ======================================================================== + # EXECUTION PHASE: Parameter extraction, validation, and substitution + # ======================================================================== + + @staticmethod + def extract_parameters(trajectory: list[ToolUseBlockParam]) -> set[str]: + """Extract all parameter names from a trajectory. + + Scans all tool inputs for {{parameter_name}} patterns and returns + a set of unique parameter names. 
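+
+        Example (mirrors the doctest of the former PlaceholderHandler.extract_placeholders):
+            >>> trajectory = [
+            ...     ToolUseBlockParam(
+            ...         id="1",
+            ...         name="computer",
+            ...         input={"action": "type", "text": "Today is {{current_date}}"},
+            ...         type="tool_use",
+            ...     )
+            ... ]
+            >>> CacheParameterHandler.extract_parameters(trajectory)
+            {'current_date'}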
+ + Args: + trajectory: List of tool use blocks to scan + + Returns: + Set of unique parameter names found in the trajectory + """ + parameters: set[str] = set() + + for step in trajectory: + # Recursively find parameters in the input object + parameters.update(CacheParameterHandler._extract_from_value(step.input)) + + return parameters + + @staticmethod + def _extract_from_value(value: Any) -> set[str]: + """Recursively extract parameters from a value. + + Args: + value: Any value (str, dict, list, etc.) to search for parameters + + Returns: + Set of parameter names found + """ + parameters: set[str] = set() + + if isinstance(value, str): + # Find all matches in the string + matches = re.finditer(CACHE_PARAMETER_PATTERN, value) + parameters.update(match.group(1) for match in matches) + elif isinstance(value, dict): + # Recursively search dict values + for v in value.values(): + parameters.update(CacheParameterHandler._extract_from_value(v)) + elif isinstance(value, list): + # Recursively search list items + for item in value: + parameters.update(CacheParameterHandler._extract_from_value(item)) + + return parameters + + @staticmethod + def validate_parameters( + trajectory: list[ToolUseBlockParam], provided_values: dict[str, str] + ) -> tuple[bool, list[str]]: + """Validate that all required parameters have values. + + Args: + trajectory: List of tool use blocks containing parameters + provided_values: Dict of parameter names to their values + + Returns: + Tuple of (is_valid, missing_parameters) + - is_valid: True if all parameters have values, False otherwise + - missing_parameters: List of parameter names that are missing values + """ + required_parameters = CacheParameterHandler.extract_parameters(trajectory) + missing = [name for name in required_parameters if name not in provided_values] + + return len(missing) == 0, missing + + @staticmethod + def substitute_parameters( + tool_block: ToolUseBlockParam, parameter_values: dict[str, str] + ) -> ToolUseBlockParam: + """Replace parameters in a tool block with actual values. + + Creates a new ToolUseBlockParam with all {{parameter}} occurrences + replaced with their corresponding values from parameter_values. + + Args: + tool_block: The tool use block containing parameters + parameter_values: Dict mapping parameter names to replacement values + + Returns: + New ToolUseBlockParam with parameters substituted + """ + # Deep copy the input and substitute parameters + substituted_input = CacheParameterHandler._substitute_in_value( + tool_block.input, parameter_values + ) + + # Create new ToolUseBlockParam with substituted values + return ToolUseBlockParam( + id=tool_block.id, + name=tool_block.name, + input=substituted_input, + type=tool_block.type, + cache_control=tool_block.cache_control, + ) + + @staticmethod + def _substitute_in_value(value: Any, parameter_values: dict[str, str]) -> Any: + """Recursively substitute parameters in a value. + + Args: + value: Any value (str, dict, list, etc.) 
containing parameters + parameter_values: Dict of parameter names to replacement values + + Returns: + New value with parameters substituted + """ + if isinstance(value, str): + # Replace all parameters in the string + result = value + for name, replacement in parameter_values.items(): + pattern = r"\{\{" + re.escape(name) + r"\}\}" + result = re.sub(pattern, replacement, result) + return result + if isinstance(value, dict): + # Recursively substitute in dict values + return { + k: CacheParameterHandler._substitute_in_value(v, parameter_values) + for k, v in value.items() + } + if isinstance(value, list): + # Recursively substitute in list items + return [ + CacheParameterHandler._substitute_in_value(item, parameter_values) + for item in value + ] + # Return other types as-is + return value diff --git a/src/askui/utils/caching/cache_writer.py b/src/askui/utils/caching/cache_writer.py index 2e7a1b53..673f27ed 100644 --- a/src/askui/utils/caching/cache_writer.py +++ b/src/askui/utils/caching/cache_writer.py @@ -3,9 +3,6 @@ from datetime import datetime, timezone from pathlib import Path -from askui.locators.serializers import VlmLocatorSerializer -from askui.models.anthropic.messages_api import AnthropicMessagesApi -from askui.models.model_router import create_api_client from askui.models.shared.agent_message_param import ( MessageParam, ToolUseBlockParam, @@ -18,8 +15,7 @@ CacheWriterSettings, ) from askui.models.shared.tools import ToolCollection -from askui.utils.placeholder_handler import PlaceholderHandler -from askui.utils.placeholder_identifier import identify_placeholders +from askui.utils.cache_parameter_handler import CacheParameterHandler logger = logging.getLogger(__name__) @@ -93,8 +89,8 @@ def generate(self) -> None: cache_file_path = self.cache_dir / self.file_name - goal_to_save, trajectory_to_save, placeholders_dict = ( - self._replace_placeholders() + goal_to_save, trajectory_to_save, parameters_dict = ( + self._parameterize_trajectory() ) if self._toolbox is not None: @@ -105,69 +101,20 @@ def generate(self) -> None: logger.info("No toolbox set, skipping non-cacheable tool input blanking") self._generate_cache_file( - goal_to_save, trajectory_to_save, placeholders_dict, cache_file_path + goal_to_save, trajectory_to_save, parameters_dict, cache_file_path ) self.reset() - def _replace_placeholders( + def _parameterize_trajectory( self, ) -> tuple[str | None, list[ToolUseBlockParam], dict[str, str]]: - # Determine which trajectory and placeholders to use - trajectory_to_save = self.messages - goal_to_save = self._goal - placeholders_dict: dict[str, str] = {} - - if ( - self._cache_writer_settings.placeholder_identification_strategy == "llm" - and self.messages - ): - # Get messages_api for placeholder identification - messages_api = AnthropicMessagesApi( - client=create_api_client( - self._cache_writer_settings.llm_placeholder_id_api_provider - ), - locator_serializer=VlmLocatorSerializer(), - ) - placeholders_dict, placeholder_definitions = identify_placeholders( - trajectory=self.messages, - messages_api=messages_api, - ) - n_placeholders = len(placeholder_definitions) - # Replace actual values with {{placeholder_name}} syntax in trajectory - if placeholder_definitions: - trajectory_to_save = ( - PlaceholderHandler.replace_values_with_placeholders( - trajectory=self.messages, - placeholder_definitions=placeholder_definitions, - ) - ) - - # Also apply placeholder replacement to the goal - if self._goal: - goal_to_save = self._goal - # Build replacement map: value -> 
placeholder syntax - replacements = { - str(p.value): f"{{{{{p.name}}}}}" - for p in placeholder_definitions - } - # Sort by length descending to replace longer matches first - for actual_value in sorted( - replacements.keys(), key=len, reverse=True - ): - if actual_value in goal_to_save: - goal_to_save = goal_to_save.replace( - actual_value, replacements[actual_value] - ) - else: - # Manual placeholder extraction - placeholder_names = PlaceholderHandler.extract_placeholders(self.messages) - placeholders_dict = { - name: f"Placeholder for {name}" # Generic description - for name in placeholder_names - } - n_placeholders = len(placeholder_names) - logger.info("Replaced %s placeholder values in trajectory", n_placeholders) - return goal_to_save, trajectory_to_save, placeholders_dict + """Identify parameters and return parameterized trajectory + goal.""" + return CacheParameterHandler.identify_and_parameterize( + trajectory=self.messages, + goal=self._goal, + identification_strategy=self._cache_writer_settings.parameter_identification_strategy, + api_provider=self._cache_writer_settings.llm_parameter_id_api_provider, + ) def _blank_non_cacheable_tool_inputs( self, trajectory: list[ToolUseBlockParam] @@ -223,7 +170,7 @@ def _generate_cache_file( self, goal_to_save: str | None, trajectory_to_save: list[ToolUseBlockParam], - placeholders_dict: dict[str, str], + parameters_dict: dict[str, str], cache_file_path: Path, ) -> None: cache_file = CacheFile( @@ -234,7 +181,7 @@ def _generate_cache_file( token_usage=self._accumulated_usage, ), trajectory=trajectory_to_save, - placeholders=placeholders_dict, + cache_parameters=parameters_dict, ) with cache_file_path.open("w", encoding="utf-8") as f: @@ -288,10 +235,10 @@ def read_cache_file(cache_file_path: Path) -> CacheFile: ), ), trajectory=trajectory, - placeholders={}, + cache_parameters={}, ) logger.info( - "Successfully loaded and migrated v0.0 cache: %s steps, 0 placeholders", + "Successfully loaded and migrated v0.0 cache: %s steps, 0 parameters", len(trajectory), ) return cache_file @@ -299,9 +246,9 @@ def read_cache_file(cache_file_path: Path) -> CacheFile: # v0.1 format: structured with metadata cache_file = CacheFile(**raw_data) logger.info( - "Successfully loaded v0.1 cache: %s steps, %s placeholders", + "Successfully loaded v0.1 cache: %s steps, %s parameters", len(cache_file.trajectory), - len(cache_file.placeholders), + len(cache_file.cache_parameters), ) if cache_file.metadata.goal: logger.debug("Cache goal: %s", cache_file.metadata.goal) diff --git a/src/askui/utils/placeholder_handler.py b/src/askui/utils/placeholder_handler.py deleted file mode 100644 index d491eeff..00000000 --- a/src/askui/utils/placeholder_handler.py +++ /dev/null @@ -1,292 +0,0 @@ -"""Placeholder handling for cache trajectories. - -This module provides utilities for detecting, validating, and substituting -placeholders in cached trajectories. Placeholders use the {{variable_name}} -syntax and allow dynamic values to be injected during cache execution. 
-""" - -import re -from typing import Any - -from askui.models.shared.agent_message_param import ToolUseBlockParam - -# Regex pattern for matching placeholders: {{variable_name}} -# Allows alphanumeric characters and underscores, must start with letter/underscore -PLACEHOLDER_PATTERN = r"\{\{([a-zA-Z_][a-zA-Z0-9_]*)\}\}" - - -class PlaceholderHandler: - """Handler for placeholder detection, validation, and substitution.""" - - @staticmethod - def extract_placeholders(trajectory: list[ToolUseBlockParam]) -> set[str]: - """Extract all placeholder names from a trajectory. - - Scans all tool inputs for {{placeholder_name}} patterns and returns - a set of unique placeholder names. - - Args: - trajectory: List of tool use blocks to scan - - Returns: - Set of unique placeholder names found in the trajectory - - Example: - >>> trajectory = [ - ... ToolUseBlockParam( - ... id="1", - ... name="computer", - ... input={"action": "type", "text": "Today is {{current_date}}"}, - ... type="tool_use" - ... ) - ... ] - >>> PlaceholderHandler.extract_placeholders(trajectory) - {'current_date'} - """ - placeholders: set[str] = set() - - for step in trajectory: - # Recursively find placeholders in the input object - placeholders.update(PlaceholderHandler._extract_from_value(step.input)) - - return placeholders - - @staticmethod - def _extract_from_value(value: Any) -> set[str]: - """Recursively extract placeholders from a value. - - Args: - value: Any value (str, dict, list, etc.) to search for placeholders - - Returns: - Set of placeholder names found - """ - placeholders: set[str] = set() - - if isinstance(value, str): - # Find all matches in the string - matches = re.finditer(PLACEHOLDER_PATTERN, value) - placeholders.update(match.group(1) for match in matches) - elif isinstance(value, dict): - # Recursively search dict values - for v in value.values(): - placeholders.update(PlaceholderHandler._extract_from_value(v)) - elif isinstance(value, list): - # Recursively search list items - for item in value: - placeholders.update(PlaceholderHandler._extract_from_value(item)) - - return placeholders - - @staticmethod - def validate_placeholders( - trajectory: list[ToolUseBlockParam], provided_values: dict[str, str] - ) -> tuple[bool, list[str]]: - """Validate that all required placeholders have values. - - Args: - trajectory: List of tool use blocks containing placeholders - provided_values: Dict of placeholder names to their values - - Returns: - Tuple of (is_valid, missing_placeholders) - - is_valid: True if all placeholders have values, False otherwise - - missing_placeholders: List of placeholder names that are missing values - - Example: - >>> trajectory = [...] # Contains {{current_date}} and {{user_name}} - >>> is_valid, missing = PlaceholderHandler.validate_placeholders( - ... trajectory, - ... {"current_date": "2025-12-11"} - ... ) - >>> is_valid - False - >>> missing - ['user_name'] - """ - required_placeholders = PlaceholderHandler.extract_placeholders(trajectory) - missing = [ - name for name in required_placeholders if name not in provided_values - ] - - return len(missing) == 0, missing - - @staticmethod - def replace_values_with_placeholders( - trajectory: list[ToolUseBlockParam], - placeholder_definitions: list[Any], # list[PlaceholderDefinition] - ) -> list[ToolUseBlockParam]: - """Replace actual values in trajectory with {{placeholder_name}} syntax. - - This is the reverse of substitute_placeholders - it takes identified values - and replaces them with placeholder syntax for saving to cache. 
- - Args: - trajectory: The trajectory to templatize - placeholder_definitions: List of PlaceholderDefinition objects with - name and value attributes - - Returns: - New trajectory with values replaced by placeholders - - Example: - >>> trajectory = [ - ... ToolUseBlockParam( - ... id="1", - ... name="computer", - ... input={"action": "type", "text": "Date: 2025-12-11"}, - ... type="tool_use" - ... ) - ... ] - >>> placeholders = [ - ... PlaceholderDefinition( - ... name="current_date", - ... value="2025-12-11", - ... description="Current date" - ... ) - ... ] - >>> result = PlaceholderHandler.replace_values_with_placeholders( - ... trajectory, placeholders - ... ) - >>> result[0].input["text"] - 'Date: {{current_date}}' - """ - # Build replacement map: value -> placeholder name - replacements = { - str(p.value): f"{{{{{p.name}}}}}" for p in placeholder_definitions - } - - # Apply replacements to each tool block - templated_trajectory = [] - for tool_block in trajectory: - templated_input = PlaceholderHandler._replace_values_in_value( - tool_block.input, replacements - ) - - templated_trajectory.append( - ToolUseBlockParam( - id=tool_block.id, - name=tool_block.name, - input=templated_input, - type=tool_block.type, - cache_control=tool_block.cache_control, - ) - ) - - return templated_trajectory - - @staticmethod - def _replace_values_in_value(value: Any, replacements: dict[str, str]) -> Any: - """Recursively replace actual values with placeholder syntax. - - Args: - value: Any value (str, dict, list, etc.) to process - replacements: Dict mapping actual values to placeholder syntax - - Returns: - New value with replacements applied - """ - if isinstance(value, str): - # Replace exact matches and substring matches - result = value - # Sort by length descending to replace longer matches first - # This prevents partial replacements - for actual_value in sorted(replacements.keys(), key=len, reverse=True): - if actual_value in result: - result = result.replace(actual_value, replacements[actual_value]) - return result - if isinstance(value, dict): - # Recursively replace in dict values - return { - k: PlaceholderHandler._replace_values_in_value(v, replacements) - for k, v in value.items() - } - if isinstance(value, list): - # Recursively replace in list items - return [ - PlaceholderHandler._replace_values_in_value(item, replacements) - for item in value - ] - # For non-string types, check if the value matches exactly - str_value = str(value) - if str_value in replacements: - # Return the placeholder as a string - return replacements[str_value] - return value - - @staticmethod - def substitute_placeholders( - tool_block: ToolUseBlockParam, placeholder_values: dict[str, str] - ) -> ToolUseBlockParam: - """Replace placeholders in a tool block with actual values. - - Creates a new ToolUseBlockParam with all {{placeholder}} occurrences - replaced with their corresponding values from placeholder_values. - - Args: - tool_block: The tool use block containing placeholders - placeholder_values: Dict mapping placeholder names to replacement values - - Returns: - New ToolUseBlockParam with placeholders substituted - - Example: - >>> tool_block = ToolUseBlockParam( - ... id="1", - ... name="computer", - ... input={"action": "type", "text": "Date: {{current_date}}"}, - ... type="tool_use" - ... ) - >>> result = PlaceholderHandler.substitute_placeholders( - ... tool_block, - ... {"current_date": "2025-12-11"} - ... 
) - >>> result.input["text"] - 'Date: 2025-12-11' - """ - # Deep copy the input and substitute placeholders - substituted_input = PlaceholderHandler._substitute_in_value( - tool_block.input, placeholder_values - ) - - # Create new ToolUseBlockParam with substituted values - return ToolUseBlockParam( - id=tool_block.id, - name=tool_block.name, - input=substituted_input, - type=tool_block.type, - cache_control=tool_block.cache_control, - ) - - @staticmethod - def _substitute_in_value(value: Any, placeholder_values: dict[str, str]) -> Any: - """Recursively substitute placeholders in a value. - - Args: - value: Any value (str, dict, list, etc.) containing placeholders - placeholder_values: Dict of placeholder names to replacement values - - Returns: - New value with placeholders substituted - """ - if isinstance(value, str): - # Replace all placeholders in the string - result = value - for name, replacement in placeholder_values.items(): - pattern = r"\{\{" + re.escape(name) + r"\}\}" - result = re.sub(pattern, replacement, result) - return result - if isinstance(value, dict): - # Recursively substitute in dict values - return { - k: PlaceholderHandler._substitute_in_value(v, placeholder_values) - for k, v in value.items() - } - if isinstance(value, list): - # Recursively substitute in list items - return [ - PlaceholderHandler._substitute_in_value(item, placeholder_values) - for item in value - ] - # Return other types as-is - return value diff --git a/src/askui/utils/placeholder_identifier.py b/src/askui/utils/placeholder_identifier.py deleted file mode 100644 index a54ccee2..00000000 --- a/src/askui/utils/placeholder_identifier.py +++ /dev/null @@ -1,143 +0,0 @@ -"""Module for identifying placeholders in trajectories using LLM analysis.""" - -import json -import logging -from typing import Any - -from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam -from askui.models.shared.messages_api import MessagesApi -from askui.prompts.caching import PLACEHOLDER_IDENTIFIER_SYSTEM_PROMPT - -logger = logging.getLogger(__name__) - - -class PlaceholderDefinition: - """Represents a placeholder identified in a trajectory.""" - - def __init__(self, name: str, value: Any, description: str) -> None: - self.name = name - self.value = value - self.description = description - - def __repr__(self) -> str: - return f"PlaceholderDefinition(name={self.name}, value={self.value})" - - -def identify_placeholders( - trajectory: list[ToolUseBlockParam], - messages_api: MessagesApi, - model: str = "claude-sonnet-4-5-20250929", -) -> tuple[dict[str, str], list[PlaceholderDefinition]]: - """Identify placeholders in a trajectory using LLM analysis. 
- - Args: - trajectory: The trajectory to analyze (list of tool use blocks) - messages_api: Messages API instance for LLM calls - model: Model to use for analysis - - Returns: - Tuple of: - - Dict mapping placeholder names to descriptions - - List of PlaceholderDefinition objects with name, value, and description - """ - if not trajectory: - logger.debug("Empty trajectory provided, skipping placeholder identification") - return {}, [] - - logger.info( - "Starting placeholder identification for trajectory with %s steps", - len(trajectory), - ) - - # Convert trajectory to serializable format for analysis - trajectory_data = [tool.model_dump(mode="json") for tool in trajectory] - logger.debug("Converted %s tool blocks to JSON format", len(trajectory_data)) - - user_message = ( - "Analyze this UI automation trajectory and identify all values that " - "should be placeholders:\n\n" - f"```json\n{json.dumps(trajectory_data, indent=2)}\n```\n\n" - "Return only the JSON object with identified placeholders. " - "Be thorough but conservative - only mark values that are clearly " - "dynamic or time-sensitive." - ) - - response_text = "" # Initialize for error logging - try: - # Make single API call - logger.debug("Calling LLM (%s) to analyze trajectory for placeholders", model) - response = messages_api.create_message( - messages=[MessageParam(role="user", content=user_message)], - model=model, - system=PLACEHOLDER_IDENTIFIER_SYSTEM_PROMPT, - max_tokens=4096, - temperature=0.0, # Deterministic for analysis - ) - logger.debug("Received response from LLM") - - # Extract text from response - if isinstance(response.content, list): - response_text = next( - (block.text for block in response.content if hasattr(block, "text")), - "", - ) - else: - response_text = str(response.content) - - # Parse the JSON response - logger.debug("Parsing LLM response to extract placeholder definitions") - # Handle markdown code blocks if present - if "```json" in response_text: - logger.debug("Removing JSON markdown code block wrapper from response") - response_text = response_text.split("```json")[1].split("```")[0].strip() - elif "```" in response_text: - logger.debug("Removing code block wrapper from response") - response_text = response_text.split("```")[1].split("```")[0].strip() - - placeholder_data = json.loads(response_text) - logger.debug( - "Successfully parsed JSON response with %s placeholders", - len(placeholder_data.get("placeholders", [])), - ) - - # Convert to our data structures - placeholder_definitions = [ - PlaceholderDefinition( - name=p["name"], value=p["value"], description=p["description"] - ) - for p in placeholder_data.get("placeholders", []) - ] - - placeholder_dict = {p.name: p.description for p in placeholder_definitions} - - if placeholder_definitions: - logger.info( - "Successfully identified %s placeholders in trajectory", - len(placeholder_definitions), - ) - for p in placeholder_definitions: - logger.debug(" - %s: %s (%s)", p.name, p.value, p.description) - else: - logger.info( - "No placeholders identified in trajectory " - "(this is normal for trajectories with only static values)" - ) - - except json.JSONDecodeError as e: - logger.warning( - "Failed to parse LLM response as JSON: %s. " - "Falling back to empty placeholder list.", - e, - extra={"response_text": response_text[:500]}, # Log first 500 chars - ) - return {}, [] - except Exception as e: # noqa: BLE001 - logger.warning( - "Failed to identify placeholders with LLM: %s. 
" - "Falling back to empty placeholder list.", - e, - exc_info=True, - ) - return {}, [] - else: - return placeholder_dict, placeholder_definitions diff --git a/src/askui/utils/trajectory_executor.py b/src/askui/utils/trajectory_executor.py index d953ff09..cdbb87ef 100644 --- a/src/askui/utils/trajectory_executor.py +++ b/src/askui/utils/trajectory_executor.py @@ -17,7 +17,7 @@ ToolUseBlockParam, ) from askui.models.shared.tools import ToolCollection -from askui.utils.placeholder_handler import PlaceholderHandler +from askui.utils.cache_parameter_handler import CacheParameterHandler logger = logging.getLogger(__name__) @@ -47,7 +47,7 @@ class ExecutionResult(BaseModel): class TrajectoryExecutor: """Executes cached trajectories step-by-step with control flow. - Supports pausing at non-cacheable steps, placeholder substitution, + Supports pausing at non-cacheable steps, cache_parameter substitution, and collecting execution results for the agent to review. """ @@ -55,7 +55,7 @@ def __init__( self, trajectory: list[ToolUseBlockParam], toolbox: ToolCollection, - placeholder_values: dict[str, str] | None = None, + parameter_values: dict[str, str] | None = None, delay_time: float = 0.5, visual_validation_enabled: bool = False, ): @@ -64,13 +64,13 @@ def __init__( Args: trajectory: List of tool use blocks to execute toolbox: ToolCollection for executing tools - placeholder_values: Dict of placeholder names to values + parameter_values: Dict of parameter names to values delay_time: Seconds to wait between step executions visual_validation_enabled: Enable visual validation (future feature) """ self.trajectory = trajectory self.toolbox = toolbox - self.placeholder_values = placeholder_values or {} + self.parameter_values = parameter_values or {} self.delay_time = delay_time self.visual_validation_enabled = visual_validation_enabled self.current_step_index = 0 @@ -86,7 +86,7 @@ def execute_next_step(self) -> ExecutionResult: 1. Check if there are more steps to execute 2. Check if the step should be skipped (screenshots, retrieval tools) 3. Check if the step is non-cacheable (needs agent) - 4. Substitute placeholders + 4. Substitute parameters 5. Execute the tool and build messages with proper data types 6. 
Return result with updated message history @@ -146,9 +146,9 @@ def execute_next_step(self) -> ExecutionResult: message_history=self.message_history.copy(), ) - # Substitute placeholders - substituted_step = PlaceholderHandler.substitute_placeholders( - step, self.placeholder_values + # Substitute parameters + substituted_step = CacheParameterHandler.substitute_parameters( + step, self.parameter_values ) # Execute the tool diff --git a/tests/unit/tools/test_caching_tools.py b/tests/unit/tools/test_caching_tools.py index 605eabed..253a1854 100644 --- a/tests/unit/tools/test_caching_tools.py +++ b/tests/unit/tools/test_caching_tools.py @@ -39,7 +39,7 @@ def test_retrieve_cached_test_executions_lists_json_files() -> None: "invalidation_reason": None, }, "trajectory": [], - "placeholders": {}, + "cache_parameters": {}, } (cache_dir / "cache1.json").write_text(json.dumps(cache_data), encoding="utf-8") (cache_dir / "cache2.json").write_text(json.dumps(cache_data), encoding="utf-8") @@ -90,7 +90,7 @@ def test_retrieve_cached_test_executions_respects_custom_format() -> None: "invalidation_reason": None, }, "trajectory": [], - "placeholders": {}, + "cache_parameters": {}, } (cache_dir / "cache1.json").write_text(json.dumps(cache_data), encoding="utf-8") (cache_dir / "cache2.traj").write_text(json.dumps(cache_data), encoding="utf-8") @@ -129,7 +129,7 @@ def test_retrieve_caches_filters_invalid_by_default(tmp_path: Path) -> None: "invalidation_reason": None, }, "trajectory": [], - "placeholders": {}, + "cache_parameters": {}, } with valid_cache.open("w") as f: json.dump(valid_data, f) @@ -146,7 +146,7 @@ def test_retrieve_caches_filters_invalid_by_default(tmp_path: Path) -> None: "invalidation_reason": "Too many failures", }, "trajectory": [], - "placeholders": {}, + "cache_parameters": {}, } with invalid_cache.open("w") as f: json.dump(invalid_data, f) @@ -177,7 +177,7 @@ def test_retrieve_caches_includes_invalid_when_requested(tmp_path: Path) -> None "invalidation_reason": None, }, "trajectory": [], - "placeholders": {}, + "cache_parameters": {}, } with valid_cache.open("w") as f: json.dump(valid_data, f) @@ -194,7 +194,7 @@ def test_retrieve_caches_includes_invalid_when_requested(tmp_path: Path) -> None "invalidation_reason": "Too many failures", }, "trajectory": [], - "placeholders": {}, + "cache_parameters": {}, } with invalid_cache.open("w") as f: json.dump(invalid_data, f) @@ -235,7 +235,7 @@ def test_execute_cached_execution_raises_error_without_cache_manager() -> None: "invalidation_reason": None, }, "trajectory": [], - "placeholders": {}, + "cache_parameters": {}, } cache_file.write_text(json.dumps(cache_data), encoding="utf-8") @@ -290,7 +290,7 @@ def test_execute_cached_execution_activates_cache_mode() -> None: "type": "tool_use", }, ], - "placeholders": {}, + "cache_parameters": {}, } with cache_file.open("w", encoding="utf-8") as f: @@ -343,7 +343,7 @@ def test_execute_cached_execution_works_with_toolbox() -> None: "type": "tool_use", } ], - "placeholders": {}, + "cache_parameters": {}, } with cache_file.open("w", encoding="utf-8") as f: @@ -398,12 +398,12 @@ def test_execute_cached_execution_initializes_with_custom_settings() -> None: assert tool._settings.delay_time_between_action == 1.0 # noqa: SLF001 -def test_execute_cached_execution_with_placeholders() -> None: - """Test that ExecuteCachedTrajectory validates placeholders.""" +def test_execute_cached_execution_with_parameters() -> None: + """Test that ExecuteCachedTrajectory validates parameters.""" with 
tempfile.TemporaryDirectory() as temp_dir: cache_file = Path(temp_dir) / "test_trajectory.json" - # Create a v0.1 cache file with placeholders + # Create a v0.1 cache file with parameters cache_data = { "metadata": { "version": "0.1", @@ -422,7 +422,7 @@ def test_execute_cached_execution_with_placeholders() -> None: "type": "tool_use", }, ], - "placeholders": { + "cache_parameters": { "current_date": "Current date", }, } @@ -440,21 +440,21 @@ def test_execute_cached_execution_with_placeholders() -> None: result = tool( trajectory_file=str(cache_file), - placeholder_values={"current_date": "2025-12-11"}, + parameter_values={"current_date": "2025-12-11"}, ) # Verify success assert isinstance(result, str) assert "✓ Cache execution mode activated" in result - assert "1 placeholder value" in result + assert "1 parameter value" in result -def test_execute_cached_execution_missing_placeholders() -> None: - """Test that ExecuteCachedTrajectory returns error for missing placeholders.""" +def test_execute_cached_execution_missing_parameters() -> None: + """Test that ExecuteCachedTrajectory returns error for missing parameters.""" with tempfile.TemporaryDirectory() as temp_dir: cache_file = Path(temp_dir) / "test_trajectory.json" - # Create a v0.1 cache file with placeholders + # Create a v0.1 cache file with parameters cache_data = { "metadata": { "version": "0.1", @@ -473,7 +473,7 @@ def test_execute_cached_execution_missing_placeholders() -> None: "type": "tool_use", } ], - "placeholders": { + "cache_parameters": { "current_date": "Current date", "user_name": "User name", }, @@ -493,17 +493,17 @@ def test_execute_cached_execution_missing_placeholders() -> None: # Verify error message assert isinstance(result, str) - assert "Missing required placeholder values" in result + assert "Missing required parameter values" in result assert "current_date" in result assert "user_name" in result -def test_execute_cached_execution_no_placeholders_backward_compat() -> None: - """Test backward compatibility: trajectories without placeholders work fine.""" +def test_execute_cached_execution_no_parameters_backward_compat() -> None: + """Test backward compatibility: trajectories without parameters work fine.""" with tempfile.TemporaryDirectory() as temp_dir: cache_file = Path(temp_dir) / "test_trajectory.json" - # Create a v0.0 cache file (old format, no placeholders) + # Create a v0.0 cache file (old format, no parameters) trajectory: list[dict[str, Any]] = [ { "id": "tool1", @@ -554,7 +554,7 @@ def test_continue_cached_trajectory_from_middle() -> None: {"id": "4", "name": "tool4", "input": {}, "type": "tool_use"}, {"id": "5", "name": "tool5", "input": {}, "type": "tool_use"}, ], - "placeholders": {}, + "cache_parameters": {}, } with cache_file.open("w", encoding="utf-8") as f: @@ -595,7 +595,7 @@ def test_continue_cached_trajectory_invalid_step_index_negative() -> None: "trajectory": [ {"id": "1", "name": "tool1", "input": {}, "type": "tool_use"}, ], - "placeholders": {}, + "cache_parameters": {}, } with cache_file.open("w", encoding="utf-8") as f: @@ -634,7 +634,7 @@ def test_continue_cached_trajectory_invalid_step_index_too_large() -> None: {"id": "1", "name": "tool1", "input": {}, "type": "tool_use"}, {"id": "2", "name": "tool2", "input": {}, "type": "tool_use"}, ], - "placeholders": {}, + "cache_parameters": {}, } with cache_file.open("w", encoding="utf-8") as f: @@ -655,12 +655,12 @@ def test_continue_cached_trajectory_invalid_step_index_too_large() -> None: assert "valid indices: 0-1" in result -def 
test_continue_cached_trajectory_with_placeholders() -> None: - """Test continuing execution with placeholder substitution.""" +def test_continue_cached_trajectory_with_parameters() -> None: + """Test continuing execution with parameter substitution.""" with tempfile.TemporaryDirectory() as temp_dir: cache_file = Path(temp_dir) / "test_trajectory.json" - # Create a v0.1 cache file with placeholders + # Create a v0.1 cache file with parameters cache_data = { "metadata": { "version": "0.1", @@ -691,7 +691,7 @@ def test_continue_cached_trajectory_with_placeholders() -> None: "type": "tool_use", }, ], - "placeholders": { + "cache_parameters": { "current_date": "Current date", "user_name": "User name", }, @@ -711,7 +711,7 @@ def test_continue_cached_trajectory_with_placeholders() -> None: result = tool( trajectory_file=str(cache_file), start_from_step_index=1, - placeholder_values={"current_date": "2025-12-11", "user_name": "Alice"}, + parameter_values={"current_date": "2025-12-11", "user_name": "Alice"}, ) # Verify success @@ -742,7 +742,7 @@ def test_execute_cached_trajectory_warns_if_invalid( "trajectory": [ {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, ], - "placeholders": {}, + "cache_parameters": {}, } with cache_file.open("w") as f: json.dump(cache_data, f) @@ -790,7 +790,7 @@ def test_inspect_cache_metadata_shows_basic_info(tmp_path: Path) -> None: {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, {"id": "2", "name": "type", "input": {"text": "test"}, "type": "tool_use"}, ], - "placeholders": {"current_date": "{{current_date}}"}, + "cache_parameters": {"current_date": "{{current_date}}"}, } with cache_file.open("w") as f: json.dump(cache_data, f) @@ -804,7 +804,7 @@ def test_inspect_cache_metadata_shows_basic_info(tmp_path: Path) -> None: assert "Total Execution Attempts: 5" in result assert "Is Valid: True" in result assert "Total Steps: 2" in result - assert "Placeholders: 1" in result + assert "Parameters: 1" in result assert "current_date" in result @@ -839,7 +839,7 @@ def test_inspect_cache_metadata_shows_failures(tmp_path: Path) -> None: "trajectory": [ {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, ], - "placeholders": {}, + "cache_parameters": {}, } with cache_file.open("w") as f: json.dump(cache_data, f) @@ -897,7 +897,7 @@ def test_revalidate_cache_marks_invalid_as_valid(tmp_path: Path) -> None: "trajectory": [ {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, ], - "placeholders": {}, + "cache_parameters": {}, } with cache_file.open("w") as f: json.dump(cache_data, f) @@ -938,7 +938,7 @@ def test_revalidate_cache_already_valid(tmp_path: Path) -> None: "trajectory": [ {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, ], - "placeholders": {}, + "cache_parameters": {}, } with cache_file.open("w") as f: json.dump(cache_data, f) @@ -984,7 +984,7 @@ def test_invalidate_cache_marks_valid_as_invalid(tmp_path: Path) -> None: "trajectory": [ {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, ], - "placeholders": {}, + "cache_parameters": {}, } with cache_file.open("w") as f: json.dump(cache_data, f) @@ -1027,7 +1027,7 @@ def test_invalidate_cache_updates_reason_if_already_invalid(tmp_path: Path) -> N "trajectory": [ {"id": "1", "name": "click", "input": {"x": 100}, "type": "tool_use"}, ], - "placeholders": {}, + "cache_parameters": {}, } with cache_file.open("w") as f: json.dump(cache_data, f) diff --git a/tests/unit/utils/test_cache_manager.py 
b/tests/unit/utils/test_cache_manager.py index f7c41ed9..9570a04d 100644 --- a/tests/unit/utils/test_cache_manager.py +++ b/tests/unit/utils/test_cache_manager.py @@ -31,7 +31,7 @@ def sample_cache_file() -> CacheFile: id="2", name="type", input={"text": "test"}, type="tool_use" ), ], - placeholders={}, + cache_parameters={}, ) diff --git a/tests/unit/utils/test_placeholder_handler.py b/tests/unit/utils/test_cache_parameter_handler.py similarity index 58% rename from tests/unit/utils/test_placeholder_handler.py rename to tests/unit/utils/test_cache_parameter_handler.py index d022f5bc..e03e2ec6 100644 --- a/tests/unit/utils/test_placeholder_handler.py +++ b/tests/unit/utils/test_cache_parameter_handler.py @@ -1,16 +1,19 @@ -"""Unit tests for PlaceholderHandler.""" +"""Unit tests for CacheParameterHandler.""" import pytest from askui.models.shared.agent_message_param import ToolUseBlockParam -from askui.utils.placeholder_handler import PLACEHOLDER_PATTERN, PlaceholderHandler +from askui.utils.cache_parameter_handler import ( + CACHE_PARAMETER_PATTERN, + CacheParameterHandler, +) -def test_placeholder_pattern_matches_valid_placeholders() -> None: - """Test that the regex pattern matches valid placeholder syntax.""" +def test_parameter_pattern_matches_valid_parameters() -> None: + """Test that the regex pattern matches valid parameter syntax.""" import re - valid_placeholders = [ + valid_parameters = [ "{{variable}}", "{{current_date}}", "{{user_name}}", @@ -18,16 +21,16 @@ def test_placeholder_pattern_matches_valid_placeholders() -> None: "{{VAR123}}", ] - for placeholder in valid_placeholders: - match = re.search(PLACEHOLDER_PATTERN, placeholder) - assert match is not None, f"Should match valid placeholder: {placeholder}" + for parameter in valid_parameters: + match = re.search(CACHE_PARAMETER_PATTERN, parameter) + assert match is not None, f"Should match valid parameter: {parameter}" -def test_placeholder_pattern_does_not_match_invalid() -> None: - """Test that the regex pattern rejects invalid placeholder syntax.""" +def test_parameter_pattern_does_not_match_invalid() -> None: + """Test that the regex pattern rejects invalid parameter syntax.""" import re - invalid_placeholders = [ + invalid_parameters = [ "{{123invalid}}", # Starts with number "{{var-name}}", # Contains hyphen "{{var name}}", # Contains space @@ -35,14 +38,14 @@ def test_placeholder_pattern_does_not_match_invalid() -> None: "{{}}", # Empty ] - for placeholder in invalid_placeholders: - match = re.search(PLACEHOLDER_PATTERN, placeholder) - if match and match.group(0) == placeholder: - pytest.fail(f"Should not match invalid placeholder: {placeholder}") + for parameter in invalid_parameters: + match = re.search(CACHE_PARAMETER_PATTERN, parameter) + if match and match.group(0) == parameter: + pytest.fail(f"Should not match invalid parameter: {parameter}") -def test_extract_placeholders_from_simple_string() -> None: - """Test extracting placeholders from a simple string input.""" +def test_extract_parameters_from_simple_string() -> None: + """Test extracting parameters from a simple string input.""" trajectory = [ ToolUseBlockParam( id="1", @@ -52,12 +55,12 @@ def test_extract_placeholders_from_simple_string() -> None: ) ] - placeholders = PlaceholderHandler.extract_placeholders(trajectory) - assert placeholders == {"current_date"} + parameters = CacheParameterHandler.extract_parameters(trajectory) + assert parameters == {"current_date"} -def test_extract_placeholders_multiple_in_one_string() -> None: - """Test extracting 
multiple placeholders from one string.""" +def test_extract_parameters_multiple_in_one_string() -> None: + """Test extracting multiple parameters from one string.""" trajectory = [ ToolUseBlockParam( id="1", @@ -70,12 +73,12 @@ def test_extract_placeholders_multiple_in_one_string() -> None: ) ] - placeholders = PlaceholderHandler.extract_placeholders(trajectory) - assert placeholders == {"user_name", "current_date"} + parameters = CacheParameterHandler.extract_parameters(trajectory) + assert parameters == {"user_name", "current_date"} -def test_extract_placeholders_from_nested_dict() -> None: - """Test extracting placeholders from nested dictionary structures.""" +def test_extract_parameters_from_nested_dict() -> None: + """Test extracting parameters from nested dictionary structures.""" trajectory = [ ToolUseBlockParam( id="1", @@ -88,12 +91,12 @@ def test_extract_placeholders_from_nested_dict() -> None: ) ] - placeholders = PlaceholderHandler.extract_placeholders(trajectory) - assert placeholders == {"nested_var", "another_var"} + parameters = CacheParameterHandler.extract_parameters(trajectory) + assert parameters == {"nested_var", "another_var"} -def test_extract_placeholders_from_list() -> None: - """Test extracting placeholders from lists in input.""" +def test_extract_parameters_from_list() -> None: + """Test extracting parameters from lists in input.""" trajectory = [ ToolUseBlockParam( id="1", @@ -109,12 +112,12 @@ def test_extract_placeholders_from_list() -> None: ) ] - placeholders = PlaceholderHandler.extract_placeholders(trajectory) - assert placeholders == {"item1", "item2", "item3"} + parameters = CacheParameterHandler.extract_parameters(trajectory) + assert parameters == {"item1", "item2", "item3"} -def test_extract_placeholders_no_placeholders() -> None: - """Test that extracting from trajectory without placeholders returns empty set.""" +def test_extract_parameters_no_parameters() -> None: + """Test that extracting from trajectory without parameters returns empty set.""" trajectory = [ ToolUseBlockParam( id="1", @@ -124,12 +127,12 @@ def test_extract_placeholders_no_placeholders() -> None: ) ] - placeholders = PlaceholderHandler.extract_placeholders(trajectory) - assert placeholders == set() + parameters = CacheParameterHandler.extract_parameters(trajectory) + assert parameters == set() -def test_extract_placeholders_from_multiple_steps() -> None: - """Test extracting placeholders from multiple trajectory steps.""" +def test_extract_parameters_from_multiple_steps() -> None: + """Test extracting parameters from multiple trajectory steps.""" trajectory = [ ToolUseBlockParam( id="1", @@ -151,12 +154,12 @@ def test_extract_placeholders_from_multiple_steps() -> None: ), ] - placeholders = PlaceholderHandler.extract_placeholders(trajectory) - assert placeholders == {"var1", "var2"} # No duplicates + parameters = CacheParameterHandler.extract_parameters(trajectory) + assert parameters == {"var1", "var2"} # No duplicates -def test_validate_placeholders_all_provided() -> None: - """Test validation passes when all placeholders have values.""" +def test_validate_parameters_all_provided() -> None: + """Test validation passes when all parameters have values.""" trajectory = [ ToolUseBlockParam( id="1", @@ -166,7 +169,7 @@ def test_validate_placeholders_all_provided() -> None: ) ] - is_valid, missing = PlaceholderHandler.validate_placeholders( + is_valid, missing = CacheParameterHandler.validate_parameters( trajectory, {"var1": "value1", "var2": "value2"} ) @@ -174,8 +177,8 @@ def 
test_validate_placeholders_all_provided() -> None: assert missing == [] -def test_validate_placeholders_missing_some() -> None: - """Test validation fails when some placeholders are missing.""" +def test_validate_parameters_missing_some() -> None: + """Test validation fails when some parameters are missing.""" trajectory = [ ToolUseBlockParam( id="1", @@ -185,7 +188,7 @@ def test_validate_placeholders_missing_some() -> None: ) ] - is_valid, missing = PlaceholderHandler.validate_placeholders( + is_valid, missing = CacheParameterHandler.validate_parameters( trajectory, {"var1": "value1"} ) @@ -193,7 +196,7 @@ def test_validate_placeholders_missing_some() -> None: assert set(missing) == {"var2", "var3"} -def test_validate_placeholders_extra_values_ok() -> None: +def test_validate_parameters_extra_values_ok() -> None: """Test validation passes when extra values are provided (they're ignored).""" trajectory = [ ToolUseBlockParam( @@ -204,7 +207,7 @@ def test_validate_placeholders_extra_values_ok() -> None: ) ] - is_valid, missing = PlaceholderHandler.validate_placeholders( + is_valid, missing = CacheParameterHandler.validate_parameters( trajectory, {"var1": "value1", "extra_var": "extra_value"} ) @@ -212,25 +215,25 @@ def test_validate_placeholders_extra_values_ok() -> None: assert missing == [] -def test_validate_placeholders_no_placeholders() -> None: - """Test validation passes when trajectory has no placeholders.""" +def test_validate_parameters_no_parameters() -> None: + """Test validation passes when trajectory has no parameters.""" trajectory = [ ToolUseBlockParam( id="1", name="tool", - input={"text": "No placeholders here"}, + input={"text": "No parameters here"}, type="tool_use", ) ] - is_valid, missing = PlaceholderHandler.validate_placeholders(trajectory, {}) + is_valid, missing = CacheParameterHandler.validate_parameters(trajectory, {}) assert is_valid is True assert missing == [] -def test_substitute_placeholders_simple_string() -> None: - """Test substituting placeholders in a simple string.""" +def test_substitute_parameters_simple_string() -> None: + """Test substituting parameters in a simple string.""" tool_block = ToolUseBlockParam( id="1", name="computer", @@ -238,7 +241,7 @@ def test_substitute_placeholders_simple_string() -> None: type="tool_use", ) - result = PlaceholderHandler.substitute_placeholders( + result = CacheParameterHandler.substitute_parameters( tool_block, {"current_date": "2025-12-11"} ) @@ -247,8 +250,8 @@ def test_substitute_placeholders_simple_string() -> None: assert result.name == tool_block.name -def test_substitute_placeholders_multiple() -> None: - """Test substituting multiple placeholders in one string.""" +def test_substitute_parameters_multiple() -> None: + """Test substituting multiple parameters in one string.""" tool_block = ToolUseBlockParam( id="1", name="computer", @@ -259,15 +262,15 @@ def test_substitute_placeholders_multiple() -> None: type="tool_use", ) - result = PlaceholderHandler.substitute_placeholders( + result = CacheParameterHandler.substitute_parameters( tool_block, {"user_name": "Alice", "current_date": "2025-12-11"} ) assert result.input["text"] == "Hello Alice, date is 2025-12-11" # type: ignore[index] -def test_substitute_placeholders_nested_dict() -> None: - """Test substituting placeholders in nested dictionaries.""" +def test_substitute_parameters_nested_dict() -> None: + """Test substituting parameters in nested dictionaries.""" tool_block = ToolUseBlockParam( id="1", name="tool", @@ -278,7 +281,7 @@ def 
test_substitute_placeholders_nested_dict() -> None: type="tool_use", ) - result = PlaceholderHandler.substitute_placeholders( + result = CacheParameterHandler.substitute_parameters( tool_block, {"var1": "value1", "var2": "value2"} ) @@ -286,8 +289,8 @@ def test_substitute_placeholders_nested_dict() -> None: assert result.input["another"] == "value2" # type: ignore[index] -def test_substitute_placeholders_in_list() -> None: - """Test substituting placeholders in lists.""" +def test_substitute_parameters_in_list() -> None: + """Test substituting parameters in lists.""" tool_block = ToolUseBlockParam( id="1", name="tool", @@ -295,7 +298,7 @@ def test_substitute_placeholders_in_list() -> None: type="tool_use", ) - result = PlaceholderHandler.substitute_placeholders( + result = CacheParameterHandler.substitute_parameters( tool_block, {"item1": "value1", "item2": "value2"} ) @@ -304,8 +307,8 @@ def test_substitute_placeholders_in_list() -> None: assert result.input["items"][2]["nested"] == "value2" # type: ignore[index] -def test_substitute_placeholders_no_change_if_no_placeholders() -> None: - """Test that substitution doesn't change input without placeholders.""" +def test_substitute_parameters_no_change_if_no_parameters() -> None: + """Test that substitution doesn't change input without parameters.""" tool_block = ToolUseBlockParam( id="1", name="computer", @@ -313,13 +316,13 @@ def test_substitute_placeholders_no_change_if_no_placeholders() -> None: type="tool_use", ) - result = PlaceholderHandler.substitute_placeholders(tool_block, {}) + result = CacheParameterHandler.substitute_parameters(tool_block, {}) assert result.input == tool_block.input -def test_substitute_placeholders_partial_substitution() -> None: - """Test that only provided placeholders are substituted.""" +def test_substitute_parameters_partial_substitution() -> None: + """Test that only provided parameters are substituted.""" tool_block = ToolUseBlockParam( id="1", name="tool", @@ -327,12 +330,12 @@ def test_substitute_placeholders_partial_substitution() -> None: type="tool_use", ) - result = PlaceholderHandler.substitute_placeholders(tool_block, {"var1": "value1"}) + result = CacheParameterHandler.substitute_parameters(tool_block, {"var1": "value1"}) assert result.input["text"] == "value1 and {{var2}}" # type: ignore[index] -def test_substitute_placeholders_preserves_original() -> None: +def test_substitute_parameters_preserves_original() -> None: """Test that substitution creates a new object, doesn't modify original.""" tool_block = ToolUseBlockParam( id="1", @@ -342,13 +345,13 @@ def test_substitute_placeholders_preserves_original() -> None: ) original_input = tool_block.input.copy() # type: ignore[attr-defined] - PlaceholderHandler.substitute_placeholders(tool_block, {"var1": "value1"}) + CacheParameterHandler.substitute_parameters(tool_block, {"var1": "value1"}) # Original should be unchanged assert tool_block.input == original_input -def test_substitute_placeholders_with_special_characters() -> None: +def test_substitute_parameters_with_special_characters() -> None: """Test substitution with values containing special regex characters.""" tool_block = ToolUseBlockParam( id="1", @@ -358,15 +361,15 @@ def test_substitute_placeholders_with_special_characters() -> None: ) # Value contains regex special characters - result = PlaceholderHandler.substitute_placeholders( + result = CacheParameterHandler.substitute_parameters( tool_block, {"pattern": r".*[test]$"} ) assert result.input["text"] == r"Pattern: .*[test]$" # type: 
ignore[index] -def test_substitute_placeholders_same_placeholder_multiple_times() -> None: - """Test substituting the same placeholder appearing multiple times.""" +def test_substitute_parameters_same_parameter_multiple_times() -> None: + """Test substituting the same parameter appearing multiple times.""" tool_block = ToolUseBlockParam( id="1", name="tool", @@ -374,6 +377,6 @@ def test_substitute_placeholders_same_placeholder_multiple_times() -> None: type="tool_use", ) - result = PlaceholderHandler.substitute_placeholders(tool_block, {"var": "value"}) + result = CacheParameterHandler.substitute_parameters(tool_block, {"var": "value"}) assert result.input["text"] == "value is value is value" # type: ignore[index] diff --git a/tests/unit/utils/test_cache_validator.py b/tests/unit/utils/test_cache_validator.py index c7af8b12..926ec251 100644 --- a/tests/unit/utils/test_cache_validator.py +++ b/tests/unit/utils/test_cache_validator.py @@ -31,7 +31,7 @@ def sample_cache_file() -> CacheFile: id="2", name="type", input={"text": "test"}, type="tool_use" ), ], - placeholders={}, + cache_parameters={}, ) @@ -325,7 +325,7 @@ def test_composite_validator_empty() -> None: is_valid=True, ), trajectory=[], - placeholders={}, + cache_parameters={}, ) should_inv, reason = validator.should_invalidate(cache_file) diff --git a/tests/unit/utils/test_cache_writer.py b/tests/unit/utils/test_cache_writer.py index a6e1238f..c0f04cb2 100644 --- a/tests/unit/utils/test_cache_writer.py +++ b/tests/unit/utils/test_cache_writer.py @@ -146,7 +146,7 @@ def test_cache_writer_generate_writes_file() -> None: cache_dir=str(cache_dir), file_name="output.json", cache_writer_settings=CacheWriterSettings( - placeholder_identification_strategy="preset" + parameter_identification_strategy="preset" ), ) @@ -178,7 +178,7 @@ def test_cache_writer_generate_writes_file() -> None: # Check v0.1 structure assert "metadata" in data assert "trajectory" in data - assert "placeholders" in data + assert "cache_parameters" in data # Check metadata assert data["metadata"]["version"] == "0.1" @@ -202,7 +202,7 @@ def test_cache_writer_generate_auto_names_file() -> None: cache_dir=str(cache_dir), file_name="", cache_writer_settings=CacheWriterSettings( - placeholder_identification_strategy="preset" + parameter_identification_strategy="preset" ), ) @@ -337,7 +337,7 @@ def test_cache_writer_read_cache_file_v2() -> None: "type": "tool_use", }, ], - "placeholders": {"current_date": "Current date in YYYY-MM-DD format"}, + "cache_parameters": {"current_date": "Current date in YYYY-MM-DD format"}, } with cache_file_path.open("w", encoding="utf-8") as f: @@ -353,7 +353,7 @@ def test_cache_writer_read_cache_file_v2() -> None: assert len(result.trajectory) == 2 assert result.trajectory[0].id == "id1" assert result.trajectory[1].id == "id2" - assert "current_date" in result.placeholders + assert "current_date" in result.cache_parameters def test_cache_writer_set_file_name() -> None: @@ -376,7 +376,7 @@ def test_cache_writer_generate_resets_after_writing() -> None: cache_dir=str(cache_dir), file_name="test.json", cache_writer_settings=CacheWriterSettings( - placeholder_identification_strategy="preset" + parameter_identification_strategy="preset" ), ) @@ -395,19 +395,19 @@ def test_cache_writer_generate_resets_after_writing() -> None: assert cache_writer.messages == [] -def test_cache_writer_detects_and_stores_placeholders() -> None: - """Test that CacheWriter detects placeholders and stores them in metadata.""" +def 
test_cache_writer_detects_and_stores_parameters() -> None: + """Test that CacheWriter detects parameters and stores them in metadata.""" with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) cache_writer = CacheWriter( cache_dir=str(cache_dir), file_name="test.json", cache_writer_settings=CacheWriterSettings( - placeholder_identification_strategy="preset" + parameter_identification_strategy="preset" ), ) - # Add tool use blocks with placeholders + # Add tool use blocks with parameters cache_writer.messages = [ ToolUseBlockParam( id="id1", @@ -430,26 +430,26 @@ def test_cache_writer_detects_and_stores_placeholders() -> None: with cache_file.open("r", encoding="utf-8") as f: data = json.load(f) - # Verify placeholders were detected and stored - assert "placeholders" in data - assert "current_date" in data["placeholders"] - assert "user_name" in data["placeholders"] - assert len(data["placeholders"]) == 2 + # Verify parameters were detected and stored + assert "cache_parameters" in data + assert "current_date" in data["cache_parameters"] + assert "user_name" in data["cache_parameters"] + assert len(data["cache_parameters"]) == 2 -def test_cache_writer_empty_placeholders_when_none_found() -> None: - """Test that placeholders dict is empty when no placeholders exist.""" +def test_cache_writer_empty_parameters_when_none_found() -> None: + """Test that parameters dict is empty when no parameters exist.""" with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) cache_writer = CacheWriter( cache_dir=str(cache_dir), file_name="test.json", cache_writer_settings=CacheWriterSettings( - placeholder_identification_strategy="preset" + parameter_identification_strategy="preset" ), ) - # Add tool use blocks without placeholders + # Add tool use blocks without parameters cache_writer.messages = [ ToolUseBlockParam( id="id1", @@ -466,6 +466,6 @@ def test_cache_writer_empty_placeholders_when_none_found() -> None: with cache_file.open("r", encoding="utf-8") as f: data = json.load(f) - # Verify placeholders dict is empty - assert "placeholders" in data - assert data["placeholders"] == {} + # Verify parameters dict is empty + assert "cache_parameters" in data + assert data["cache_parameters"] == {} diff --git a/tests/unit/utils/test_trajectory_executor.py b/tests/unit/utils/test_trajectory_executor.py index d7f2aa6b..cb57ce98 100644 --- a/tests/unit/utils/test_trajectory_executor.py +++ b/tests/unit/utils/test_trajectory_executor.py @@ -22,13 +22,13 @@ def test_trajectory_executor_initialization() -> None: executor = TrajectoryExecutor( trajectory=trajectory, toolbox=toolbox, - placeholder_values={"var": "value"}, + parameter_values={"var": "value"}, delay_time=0.1, ) assert executor.trajectory == trajectory assert executor.toolbox == toolbox - assert executor.placeholder_values == {"var": "value"} + assert executor.parameter_values == {"var": "value"} assert executor.delay_time == 0.1 assert executor.current_step_index == 0 @@ -228,8 +228,8 @@ def test_trajectory_executor_handles_tool_error() -> None: assert "Tool execution failed" in (result.error_message or "") -def test_trajectory_executor_substitutes_placeholders() -> None: - """Test that executor substitutes placeholders before execution.""" +def test_trajectory_executor_substitutes_cache_parameters() -> None: + """Test that executor substitutes cache_parameters before execution.""" captured_steps = [] def capture_run(steps): # type: ignore @@ -257,7 +257,7 @@ def capture_run(steps): # type: ignore executor = 
TrajectoryExecutor( trajectory=trajectory, toolbox=mock_toolbox, - placeholder_values={"name": "Alice"}, + parameter_values={"name": "Alice"}, delay_time=0, ) From 5a2cd2de38f51abf577177738f947c9e1f2f1fc6 Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Thu, 18 Dec 2025 14:27:30 +0100 Subject: [PATCH 23/30] fix(caching): fix parameter identification when using non-askui messagesAPI --- src/askui/agent_base.py | 5 +- src/askui/models/shared/settings.py | 2 - src/askui/utils/cache_parameter_handler.py | 72 +++++++++++----------- src/askui/utils/caching/cache_writer.py | 40 +++++++++++- 4 files changed, 79 insertions(+), 40 deletions(-) diff --git a/src/askui/agent_base.py b/src/askui/agent_base.py index 3e2d7f4d..f9540112 100644 --- a/src/askui/agent_base.py +++ b/src/askui/agent_base.py @@ -310,7 +310,7 @@ def act( if _caching_settings.strategy != "no": on_message = self._patch_act_with_cache( - _caching_settings, _settings, _tools, on_message, goal_str + _caching_settings, _settings, _tools, on_message, goal_str, _model ) logger.info( "Starting agent act with caching enabled (strategy=%s)", @@ -342,6 +342,7 @@ def _patch_act_with_cache( toolbox: ToolCollection, on_message: OnMessageCb | None, goal: str | None = None, + model: str | None = None, ) -> OnMessageCb | None: """Patch act settings and toolbox with caching functionality. @@ -399,6 +400,8 @@ def _patch_act_with_cache( cache_writer_settings=caching_settings.cache_writer_settings, toolbox=toolbox, goal=goal, + model_router=self._model_router, + model=model, ) if on_message is None: on_message = cache_writer.add_message_cb diff --git a/src/askui/models/shared/settings.py b/src/askui/models/shared/settings.py index e2f1480a..76f033b6 100644 --- a/src/askui/models/shared/settings.py +++ b/src/askui/models/shared/settings.py @@ -11,7 +11,6 @@ from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Literal -from askui.models.anthropic.factory import AnthropicApiProvider from askui.models.shared.agent_message_param import ToolUseBlockParam, UsageParam COMPUTER_USE_20250124_BETA_FLAG = "computer-use-2025-01-24" @@ -44,7 +43,6 @@ class CachedExecutionToolSettings(BaseModel): class CacheWriterSettings(BaseModel): parameter_identification_strategy: CACHE_PARAMETER_IDENTIFICATION_STRATEGY = "llm" - llm_parameter_id_api_provider: AnthropicApiProvider = "askui" class CachingSettings(BaseModel): diff --git a/src/askui/utils/cache_parameter_handler.py b/src/askui/utils/cache_parameter_handler.py index f036757f..5154c133 100644 --- a/src/askui/utils/cache_parameter_handler.py +++ b/src/askui/utils/cache_parameter_handler.py @@ -13,10 +13,6 @@ import re from typing import Any -from askui.locators.serializers import VlmLocatorSerializer -from askui.models.anthropic.factory import AnthropicApiProvider -from askui.models.anthropic.messages_api import AnthropicMessagesApi -from askui.models.model_router import create_api_client from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam from askui.models.shared.messages_api import MessagesApi from askui.prompts.caching import CACHING_PARAMETER_IDENTIFIER_SYSTEM_PROMPT @@ -52,8 +48,8 @@ def identify_and_parameterize( trajectory: list[ToolUseBlockParam], goal: str | None, identification_strategy: str, - api_provider: AnthropicApiProvider = "askui", - model: str = "claude-sonnet-4-5-20250929", + messages_api: MessagesApi | None = None, + model: str | None = None, ) -> tuple[str | None, list[ToolUseBlockParam], dict[str, str]]: """Identify parameters and 
return parameterized trajectory + goal.
@@ -73,46 +69,52 @@
             - Parameterized trajectory (with {{param}} syntax)
             - Dict mapping parameter names to descriptions
         """
-        if identification_strategy == "llm" and trajectory:
-            # Create messages_api for LLM-based identification
-            messages_api = AnthropicMessagesApi(
-                client=create_api_client(api_provider),
-                locator_serializer=VlmLocatorSerializer(),
-            )
-
+        if identification_strategy == "llm" and trajectory and messages_api and model:
             # Use LLM to identify parameters
-            parameters_dict, parameter_definitions = (
-                CacheParameterHandler._identify_parameters_with_llm(
-                    trajectory, messages_api, model
+            try:
+                logger.info("Trying to extract parameters using the strategy 'llm'")
+                parameters_dict, parameter_definitions = (
+                    CacheParameterHandler._identify_parameters_with_llm(
+                        trajectory, messages_api, model
+                    )
                 )
-            )

-            if parameter_definitions:
-                # Replace values with {{parameter}} syntax in trajectory
-                parameterized_trajectory = (
-                    CacheParameterHandler._replace_values_with_parameters(
-                        trajectory, parameter_definitions
+                if parameter_definitions:
+                    # Replace values with {{parameter}} syntax in trajectory
+                    parameterized_trajectory = (
+                        CacheParameterHandler._replace_values_with_parameters(
+                            trajectory, parameter_definitions
+                        )
                     )
-                )

-                # Apply same replacement to goal text
-                parameterized_goal = goal
-                if goal:
-                    parameterized_goal = (
-                        CacheParameterHandler._apply_parameters_to_text(
-                            goal, parameter_definitions
+                    # Apply same replacement to goal text
+                    parameterized_goal = goal
+                    if goal:
+                        parameterized_goal = (
+                            CacheParameterHandler._apply_parameters_to_text(
+                                goal, parameter_definitions
+                            )
                         )
+
+                    n_parameters = len(parameter_definitions)
+                    logger.info(
+                        "Replaced %s parameter values in trajectory", n_parameters
                     )
+                    return parameterized_goal, parameterized_trajectory, parameters_dict

-                n_parameters = len(parameter_definitions)
-                logger.info("Replaced %s parameter values in trajectory", n_parameters)
-                return parameterized_goal, parameterized_trajectory, parameters_dict
+                else:  # noqa: RET505
+                    # No parameters identified
+                    logger.info("No parameters identified in trajectory")
+                    return goal, trajectory, {}

-        # No parameters identified
-        logger.info("No parameters identified in trajectory")
-        return goal, trajectory, {}
+            except Exception:
+                logger.exception(
+                    "An error occurred while extracting parameters using the strategy "
+                    "'llm'. 
Will use 'preset' strategy instead" + ) # Manual extraction (preset strategy) + logger.info("Extracting parameters using the strategy 'preset'") parameter_names = CacheParameterHandler.extract_parameters(trajectory) parameters_dict = { name: f"Parameter for {name}" diff --git a/src/askui/utils/caching/cache_writer.py b/src/askui/utils/caching/cache_writer.py index 673f27ed..24d3d12f 100644 --- a/src/askui/utils/caching/cache_writer.py +++ b/src/askui/utils/caching/cache_writer.py @@ -2,13 +2,16 @@ import logging from datetime import datetime, timezone from pathlib import Path +from typing import TYPE_CHECKING +from askui.models.model_router import ModelRouter from askui.models.shared.agent_message_param import ( MessageParam, ToolUseBlockParam, UsageParam, ) from askui.models.shared.agent_on_message_cb import OnMessageCbParam +from askui.models.shared.facade import ModelFacade from askui.models.shared.settings import ( CacheFile, CacheMetadata, @@ -17,6 +20,9 @@ from askui.models.shared.tools import ToolCollection from askui.utils.cache_parameter_handler import CacheParameterHandler +if TYPE_CHECKING: + from askui.models.models import ActModel + logger = logging.getLogger(__name__) @@ -28,6 +34,8 @@ def __init__( cache_writer_settings: CacheWriterSettings | None = None, toolbox: ToolCollection | None = None, goal: str | None = None, + model_router: ModelRouter | None = None, + model: str | None = None, ) -> None: self.cache_dir = Path(cache_dir) self.cache_dir.mkdir(exist_ok=True) @@ -38,6 +46,8 @@ def __init__( self.was_cached_execution = False self._cache_writer_settings = cache_writer_settings or CacheWriterSettings() self._goal = goal + self._model_router = model_router + self._model = model self._toolbox: ToolCollection | None = None self._accumulated_usage = UsageParam() @@ -109,11 +119,37 @@ def _parameterize_trajectory( self, ) -> tuple[str | None, list[ToolUseBlockParam], dict[str, str]]: """Identify parameters and return parameterized trajectory + goal.""" + identification_strategy = "preset" + messages_api = None + model = None + + if self._cache_writer_settings.parameter_identification_strategy == "llm": + if self._model_router and self._model: + try: + _get_model: tuple[ActModel, str] = self._model_router._get_model( # noqa: SLF001 + self._model, "act" + ) + if isinstance(_get_model[0], ModelFacade): + act_model: ActModel = _get_model[0]._act_model # noqa: SLF001 + else: + act_model = _get_model[0] + model_name: str = _get_model[1] + if hasattr(act_model, "_messages_api"): + messages_api = act_model._messages_api # noqa: SLF001 + identification_strategy = "llm" + model = model_name + except Exception: + logger.exception( + "Using 'llm' for parameter identification caused an exception." 
+                    " Will use 'preset' strategy instead"
+                )
+
         return CacheParameterHandler.identify_and_parameterize(
             trajectory=self.messages,
             goal=self._goal,
-            identification_strategy=self._cache_writer_settings.parameter_identification_strategy,
-            api_provider=self._cache_writer_settings.llm_parameter_id_api_provider,
+            identification_strategy=identification_strategy,
+            messages_api=messages_api,
+            model=model,
         )

     def _blank_non_cacheable_tool_inputs(

From 10335e6393e589f20080c5000f9a2dcc4381b2dc Mon Sep 17 00:00:00 2001
From: philipph-askui
Date: Mon, 22 Dec 2025 14:01:22 +0100
Subject: [PATCH 24/30] fix(caching): refine CACHE_USE_PROMPT to prevent interference of other (custom) tools

---
 src/askui/prompts/caching.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/askui/prompts/caching.py b/src/askui/prompts/caching.py
index 7be63b5b..12a6f1dc 100644
--- a/src/askui/prompts/caching.py
+++ b/src/askui/prompts/caching.py
@@ -4,6 +4,10 @@
     "task more robust and faster!\n"
     " To do so, first use the RetrieveCachedTestExecutions tool to check "
     "which trajectories are available for you.\n"
+    " It is very important that you use the RetrieveCachedTestExecutions and not "
+    "another tool for finding precomputed trajectories.\n"
+    " Hence, please use the RetrieveCachedTestExecutions tool in this step, even in "
+    "cases where another comparable tool (e.g. list_files tool) might be available.\n"
     " The details what each trajectory that is available for you does are "
     "at the end of this prompt.\n"
     " A trajectory contains all necessary mouse movements, clicks, and "

From 1581c3eb2d14eb653dc227a81b28b18501c1db20 Mon Sep 17 00:00:00 2001
From: philipph-askui
Date: Wed, 24 Dec 2025 08:43:51 +0100
Subject: [PATCH 25/30] fix(caching): refine verification_request message from CacheExecutionManager to Agent, to fix false-negative cache invalidation

---
 .../utils/caching/cache_execution_manager.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/src/askui/utils/caching/cache_execution_manager.py b/src/askui/utils/caching/cache_execution_manager.py
index b9bedde4..765bd2a2 100644
--- a/src/askui/utils/caching/cache_execution_manager.py
+++ b/src/askui/utils/caching/cache_execution_manager.py
@@ -241,16 +241,22 @@ def _handle_cache_completed(self, truncation_strategy: TruncationStrategy) -> bo
                 TextBlockParam(
                     type="text",
                     text=(
-                        "The cached trajectory execution has completed. "
-                        "Please verify if the execution correctly achieved "
-                        "the target system state. "
-                        "Use the verify_cache_execution tool to report "
-                        "your verification result."
+                        "[CACHE EXECUTION COMPLETED]\n\n"
+                        "The CacheExecutor has automatically executed"
+                        " all steps from the cached trajectory"
+                        f" '{self._cache_file_path}'. All previous tool calls in this"
+                        " conversation were replayed from cache, not performed by the"
+                        " agent.\n\nPlease verify if the cached execution correctly"
+                        " achieved the target system state using the"
+                        " verify_cache_execution tool."
), ) ], ) truncation_strategy.append_message(verification_request) + self._reporter.add_message( + self.__class__.__name__, verification_request.model_dump(mode="json") + ) logger.debug("Injected cache verification request message") return False # Fall through to let agent verify execution From 617ee38b4e1d9816eb07a35d6574b331945acf82 Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Tue, 30 Dec 2025 15:32:36 +0100 Subject: [PATCH 26/30] feat(caching)!: add caching v0.2 features, including visual validation BREAKING CHANGE: CachingSettings follow a new (simpler) format now --- docs/caching.md | 462 +++++++++++++++--- docs/visual_validation.md | 160 ++++++ src/askui/agent_base.py | 17 +- src/askui/models/anthropic/messages_api.py | 8 +- src/askui/models/shared/agent.py | 10 +- .../models/shared/agent_message_param.py | 22 +- src/askui/models/shared/settings.py | 33 +- src/askui/models/shared/token_counter.py | 46 +- src/askui/tools/caching_tools.py | 43 +- src/askui/utils/cache_parameter_handler.py | 1 + src/askui/utils/caching/cache_writer.py | 203 +++++++- src/askui/utils/trajectory_executor.py | 146 ++++-- src/askui/utils/visual_validation.py | 312 ++++++++++++ tests/unit/utils/test_trajectory_executor.py | 34 +- tests/unit/utils/test_visual_validation.py | 247 ++++++++++ 15 files changed, 1569 insertions(+), 175 deletions(-) create mode 100644 docs/visual_validation.md create mode 100644 src/askui/utils/visual_validation.py create mode 100644 tests/unit/utils/test_visual_validation.py diff --git a/docs/caching.md b/docs/caching.md index 025c639a..4f1260b3 100644 --- a/docs/caching.md +++ b/docs/caching.md @@ -10,57 +10,91 @@ The caching system works by recording all tool use actions (mouse movements, cli **New in v0.1:** The caching system now includes advanced features like parameter support for dynamic values, smart handling of non-cacheable tools that require agent intervention, comprehensive message history tracking, and automatic failure detection with recovery capabilities. +**New in v0.2:** Visual validation using perceptual hashing ensures cached trajectories execute only when the UI state matches expectations. The settings structure has been refactored for better clarity, separating writing settings from execution settings. + ## Caching Strategies +**Updated in v0.2:** Strategy names have been renamed for clarity. + The caching mechanism supports four strategies, configured via the `caching_settings` parameter in the `act()` method: -- **`"no"`** (default): No caching is used. The agent executes normally without recording or replaying actions. -- **`"write"`**: Records all agent actions to a cache file for future replay. -- **`"read"`**: Provides tools to the agent to list and execute previously cached trajectories. -- **`"both"`**: Combines read and write modes - the agent can use existing cached trajectories and will also record new ones. +- **`None`** (default): No caching is used. The agent executes normally without recording or replaying actions. +- **`"record"`**: Records all agent actions to a cache file for future replay. +- **`"execute"`**: Provides tools to the agent to list and execute previously cached trajectories. +- **`"both"`**: Combines execute and record modes - the agent can use existing cached trajectories and will also record new ones. ## Configuration +**Updated in v0.2:** Caching settings have been refactored for better clarity, separating writing-related settings from execution-related settings. 
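+
+As a quick orientation for readers migrating from v0.1, the sketch below shows how the old strategy values map onto the new ones. The mapping follows the renames described in this document; the helper function itself is purely illustrative and not part of the askui package:
+
+```python
+# Hypothetical migration helper (not part of askui): maps v0.1 caching
+# strategy names onto their v0.2 equivalents, as documented in this section.
+V01_TO_V02_STRATEGY: dict[str, str | None] = {
+    "read": "execute",  # replaying cached trajectories
+    "write": "record",  # recording new trajectories
+    "both": "both",     # unchanged
+    "no": None,         # "caching disabled" is now expressed as None
+}
+
+
+def migrate_strategy(old: str) -> str | None:
+    """Translate a v0.1 caching strategy value to its v0.2 equivalent."""
+    return V01_TO_V02_STRATEGY[old]
+```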
+ Caching is configured using the `CachingSettings` class: ```python -from askui.models.shared.settings import CachingSettings, CachedExecutionToolSettings, CacheWriterSettings +from askui.models.shared.settings import ( + CachingSettings, + CacheWritingSettings, + CacheExecutionSettings, +) caching_settings = CachingSettings( - strategy="write", # One of: "read", "write", "both", "no" + strategy="both", # One of: "execute", "record", "both", or None cache_dir=".cache", # Directory to store cache files - filename="my_test.json", # Filename for the cache file (optional for write mode) - cache_writer_settings=CacheWriterSettings( - parameter_identification_strategy="llm", - ) # Auto-detect dynamic values (default: "llm") - execute_cached_trajectory_tool_settings=CachedExecutionToolSettings( - delay_time_between_action=0.5 # Delay in seconds between each cached action - ) + writing_settings=CacheWritingSettings( + filename="my_test.json", # Cache file name + parameter_identification_strategy="llm", # Auto-detect dynamic values + visual_verification_method="phash", # Visual validation method + visual_validation_region_size=100, # Size of validation region (pixels) + visual_validation_threshold=10, # Hamming distance threshold + ), + execution_settings=CacheExecutionSettings( + delay_time_between_action=0.5 # Delay in seconds between each action + ), ) ``` ### Parameters -- **`strategy`**: The caching strategy to use (`"read"`, `"write"`, `"both"`, or `"no"`). +- **`strategy`**: The caching strategy to use (`"execute"`, `"record"`, `"both"`, or `None`). **Updated in v0.2:** Renamed from "read"/"write"/"no" to "execute"/"record"/None for clarity. - **`cache_dir`**: Directory where cache files are stored. Defaults to `".cache"`. -- **`filename`**: Name of the cache file to write to or read from. If not specified in write mode, a timestamped filename will be generated automatically (format: `cached_trajectory_YYYYMMDDHHMMSSffffff.json`). -- **`CacheWriterSettings`**: **New in v0.1!** Configuration for the Cache Writer See [CacheWriter Settings](#cachewriter-settings) below. -- **`execute_cached_trajectory_tool_settings`**: Configuration for the trajectory execution tool (optional). See [Execution Settings](#execution-settings) below. +- **`writing_settings`**: **New in v0.2!** Configuration for cache recording. See [Writing Settings](#writing-settings) below. Can be `None` if only executing caches. +- **`execution_settings`**: **New in v0.2!** Configuration for cache execution. See [Execution Settings](#execution-settings) below. Can be `None` if only recording caches. + +### Writing Settings -### CacheWriter Settings +**New in v0.2!** The `CacheWritingSettings` class configures how cache files are recorded: + +```python +from askui.models.shared.settings import CacheWritingSettings + +writing_settings = CacheWritingSettings( + filename="my_test.json", # Name of cache file to create + parameter_identification_strategy="llm", # "llm" or "preset" + visual_verification_method="phash", # "phash", "ahash", or "none" + visual_validation_region_size=100, # Size of region to validate (pixels) + visual_validation_threshold=10, # Hamming distance threshold (0-64) +) +``` + +#### Parameters -- `parameter_identification_strategy`: When `llm` (default), uses AI to automatically identify and parameterize dynamic values like dates, usernames, and IDs during cache recording. When `preset`, only manually specified cache_parameters (using `{{...}}` syntax) are detected. 
See [Automatic Cache Parameter Identification](#automatic-parameter-identification). -- `llm_parameter_id_api_provider`: The provider of that will be used for for the llm in the parameter identification (will only be used if `parameter_identification_strategy`is set to `llm`). Defaults to `askui`. +- **`filename`**: Name of the cache file to write. Defaults to `""` (auto-generates timestamped filename: `cached_trajectory_YYYYMMDDHHMMSSffffff.json`). +- **`parameter_identification_strategy`**: When `"llm"` (default), uses AI to automatically identify and parameterize dynamic values like dates, usernames, and IDs. When `"preset"`, only manually specified parameters using `{{...}}` syntax are detected. See [Automatic Parameter Identification](#automatic-parameter-identification). +- **`visual_verification_method`**: **New in v0.2!** Visual validation method to use: + - `"phash"` (default): Perceptual hash using DCT - robust to minor changes like compression and lighting + - `"ahash"`: Average hash - simpler and faster, less robust to transformations + - `"none"`: Disable visual validation +- **`visual_validation_region_size`**: **New in v0.2!** Size of the square region (in pixels) to extract around interaction coordinates for visual validation. Defaults to `100` (100×100 pixel region). +- **`visual_validation_threshold`**: **New in v0.2!** Maximum Hamming distance (0-64) between stored and current visual hashes to consider a match. Lower values require closer matches. Defaults to `10`. ### Execution Settings -The `CachedExecutionToolSettings` class allows you to configure how cached trajectories are executed: +The `CacheExecutionSettings` class configures how cached trajectories are executed: ```python -from askui.models.shared.settings import CachedExecutionToolSettings +from askui.models.shared.settings import CacheExecutionSettings -execution_settings = CachedExecutionToolSettings( - delay_time_between_action=0.5 # Delay in seconds between each action (default: 0.5) +execution_settings = CacheExecutionSettings( + delay_time_between_action=0.5 # Delay in seconds between each action ) ``` @@ -74,28 +108,35 @@ You can adjust this value based on your application's responsiveness: ## Usage Examples -### Writing a Cache (Recording) +### Recording a Cache Record agent actions to a cache file for later replay: ```python from askui import VisionAgent -from askui.models.shared.settings import CachingSettings +from askui.models.shared.settings import CachingSettings, CacheWritingSettings with VisionAgent() as agent: agent.act( goal="Fill out the login form with username 'admin' and password 'secret123'", caching_settings=CachingSettings( - strategy="write", + strategy="record", cache_dir=".cache", - filename="login_test.json" + writing_settings=CacheWritingSettings( + filename="login_test.json", + visual_verification_method="phash", # Enable visual validation + ), ) ) ``` -After execution, a cache file will be created at `.cache/login_test.json` containing all the tool use actions performed by the agent, along with metadata about the execution. 
+After execution, a cache file will be created at `.cache/login_test.json` containing: +- All tool use actions performed by the agent +- Metadata about the execution +- **New in v0.2:** Visual validation hashes for click and type actions +- Automatically detected cache parameters (if any) -### Reading from Cache (Replaying) +### Executing from Cache Provide the agent with access to previously recorded trajectories: @@ -107,13 +148,13 @@ with VisionAgent() as agent: agent.act( goal="Fill out the login form", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir=".cache" ) ) ``` -When using `strategy="read"`, the agent receives two tools: +When using `strategy="execute"`, the agent receives two tools: 1. **`RetrieveCachedTestExecutions`**: Lists all available cache files in the cache directory 2. **`ExecuteCachedTrajectory`**: Executes a cached trajectory. Can start from the beginning (default) or continue from a specific step index using the optional `start_from_step_index` parameter (useful after handling non-cacheable steps) @@ -134,19 +175,21 @@ with VisionAgent() as agent: agent.act( goal="Create a new task for today with the title 'Review PR'", caching_settings=CachingSettings( - strategy="write", + strategy="record", cache_dir=".cache", - filename="create_task.json" + writing_settings=CacheWritingSettings( + filename="create_task.json" + ) ) ) -# Later, when replaying, the agent can provide parameter values +# Later, when executing, the agent can provide parameter values # If the cache file contains {{current_date}} or {{task_title}}, provide them: with VisionAgent() as agent: agent.act( goal="Create a task using the cached flow", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir=".cache" ) ) @@ -168,7 +211,7 @@ with VisionAgent() as agent: agent.act( goal="Debug the login form by checking element states", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir=".cache" ) ) @@ -201,7 +244,7 @@ This is particularly useful for: ### Referencing Cache Files in Goal Prompts -When using `strategy="read"` or `strategy="both"`, you need to inform the agent about which cache files are available and when to use them. This is done by including cache file information directly in your goal prompt. +When using `strategy="execute"` or `strategy="both"`, you need to inform the agent about which cache files are available and when to use them. This is done by including cache file information directly in your goal prompt. #### Explicit Cache File References @@ -218,7 +261,7 @@ with VisionAgent() as agent: If the cache file "open_website_in_chrome.json" is available, please use it for this execution. It will open a new window in Chrome and navigate to the website.""", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir=".cache" ) ) @@ -241,7 +284,7 @@ with VisionAgent() as agent: Check if a cache file named "{test_id}.json" exists. 
If it does, use it to replay the test actions, then verify the results.""", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir="test_cache" ) ) @@ -263,7 +306,7 @@ with VisionAgent() as agent: Choose the most recent one if multiple are available, as it likely contains the most up-to-date interaction sequence.""", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir=".cache" ) ) @@ -287,7 +330,7 @@ with VisionAgent() as agent: After each cached execution, verify the step completed successfully before proceeding.""", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir=".cache" ) ) @@ -305,15 +348,15 @@ You can customize the delay between cached actions to match your application's r ```python from askui import VisionAgent -from askui.models.shared.settings import CachingSettings, CachedExecutionToolSettings +from askui.models.shared.settings import CachingSettings, CacheExecutionSettings with VisionAgent() as agent: agent.act( goal="Fill out the login form", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir=".cache", - execute_cached_trajectory_tool_settings=CachedExecutionToolSettings( + execute_cached_trajectory_tool_settings=CacheExecutionSettings( delay_time_between_action=1.0 # Wait 1 second between each action ) ) @@ -353,44 +396,67 @@ In this mode: **New in v0.1:** Cache files now use an enhanced format with metadata tracking, parameter support, and execution history. -### v0.1 Format (Current) +**New in v0.2:** Cache files include visual validation metadata and enhanced trajectory steps with visual hashes. + +### v0.2 Format (Current) Cache files are JSON objects with the following structure: ```json { "metadata": { - "version": "0.1", - "created_at": "2025-12-11T10:30:00Z", + "version": "0.2", + "created_at": "2025-12-30T10:30:00Z", "goal": "Greet user {{user_name}} and log them in", - "last_executed_at": "2025-12-11T15:45:00Z", + "last_executed_at": "2025-12-30T15:45:00Z", + "token_usage": { + "input_tokens": 1250, + "output_tokens": 380, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0 + }, "execution_attempts": 3, "failures": [ { - "timestamp": "2025-12-11T14:20:00Z", + "timestamp": "2025-12-30T14:20:00Z", "step_index": 5, - "error_message": "Element not found", + "error_message": "Visual validation failed: UI region changed", "failure_count_at_step": 1 } ], "is_valid": true, - "invalidation_reason": null + "invalidation_reason": null, + "visual_verification_method": "phash", + "visual_validation_region_size": 100, + "visual_validation_threshold": 10 }, "trajectory": [ { "type": "tool_use", "id": "toolu_01AbCdEfGhIjKlMnOpQrStUv", "name": "computer", + "input": { + "action": "left_click", + "coordinate": [450, 320] + }, + "visual_representation": "80c0e3f3e3e7e381c7c78f1f3f3f7f7e" + }, + { + "type": "tool_use", + "id": "toolu_02XyZaBcDeFgHiJkLmNoPqRs", + "name": "computer", "input": { "action": "type", "text": "Hello {{user_name}}!" 
- } + }, + "visual_representation": "91d1f4e4d4c6c282c6c79e2e4e4e6e6d" }, { "type": "tool_use", - "id": "toolu_02XyZaBcDeFgHiJkLmNoPqRs", + "id": "toolu_03StUvWxYzAbCdEfGhIjKlMn", "name": "print_debug_info", - "input": {} + "input": {}, + "visual_representation": null } ], "cache_parameters": { @@ -403,14 +469,18 @@ Cache files are JSON objects with the following structure: #### Metadata Fields -- **`version`**: Cache file format version (currently "0.1") +- **`version`**: Cache file format version (currently "0.2") - **`created_at`**: ISO 8601 timestamp when the cache was created -- **`goal`**: **New!** The original goal/instruction given to the agent when recording this trajectory. Cache Parameters are applied to the goal text just like in the trajectory, making it easy to understand what the cache was designed to accomplish. +- **`goal`**: The original goal/instruction given to the agent when recording this trajectory. Cache Parameters are applied to the goal text just like in the trajectory, making it easy to understand what the cache was designed to accomplish. - **`last_executed_at`**: ISO 8601 timestamp of the last execution (null if never executed) +- **`token_usage`**: **New in v0.1!** Token usage statistics from the recording execution - **`execution_attempts`**: Number of times this trajectory has been executed - **`failures`**: List of failures encountered during execution (see [Failure Tracking](#failure-tracking)) - **`is_valid`**: Boolean indicating if the cache is still considered valid - **`invalidation_reason`**: Optional string explaining why the cache was invalidated +- **`visual_verification_method`**: **New in v0.2!** Visual validation method used when recording (`"phash"`, `"ahash"`, or `null`) +- **`visual_validation_region_size`**: **New in v0.2!** Size of the validation region in pixels (e.g., `100` for 100×100 pixels) +- **`visual_validation_threshold`**: **New in v0.2!** Hamming distance threshold for visual validation (0-64) #### Cache Parameters @@ -452,7 +522,7 @@ The old format was a simple JSON array: The caching system consists of several key components: -- **`CacheWriter`**: Handles recording trajectories in write mode +- **`CacheWriter`**: Handles recording trajectories in record mode - **`CacheExecutionManager`**: Manages cache execution state, flow control, and metadata updates during trajectory replay - **`TrajectoryExecutor`**: Executes individual steps from cached trajectories - **Agent**: Orchestrates the conversation flow and delegates cache execution to `CacheExecutionManager` @@ -465,14 +535,19 @@ When executing a cached trajectory, the `Agent` class delegates all cache-relate This separation of concerns keeps the Agent focused on conversation orchestration while CacheExecutionManager handles all caching complexity. -### Write Mode +### Record Mode -In write mode, the `CacheWriter` class: +In record mode, the `CacheWriter` class: 1. Intercepts all assistant messages via a callback function 2. Extracts tool use blocks from the messages -3. Stores tool blocks in memory during execution -4. When agent finishes (on `stop_reason="end_turn"`): +3. **Enhances with visual validation** (New in v0.2): + - For click and type actions, captures screenshot before execution + - Extracts region around interaction coordinate + - Computes perceptual hash using selected method (pHash/aHash) + - Attaches hash and validation settings to tool block +4. Stores enhanced tool blocks in memory during execution +5. 
When agent finishes (on `stop_reason="end_turn"`): - **Automatically identifies cache_parameters** using AI (if `parameter_identification_strategy=llm`) - Analyzes trajectory to find dynamic values (dates, usernames, IDs, etc.) - Generates descriptive parameter definitions @@ -480,14 +555,14 @@ In write mode, the `CacheWriter` class: - Applies same replacements to the goal text - **Blanks non-cacheable tool inputs** by setting `input: {}` for tools with `is_cacheable=False` (saves space and privacy) - **Writes to JSON file** with: - - v0.1 metadata (version, timestamps, goal with cache_parameters) - - Trajectory of tool use blocks (with cache_parameters and blanked inputs) + - v0.2 metadata (version, timestamps, goal, token usage, visual validation settings) + - Trajectory of tool use blocks (with cache_parameters, visual hashes, and blanked inputs) - Parameter definitions with descriptions -5. Automatically skips writing if a cached execution was used (to avoid recording replays) +6. Automatically skips writing if a cached execution was used (to avoid recording replays) -### Read Mode +### Execute Mode -In read mode: +In execute mode: 1. Two caching tools are added to the agent's toolbox: - `RetrieveCachedTestExecutions`: Lists available trajectories @@ -499,6 +574,7 @@ In read mode: - Failure recovery strategies 3. The agent can list available cache files and choose appropriate ones 4. During execution via `TrajectoryExecutor`: + - **Visual validation** (New in v0.2): Before each validated step, captures current UI and compares hash to stored hash - Each step is executed sequentially with configurable delays - All tools in the trajectory are executed, including screenshots and retrieval tools - Non-cacheable tools trigger a pause with `NEEDS_AGENT` status @@ -506,6 +582,7 @@ In read mode: - Message history is built with assistant (tool use) and user (tool result) messages - Agent sees all screenshots and results in the message history 5. Execution can pause for agent intervention: + - When visual validation fails (New in v0.2) - When reaching non-cacheable tools - When errors occur (with failure details) 6. Agent can resume execution: @@ -699,14 +776,14 @@ If you prefer manual parameter control: ```python caching_settings = CachingSettings( - strategy="write", - cache_writer_settings = CacheWriterSettings( - parameter_identification_strategy="default" # Only detect {{...}} syntax + strategy="record", + writing_settings=CacheWritingSettings( + parameter_identification_strategy="preset" # Only detect {{...}} syntax ) ) ``` -With `parameter_identification_strategy=default`, only manually specified cache_parameters using the `{{...}}` syntax will be detected. +With `parameter_identification_strategy="preset"`, only manually specified cache_parameters using the `{{...}}` syntax will be detected. #### Logging @@ -785,12 +862,245 @@ Example: } ``` +## Visual Validation + +**New in v0.2!** Visual validation ensures cached trajectories execute only when the UI state matches the recorded state, preventing actions from being executed on incorrect UI elements. + +### How It Works + +During cache recording (record mode), the system: +1. **Captures screenshots** before each interaction (clicks, typing, key presses) +2. **Extracts a region** (e.g., 100×100 pixels) around the interaction coordinate +3. **Computes a perceptual hash** of that region using the selected method +4. 
**Stores the hash** in the trajectory step along with validation settings + +During cache execution (execute mode), the system: +1. **Captures the current UI state** before each step +2. **Extracts the same region** around the interaction coordinate +3. **Computes the hash** of the current region +4. **Compares hashes** using Hamming distance +5. **Validates the match** against the threshold +6. **Executes the step** only if validation passes, otherwise returns control to the agent + +### Visual Validation Methods + +#### pHash (Perceptual Hash) + +Default method using Discrete Cosine Transform (DCT): + +```python +writing_settings=CacheWritingSettings( + visual_verification_method="phash", # Default + visual_validation_region_size=100, + visual_validation_threshold=10, +) +``` + +**Characteristics:** +- ✅ Robust to minor changes (compression, scaling, lighting adjustments) +- ✅ Sensitive to structural changes (moved buttons, different layouts) +- ✅ Best for most use cases +- ⚠️ Slightly slower than aHash + +**When to use:** +- Production environments where UI may have subtle variations +- Cross-platform testing (different rendering engines) +- Long-lived caches that may encounter minor UI updates + +#### aHash (Average Hash) + +Simpler method using mean pixel values: + +```python +writing_settings=CacheWritingSettings( + visual_verification_method="ahash", + visual_validation_region_size=100, + visual_validation_threshold=10, +) +``` + +**Characteristics:** +- ✅ Fast computation +- ✅ Simple and predictable +- ⚠️ Less robust to transformations +- ⚠️ More sensitive to color/brightness changes + +**When to use:** +- Development/testing environments with controlled conditions +- Performance-critical scenarios +- UI that rarely changes + +#### Disabled + +Disable visual validation entirely: + +```python +writing_settings=CacheWritingSettings( + visual_verification_method="none", +) +``` + +**When to use:** +- UI that never changes +- Testing the caching system itself +- Debugging trajectory execution + +### Configuration Options + +#### Region Size + +The `visual_validation_region_size` parameter controls the size of the square region extracted around each interaction coordinate: + +```python +writing_settings=CacheWritingSettings( + visual_validation_region_size=50, # 50×50 pixel region (smaller, faster) + # visual_validation_region_size=100, # 100×100 pixel region (default, balanced) + # visual_validation_region_size=200, # 200×200 pixel region (larger, more context) +) +``` + +**Smaller regions (50-75 pixels):** +- ✅ Faster processing +- ✅ More focused validation (just the element) +- ⚠️ May miss context changes + +**Larger regions (150-200 pixels):** +- ✅ Captures more UI context +- ✅ Detects broader layout changes +- ⚠️ Slower processing +- ⚠️ More sensitive to unrelated UI changes + +**Default (100 pixels):** +- Balanced between speed and context +- Suitable for most use cases + +#### Validation Threshold + +The `visual_validation_threshold` parameter controls how similar the UI must be (Hamming distance, 0-64): + +```python +writing_settings=CacheWritingSettings( + visual_validation_threshold=5, # Strict: requires very close match + # visual_validation_threshold=10, # Default: balanced + # visual_validation_threshold=20, # Lenient: allows more variation +) +``` + +**Lower thresholds (0-5):** +- Very strict matching +- Fails on minor UI changes +- Best for pixel-perfect UIs + +**Medium thresholds (8-15):** +- Balanced sensitivity +- Tolerates minor variations +- **Default: 10** + 
+**Higher thresholds (20-30):** +- Lenient matching +- May allow too much variation +- Risk of false positives + +### Validated Actions + +Visual validation is applied to actions that interact with specific UI coordinates: + +**Validated automatically:** +- `left_click` +- `right_click` +- `double_click` +- `middle_click` +- `type` (validates input field location) +- `key` (validates focus location) + +**NOT validated:** +- `mouse_move` (movement doesn't require validation) +- `screenshot` (no UI interaction) +- Non-computer tools +- Tools marked as `is_cacheable=False` + +### Handling Validation Failures + +When visual validation fails during cache execution: + +1. **Execution stops** at the failed step +2. **Agent receives notification** with details: + - Which step failed + - The validation error message + - Current message history and screenshots +3. **Agent can decide**: + - Take a screenshot to assess current UI state + - Execute the step manually if safe + - Skip the step and continue + - Invalidate the cache and request re-recording + +Example agent recovery flow: +``` +Step 5 validation fails: "Visual validation failed: UI region changed (distance: 15 > threshold: 10)" +↓ +Agent takes screenshot to see current state +↓ +Agent sees button is present but slightly moved +↓ +Agent clicks button manually at new location +↓ +Agent continues execution from step 6 +``` + +### Best Practices + +1. **Choose the right method:** + - Use `phash` (default) for most cases + - Use `ahash` only for controlled environments + - Never use `none` in production + +2. **Tune the threshold:** + - Start with default (10) + - Increase if getting too many false failures + - Decrease if allowing incorrect executions + +3. **Adjust region size:** + - Use default (100) initially + - Increase for complex layouts + - Decrease for simple, isolated elements + +4. **Monitor validation logs:** + - Enable INFO logging to see validation results + - Track failure patterns + - Adjust settings based on failure analysis + +5. **Re-record when needed:** + - After significant UI changes + - When validation consistently fails + - After threshold/region adjustments + +### Logging + +Enable INFO-level logging to see visual validation activity: + +```python +import logging +logging.basicConfig(level=logging.INFO) +``` + +During **recording**, you'll see: +``` +INFO: ✓ Visual validation added to computer action=left_click at coordinate (450, 320) (hash=80c0e3f3e3e7e381...) +INFO: ✓ Visual validation added to computer action=type at coordinate (450, 380) (hash=91d1f4e4d4c6c282...) +``` + +During **execution**, validation happens silently on success. On **failure**, you'll see: +``` +WARNING: Visual validation failed at step 5: Visual validation failed: UI region changed significantly (Hamming distance: 15 > threshold: 10) +WARNING: Handing execution back to agent. +``` + ## Limitations and Considerations ### Current Limitations - **UI State Sensitivity**: Cached trajectories assume the UI is in the same state as when they were recorded. If the UI has changed significantly, replay may fail. -- **No on_message Callback**: When using `strategy="write"` or `strategy="both"`, you cannot provide a custom `on_message` callback, as the caching system uses this callback to record actions. +- **No on_message Callback**: When using `strategy="record"` or `strategy="both"`, you cannot provide a custom `on_message` callback, as the caching system uses this callback to record actions. 
- **Verification Required**: After executing a cached trajectory, the agent should verify that the results are correct, as UI changes may cause partial failures.

 ### Best Practices
@@ -916,9 +1226,9 @@ with VisionAgent() as agent:
         the login sequence. It contains the steps to navigate to the login
         page and authenticate with the test credentials.""",
         caching_settings=CachingSettings(
-            strategy="read",
+            strategy="execute",
             cache_dir="test_cache",
-            execute_cached_trajectory_tool_settings=CachedExecutionToolSettings(
+            execution_settings=CacheExecutionSettings(
                 delay_time_between_action=0.75
             )
         )
@@ -955,7 +1265,7 @@ if __name__ == "__main__":

 Planned features for future versions:

-- **Visual Validation**: Screenshot comparison using perceptual hashing (aHash) to detect UI changes
+- **✅ Visual Validation** (Implemented in v0.2): Screenshot comparison using perceptual hashing (pHash/aHash) to detect UI changes
 - **Cache Invalidation Strategies**: Configurable validators for automatic cache invalidation
 - **Cache Management Tools**: Tools for listing, validating, and invalidating caches
 - **Smart Retry**: Automatic retry with adjustments when specific failure patterns are detected
@@ -980,7 +1290,7 @@

 **Issue**: Actions execute too quickly, causing failures
 - **Cause**: `delay_time_between_action` is too short for your application
-- **Solution**: Increase delay in `CachedExecutionToolSettings` (e.g., from 0.5 to 1.0 seconds)
+- **Solution**: Increase delay in `CacheExecutionSettings` (e.g., from 0.5 to 1.0 seconds)

 **Issue**: "Tool not found in toolbox" error
 - **Cause**: Cached trajectory uses a tool that's no longer available
diff --git a/docs/visual_validation.md b/docs/visual_validation.md
new file mode 100644
index 00000000..46d0927d
--- /dev/null
+++ b/docs/visual_validation.md
@@ -0,0 +1,160 @@
+# Visual Validation for Caching
+
+> **Status**: ✅ Implemented
+> **Version**: v0.2
+> **Last Updated**: 2025-12-30
+
+## Overview
+
+Visual validation verifies that the UI state matches expectations before executing cached trajectory steps. This significantly improves cache reliability by detecting UI changes that would cause cached actions to fail.
+
+The system stores visual representations (perceptual hashes) of UI regions where actions like clicks are executed. During cache execution, these hashes are compared with the current UI state to detect changes.
+
+## How the Visual Representations Are Computed
+
+Several hashing methods are possible, e.g. aHash, pHash, ...
+
+**Key Necessary Properties:**
+- Fast computation (~1-2ms per hash)
+- Small storage footprint (64 bits = 8 bytes)
+- Robust to minor changes (compression, scaling, lighting)
+- Sensitive to structural changes (moved buttons, different layouts)
+
+The method that was used is recorded in the metadata field of the cached trajectory.
+
+
+## How It Works
+
+### 1. 
Representation Storage
+
+When a trajectory is recorded and cached, visual representations are captured for critical steps:
+
+```json
+{
+  "type": "tool_use",
+  "name": "computer",
+  "input": {
+    "action": "left_click",
+    "coordinate": [450, 300]
+  },
+  "visual_representation": "a8f3c9e14b7d2056"
+}
+```
+
+**Which steps should be validated?**
+- Mouse clicks (left_click, right_click, double_click, middle_click)
+- Type actions (verify input field hasn't moved)
+- Key presses targeting specific UI elements
+
+**Hash region selection:**
+- For clicks: Capture region around click coordinate (e.g., 100x100px centered on target)
+- For type actions: Capture region around text input field (e.g., 100x100px centered on target)
+
+### 2. Hash Verification (During Cache Execution)
+
+Before executing each step that has a `visual_representation`:
+
+1. **Capture current screen region** at the same coordinates used during recording
+2. **Compute the visual representation (e.g. aHash)** of the current region
+3. **Compare with stored hash** using Hamming distance
+4. **Make decision** based on threshold:
+
+```python
+def should_validate_step(stored_hash: str, current_screen: Image, threshold: int = 10) -> bool:
+    """
+    Check if visual validation passes.
+
+    Args:
+        stored_hash: The aHash stored in the cache
+        current_screen: Current screenshot region
+        threshold: Maximum Hamming distance (0-64)
+            - 0-5: Nearly identical (recommended for strict validation)
+            - 6-10: Very similar (default - allows minor changes)
+            - 11-15: Similar (more lenient)
+            - 16+: Different (validation should fail)
+
+    Returns:
+        True if validation passes, False if UI has changed significantly
+    """
+    current_hash = compute_ahash(current_screen)
+    distance = hamming_distance(stored_hash, current_hash)
+    return distance <= threshold
+```
+
+### 3. Validation Results
+
+**If validation passes** (distance ≤ threshold):
+- ✅ Execute the cached step normally
+- Continue with trajectory execution
+
+**If validation fails** (distance > threshold):
+- ⚠️ Pause trajectory execution
+- Return control to agent with detailed information:
+  ```
+  Visual validation failed at step 5 (left_click at [450, 300]) because the UI region has changed significantly compared to when it was recorded.
+  Please inspect the current UI state and perform the necessary step.
+  ```
+
+## Configuration
+
+Visual validation is configured in the Cache Settings:
+
+```python
+# In settings
+class CachingSettings:
+    visual_verification_method: CACHING_VISUAL_VERIFICATION_METHOD = "phash"  # or "ahash", "none"
+
+class CachedExecutionToolSettings:
+    visual_validation_threshold: int = 10  # Hamming distance threshold (0-64)
+```
+
+**Configuration Options:**
+- `visual_verification_method`: Hash method to use
+  - `"phash"` (default): Perceptual hash - robust to minor changes, sensitive to structural changes
+  - `"ahash"`: Average hash - faster but less robust
+  - `"none"`: Disable visual validation
+- `visual_validation_threshold`: Maximum allowed Hamming distance (0-64)
+  - `0-5`: Nearly identical (strict validation)
+  - `6-10`: Very similar (default - recommended)
+  - `11-15`: Similar (lenient)
+  - `16+`: Different (likely to fail validation)
+
+
+## Benefits
+
+### 1. Improved Reliability
+- Detect UI changes before execution fails
+- Reduce cache invalidation due to false negatives
+- Provide early warning of UI state mismatches
+
+### 2. 
Better User Experience +- Agent can make informed decisions about cache validity +- Clear feedback when UI has changed +- Opportunity to adapt instead of failing + +### 3. Intelligent Cache Management +- Automatically identify outdated caches +- Track which UI regions are stable vs. volatile +- Optimize cache usage patterns + +## Limitations and Considerations + +### 1. Performance Impact +- Each validation requires a screenshot + hash computation (~5-10ms) +- May slow down trajectory execution +- Mitigation: Only validate critical steps, not every action + +### 2. False Positives +- Minor UI changes (animations, hover states) may trigger validation failures +- Threshold tuning required for different applications +- Mitigation: Adaptive thresholds, ignore transient changes + +### 3. False Negatives +- Subtle but critical changes might not be detected +- Text content changes may not affect visual hash significantly +- Mitigation: Combine with other validation methods (OCR, element detection) + +### 4. Storage Overhead +- Each validated step adds 8 bytes (visual_hash) + 1 byte (flag) +- A 100-step trajectory adds ~900 bytes +- Mitigation: Acceptable overhead for improved reliability diff --git a/src/askui/agent_base.py b/src/askui/agent_base.py index f9540112..d07e60d4 100644 --- a/src/askui/agent_base.py +++ b/src/askui/agent_base.py @@ -203,7 +203,7 @@ def act( be used for achieving the `goal`. on_message (OnMessageCb | None, optional): Callback for new messages. If it returns `None`, stops and does not add the message. Cannot be used - with caching_settings strategy "write" or "both". + with caching_settings strategy "record" or "both". tools (list[Tool] | ToolCollection | None, optional): The tools for the agent. Defaults to default tools depending on the selected model. settings (AgentSettings | None, optional): The settings for the agent. 
@@ -308,7 +308,7 @@ def act( _tools = self._build_tools(tools, _model) - if _caching_settings.strategy != "no": + if _caching_settings.strategy is not None: on_message = self._patch_act_with_cache( _caching_settings, _settings, _tools, on_message, goal_str, _model ) @@ -359,8 +359,8 @@ def _patch_act_with_cache( logger.debug("Setting up caching") caching_tools: list[Tool] = [] - # Setup read mode: add caching tools and modify system prompt - if caching_settings.strategy in ["read", "both"]: + # Setup execute mode: add caching tools and modify system prompt + if caching_settings.strategy in ["execute", "both"]: from askui.tools.caching_tools import VerifyCacheExecution caching_tools.extend( @@ -368,7 +368,7 @@ def _patch_act_with_cache( RetrieveCachedTestExecutions(caching_settings.cache_dir), ExecuteCachedTrajectory( toolbox=toolbox, - settings=caching_settings.execute_cached_trajectory_tool_settings, + settings=caching_settings.execution_settings, ), VerifyCacheExecution(), ] @@ -391,13 +391,12 @@ def _patch_act_with_cache( if caching_tools: toolbox.append_tool(*caching_tools) - # Setup write mode: create cache writer and set message callback + # Setup record mode: create cache writer and set message callback cache_writer = None - if caching_settings.strategy in ["write", "both"]: + if caching_settings.strategy in ["record", "both"]: cache_writer = CacheWriter( cache_dir=caching_settings.cache_dir, - file_name=caching_settings.filename, - cache_writer_settings=caching_settings.cache_writer_settings, + cache_writing_settings=caching_settings.writing_settings, toolbox=toolbox, goal=goal, model_router=self._model_router, diff --git a/src/askui/models/anthropic/messages_api.py b/src/askui/models/anthropic/messages_api.py index 01814164..dab29c7a 100644 --- a/src/askui/models/anthropic/messages_api.py +++ b/src/askui/models/anthropic/messages_api.py @@ -66,12 +66,18 @@ def create_message( tool_choice: BetaToolChoiceParam | Omit = omit, temperature: float | Omit = omit, ) -> MessageParam: + # Convert messages to dicts with API context to exclude internal fields _messages = [ cast( - "BetaMessageParam", message.model_dump(exclude={"stop_reason", "usage"}) + "BetaMessageParam", + message.model_dump( + exclude={"stop_reason", "usage"}, + context={"for_api": True} # Triggers exclusion of internal fields + ), ) for message in messages ] + response = self._client.beta.messages.create( # type: ignore[misc] messages=_messages, max_tokens=max_tokens or 4096, diff --git a/src/askui/models/shared/agent.py b/src/askui/models/shared/agent.py index adde725a..7fa4d298 100644 --- a/src/askui/models/shared/agent.py +++ b/src/askui/models/shared/agent.py @@ -96,8 +96,10 @@ def _get_agent_response( self._accumulate_usage(message_by_assistant.usage) # type: ignore - message_by_assistant_dict = message_by_assistant.model_dump(mode="json") - logger.debug(message_by_assistant_dict) + message_by_assistant_dict = message_by_assistant.model_dump( + mode="json", context={"for_api": True} + ) + # logger.debug(message_by_assistant_dict) truncation_strategy.append_message(message_by_assistant) self._reporter.add_message(self.__class__.__name__, message_by_assistant_dict) @@ -132,7 +134,9 @@ def _process_tool_execution( if not tool_result_message: return False - tool_result_message_dict = tool_result_message.model_dump(mode="json") + tool_result_message_dict = tool_result_message.model_dump( + mode="json", context={"for_api": True} + ) logger.debug(tool_result_message_dict) 
truncation_strategy.append_message(tool_result_message) diff --git a/src/askui/models/shared/agent_message_param.py b/src/askui/models/shared/agent_message_param.py index b5b82d9e..cca9f890 100644 --- a/src/askui/models/shared/agent_message_param.py +++ b/src/askui/models/shared/agent_message_param.py @@ -1,4 +1,6 @@ -from pydantic import BaseModel +from typing import Any + +from pydantic import BaseModel, model_serializer from typing_extensions import Literal @@ -78,6 +80,24 @@ class ToolUseBlockParam(BaseModel): name: str type: Literal["tool_use"] = "tool_use" cache_control: CacheControlEphemeralParam | None = None + # Visual validation field - internal use only, not sent to Anthropic API + visual_representation: str | None = None + + @model_serializer(mode="wrap") + def _serialize_model(self, serializer, info) -> dict[str, Any]: + """Custom serializer to exclude internal fields when serializing for API. + + When context={'for_api': True}, visual validation fields are excluded. + Otherwise, all fields are included (for cache storage, internal use). + """ + # Use default serialization + data = serializer(self) + + # If serializing for API, remove internal fields + if info.context and info.context.get("for_api"): + data.pop("visual_representation", None) + + return data class BetaThinkingBlock(BaseModel): diff --git a/src/askui/models/shared/settings.py b/src/askui/models/shared/settings.py index 76f033b6..5b0da29a 100644 --- a/src/askui/models/shared/settings.py +++ b/src/askui/models/shared/settings.py @@ -16,8 +16,9 @@ COMPUTER_USE_20250124_BETA_FLAG = "computer-use-2025-01-24" COMPUTER_USE_20251124_BETA_FLAG = "computer-use-2025-11-24" -CACHING_STRATEGY = Literal["read", "write", "both", "no"] +CACHING_STRATEGY = Literal["execute", "record", "both"] CACHE_PARAMETER_IDENTIFICATION_STRATEGY = Literal["llm", "preset"] +CACHING_VISUAL_VERIFICATION_METHOD = Literal["phash", "ahash", "none"] class MessageSettings(BaseModel): @@ -37,22 +38,27 @@ class ActSettings(BaseModel): messages: MessageSettings = Field(default_factory=MessageSettings) -class CachedExecutionToolSettings(BaseModel): - delay_time_between_action: float = 0.5 - +class CacheWritingSettings(BaseModel): + """Settings for writing/recording cache files.""" -class CacheWriterSettings(BaseModel): + filename: str = "" parameter_identification_strategy: CACHE_PARAMETER_IDENTIFICATION_STRATEGY = "llm" + visual_verification_method: CACHING_VISUAL_VERIFICATION_METHOD = "phash" + visual_validation_region_size: int = 100 + visual_validation_threshold: int = 10 + + +class CacheExecutionSettings(BaseModel): + """Settings for executing/replaying cache files.""" + + delay_time_between_action: float = 0.5 class CachingSettings(BaseModel): - strategy: CACHING_STRATEGY = "no" - cache_dir: str = ".cache" - filename: str = "" - execute_cached_trajectory_tool_settings: CachedExecutionToolSettings = ( - CachedExecutionToolSettings() - ) - cache_writer_settings: CacheWriterSettings = CacheWriterSettings() + strategy: CACHING_STRATEGY | None = None + cache_dir: str = ".askui_cache" + writing_settings: CacheWritingSettings | None = None + execution_settings: CacheExecutionSettings | None = None class CacheFailure(BaseModel): @@ -72,6 +78,9 @@ class CacheMetadata(BaseModel): failures: list[CacheFailure] = Field(default_factory=list) is_valid: bool = True invalidation_reason: Optional[str] = None + visual_verification_method: Optional[CACHING_VISUAL_VERIFICATION_METHOD] = None + visual_validation_region_size: Optional[int] = None + 
visual_validation_threshold: Optional[int] = None class CacheFile(BaseModel): diff --git a/src/askui/models/shared/token_counter.py b/src/askui/models/shared/token_counter.py index 592e9af9..5f162437 100644 --- a/src/askui/models/shared/token_counter.py +++ b/src/askui/models/shared/token_counter.py @@ -165,6 +165,8 @@ def _count_tokens_for_message(self, message: MessageParam) -> int: For image blocks, uses the formula: tokens = (width * height) / 750 (see https://docs.anthropic.com/en/docs/build-with-claude/vision) For other content types, uses the standard character-based estimation. + Uses for_api context to exclude internal fields from token counting. + Args: message (MessageParam): The message to count tokens for. @@ -175,9 +177,11 @@ def _count_tokens_for_message(self, message: MessageParam) -> int: # Simple string content - use standard estimation return int(len(message.content) / self._chars_per_token) - # base tokens for rest of message - total_tokens = 10 - # Content is a list of blocks - process each individually + # Process content blocks individually to handle images properly + # Base tokens for the message structure (role, etc.) + base_tokens = 20 + + total_tokens = base_tokens for block in message.content: total_tokens += self._count_tokens_for_content_block(block) @@ -186,6 +190,8 @@ def _count_tokens_for_message(self, message: MessageParam) -> int: def _count_tokens_for_content_block(self, block: ContentBlockParam) -> int: """Count tokens for a single content block. + Uses for_api context to exclude internal fields like visual validation. + Args: block (ContentBlockParam): The content block to count tokens for. @@ -207,8 +213,25 @@ def _count_tokens_for_content_block(self, block: ContentBlockParam) -> int: total_tokens += self._count_tokens_for_content_block(nested_block) return total_tokens - # For other block types, use string representation - return int(len(self._stringify_object(block)) / self._chars_per_token) + # For other block types (ToolUseBlockParam, TextBlockParam, etc.), + # use string representation with API context to exclude internal fields + stringified = self._stringify_object(block) + token_count = int(len(stringified) / self._chars_per_token) + + # Debug: Log if this is a ToolUseBlockParam with visual validation fields + if hasattr(block, 'visual_representation') and block.visual_representation: + import logging + logger = logging.getLogger(__name__) + logger.debug( + "Token counting for %s: stringified_length=%d, tokens=%d, " + "has_visual_fields=%s", + getattr(block, 'name', 'unknown'), + len(stringified), + token_count, + 'visual_representation' in stringified + ) + + return token_count def _count_tokens_for_image_block(self, block: ImageBlockParam) -> int: """Count tokens for an image block using Anthropic's formula. @@ -248,6 +271,9 @@ def _stringify_object(self, obj: object) -> str: Not whitespace in dumped jsons between object keys and values and among array elements. + For Pydantic models, uses API serialization context to exclude internal fields + that won't be sent to the API (e.g., visual validation fields). + Args: obj (object): The object to stringify. 
@@ -256,6 +282,16 @@ def _stringify_object(self, obj: object) -> str: """ if isinstance(obj, str): return obj + + # Check if object is a Pydantic model with model_dump method + if hasattr(obj, "model_dump") and callable(getattr(obj, "model_dump")): + try: + # Use for_api context to exclude internal fields from token counting + serialized = obj.model_dump(context={"for_api": True}) # type: ignore[attr-defined] + return json.dumps(serialized, separators=(",", ":")) + except (TypeError, ValueError, AttributeError): + pass # Fall through to default handling + try: return json.dumps(obj, separators=(",", ":")) except (TypeError, ValueError): diff --git a/src/askui/tools/caching_tools.py b/src/askui/tools/caching_tools.py index 8705f203..e5e3bf80 100644 --- a/src/askui/tools/caching_tools.py +++ b/src/askui/tools/caching_tools.py @@ -6,7 +6,7 @@ from pydantic import validate_call from typing_extensions import override -from ..models.shared.settings import CachedExecutionToolSettings +from ..models.shared.settings import CacheExecutionSettings from ..models.shared.tools import Tool, ToolCollection from ..utils.cache_parameter_handler import CacheParameterHandler from ..utils.caching.cache_execution_manager import CacheExecutionManager @@ -136,7 +136,7 @@ class ExecuteCachedTrajectory(Tool): def __init__( self, toolbox: ToolCollection, - settings: CachedExecutionToolSettings | None = None, + settings: CacheExecutionSettings | None = None, ) -> None: super().__init__( name="execute_cached_executions_tool", @@ -204,7 +204,7 @@ def __init__( }, ) if not settings: - settings = CachedExecutionToolSettings() + settings = CacheExecutionSettings() self._settings = settings self._cache_execution_manager: CacheExecutionManager | None = None self._toolbox = toolbox @@ -313,9 +313,40 @@ def _create_executor( Returns: Configured TrajectoryExecutor instance """ + # Read visual validation settings ONLY from cache metadata + # Visual validation is only enabled if the cache was recorded with it + visual_validation_enabled = False + visual_hash_method = "phash" # Default (unused if validation disabled) + visual_validation_threshold = 10 # Default (unused if validation disabled) + visual_validation_region_size = 100 # Default (unused if validation disabled) + + if cache_file.metadata.visual_verification_method: + # Cache has visual validation metadata - use those exact settings + visual_validation_enabled = cache_file.metadata.visual_verification_method != "none" + visual_hash_method = cache_file.metadata.visual_verification_method + + if cache_file.metadata.visual_validation_threshold is not None: + visual_validation_threshold = cache_file.metadata.visual_validation_threshold + + if cache_file.metadata.visual_validation_region_size is not None: + visual_validation_region_size = cache_file.metadata.visual_validation_region_size + + logger.debug( + "Visual validation enabled from cache metadata: method=%s, threshold=%d, region_size=%d", + visual_hash_method, + visual_validation_threshold, + visual_validation_region_size, + ) + else: + # Cache doesn't have visual validation metadata - don't validate + logger.debug( + "Visual validation disabled: cache file has no visual validation metadata" + ) + logger.debug( - "Creating TrajectoryExecutor with delay=%ss", + "Creating TrajectoryExecutor with delay=%ss, visual_validation=%s", self._settings.delay_time_between_action, + visual_validation_enabled, ) # Import here to avoid circular dependency @@ -326,6 +357,10 @@ def _create_executor( toolbox=self._toolbox, 
parameter_values=parameter_values, delay_time=self._settings.delay_time_between_action, + visual_validation_enabled=visual_validation_enabled, + visual_validation_threshold=visual_validation_threshold, + visual_hash_method=visual_hash_method, + visual_validation_region_size=visual_validation_region_size, ) # Set the starting position if continuing diff --git a/src/askui/utils/cache_parameter_handler.py b/src/askui/utils/cache_parameter_handler.py index 5154c133..953eb36b 100644 --- a/src/askui/utils/cache_parameter_handler.py +++ b/src/askui/utils/cache_parameter_handler.py @@ -287,6 +287,7 @@ def _replace_values_with_parameters( input=parameterized_input, type=tool_block.type, cache_control=tool_block.cache_control, + visual_representation=tool_block.visual_representation, ) ) diff --git a/src/askui/utils/caching/cache_writer.py b/src/askui/utils/caching/cache_writer.py index 24d3d12f..521edc83 100644 --- a/src/askui/utils/caching/cache_writer.py +++ b/src/askui/utils/caching/cache_writer.py @@ -4,6 +4,8 @@ from pathlib import Path from typing import TYPE_CHECKING +from PIL import Image + from askui.models.model_router import ModelRouter from askui.models.shared.agent_message_param import ( MessageParam, @@ -15,10 +17,17 @@ from askui.models.shared.settings import ( CacheFile, CacheMetadata, - CacheWriterSettings, + CacheWritingSettings, ) from askui.models.shared.tools import ToolCollection from askui.utils.cache_parameter_handler import CacheParameterHandler +from askui.utils.visual_validation import ( + compute_ahash, + compute_phash, + extract_region, + get_validation_coordinate, + should_validate_step, +) if TYPE_CHECKING: from askui.models.models import ActModel @@ -30,8 +39,7 @@ class CacheWriter: def __init__( self, cache_dir: str = ".cache", - file_name: str = "", - cache_writer_settings: CacheWriterSettings | None = None, + cache_writing_settings: CacheWritingSettings | None = None, toolbox: ToolCollection | None = None, goal: str | None = None, model_router: ModelRouter | None = None, @@ -40,17 +48,28 @@ def __init__( self.cache_dir = Path(cache_dir) self.cache_dir.mkdir(exist_ok=True) self.messages: list[ToolUseBlockParam] = [] + + # Use default settings if not provided + self._cache_writing_settings = cache_writing_settings or CacheWritingSettings() + + # Extract file_name from settings + file_name = self._cache_writing_settings.filename if file_name and not file_name.endswith(".json"): file_name += ".json" self.file_name = file_name + self.was_cached_execution = False - self._cache_writer_settings = cache_writer_settings or CacheWriterSettings() self._goal = goal self._model_router = model_router self._model = model self._toolbox: ToolCollection | None = None self._accumulated_usage = UsageParam() + # Extract visual verification settings from cache_writing_settings + self._visual_verification_method = self._cache_writing_settings.visual_verification_method + self._visual_validation_region_size = self._cache_writing_settings.visual_validation_region_size + self._visual_validation_threshold = self._cache_writing_settings.visual_validation_threshold + # Set toolbox for cache writer so it can check which tools are cacheable self._toolbox = toolbox @@ -61,10 +80,14 @@ def add_message_cb(self, param: OnMessageCbParam) -> MessageParam: if isinstance(contents, list): for content in contents: if isinstance(content, ToolUseBlockParam): - self.messages.append(content) + # Detect if we're starting a cached execution if content.name == "execute_cached_executions_tool": 
self.was_cached_execution = True + # Enhance with visual validation if applicable (skip during cached execution) + enhanced_content = self._enhance_with_visual_validation(content) + self.messages.append(enhanced_content) + # Accumulate usage from assistant messages if param.message.usage: self._accumulate_usage(param.message.usage) @@ -123,7 +146,7 @@ def _parameterize_trajectory( messages_api = None model = None - if self._cache_writer_settings.parameter_identification_strategy == "llm": + if self._cache_writing_settings.parameter_identification_strategy == "llm": if self._model_router and self._model: try: _get_model: tuple[ActModel, str] = self._model_router._get_model( # noqa: SLF001 @@ -215,6 +238,9 @@ def _generate_cache_file( created_at=datetime.now(tz=timezone.utc), goal=goal_to_save, token_usage=self._accumulated_usage, + visual_verification_method=self._visual_verification_method, + visual_validation_region_size=self._visual_validation_region_size, + visual_validation_threshold=self._visual_validation_threshold, ), trajectory=trajectory_to_save, cache_parameters=parameters_dict, @@ -224,6 +250,171 @@ def _generate_cache_file( json.dump(cache_file.model_dump(mode="json"), f, indent=4) logger.info("Cache file successfully written: %s ", cache_file_path) + def _enhance_with_visual_validation( + self, tool_block: ToolUseBlockParam + ) -> ToolUseBlockParam: + """Enhance ToolUseBlockParam with visual validation data if applicable. + + Args: + tool_block: The tool use block to potentially enhance + + Returns: + Enhanced ToolUseBlockParam with visual validation data, or original if N/A + """ + # Skip if we're in a cached execution (recording disabled during replay) + if self.was_cached_execution: + return tool_block + + # Skip if visual verification is disabled + if self._visual_verification_method == "none": + logger.debug( + "Visual validation skipped for %s: method='none'", tool_block.name + ) + return tool_block + + # Skip if no toolbox available + if self._toolbox is None: + logger.warning( + "Visual validation skipped for %s: no toolbox available", + tool_block.name, + ) + return tool_block + + # Check if this tool input should be validated + action = None + if isinstance(tool_block.input, dict): + action = tool_block.input.get("action") + + if not should_validate_step(tool_block.name, action): + logger.debug( + "Visual validation skipped for %s action=%s: not a validatable action", + tool_block.name, + action, + ) + return tool_block + + # Get validation coordinate + if not isinstance(tool_block.input, dict): + logger.debug("Visual validation skipped: input is not a dict") + return tool_block + + coordinate = get_validation_coordinate(tool_block.input) + if coordinate is None: + logger.debug( + "Visual validation skipped for %s action=%s: no coordinate found", + tool_block.name, + action, + ) + return tool_block + + # Capture current screenshot and compute hash + try: + screenshot = self._capture_screenshot() + if screenshot is None: + logger.warning( + "Visual validation skipped for %s action=%s: screenshot capture failed", + tool_block.name, + action, + ) + return tool_block + + # Extract region around coordinate + region = extract_region( + screenshot, coordinate, size=self._visual_validation_region_size + ) + + # Compute hash based on method + if self._visual_verification_method == "phash": + visual_hash = compute_phash(region) + elif self._visual_verification_method == "ahash": + visual_hash = compute_ahash(region) + else: + logger.warning( + "Unknown visual verification 
method: %s", + self._visual_verification_method, + ) + return tool_block + + # Create enhanced ToolUseBlockParam with visual validation data + enhanced = ToolUseBlockParam( + id=tool_block.id, + name=tool_block.name, + input=tool_block.input, + type=tool_block.type, + cache_control=tool_block.cache_control, + visual_representation=visual_hash, + ) + + logger.info( + "✓ Visual validation added to %s action=%s at coordinate %s (hash=%s...)", + tool_block.name, + action, + coordinate, + visual_hash[:16], + ) + + return enhanced + + except Exception as e: + logger.warning( + "Visual validation skipped for %s action=%s: error during enhancement: %s", + tool_block.name, + action, + str(e), + ) + return tool_block + + def _capture_screenshot(self) -> Image.Image | None: + """Capture current screenshot using the computer tool. + + Returns: + PIL Image or None if screenshot capture fails + """ + if self._toolbox is None: + logger.warning("Cannot capture screenshot: toolbox is None") + return None + + # Get the computer tool from the toolbox + tools = self._toolbox.get_tools() + computer_tool = tools.get("computer") + + if computer_tool is None: + logger.warning( + "Cannot capture screenshot: computer tool not found in toolbox. " + "Available tools: %s", + list(tools.keys()), + ) + return None + + # Call the screenshot action + try: + # Try to call _screenshot() method directly if available + if hasattr(computer_tool, "_screenshot"): + result = computer_tool._screenshot() # type: ignore[attr-defined] + if isinstance(result, Image.Image): + logger.debug("Screenshot captured successfully via _screenshot()") + return result + + # Fallback to calling via __call__ with action parameter + result = computer_tool(action="screenshot") # type: ignore[call-arg] + if isinstance(result, Image.Image): + logger.debug("Screenshot captured successfully via __call__") + return result + + logger.warning( + "Screenshot action did not return an Image. Type: %s, Value: %s", + type(result).__name__, + str(result)[:100], + ) + return None + except Exception as e: + logger.warning( + "Error capturing screenshot for visual validation: %s: %s", + type(e).__name__, + str(e), + ) + return None + def _accumulate_usage(self, step_usage: UsageParam) -> None: """Accumulate usage statistics from a single API call. diff --git a/src/askui/utils/trajectory_executor.py b/src/askui/utils/trajectory_executor.py index cdbb87ef..59824bfe 100644 --- a/src/askui/utils/trajectory_executor.py +++ b/src/askui/utils/trajectory_executor.py @@ -9,6 +9,7 @@ import time from typing import Any, Optional +from PIL import Image from pydantic import BaseModel, Field from typing_extensions import Literal @@ -18,6 +19,11 @@ ) from askui.models.shared.tools import ToolCollection from askui.utils.cache_parameter_handler import CacheParameterHandler +from askui.utils.visual_validation import ( + extract_region, + get_validation_coordinate, + validate_visual_hash, +) logger = logging.getLogger(__name__) @@ -58,6 +64,9 @@ def __init__( parameter_values: dict[str, str] | None = None, delay_time: float = 0.5, visual_validation_enabled: bool = False, + visual_validation_threshold: int = 10, + visual_hash_method: str = "phash", + visual_validation_region_size: int = 100, ): """Initialize the trajectory executor. 
@@ -66,13 +75,19 @@ def __init__( toolbox: ToolCollection for executing tools parameter_values: Dict of parameter names to values delay_time: Seconds to wait between step executions - visual_validation_enabled: Enable visual validation (future feature) + visual_validation_enabled: Enable visual validation + visual_validation_threshold: Hamming distance threshold (0-64) + visual_hash_method: Hash method to use ('phash' or 'ahash') + visual_validation_region_size: Size of square region to extract (in pixels) """ self.trajectory = trajectory self.toolbox = toolbox self.parameter_values = parameter_values or {} self.delay_time = delay_time self.visual_validation_enabled = visual_validation_enabled + self.visual_validation_threshold = visual_validation_threshold + self.visual_hash_method = visual_hash_method + self.visual_validation_region_size = visual_validation_region_size self.current_step_index = 0 self.message_history: list[MessageParam] = [] @@ -129,13 +144,14 @@ def execute_next_step(self) -> ExecutionResult: tool_result=step, # Pass the tool use block for reference ) - # Visual validation (future feature - currently always passes) - # Extension point for aHash-based UI validation + # Visual validation: verify UI state matches cached expectations + # Compares stored visual hash with current screen region if self.visual_validation_enabled: is_valid, error_msg = self.validate_step_visually(step) if not is_valid: logger.warning( - "Visual validation failed at step %d: %s", + "Visual validation failed at step %d: %s. " + "Handing execution back to agent.", step_index, error_msg, ) @@ -284,53 +300,113 @@ def _should_skip_step(self, _step: ToolUseBlockParam) -> bool: return False def validate_step_visually( - self, _step: ToolUseBlockParam, _current_screenshot: Any = None + self, step: ToolUseBlockParam, current_screenshot: Any = None ) -> tuple[bool, str | None]: - """Hook for visual validation of cached steps using aHash comparison. - - This is an extension point for future visual validation implementation. - Currently returns (True, None) - no validation performed. + """Validate cached steps using visual hash comparison. - Future implementation will: - 1. Check if step has visual_validation_required=True - 2. Compute aHash of current screen region - 3. Compare with stored visual_hash - 4. Return validation result based on Hamming distance threshold + Compares the current UI state against the stored visual hash to detect + if the UI has changed significantly since the trajectory was recorded. 
Args: step: The trajectory step to validate - current_screenshot: Optional current screen capture (future use) + current_screenshot: Optional current screen capture (will capture if None) Returns: Tuple of (is_valid: bool, error_message: str | None) - (True, None) if validation passes or is disabled - (False, error_msg) if validation fails + """ + # Skip validation if disabled + if not self.visual_validation_enabled: + return True, None - Example future implementation: - if not self.visual_validation_enabled: - return True, None + # Skip if no visual representation stored (implies no validation needed) + if step.visual_representation is None: + return True, None + + # Get coordinate for validation + if not isinstance(step.input, dict): + return True, None + + coordinate = get_validation_coordinate(step.input) + if coordinate is None: + logger.debug( + "Could not extract coordinate from step %d for visual validation", + self.current_step_index, + ) + return True, None - if not step.visual_validation_required: + # Capture current screenshot if not provided + if current_screenshot is None: + current_screenshot = self._capture_screenshot() + if current_screenshot is None: + logger.warning( + "Could not capture screenshot for visual validation at step %d", + self.current_step_index, + ) + # Unable to validate, but don't fail execution return True, None - if step.visual_hash is None: - return True, None # No hash stored, skip validation + # Extract region around coordinate + try: + region = extract_region( + current_screenshot, coordinate, size=self.visual_validation_region_size + ) + except Exception as e: + logger.warning( + "Error extracting region for visual validation at step %d: %s", + self.current_step_index, + e, + ) + return True, None - # Capture current screen region - current_hash = compute_ahash(current_screenshot) + # Validate hash + is_valid, error_msg, _distance = validate_visual_hash( + stored_hash=step.visual_representation, + current_image=region, + threshold=self.visual_validation_threshold, + hash_method=self.visual_hash_method, + ) - # Compare hashes - distance = hamming_distance(step.visual_hash, current_hash) - threshold = 10 # Configurable + # Only log if validation fails + if not is_valid: + logger.warning( + "Visual validation failed at step %d: %s", + self.current_step_index, + error_msg, + ) - if distance > threshold: - return False, ( - f"Visual validation failed: UI changed significantly " - f"(distance: {distance} > threshold: {threshold})" - ) + return is_valid, error_msg - return True, None + def _capture_screenshot(self) -> Image.Image | None: + """Capture current screenshot using the computer tool. 
+
+        Returns:
+            PIL Image or None if screenshot capture fails
        """
-        # Future: Implement aHash comparison
-        # For now, always return True (no validation)
-        return True, None
+        # Get the computer tool from toolbox
+        tools = self.toolbox.get_tools()
+        computer_tool = tools.get("computer")
+
+        if computer_tool is None:
+            logger.debug("Computer tool not found in toolbox")
+            return None
+
+        # Call the screenshot action
+        try:
+            # Try to call _screenshot() method directly if available
+            if hasattr(computer_tool, "_screenshot"):
+                result = computer_tool._screenshot()  # type: ignore[attr-defined]
+                if isinstance(result, Image.Image):
+                    return result
+
+            # Fallback to calling via __call__ with action parameter
+            result = computer_tool(action="screenshot")  # type: ignore[call-arg]
+            if isinstance(result, Image.Image):
+                return result
+
+            logger.warning("Screenshot action did not return an Image: %s", type(result))
+            return None
+        except Exception:
+            logger.exception("Error capturing screenshot")
+            return None
diff --git a/src/askui/utils/visual_validation.py b/src/askui/utils/visual_validation.py
new file mode 100644
index 00000000..f1e0591e
--- /dev/null
+++ b/src/askui/utils/visual_validation.py
@@ -0,0 +1,312 @@
+"""Visual validation utilities for cache execution.
+
+This module provides utilities for computing and comparing visual hashes
+to validate UI state before executing cached trajectory steps.
+"""
+
+import logging
+from typing import Any
+
+import numpy as np
+from numpy.typing import NDArray
+from PIL import Image
+
+logger = logging.getLogger(__name__)
+
+
+def compute_phash(image: Image.Image, hash_size: int = 8) -> str:
+    """Compute perceptual hash (pHash) of an image.
+
+    pHash is robust to minor changes (compression, scaling, lighting) while
+    being sensitive to structural changes (moved buttons, different layouts).
+
+    Args:
+        image: PIL Image to hash
+        hash_size: Size of the hash (default 8 = 64-bit hash)
+
+    Returns:
+        Hexadecimal string representation of the hash
+
+    The algorithm:
+    1. Resize image to hash_size x hash_size
+    2. Convert to grayscale
+    3. Compute DCT (Discrete Cosine Transform)
+    4. Extract top-left 8x8 DCT coefficients
+    5. Compute median of coefficients
+    6. Create binary hash: 1 if coeff > median, else 0
+    """
+    # Resize to hash_size x hash_size
+    resized = image.resize((hash_size, hash_size), Image.Resampling.LANCZOS)
+
+    # Convert to grayscale
+    gray = resized.convert("L")
+
+    # Convert to numpy array
+    pixels = np.array(gray, dtype=np.float32)
+
+    # Compute DCT (using a simple 2D DCT approximation)
+    # For production, consider using scipy.fftpack.dct
+    dct = _dct_2d(pixels)
+
+    # Extract the top-left hash_size x hash_size coefficients (this simplified
+    # version keeps the DC component at [0,0])
+    dct_low = dct[:hash_size, :hash_size]
+
+    # Compute median
+    median = np.median(dct_low)
+
+    # Create binary hash
+    diff = dct_low > median
+
+    # Convert to hexadecimal string
+    hash_bytes = _binary_array_to_bytes(diff.flatten())
+    return hash_bytes.hex()
+
+
+def compute_ahash(image: Image.Image, hash_size: int = 8) -> str:
+    """Compute average hash (aHash) of an image.
+
+    aHash is a simpler but faster alternative to pHash. It's less robust
+    to transformations but still useful for basic visual validation.
+
+    Args:
+        image: PIL Image to hash
+        hash_size: Size of the hash (default 8 = 64-bit hash)
+
+    Returns:
+        Hexadecimal string representation of the hash
+
+    The algorithm:
+    1. Resize image to hash_size x hash_size
+    2. Convert to grayscale
+    3. Compute mean pixel value
+    4.
Create binary hash: 1 if pixel > mean, else 0 + """ + # Resize to hash_size x hash_size + resized = image.resize((hash_size, hash_size), Image.Resampling.LANCZOS) + + # Convert to grayscale + gray = resized.convert("L") + + # Convert to numpy array + pixels = np.array(gray, dtype=np.float32) + + # Compute mean + mean = pixels.mean() + + # Create binary hash + diff = pixels > mean + + # Convert to hexadecimal string + hash_bytes = _binary_array_to_bytes(diff.flatten()) + return hash_bytes.hex() + + +def hamming_distance(hash1: str, hash2: str) -> int: + """Compute Hamming distance between two hash strings. + + The Hamming distance is the number of bit positions where the two + hashes differ. A distance of 0 means identical hashes, while 64 + means completely different (for 64-bit hashes). + + Args: + hash1: First hash (hexadecimal string) + hash2: Second hash (hexadecimal string) + + Returns: + Number of differing bits (0-64 for 64-bit hashes) + + Raises: + ValueError: If hashes have different lengths + """ + if len(hash1) != len(hash2): + msg = f"Hash lengths differ: {len(hash1)} vs {len(hash2)}" + raise ValueError(msg) + + # Convert hex strings to integers and XOR them + # XOR will have 1s where bits differ + xor_result = int(hash1, 16) ^ int(hash2, 16) + + # Count number of 1s (differing bits) + return bin(xor_result).count("1") + + +def extract_region( + image: Image.Image, center: tuple[int, int], size: int = 100 +) -> Image.Image: + """Extract a square region from an image centered at given coordinates. + + Args: + image: Source image + center: (x, y) coordinates of region center + size: Size of the square region in pixels + + Returns: + Cropped image region + + The region is clipped to image boundaries if necessary. + """ + x, y = center + half_size = size // 2 + + # Calculate bounds, clipping to image boundaries + left = max(0, x - half_size) + top = max(0, y - half_size) + right = min(image.width, x + half_size) + bottom = min(image.height, y + half_size) + + # Crop and return + return image.crop((left, top, right, bottom)) + + +def validate_visual_hash( + stored_hash: str, + current_image: Image.Image, + threshold: int = 10, + hash_method: str = "phash", +) -> tuple[bool, str | None, int]: + """Validate that current image matches stored visual hash. 
+
+    Args:
+        stored_hash: The hash stored in the cache
+        current_image: Current screenshot region
+        threshold: Maximum Hamming distance to accept (0-64)
+            - 0-5: Nearly identical (strict validation)
+            - 6-10: Very similar (recommended default)
+            - 11-15: Similar (lenient)
+            - 16+: Different (validation should fail)
+        hash_method: Hash method to use ('phash' or 'ahash')
+
+    Returns:
+        Tuple of (is_valid, error_message, distance)
+        - is_valid: True if validation passes
+        - error_message: None if valid, error description if invalid
+        - distance: Hamming distance between hashes
+    """
+    # Compute current hash
+    if hash_method == "phash":
+        current_hash = compute_phash(current_image)
+    elif hash_method == "ahash":
+        current_hash = compute_ahash(current_image)
+    else:
+        return False, f"Unknown hash method: {hash_method}", -1
+
+    # Compare hashes
+    try:
+        distance = hamming_distance(stored_hash, current_hash)
+    except ValueError as e:
+        return False, f"Hash comparison failed: {e}", -1
+
+    # Validate against threshold
+    if distance <= threshold:
+        return True, None, distance
+
+    error_msg = (
+        f"Visual validation failed: UI region changed significantly "
+        f"(Hamming distance: {distance} > threshold: {threshold})"
+    )
+    return False, error_msg, distance
+
+
+def should_validate_step(tool_name: str, action: str | None = None) -> bool:
+    """Determine if a tool step should have visual validation.
+
+    Args:
+        tool_name: Name of the tool
+        action: Action type (for computer tool)
+
+    Returns:
+        True if step should be validated
+
+    Steps that should be validated:
+    - Mouse clicks (left_click, right_click, double_click, middle_click)
+    - Type actions (verify input field hasn't moved)
+    - Key presses targeting specific UI elements
+    """
+    # Computer tool with click or type actions
+    if tool_name == "computer":
+        if action in [
+            "left_click",
+            "right_click",
+            "double_click",
+            "middle_click",
+            "type",
+            "key",
+        ]:
+            return True
+
+    # Other tools that interact with specific UI regions
+    # Add more as needed
+    return False
+
+
+def get_validation_coordinate(tool_input: dict[str, Any]) -> tuple[int, int] | None:
+    """Extract the coordinate for visual validation from tool input.
+
+    Args:
+        tool_input: Tool input dictionary
+
+    Returns:
+        (x, y) coordinate tuple or None if not applicable
+
+    For click actions, returns the click coordinate.
+    For type actions, returns the coordinate of the text input field.
+    """
+    # Computer tool coordinates
+    if "coordinate" in tool_input:
+        coord = tool_input["coordinate"]
+        if isinstance(coord, list) and len(coord) == 2:
+            return (int(coord[0]), int(coord[1]))
+
+    return None
+
+
+# Private helper functions
+
+
+def _dct_2d(image_array: NDArray[np.float32]) -> NDArray[np.float64]:
+    """Compute an approximate 2D Discrete Cosine Transform.
+
+    This is a simplified implementation. For production use, consider
+    using scipy.fftpack.dct for better performance and accuracy.
+
+    Args:
+        image_array: 2D numpy array
+
+    Returns:
+        Approximate 2D DCT coefficient magnitudes (real-valued)
+    """
+    # Using a simple DCT approximation via FFT
+    # For production, use: from scipy.fftpack import dct
+    # return dct(dct(image_array.T, norm='ortho').T, norm='ortho')
+
+    # Simplified approach: use numpy's FFT and take the magnitude.
+    # np.abs of the complex FFT yields a real-valued (float64) array.
+    # Note: This is not a true DCT, but works for hash purposes
+    fft = np.fft.fft2(image_array)
+    return np.abs(fft)
+
+
+def _binary_array_to_bytes(binary_array: NDArray[np.bool_]) -> bytes:
+    """Convert binary numpy array to bytes.
+ + Args: + binary_array: 1D array of boolean values + + Returns: + Bytes representation + """ + # Convert to string of 0s and 1s + binary_string = "".join("1" if b else "0" for b in binary_array) + + # Pad to multiple of 8 + padding = 8 - (len(binary_string) % 8) + if padding != 8: + binary_string += "0" * padding + + # Convert to bytes + byte_array = bytearray() + for i in range(0, len(binary_string), 8): + byte = binary_string[i : i + 8] + byte_array.append(int(byte, 2)) + + return bytes(byte_array) diff --git a/tests/unit/utils/test_trajectory_executor.py b/tests/unit/utils/test_trajectory_executor.py index cb57ce98..76e0f010 100644 --- a/tests/unit/utils/test_trajectory_executor.py +++ b/tests/unit/utils/test_trajectory_executor.py @@ -2,8 +2,6 @@ from unittest.mock import MagicMock -import pytest - from askui.models.shared.agent_message_param import ( MessageParam, TextBlockParam, @@ -672,8 +670,7 @@ def test_visual_validation_disabled_by_default() -> None: name="click", input={"x": 100}, type="tool_use", - visual_hash="abc123", - visual_validation_required=True, + visual_representation="abc123", ), ] @@ -759,8 +756,7 @@ def test_validate_step_visually_always_passes_when_disabled() -> None: name="click", input={"x": 100}, type="tool_use", - visual_hash="abc123", - visual_validation_required=True, + visual_representation="abc123", ), ] @@ -794,8 +790,7 @@ def test_validate_step_visually_hook_called_when_enabled() -> None: name="click", input={"x": 100}, type="tool_use", - visual_hash="abc123", - visual_validation_required=True, + visual_representation="abc123", ), ] @@ -824,34 +819,27 @@ def mock_validate(step, screenshot=None) -> tuple[bool, str | None]: # type: ig assert results[0].status == "SUCCESS" -@pytest.mark.skip( - reason="Visual validation fields not yet implemented - future feature" -) def test_visual_validation_fields_on_tool_use_block() -> None: """Test that ToolUseBlockParam supports visual validation fields. - Note: This test is for future functionality. Visual validation fields - (visual_hash, visual_validation_required) are planned but not yet - implemented in the ToolUseBlockParam model. + The visual_representation field stores perceptual hashes (pHash/aHash) for + visual validation during cache execution. 
""" - # Create step with visual validation fields + # Create step with visual representation field step = ToolUseBlockParam( id="1", name="click", input={"x": 100, "y": 200}, type="tool_use", - visual_hash="a8f3c9e14b7d2056", - visual_validation_required=True, + visual_representation="a8f3c9e14b7d2056", ) - # Fields should be accessible - assert step.visual_hash == "a8f3c9e14b7d2056" # type: ignore[attr-defined] - assert step.visual_validation_required is True # type: ignore[attr-defined] + # Field should be accessible + assert step.visual_representation == "a8f3c9e14b7d2056" - # Default values should work + # Default value should be None step_default = ToolUseBlockParam( id="2", name="type", input={"text": "hello"}, type="tool_use" ) - assert step_default.visual_hash is None # type: ignore[attr-defined] - assert step_default.visual_validation_required is False # type: ignore[attr-defined] + assert step_default.visual_representation is None diff --git a/tests/unit/utils/test_visual_validation.py b/tests/unit/utils/test_visual_validation.py new file mode 100644 index 00000000..161f3dcb --- /dev/null +++ b/tests/unit/utils/test_visual_validation.py @@ -0,0 +1,247 @@ +"""Tests for visual validation utilities.""" + +import pytest +from PIL import Image, ImageDraw + +from askui.utils.visual_validation import ( + compute_ahash, + compute_phash, + extract_region, + get_validation_coordinate, + hamming_distance, + should_validate_step, + validate_visual_hash, +) + + +class TestHashComputation: + """Test hash computation functions.""" + + def test_compute_phash_returns_hex_string(self): + """Test that compute_phash returns a hexadecimal string.""" + # Create a simple test image + img = Image.new("RGB", (100, 100), color="red") + hash_result = compute_phash(img) + + # Should be a hex string + assert isinstance(hash_result, str) + assert len(hash_result) > 0 + # Should be valid hex + int(hash_result, 16) # Will raise if not valid hex + + def test_compute_ahash_returns_hex_string(self): + """Test that compute_ahash returns a hexadecimal string.""" + # Create a simple test image + img = Image.new("RGB", (100, 100), color="blue") + hash_result = compute_ahash(img) + + # Should be a hex string + assert isinstance(hash_result, str) + assert len(hash_result) > 0 + # Should be valid hex + int(hash_result, 16) # Will raise if not valid hex + + def test_identical_images_produce_same_phash(self): + """Test that identical images produce identical phashes.""" + img1 = Image.new("RGB", (100, 100), color="green") + img2 = Image.new("RGB", (100, 100), color="green") + + hash1 = compute_phash(img1) + hash2 = compute_phash(img2) + + assert hash1 == hash2 + + def test_different_images_produce_different_phash(self): + """Test that different images produce different phashes.""" + # Create images with patterns, not solid colors + img1 = Image.new("RGB", (100, 100), color="white") + draw1 = ImageDraw.Draw(img1) + draw1.rectangle([10, 10, 50, 50], fill="red") + + img2 = Image.new("RGB", (100, 100), color="white") + draw2 = ImageDraw.Draw(img2) + draw2.rectangle([60, 60, 90, 90], fill="blue") + + hash1 = compute_phash(img1) + hash2 = compute_phash(img2) + + assert hash1 != hash2 + + +class TestHammingDistance: + """Test Hamming distance calculation.""" + + def test_identical_hashes_have_zero_distance(self): + """Test that identical hashes have Hamming distance of 0.""" + hash1 = "a1b2c3d4" + hash2 = "a1b2c3d4" + + distance = hamming_distance(hash1, hash2) + assert distance == 0 + + def 
test_different_hashes_have_nonzero_distance(self): + """Test that different hashes have non-zero Hamming distance.""" + hash1 = "ffffffff" # All 1s in binary + hash2 = "00000000" # All 0s in binary + + distance = hamming_distance(hash1, hash2) + assert distance > 0 + + def test_hamming_distance_raises_on_different_lengths(self): + """Test that hamming_distance raises ValueError for different lengths.""" + hash1 = "a1b2" + hash2 = "a1b2c3" + + with pytest.raises(ValueError, match="Hash lengths differ"): + hamming_distance(hash1, hash2) + + +class TestExtractRegion: + """Test region extraction from images.""" + + def test_extract_region_returns_image(self): + """Test that extract_region returns a PIL Image.""" + img = Image.new("RGB", (200, 200), color="red") + center = (100, 100) + + region = extract_region(img, center, size=50) + + assert isinstance(region, Image.Image) + + def test_extract_region_has_correct_size(self): + """Test that extracted region has correct size.""" + img = Image.new("RGB", (200, 200), color="red") + center = (100, 100) + size = 50 + + region = extract_region(img, center, size=size) + + # Region should be approximately size x size + assert region.width <= size + assert region.height <= size + + def test_extract_region_at_edge(self): + """Test that extract_region handles edge cases.""" + img = Image.new("RGB", (100, 100), color="red") + center = (10, 10) # Near edge + + # Should not raise an error + region = extract_region(img, center, size=50) + assert isinstance(region, Image.Image) + + +class TestValidateVisualHash: + """Test visual hash validation.""" + + def test_validate_visual_hash_passes_for_identical_images(self): + """Test validation passes for identical images.""" + img = Image.new("RGB", (100, 100), color="red") + stored_hash = compute_phash(img) + + is_valid, error_msg, distance = validate_visual_hash( + stored_hash, img, threshold=10, hash_method="phash" + ) + + assert is_valid is True + assert error_msg is None + assert distance == 0 + + def test_validate_visual_hash_fails_for_different_images(self): + """Test validation fails for very different images.""" + # Create images with different patterns + img1 = Image.new("RGB", (100, 100), color="white") + draw1 = ImageDraw.Draw(img1) + draw1.rectangle([10, 10, 50, 50], fill="red") + + img2 = Image.new("RGB", (100, 100), color="white") + draw2 = ImageDraw.Draw(img2) + draw2.rectangle([60, 60, 90, 90], fill="blue") + + stored_hash = compute_phash(img1) + + is_valid, error_msg, distance = validate_visual_hash( + stored_hash, img2, threshold=5, hash_method="phash" + ) + + # Should fail due to high distance + assert is_valid is False + assert error_msg is not None + assert "Visual validation failed" in error_msg + + def test_validate_visual_hash_with_ahash_method(self): + """Test validation works with ahash method.""" + img = Image.new("RGB", (100, 100), color="green") + stored_hash = compute_ahash(img) + + is_valid, error_msg, distance = validate_visual_hash( + stored_hash, img, threshold=10, hash_method="ahash" + ) + + assert is_valid is True + assert error_msg is None + assert distance == 0 + + def test_validate_visual_hash_unknown_method(self): + """Test validation fails gracefully with unknown hash method.""" + img = Image.new("RGB", (100, 100), color="red") + stored_hash = "abcdef" + + is_valid, error_msg, distance = validate_visual_hash( + stored_hash, img, threshold=10, hash_method="unknown_method" + ) + + assert is_valid is False + assert error_msg is not None + assert "Unknown hash method" in 
error_msg + + +class TestShouldValidateStep: + """Test step validation logic.""" + + def test_should_validate_left_click(self): + """Test that left_click actions should be validated.""" + assert should_validate_step("computer", "left_click") is True + + def test_should_validate_right_click(self): + """Test that right_click actions should be validated.""" + assert should_validate_step("computer", "right_click") is True + + def test_should_validate_type_action(self): + """Test that type actions should be validated.""" + assert should_validate_step("computer", "type") is True + + def test_should_not_validate_screenshot(self): + """Test that screenshot actions should not be validated.""" + assert should_validate_step("computer", "screenshot") is False + + def test_should_not_validate_unknown_tool(self): + """Test that unknown tools should not be validated.""" + assert should_validate_step("unknown_tool", None) is False + + +class TestGetValidationCoordinate: + """Test coordinate extraction for validation.""" + + def test_get_validation_coordinate_from_computer_tool(self): + """Test extracting coordinate from computer tool input.""" + tool_input = {"action": "left_click", "coordinate": [450, 300]} + + coord = get_validation_coordinate(tool_input) + + assert coord == (450, 300) + + def test_get_validation_coordinate_returns_none_without_coordinate(self): + """Test returns None when no coordinate in input.""" + tool_input = {"action": "screenshot"} + + coord = get_validation_coordinate(tool_input) + + assert coord is None + + def test_get_validation_coordinate_handles_invalid_format(self): + """Test handles invalid coordinate format gracefully.""" + tool_input = {"coordinate": "invalid"} + + coord = get_validation_coordinate(tool_input) + + assert coord is None From ad9ae225dd5b4fa6852b4977212c2fa06621b087 Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Wed, 31 Dec 2025 10:02:57 +0100 Subject: [PATCH 27/30] feat(caching): add examples --- docs/caching.md | 228 +++++++++++++----- examples/README.md | 186 +++++++++++++++ examples/basic_caching_example.py | 190 +++++++++++++++ examples/visual_validation_example.py | 324 ++++++++++++++++++++++++++ 4 files changed, 866 insertions(+), 62 deletions(-) create mode 100644 examples/README.md create mode 100644 examples/basic_caching_example.py create mode 100644 examples/visual_validation_example.py diff --git a/docs/caching.md b/docs/caching.md index 4f1260b3..df7fae7b 100644 --- a/docs/caching.md +++ b/docs/caching.md @@ -1,6 +1,6 @@ # Caching (Experimental) -**CAUTION: The Caching feature is still in alpha state and subject to change! Use it at your own risk. In case you run into issues, you can disable caching by removing the caching_settings parameter or by explicitly setting the caching_strategy to `no`.** +**CAUTION: The Caching feature is still in alpha state and subject to change! Use it at your own risk. In case you run into issues, you can disable caching by removing the caching_settings parameter or by explicitly setting the strategy to `None`.** The caching mechanism allows you to record and replay agent action sequences (trajectories) for faster and more robust test execution. This feature is particularly useful for regression testing, where you want to replay known-good interaction sequences to verify that your application still behaves correctly. 
@@ -356,7 +356,7 @@ with VisionAgent() as agent: caching_settings=CachingSettings( strategy="execute", cache_dir=".cache", - execute_cached_trajectory_tool_settings=CacheExecutionSettings( + execution_settings=CacheExecutionSettings( delay_time_between_action=1.0 # Wait 1 second between each action ) ) @@ -496,26 +496,6 @@ Each failure record contains: This information helps with cache invalidation decisions and debugging. -### v0.0 Format (Legacy) - -The old format was a simple JSON array: - -```json -[ - { - "type": "tool_use", - "id": "toolu_01AbCdEfGhIjKlMnOpQrStUv", - "name": "computer", - "input": { - "action": "mouse_move", - "coordinate": [150, 200] - } - } -] -``` - -**Backward Compatibility:** v0.0 cache files are automatically migrated to v0.1 format when read. The system adds default metadata and wraps the trajectory array in the new structure. This migration is transparent and requires no user intervention. - ## How It Works ### Internal Architecture @@ -1123,61 +1103,185 @@ Consider re-recording a cached trajectory when: - Execution takes significantly longer than expected - The cache has been marked invalid due to failure patterns -## Migration from v0.0 to v0.1 +## Migration from v0.1 to v0.2 -**Automatic Migration:** All v0.0 cache files are automatically migrated when read by the v0.1 system. No manual intervention is required. +v0.2 introduces visual validation and refactored settings structure. Here's what you need to know to migrate from v0.1. -### What Happens During Migration +### What Changed in v0.2 -When a v0.0 cache file (simple JSON array) is read: +**Functional Changes:** +- **Visual Validation**: Cache recording now captures visual hashes (pHash/aHash) of UI regions around each click/type action. During execution, these hashes are validated to ensure the UI state matches expectations before executing cached actions. +- **Smarter Cache Execution**: Visual validation helps detect when the UI has changed, preventing cached actions from executing on wrong elements. -1. System detects v0.0 format (array instead of object with metadata) -2. Wraps trajectory in v0.1 structure -3. Adds default metadata: - ```json - { - "version": "0.1", - "created_at": "", - "last_executed_at": null, - "execution_attempts": 0, - "failures": [], - "is_valid": true, - "invalidation_reason": null - } - ``` -4. Extracts any cache_parameters found in trajectory -5. 
Returns fully-formed `CacheFile` object
+**API Changes:**
+- **Strategy names renamed** for clarity:
+  - `"read"` → `"execute"`
+  - `"write"` → `"record"`
+  - `"no"` → `None`
+- **Settings refactored** into separate writing and execution settings:
+  - `CacheWritingSettings` for recording-related configuration
+  - `CacheExecutionSettings` for playback-related configuration
+- **Default cache directory** changed from `".cache"` to `".askui_cache"`
+
+### Step 1: Update Your Code
+
+**Old v0.1 code:**
+```python
+from askui.models.shared.settings import (
+    CachedExecutionToolSettings,
+    CachingSettings,
+)
+
+caching_settings = CachingSettings(
+    strategy="read",  # Old strategy name
+    cache_dir=".cache",
+    filename="my_test.json",
+    execute_cached_trajectory_tool_settings=CachedExecutionToolSettings(
+        delay_time_between_action=0.5
+    ),
+)
+```
+
+**New v0.2 code:**
+```python
+from askui.models.shared.settings import (
+    CachingSettings,
+    CacheWritingSettings,
+    CacheExecutionSettings,
+)
+
+caching_settings = CachingSettings(
+    strategy="execute",  # New name (was "read")
+    cache_dir=".askui_cache",  # New default directory
+    writing_settings=CacheWritingSettings(
+        filename="my_test.json",  # Moved from top-level `filename`
+        visual_verification_method="phash",  # New in v0.2
+        visual_validation_region_size=100,  # New in v0.2
+        visual_validation_threshold=10,  # New in v0.2
+    ),
+    execution_settings=CacheExecutionSettings(
+        delay_time_between_action=0.5,
+    ),
+)
+```
-### Compatibility Guarantees
+**Migration checklist:**
+- [ ] Rename the strategy values: `"read"` → `"execute"`, `"write"` → `"record"`, `"no"` → `None`
+- [ ] Move `filename` → `writing_settings.filename`
+- [ ] Move `execute_cached_trajectory_tool_settings.delay_time_between_action` → `execution_settings.delay_time_between_action`
+- [ ] Add `writing_settings=CacheWritingSettings(...)` if using record mode
+- [ ] Add `execution_settings=CacheExecutionSettings(...)` if using execute mode
+- [ ] Update `cache_dir` from `".cache"` to `".askui_cache"` (optional but recommended)
-
-- All v0.0 cache files continue to work without modification
-- Migration is performed on-the-fly during read
-- Original files are not modified on disk (unless re-written)
-- v0.1 system can read both formats seamlessly
-### Programmatic Migration (Optional)
+### Step 2: Handle Existing Cache Files
-If you prefer to upgrade v0.0 cache files to v0.1 format on disk (rather than letting the system migrate them on-the-fly during read), you can do so programmatically:
+**Important:** v0.1 cache files do NOT work with v0.2 due to visual validation changes.
+
+You have two options:
+
+#### Option A: Delete Old Cache Files (Recommended)
+
+The simplest approach is to delete all v0.1 cache files and re-record them with v0.2:
+
+```bash
+# Delete all old cache files
+rm -rf .cache/*.json
+
+# Or if you updated to new directory:
+rm -rf .askui_cache/*.json
+```
+
+Then re-run your workflows in `record` mode to create new v0.2 cache files with visual validation.
+
+#### Option B: Disable Visual Validation (Not Recommended)
+
+If you must use old cache files temporarily, you can disable visual validation:
 
 ```python
-from pathlib import Path
-from askui.utils.cache_writer import CacheWriter
-import json
+writing_settings=CacheWritingSettings(
+    filename="old_cache.json",
+    visual_verification_method="none",  # Disable visual validation
+)
+```
+
+**Warning:** Without visual validation, cached actions may execute on wrong UI elements if the interface has changed. This defeats the primary benefit of v0.2.
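+
+If you are unsure which of your existing cache files were recorded with visual validation, you can check the metadata directly before deciding what to delete. The following is a minimal sketch; it relies only on the `metadata.visual_verification_method` field described above, and the directory name is an example:
+
+```python
+import json
+from pathlib import Path
+
+for path in Path(".askui_cache").glob("*.json"):
+    cache = json.loads(path.read_text(encoding="utf-8"))
+    # v0.0 files are plain arrays; v0.1+ files are objects with metadata
+    metadata = cache.get("metadata", {}) if isinstance(cache, dict) else {}
+    method = metadata.get("visual_verification_method")
+    if method in (None, "none"):
+        print(f"{path.name}: no visual validation data - re-record with v0.2")
+    else:
+        print(f"{path.name}: recorded with {method} validation")
+```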
+
+### Step 3: Verify Migration
+
+After updating your code and cache files:
+
+1. **Test record mode**: Verify new cache files are created with visual validation
+   ```bash
+   # Check for visual_representation fields in cache file
+   cat .askui_cache/my_test.json | grep -A2 visual_representation
+   ```
-# Read v0.0 file (auto-migrates to v0.1 in memory)
-cache_path = Path(".cache/old_cache.json")
-cached_trajectory = CacheWriter.read_cache_file(cache_path)
+2. **Test execute mode**: Verify cached trajectories execute with visual validation
+   - You should see log messages about visual validation during execution
+   - If UI has changed, execution should fail with visual validation errors
-# Write back to disk in v0.1 format
-with cache_path.open("w", encoding="utf-8") as f:
-    json.dump(cached_trajectory.model_dump(mode="json"), f, indent=2, default=str)
-```
+
+3. **Check metadata**: Verify cache files contain v0.2 metadata
+   ```bash
+   cat .askui_cache/my_test.json | grep -A5 '"metadata"'
+   ```
+
+   Should include:
+   ```json
+   "visual_verification_method": "phash",
+   "visual_validation_region_size": 100,
+   "visual_validation_threshold": 10
+   ```
-**Note:** Programmatic migration is optional - all v0.0 caches are automatically migrated during read operations. You only need to manually upgrade cache files if you want them in v0.1 format on disk immediately.
+
+### Example: Complete Migration
+
+**Before (v0.1):**
+```python
+caching_settings = CachingSettings(
+    strategy="both",
+    cache_dir=".cache",
+    filename="login_test.json",
+    execute_cached_trajectory_tool_settings=CachedExecutionToolSettings(
+        delay_time_between_action=0.3
+    ),
+)
+```
+
+**After (v0.2):**
+```python
+caching_settings = CachingSettings(
+    strategy="both",  # "read"/"write" strategy names are now "execute"/"record"
+    cache_dir=".askui_cache",  # New default
+    writing_settings=CacheWritingSettings(
+        filename="login_test.json",  # Moved from top-level `filename`
+        visual_verification_method="phash",  # New: visual validation
+        visual_validation_region_size=100,  # New: validation region
+        visual_validation_threshold=10,  # New: strictness level
+        parameter_identification_strategy="llm",
+    ),
+    execution_settings=CacheExecutionSettings(
+        delay_time_between_action=0.3,  # Moved from the old execution tool settings
+    ),
+)
+```
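+
+To build intuition for the threshold values above, you can compare the hashes of two screenshots of the same UI region directly with the `compute_phash` and `hamming_distance` helpers from `askui.utils.visual_validation` introduced in this release. A small sketch; the image file names are illustrative:
+
+```python
+from PIL import Image
+
+from askui.utils.visual_validation import compute_phash, hamming_distance
+
+# Hypothetical screenshots of the same UI region before and after a change
+before = compute_phash(Image.open("region_before.png"))
+after = compute_phash(Image.open("region_after.png"))
+
+# For 64-bit hashes the distance ranges from 0 to 64; the default
+# threshold of 10 tolerates roughly 15% of the bits changing
+distance = hamming_distance(before, after)
+print(f"{distance} bits differ -> {'pass' if distance <= 10 else 'fail'}")
+```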
+### Troubleshooting + +**Issue**: Cache execution fails with "Visual validation failed" +- **Cause**: UI has changed since cache was recorded +- **Solution**: Re-record the cache or adjust `visual_validation_threshold` (higher = more tolerant) + +**Issue**: Import error for `CacheWritingSettings` +- **Cause**: Old import statement +- **Solution**: Update imports: + ```python + from askui.models.shared.settings import ( + CachingSettings, + CacheWritingSettings, + CacheExecutionSettings, + ) + ``` + +**Issue**: Old cache files don't work +- **Cause**: v0.1 cache files lack visual validation data +- **Solution**: Delete old cache files and re-record with v0.2 -## Example: Complete Test Workflow with v0.1 Features +## Example: Complete Test Workflow -Here's a complete example showing advanced v0.1 features: +Here's a complete example showing the caching system: ```python import logging @@ -1228,7 +1332,7 @@ with VisionAgent() as agent: caching_settings=CachingSettings( strategy="execute", cache_dir="test_cache", - execute_cached_trajectory_tool_settings=CacheExecutionSettings( + execution_settings=CacheExecutionSettings( delay_time_between_action=0.75 ) ) diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 00000000..39e27570 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,186 @@ +# AskUI Caching Examples + +This directory contains example scripts demonstrating the capabilities of the AskUI caching system (v0.2). + +## Examples Overview + +### 1. `basic_caching_example.py` +**Introduction to cache recording and execution** + +Demonstrates: +- ✅ **Record mode**: Save a trajectory to a cache file +- ✅ **Execute mode**: Replay a cached trajectory +- ✅ **Both mode**: Try execute, fall back to record +- ✅ **Cache parameters**: Dynamic value substitution with `{{parameter}}` syntax +- ✅ **AI-based parameter detection**: Automatic identification of dynamic values + +**Best for**: Getting started with caching, understanding the basic workflow + +### 2. `visual_validation_example.py` +**Visual UI state validation with perceptual hashing** + +Demonstrates: +- ✅ **pHash validation**: Perceptual hashing (recommended, robust) +- ✅ **aHash validation**: Average hashing (faster, simpler) +- ✅ **Threshold tuning**: Adjusting strictness (0-64 range) +- ✅ **Region size**: Controlling validation area (50-200 pixels) +- ✅ **Disabling validation**: When to skip visual validation + +**Best for**: Understanding visual validation, tuning validation parameters for your use case + +## Quick Start + +1. **Install dependencies**: + ```bash + pdm install + ``` + +2. **Run an example**: + ```bash + pdm run python examples/basic_caching_example.py + ``` + +3. **Explore the cache files**: + ```bash + cat .askui_cache/basic_example.json + ``` + +## Understanding the Examples + +### Basic Workflow + +```python +# 1. Record a trajectory +caching_settings = CachingSettings( + strategy="record", # Save to cache + cache_dir=".askui_cache", + writing_settings=CacheWritingSettings( + filename="my_cache.json", + visual_verification_method="phash", + ), +) + +# 2. Execute from cache +caching_settings = CachingSettings( + strategy="execute", # Replay from cache + cache_dir=".askui_cache", + execution_settings=CacheExecutionSettings( + delay_time_between_action=0.5, + ), +) + +# 3. 
Both (recommended for development)
+caching_settings = CachingSettings(
+    strategy="both",  # Try execute, fall back to record
+    cache_dir=".askui_cache",
+    writing_settings=CacheWritingSettings(filename="my_cache.json"),
+    execution_settings=CacheExecutionSettings(),
+)
+```
+
+### Visual Validation Settings
+
+```python
+writing_settings=CacheWritingSettings(
+    visual_verification_method="phash",  # or "ahash" or "none"
+    visual_validation_region_size=100,  # 100x100 pixel region
+    visual_validation_threshold=10,  # Hamming distance (0-64)
+)
+```
+
+**Threshold Guidelines**:
+- `0-5`: Very strict (detects tiny changes)
+- `6-10`: Strict (recommended for stable UIs) ✅
+- `11-15`: Moderate (tolerates minor changes)
+- `16+`: Lenient (may miss significant changes)
+
+**Region Size Guidelines**:
+- `50`: Small, precise validation
+- `100`: Balanced (recommended default) ✅
+- `150-200`: Large, more context
+
+## Customizing Examples
+
+Each example can be customized by modifying:
+
+1. **The goal**: Change the task description
+2. **Cache settings**: Adjust validation parameters
+3. **Tools**: Add custom tools to the agent
+4. **Model**: Change the AI model (e.g., `model="askui/claude-sonnet-4-5-20250929"`)
+
+## Cache File Structure (v0.2)
+
+```json
+{
+  "metadata": {
+    "version": "0.2",
+    "created_at": "2025-01-15T10:30:00Z",
+    "goal": "Task description",
+    "visual_verification_method": "phash",
+    "visual_validation_region_size": 100,
+    "visual_validation_threshold": 10
+  },
+  "trajectory": [
+    {
+      "type": "tool_use",
+      "name": "computer",
+      "input": {"action": "left_click", "coordinate": [450, 320]},
+      "visual_representation": "80c0e3f3e3e7e381..."  // pHash/aHash
+    }
+  ],
+  "cache_parameters": {
+    "search_term": "Description of the parameter"
+  }
+}
+```
+
+## Tips and Best Practices
+
+### When to Use Caching
+
+✅ **Good use cases**:
+- Repetitive UI automation tasks
+- Testing workflows that require setup
+- Demos and presentations
+- Regression testing of UI workflows
+
+❌ **Not recommended**:
+- Highly dynamic UIs that change frequently
+- Tasks requiring real-time decision making
+- One-off tasks that won't be repeated
+
+### Choosing Validation Settings
+
+**For stable UIs** (e.g., desktop applications):
+- Method: `phash`
+- Threshold: `5-10`
+- Region: `100`
+
+**For dynamic UIs** (e.g., websites with ads):
+- Method: `phash`
+- Threshold: `15-20`
+- Region: `150`
+
+**For maximum performance** (trusted cache):
+- Method: `none`
+- (Visual validation disabled)
+
+### Debugging Cache Execution
+
+If cache execution fails:
+
+1. **Check visual validation**: Lower threshold or disable temporarily
+2. **Verify UI state**: Ensure UI hasn't changed since recording
+3. **Check cache file**: Look for `visual_representation` fields
+4. **Review logs**: Look for "Visual validation failed" messages
+5. **Re-record**: Delete cache file and record fresh trajectory
+
+## Additional Resources
+
+- **Documentation**: See `docs/caching.md` for complete documentation
+- **Visual Validation**: See `docs/visual_validation.md` for technical details
+- **Playground**: See `playground/caching_demo.py` for more examples
+
+## Questions?
+
+For issues or questions, please refer to the main documentation or open an issue in the repository.
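+
+## Understanding Hash Distances
+
+To build intuition for the threshold guidelines earlier in this README, you can compute the Hamming distance between two perceptual hashes yourself. The sketch below is a minimal illustration, not part of the examples above: the import path is assumed from the module location `src/askui/utils/visual_validation.py` in this change set, and the two images are synthetic stand-ins for UI regions.
+
+```python
+from PIL import Image, ImageDraw
+
+from askui.utils.visual_validation import compute_phash, hamming_distance
+
+# Baseline "UI region": a red box on a white background
+img1 = Image.new("RGB", (100, 100), color="white")
+ImageDraw.Draw(img1).rectangle((10, 10, 50, 50), fill="red")
+
+# The same region with the box shifted by two pixels
+img2 = Image.new("RGB", (100, 100), color="white")
+ImageDraw.Draw(img2).rectangle((12, 12, 52, 52), fill="red")
+
+distance = hamming_distance(compute_phash(img1), compute_phash(img2))
+print(f"Hamming distance: {distance}")
+
+# With visual_validation_threshold=10, this step passes only if distance <= 10
+print(f"Passes threshold 10: {distance <= 10}")
+```
+
+A small shift usually yields a small distance, while a different element in the region pushes the distance well past typical thresholds, which is exactly how the validator detects UI drift.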
diff --git a/examples/basic_caching_example.py b/examples/basic_caching_example.py new file mode 100644 index 00000000..382d482f --- /dev/null +++ b/examples/basic_caching_example.py @@ -0,0 +1,190 @@ +"""Basic Caching Example - Introduction to Cache Recording and Execution. + +This example demonstrates: +- Recording a trajectory to a cache file (record mode) +- Executing a cached trajectory (execute mode) +- Using both modes together (both mode) +- Cache parameters for dynamic value substitution +""" + +import logging + +from askui import VisionAgent +from askui.models.shared.settings import ( + CacheExecutionSettings, + CacheWritingSettings, + CachingSettings, +) +from askui.reporting import SimpleHtmlReporter + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger() + + +def record_trajectory_example() -> None: + """Record a new trajectory to a cache file. + + This example records a simple task (opening Calculator and performing + a calculation) to a cache file. The first time you run this, the agent + will perform the task normally. The trajectory will be saved to + 'basic_example.json' for later reuse. + """ + goal = """Please open the Calculator application and calculate 15 + 27. + Then close the Calculator application. + """ + + caching_settings = CachingSettings( + strategy="record", # Record mode: save trajectory to cache file + cache_dir=".askui_cache", + writing_settings=CacheWritingSettings( + filename="basic_example.json", + visual_verification_method="phash", # Use perceptual hashing + visual_validation_region_size=100, # 100x100 pixel region around clicks + visual_validation_threshold=10, # Hamming distance threshold + parameter_identification_strategy="llm", # AI-based parameter detection + ), + ) + + with VisionAgent( + display=1, + reporters=[SimpleHtmlReporter()], + ) as agent: + agent.act(goal, caching_settings=caching_settings) + + logger.info("✓ Trajectory recorded to .askui_cache/basic_example.json") + logger.info("Run execute_trajectory_example() to replay this trajectory") + + +def execute_trajectory_example() -> None: + """Execute a previously recorded trajectory from cache. + + This example executes the trajectory recorded in record_trajectory_example(). + The agent will replay the exact sequence of actions from the cache file, + validating the UI state at each step using visual hashing. + + Prerequisites: + - Run record_trajectory_example() first to create the cache file + """ + goal = """Please execute the cached trajectory from basic_example.json + """ + + caching_settings = CachingSettings( + strategy="execute", # Execute mode: replay from cache file + cache_dir=".askui_cache", + execution_settings=CacheExecutionSettings( + delay_time_between_action=0.5, # Wait 0.5s between actions + ), + ) + + with VisionAgent( + display=1, + reporters=[SimpleHtmlReporter()], + ) as agent: + agent.act(goal, caching_settings=caching_settings) + + logger.info("✓ Trajectory executed from cache") + + +def both_mode_example() -> None: + """Use both record and execute modes together. + + This example demonstrates the "both" strategy, which: + 1. First tries to execute from cache if available + 2. Falls back to normal execution if cache doesn't exist + 3. Records the trajectory if it wasn't cached + + This is the most flexible mode for development and testing. + """ + goal = """Please open the TextEdit application, type "Hello from cache!", + and close the application without saving. 
+    If available, use cache file at both_mode_example.json
+    """
+
+    caching_settings = CachingSettings(
+        strategy="both",  # Both: try execute, fall back to record
+        cache_dir=".askui_cache",
+        writing_settings=CacheWritingSettings(
+            filename="both_mode_example.json",
+            visual_verification_method="phash",
+            parameter_identification_strategy="llm",
+        ),
+        execution_settings=CacheExecutionSettings(
+            delay_time_between_action=0.3,
+        ),
+    )
+
+    with VisionAgent(
+        display=1,
+        reporters=[SimpleHtmlReporter()],
+    ) as agent:
+        agent.act(goal, caching_settings=caching_settings)
+
+    logger.info("✓ Task completed using both mode")
+
+
+def cache_parameters_example() -> None:
+    """Demonstrate cache parameters for dynamic value substitution.
+
+    Cache parameters allow you to record a trajectory once and replay it
+    with different values. The AI identifies dynamic values (like search
+    terms, numbers, dates) and replaces them with {{parameter_name}} syntax.
+
+    When executing the cache, you can provide different values for these
+    parameters.
+    """
+    # First run: Record with original value
+    goal_record = """Please open Safari, navigate to www.google.com,
+    search for "Python programming", and close Safari.
+    """
+
+    caching_settings = CachingSettings(
+        strategy="record",
+        cache_dir=".askui_cache",
+        writing_settings=CacheWritingSettings(
+            filename="parameterized_example.json",
+            visual_verification_method="phash",
+            parameter_identification_strategy="llm",  # AI detects "Python programming" as parameter
+        ),
+    )
+
+    logger.info("Recording parameterized trajectory...")
+    with VisionAgent(display=1, reporters=[SimpleHtmlReporter()]) as agent:
+        agent.act(goal_record, caching_settings=caching_settings)
+
+    logger.info("✓ Parameterized trajectory recorded")
+    logger.info("The AI identified dynamic values and created parameters")
+    logger.info(
+        "Check .askui_cache/parameterized_example.json to see {{parameter}} syntax"
+    )
+
+
+if __name__ == "__main__":
+    # Run examples in sequence
+    print("\n" + "=" * 70)
+    print("BASIC CACHING EXAMPLES")
+    print("=" * 70 + "\n")
+
+    print("\n1. Recording a trajectory to cache...")
+    print("-" * 70)
+    record_trajectory_example()
+
+    print("\n2. Executing from cache...")
+    print("-" * 70)
+    # Replays the trajectory recorded in step 1:
+    execute_trajectory_example()
+
+    print("\n3. Using 'both' mode (execute or record)...")
+    print("-" * 70)
+    both_mode_example()
+
+    print("\n4. Cache parameters for dynamic values...")
+    print("-" * 70)
+    cache_parameters_example()
+
+    print("\n" + "=" * 70)
+    print("Examples completed!")
+    print("=" * 70 + "\n")
diff --git a/examples/visual_validation_example.py b/examples/visual_validation_example.py
new file mode 100644
index 00000000..91cb00fa
--- /dev/null
+++ b/examples/visual_validation_example.py
@@ -0,0 +1,324 @@
+"""Visual Validation Example - Demonstrating Visual UI State Validation.
+ +This example demonstrates: +- Visual validation using perceptual hashing (pHash) +- Visual validation using average hashing (aHash) +- Adjusting validation thresholds for strictness +- Handling visual validation failures gracefully +- Understanding when visual validation helps detect UI changes +""" + +import logging + +from askui import VisionAgent +from askui.models.shared.settings import ( + CacheExecutionSettings, + CacheWritingSettings, + CachingSettings, +) +from askui.reporting import SimpleHtmlReporter + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger() + + +def phash_validation_example() -> None: + """Record and execute with pHash visual validation. + + Perceptual hashing (pHash) is the default and recommended method for + visual validation. It uses Discrete Cosine Transform (DCT) to create + a fingerprint of the UI region around each click/action. + + pHash is robust to: + - Minor image compression artifacts + - Small lighting changes + - Slight color variations + + But will detect: + - UI element position changes + - Different UI states (buttons, menus, etc.) + - Major layout changes + """ + goal = """Please open the System Settings application, click on "General", + then close the System Settings application. + If available, use cache file at phash_example.json + """ + + caching_settings = CachingSettings( + strategy="both", + cache_dir=".askui_cache", + writing_settings=CacheWritingSettings( + filename="phash_example.json", + visual_verification_method="phash", # Perceptual hashing (recommended) + visual_validation_region_size=100, # 100x100 pixel region + visual_validation_threshold=10, # Lower = stricter (0-64 range) + ), + execution_settings=CacheExecutionSettings( + delay_time_between_action=0.5, + ), + ) + + with VisionAgent( + display=1, + reporters=[SimpleHtmlReporter()], + ) as agent: + agent.act(goal, caching_settings=caching_settings) + + logger.info("✓ pHash validation example completed") + logger.info("Visual hashes were computed for each click/action") + logger.info("Check the cache file to see 'visual_representation' fields") + + +def ahash_validation_example() -> None: + """Record and execute with aHash visual validation. + + Average hashing (aHash) is an alternative method that computes + the average pixel values in a region. It's simpler and faster than + pHash but slightly less robust. + + aHash is good for: + - Very fast validation + - Simple UI elements + - High-contrast interfaces + + Use pHash for better accuracy in most cases. + """ + goal = """Please open the Finder application, navigate to Documents folder, + then close the Finder window. + If available, use cache file at ahash_example.json + """ + + caching_settings = CachingSettings( + strategy="both", + cache_dir=".askui_cache", + writing_settings=CacheWritingSettings( + filename="ahash_example.json", + visual_verification_method="ahash", # Average hashing (alternative) + visual_validation_region_size=100, + visual_validation_threshold=10, + ), + execution_settings=CacheExecutionSettings( + delay_time_between_action=0.5, + ), + ) + + with VisionAgent( + display=1, + reporters=[SimpleHtmlReporter()], + ) as agent: + agent.act(goal, caching_settings=caching_settings) + + logger.info("✓ aHash validation example completed") + + +def strict_validation_example() -> None: + """Use strict visual validation (low threshold). + + A lower threshold means stricter validation. 
The threshold represents + the maximum Hamming distance (number of differing bits) allowed between + the cached hash and the current UI state. + + Threshold guidelines: + - 0-5: Very strict (detects tiny changes) + - 6-10: Strict (recommended for stable UIs) + - 11-15: Moderate (tolerates minor changes) + - 16+: Lenient (may miss significant changes) + """ + goal = """Please open the Calendar application and click on today's date. + Then close the Calendar application. + If available, use cache file at strict_validation.json + """ + + caching_settings = CachingSettings( + strategy="both", + cache_dir=".askui_cache", + writing_settings=CacheWritingSettings( + filename="strict_validation.json", + visual_verification_method="phash", + visual_validation_region_size=100, + visual_validation_threshold=5, # Very strict validation + ), + execution_settings=CacheExecutionSettings( + delay_time_between_action=0.5, + ), + ) + + with VisionAgent( + display=1, + reporters=[SimpleHtmlReporter()], + ) as agent: + agent.act(goal, caching_settings=caching_settings) + + logger.info("✓ Strict validation example completed") + logger.info("Threshold=5 will catch even minor UI changes") + + +def lenient_validation_example() -> None: + """Use lenient visual validation (high threshold). + + A higher threshold makes validation more tolerant of UI changes. + Use this when: + - UI has dynamic elements (ads, recommendations, etc.) + - Layout changes slightly between executions + - You want cache to work across minor UI updates + + Be careful: Too lenient validation may miss important UI changes! + """ + goal = """Please open Safari, navigate to www.apple.com, + and close Safari. + If available, use cache file at lenient_validation.json + """ + + caching_settings = CachingSettings( + strategy="both", + cache_dir=".askui_cache", + writing_settings=CacheWritingSettings( + filename="lenient_validation.json", + visual_verification_method="phash", + visual_validation_region_size=100, + visual_validation_threshold=20, # Lenient validation + ), + execution_settings=CacheExecutionSettings( + delay_time_between_action=0.5, + ), + ) + + with VisionAgent( + display=1, + reporters=[SimpleHtmlReporter()], + ) as agent: + agent.act(goal, caching_settings=caching_settings) + + logger.info("✓ Lenient validation example completed") + logger.info("Threshold=20 tolerates more UI variation") + + +def region_size_example() -> None: + """Demonstrate different validation region sizes. + + The region size determines how large an area around each click point + is captured and validated. Larger regions capture more context but + may be less precise. + + Region size guidelines: + - 50x50: Small, precise validation of exact click target + - 100x100: Balanced (recommended default) + - 150x150: Large, includes more surrounding UI context + - 200x200: Very large, captures entire UI section + """ + goal = """Please open the Notes application, click on "New Note" button, + and close Notes without saving. 
+    If available, use cache file at region_size_example.json
+    """
+
+    caching_settings = CachingSettings(
+        strategy="both",
+        cache_dir=".askui_cache",
+        writing_settings=CacheWritingSettings(
+            filename="region_size_example.json",
+            visual_verification_method="phash",
+            visual_validation_region_size=150,  # Larger region for more context
+            visual_validation_threshold=10,
+        ),
+        execution_settings=CacheExecutionSettings(
+            delay_time_between_action=0.5,
+        ),
+    )
+
+    with VisionAgent(
+        display=1,
+        reporters=[SimpleHtmlReporter()],
+    ) as agent:
+        agent.act(goal, caching_settings=caching_settings)
+
+    logger.info("✓ Region size example completed")
+    logger.info("Used 150x150 pixel region for validation")
+
+
+def no_validation_example() -> None:
+    """Disable visual validation entirely.
+
+    Set visual_verification_method="none" to disable visual validation.
+    The cache will still work, but won't verify UI state during execution.
+
+    Use this when:
+    - You trust the cache completely
+    - UI changes frequently and validation would fail
+    - Performance is critical (validation adds small overhead)
+
+    Warning: Without validation, cached actions may execute on wrong UI!
+    """
+    goal = """Please open the Music application and close it.
+    If available, use cache file at no_validation_example.json
+    """
+
+    caching_settings = CachingSettings(
+        strategy="both",
+        cache_dir=".askui_cache",
+        writing_settings=CacheWritingSettings(
+            filename="no_validation_example.json",
+            visual_verification_method="none",  # Disable visual validation
+        ),
+        execution_settings=CacheExecutionSettings(
+            delay_time_between_action=0.5,
+        ),
+    )
+
+    with VisionAgent(
+        display=1,
+        reporters=[SimpleHtmlReporter()],
+    ) as agent:
+        agent.act(goal, caching_settings=caching_settings)
+
+    logger.info("✓ No validation example completed")
+    logger.info("Cache executed without visual validation")
+
+
+if __name__ == "__main__":
+    # Run examples to demonstrate different validation modes
+    print("\n" + "=" * 70)
+    print("VISUAL VALIDATION EXAMPLES")
+    print("=" * 70 + "\n")
+
+    print("\n1. pHash validation (recommended)...")
+    print("-" * 70)
+    phash_validation_example()
+
+    print("\n2. aHash validation (alternative)...")
+    print("-" * 70)
+    ahash_validation_example()
+
+    print("\n3. Strict validation (threshold=5)...")
+    print("-" * 70)
+    strict_validation_example()
+
+    print("\n4. Lenient validation (threshold=20)...")
+    print("-" * 70)
+    lenient_validation_example()
+
+    print("\n5. Large region size (150x150)...")
+    print("-" * 70)
+    region_size_example()
+
+    print("\n6. 
No validation (disabled)...")
+    print("-" * 70)
+    no_validation_example()
+
+    print("\n" + "=" * 70)
+    print("Visual validation examples completed!")
+    print("=" * 70 + "\n")
+
+    print("\nKey Takeaways:")
+    print("- pHash is recommended for most use cases (robust and accurate)")
+    print("- Threshold 5-10 is good for stable UIs")
+    print("- Threshold 15-20 is better for dynamic UIs")
+    print("- Region size 100x100 is a good default")
+    print("- Visual validation helps detect unexpected UI changes")
+    print()

From 29900318043dd8811a5cc9aa54f2bace83c9dabb Mon Sep 17 00:00:00 2001
From: philipph-askui
Date: Wed, 31 Dec 2025 10:33:29 +0100
Subject: [PATCH 28/30] chore(caching): refactor tests to use new caching api

---
 tests/e2e/agent/test_act_caching.py           |  2 +-
 .../test_messages_api_field_stripping.py      | 99 +++++++++++++++++++
 .../test_token_counting_visual_validation.py  | 96 ++++++++++++++++++
 tests/unit/tools/test_caching_tools.py        |  4 +-
 tests/unit/utils/test_cache_writer.py         | 74 ++++++++++----
 5 files changed, 251 insertions(+), 24 deletions(-)
 create mode 100644 tests/unit/models/test_messages_api_field_stripping.py
 create mode 100644 tests/unit/models/test_token_counting_visual_validation.py

diff --git a/tests/e2e/agent/test_act_caching.py b/tests/e2e/agent/test_act_caching.py
index 70652b43..52910f92 100644
--- a/tests/e2e/agent/test_act_caching.py
+++ b/tests/e2e/agent/test_act_caching.py
@@ -9,7 +9,7 @@
 from askui.agent import VisionAgent
 from askui.models.shared.agent_message_param import MessageParam
 from askui.models.shared.agent_on_message_cb import OnMessageCbParam
-from askui.models.shared.settings import CachedExecutionToolSettings, CachingSettings
+from askui.models.shared.settings import CacheExecutionSettings, CachingSettings
 
 
 def test_act_with_caching_strategy_read(vision_agent: VisionAgent) -> None:
diff --git a/tests/unit/models/test_messages_api_field_stripping.py b/tests/unit/models/test_messages_api_field_stripping.py
new file mode 100644
index 00000000..3cf46c64
--- /dev/null
+++ b/tests/unit/models/test_messages_api_field_stripping.py
@@ -0,0 +1,99 @@
+"""Tests for Pydantic-based context-aware serialization of internal fields."""
+
+from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam
+
+
+def test_tool_use_block_includes_fields_by_default():
+    """Test that visual validation fields are included in normal serialization."""
+    tool_block = ToolUseBlockParam(
+        id="test_id",
+        name="computer",
+        input={"action": "left_click", "coordinate": [100, 200]},
+        type="tool_use",
+        visual_representation="abc123",
+    )
+
+    # Default serialization includes all fields
+    serialized = tool_block.model_dump()
+
+    assert serialized["visual_representation"] == "abc123"
+    assert serialized["id"] == "test_id"
+    assert serialized["name"] == "computer"
+
+
+def test_tool_use_block_excludes_fields_for_api():
+    """Test that visual validation fields are excluded when for_api context is set."""
+    tool_block = ToolUseBlockParam(
+        id="test_id",
+        name="computer",
+        input={"action": "left_click", "coordinate": [100, 200]},
+        type="tool_use",
+        visual_representation="abc123",
+    )
+
+    # Serialization with for_api context excludes internal fields
+    serialized = tool_block.model_dump(context={"for_api": True})
+
+    # Internal fields should be excluded
+    assert "visual_representation" not in serialized
+
+    # Other fields should remain
+    assert serialized["id"] == "test_id"
+    assert serialized["name"] == "computer"
+    assert 
serialized["input"] == {"action": "left_click", "coordinate": [100, 200]} + + +def test_tool_use_block_without_visual_validation(): + """Test serialization of tool block without visual validation fields.""" + tool_block = ToolUseBlockParam( + id="test_id", + name="computer", + input={"action": "screenshot"}, + type="tool_use", + ) + + # Both modes should work fine + normal = tool_block.model_dump() + for_api = tool_block.model_dump(context={"for_api": True}) + + # Should not have visual representation field in either case (or it should be None) + assert "visual_representation" not in normal or normal["visual_representation"] is None + assert "visual_representation" not in for_api + + +def test_message_with_tool_use_context_propagation(): + """Test that context propagates through nested MessageParam serialization.""" + tool_block = ToolUseBlockParam( + id="test_id", + name="computer", + input={"action": "left_click", "coordinate": [100, 200]}, + type="tool_use", + visual_representation="abc123", + ) + + message = MessageParam(role="assistant", content=[tool_block]) + + # Normal dump includes fields + normal = message.model_dump() + assert normal["content"][0]["visual_representation"] == "abc123" + + # API dump excludes fields + for_api = message.model_dump(context={"for_api": True}) + assert "visual_representation" not in for_api["content"][0] + + +def test_cache_storage_includes_all_fields(): + """Test that cache storage (mode='json') includes all fields.""" + tool_block = ToolUseBlockParam( + id="test_id", + name="computer", + input={"action": "left_click", "coordinate": [100, 200]}, + type="tool_use", + visual_representation="abc123", + ) + + # Cache storage uses mode='json' without for_api context + cache_dump = tool_block.model_dump(mode="json") + + # Should include all fields for cache storage + assert cache_dump["visual_representation"] == "abc123" diff --git a/tests/unit/models/test_token_counting_visual_validation.py b/tests/unit/models/test_token_counting_visual_validation.py new file mode 100644 index 00000000..68d798d0 --- /dev/null +++ b/tests/unit/models/test_token_counting_visual_validation.py @@ -0,0 +1,96 @@ +"""Test that token counting excludes visual validation fields.""" + +from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam +from askui.models.shared.token_counter import SimpleTokenCounter + + +def test_token_counting_excludes_visual_validation_fields(): + """Verify that visual validation fields don't inflate token counts.""" + # Create two identical tool blocks, one with and one without visual validation + tool_block_without = ToolUseBlockParam( + id="test_id", + name="computer", + input={"action": "left_click", "coordinate": [100, 200]}, + type="tool_use", + ) + + tool_block_with = ToolUseBlockParam( + id="test_id", + name="computer", + input={"action": "left_click", "coordinate": [100, 200]}, + type="tool_use", + visual_representation="a" * 1000, # 1000 character hash + ) + + # Create messages + msg_without = MessageParam(role="assistant", content=[tool_block_without]) + msg_with = MessageParam(role="assistant", content=[tool_block_with]) + + # Count tokens + counter = SimpleTokenCounter() + counts_without = counter.count_tokens(messages=[msg_without]) + counts_with = counter.count_tokens(messages=[msg_with]) + + # Token counts should be identical (visual fields excluded from counting) + assert counts_without.total == counts_with.total, ( + f"Token counts differ: {counts_without.total} vs {counts_with.total}. 
" + "Visual validation fields should be excluded from token counting." + ) + + +def test_token_counter_uses_api_context(): + """Verify that token counter uses for_api context when stringifying objects.""" + tool_block = ToolUseBlockParam( + id="test_id", + name="computer", + input={"action": "left_click", "coordinate": [100, 200]}, + type="tool_use", + visual_representation="hash123", + ) + + counter = SimpleTokenCounter() + + # Stringify the object (as token counter does internally) + stringified = counter._stringify_object(tool_block) + + # Should not contain visual validation fields + assert "visual_representation" not in stringified + assert "hash123" not in stringified + + # Should contain regular fields + assert "test_id" in stringified + assert "computer" in stringified + + +def test_token_counting_with_multiple_tool_blocks(): + """Test token counting with multiple tool blocks in one message.""" + blocks = [ + ToolUseBlockParam( + id=f"id_{i}", + name="computer", + input={"action": "left_click", "coordinate": [i * 100, i * 100]}, + type="tool_use", + visual_representation="x" * 500, # Large hash + ) + for i in range(5) + ] + + blocks_without_validation = [ + ToolUseBlockParam( + id=f"id_{i}", + name="computer", + input={"action": "left_click", "coordinate": [i * 100, i * 100]}, + type="tool_use", + ) + for i in range(5) + ] + + msg_with = MessageParam(role="assistant", content=blocks) + msg_without = MessageParam(role="assistant", content=blocks_without_validation) + + counter = SimpleTokenCounter() + counts_with = counter.count_tokens(messages=[msg_with]) + counts_without = counter.count_tokens(messages=[msg_without]) + + # Should have same token count + assert counts_with.total == counts_without.total diff --git a/tests/unit/tools/test_caching_tools.py b/tests/unit/tools/test_caching_tools.py index 253a1854..7c9e69eb 100644 --- a/tests/unit/tools/test_caching_tools.py +++ b/tests/unit/tools/test_caching_tools.py @@ -9,7 +9,7 @@ import pytest -from askui.models.shared.settings import CachedExecutionToolSettings +from askui.models.shared.settings import CacheExecutionSettings from askui.models.shared.tools import ToolCollection from askui.tools.caching_tools import ( ExecuteCachedTrajectory, @@ -390,7 +390,7 @@ def test_execute_cached_execution_initializes_with_default_settings() -> None: def test_execute_cached_execution_initializes_with_custom_settings() -> None: """Test that ExecuteCachedTrajectory accepts custom settings.""" mock_toolbox = MagicMock(spec=ToolCollection) - custom_settings = CachedExecutionToolSettings(delay_time_between_action=1.0) + custom_settings = CacheExecutionSettings(delay_time_between_action=1.0) tool = ExecuteCachedTrajectory(toolbox=mock_toolbox, settings=custom_settings) # Should have custom settings initialized diff --git a/tests/unit/utils/test_cache_writer.py b/tests/unit/utils/test_cache_writer.py index c0f04cb2..c9ce4a5d 100644 --- a/tests/unit/utils/test_cache_writer.py +++ b/tests/unit/utils/test_cache_writer.py @@ -7,14 +7,19 @@ from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam from askui.models.shared.agent_on_message_cb import OnMessageCbParam -from askui.models.shared.settings import CacheFile, CacheWriterSettings +from askui.models.shared.settings import CacheFile, CacheWritingSettings + +# Note: CacheWritingSettings was renamed to CacheWritingSettings in v0.2 from askui.utils.caching.cache_writer import CacheWriter def test_cache_writer_initialization() -> None: """Test CacheWriter initialization.""" with 
tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="test.json") + cache_writer = CacheWriter( + cache_dir=temp_dir, + cache_writing_settings=CacheWritingSettings(filename="test.json"), + ) assert cache_writer.cache_dir == Path(temp_dir) assert cache_writer.file_name == "test.json" assert cache_writer.messages == [] @@ -35,17 +40,26 @@ def test_cache_writer_creates_cache_directory() -> None: def test_cache_writer_adds_json_extension() -> None: """Test that CacheWriter adds .json extension if not present.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="test") + cache_writer = CacheWriter( + cache_dir=temp_dir, + cache_writing_settings=CacheWritingSettings(filename="test"), + ) assert cache_writer.file_name == "test.json" - cache_writer2 = CacheWriter(cache_dir=temp_dir, file_name="test.json") + cache_writer2 = CacheWriter( + cache_dir=temp_dir, + cache_writing_settings=CacheWritingSettings(filename="test.json"), + ) assert cache_writer2.file_name == "test.json" def test_cache_writer_add_message_cb_stores_tool_use_blocks() -> None: """Test that add_message_cb stores ToolUseBlockParam from assistant messages.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="test.json") + cache_writer = CacheWriter( + cache_dir=temp_dir, + cache_writing_settings=CacheWritingSettings(filename="test.json"), + ) tool_use_block = ToolUseBlockParam( id="test_id", @@ -74,7 +88,10 @@ def test_cache_writer_add_message_cb_stores_tool_use_blocks() -> None: def test_cache_writer_add_message_cb_ignores_non_tool_use_content() -> None: """Test that add_message_cb ignores non-ToolUseBlockParam content.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="test.json") + cache_writer = CacheWriter( + cache_dir=temp_dir, + cache_writing_settings=CacheWritingSettings(filename="test.json"), + ) message = MessageParam( role="assistant", @@ -94,7 +111,10 @@ def test_cache_writer_add_message_cb_ignores_non_tool_use_content() -> None: def test_cache_writer_add_message_cb_ignores_user_messages() -> None: """Test that add_message_cb ignores user messages.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="test.json") + cache_writer = CacheWriter( + cache_dir=temp_dir, + cache_writing_settings=CacheWritingSettings(filename="test.json"), + ) message = MessageParam( role="user", @@ -114,7 +134,10 @@ def test_cache_writer_add_message_cb_ignores_user_messages() -> None: def test_cache_writer_detects_cached_execution() -> None: """Test that CacheWriter detects when execute_cached_executions_tool is used.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="test.json") + cache_writer = CacheWriter( + cache_dir=temp_dir, + cache_writing_settings=CacheWritingSettings(filename="test.json"), + ) tool_use_block = ToolUseBlockParam( id="cached_exec_id", @@ -144,8 +167,8 @@ def test_cache_writer_generate_writes_file() -> None: cache_dir = Path(temp_dir) cache_writer = CacheWriter( cache_dir=str(cache_dir), - file_name="output.json", - cache_writer_settings=CacheWriterSettings( + cache_writing_settings=CacheWritingSettings( + filename="output.json", parameter_identification_strategy="preset" ), ) @@ -200,8 +223,8 @@ def test_cache_writer_generate_auto_names_file() -> None: cache_dir = Path(temp_dir) 
cache_writer = CacheWriter( cache_dir=str(cache_dir), - file_name="", - cache_writer_settings=CacheWriterSettings( + cache_writing_settings=CacheWritingSettings( + filename="", parameter_identification_strategy="preset" ), ) @@ -225,7 +248,10 @@ def test_cache_writer_generate_skips_cached_execution() -> None: """Test that generate() doesn't write file for cached executions.""" with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) - cache_writer = CacheWriter(cache_dir=str(cache_dir), file_name="test.json") + cache_writer = CacheWriter( + cache_dir=str(cache_dir), + cache_writing_settings=CacheWritingSettings(filename="test.json"), + ) cache_writer.was_cached_execution = True cache_writer.messages = [ @@ -247,7 +273,10 @@ def test_cache_writer_generate_skips_cached_execution() -> None: def test_cache_writer_reset() -> None: """Test that reset() clears messages and filename.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="original.json") + cache_writer = CacheWriter( + cache_dir=temp_dir, + cache_writing_settings=CacheWritingSettings(filename="original.json"), + ) # Add some data cache_writer.messages = [ @@ -359,7 +388,10 @@ def test_cache_writer_read_cache_file_v2() -> None: def test_cache_writer_set_file_name() -> None: """Test that set_file_name() updates the filename.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="original.json") + cache_writer = CacheWriter( + cache_dir=temp_dir, + cache_writing_settings=CacheWritingSettings(filename="original.json"), + ) cache_writer.set_file_name("new_name") assert cache_writer.file_name == "new_name.json" @@ -374,8 +406,8 @@ def test_cache_writer_generate_resets_after_writing() -> None: cache_dir = Path(temp_dir) cache_writer = CacheWriter( cache_dir=str(cache_dir), - file_name="test.json", - cache_writer_settings=CacheWriterSettings( + cache_writing_settings=CacheWritingSettings( + filename="test.json", parameter_identification_strategy="preset" ), ) @@ -401,8 +433,8 @@ def test_cache_writer_detects_and_stores_parameters() -> None: cache_dir = Path(temp_dir) cache_writer = CacheWriter( cache_dir=str(cache_dir), - file_name="test.json", - cache_writer_settings=CacheWriterSettings( + cache_writing_settings=CacheWritingSettings( + filename="test.json", parameter_identification_strategy="preset" ), ) @@ -443,8 +475,8 @@ def test_cache_writer_empty_parameters_when_none_found() -> None: cache_dir = Path(temp_dir) cache_writer = CacheWriter( cache_dir=str(cache_dir), - file_name="test.json", - cache_writer_settings=CacheWriterSettings( + cache_writing_settings=CacheWritingSettings( + filename="test.json", parameter_identification_strategy="preset" ), ) From f32f661d367d5266528052382f67007a1d9e5899 Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Wed, 31 Dec 2025 11:16:26 +0100 Subject: [PATCH 29/30] chore(caching): fix formatting and linting issues --- .gitignore | 2 + src/askui/models/anthropic/messages_api.py | 2 +- .../models/shared/agent_message_param.py | 10 +++- src/askui/models/shared/token_counter.py | 11 ++-- src/askui/tools/caching_tools.py | 28 +++++++--- src/askui/utils/caching/cache_writer.py | 53 ++++++++++--------- src/askui/utils/trajectory_executor.py | 12 +++-- src/askui/utils/visual_validation.py | 2 +- tests/e2e/agent/test_act_caching.py | 42 +++++++++------ .../test_messages_api_field_stripping.py | 14 ++--- .../test_token_counting_visual_validation.py | 6 +-- 
tests/unit/utils/test_cache_writer.py | 15 ++---- tests/unit/utils/test_visual_validation.py | 52 +++++++++--------- 13 files changed, 143 insertions(+), 106 deletions(-) diff --git a/.gitignore b/.gitignore index 246d7df9..600c9ba0 100644 --- a/.gitignore +++ b/.gitignore @@ -166,4 +166,6 @@ reports/ /chat /askui_chat.db .cache/ +.askui_cache/ +playground/* diff --git a/src/askui/models/anthropic/messages_api.py b/src/askui/models/anthropic/messages_api.py index dab29c7a..9f16db97 100644 --- a/src/askui/models/anthropic/messages_api.py +++ b/src/askui/models/anthropic/messages_api.py @@ -72,7 +72,7 @@ def create_message( "BetaMessageParam", message.model_dump( exclude={"stop_reason", "usage"}, - context={"for_api": True} # Triggers exclusion of internal fields + context={"for_api": True}, # Triggers exclusion of internal fields ), ) for message in messages diff --git a/src/askui/models/shared/agent_message_param.py b/src/askui/models/shared/agent_message_param.py index cca9f890..1503bb5b 100644 --- a/src/askui/models/shared/agent_message_param.py +++ b/src/askui/models/shared/agent_message_param.py @@ -1,6 +1,8 @@ from typing import Any from pydantic import BaseModel, model_serializer +from pydantic.functional_serializers import SerializerFunctionWrapHandler +from pydantic_core import core_schema from typing_extensions import Literal @@ -84,14 +86,18 @@ class ToolUseBlockParam(BaseModel): visual_representation: str | None = None @model_serializer(mode="wrap") - def _serialize_model(self, serializer, info) -> dict[str, Any]: + def _serialize_model( + self, + serializer: SerializerFunctionWrapHandler, + info: core_schema.SerializationInfo, + ) -> dict[str, Any]: """Custom serializer to exclude internal fields when serializing for API. When context={'for_api': True}, visual validation fields are excluded. Otherwise, all fields are included (for cache storage, internal use). 
""" # Use default serialization - data = serializer(self) + data: dict[str, Any] = serializer(self) # If serializing for API, remove internal fields if info.context and info.context.get("for_api"): diff --git a/src/askui/models/shared/token_counter.py b/src/askui/models/shared/token_counter.py index 5f162437..378f7033 100644 --- a/src/askui/models/shared/token_counter.py +++ b/src/askui/models/shared/token_counter.py @@ -219,16 +219,17 @@ def _count_tokens_for_content_block(self, block: ContentBlockParam) -> int: token_count = int(len(stringified) / self._chars_per_token) # Debug: Log if this is a ToolUseBlockParam with visual validation fields - if hasattr(block, 'visual_representation') and block.visual_representation: + if hasattr(block, "visual_representation") and block.visual_representation: import logging + logger = logging.getLogger(__name__) logger.debug( "Token counting for %s: stringified_length=%d, tokens=%d, " "has_visual_fields=%s", - getattr(block, 'name', 'unknown'), + getattr(block, "name", "unknown"), len(stringified), token_count, - 'visual_representation' in stringified + "visual_representation" in stringified, ) return token_count @@ -284,10 +285,10 @@ def _stringify_object(self, obj: object) -> str: return obj # Check if object is a Pydantic model with model_dump method - if hasattr(obj, "model_dump") and callable(getattr(obj, "model_dump")): + if hasattr(obj, "model_dump") and callable(obj.model_dump): try: # Use for_api context to exclude internal fields from token counting - serialized = obj.model_dump(context={"for_api": True}) # type: ignore[attr-defined] + serialized = obj.model_dump(context={"for_api": True}) return json.dumps(serialized, separators=(",", ":")) except (TypeError, ValueError, AttributeError): pass # Fall through to default handling diff --git a/src/askui/tools/caching_tools.py b/src/askui/tools/caching_tools.py index e5e3bf80..bcba4ef6 100644 --- a/src/askui/tools/caching_tools.py +++ b/src/askui/tools/caching_tools.py @@ -322,25 +322,37 @@ def _create_executor( if cache_file.metadata.visual_verification_method: # Cache has visual validation metadata - use those exact settings - visual_validation_enabled = cache_file.metadata.visual_verification_method != "none" + visual_validation_enabled = ( + cache_file.metadata.visual_verification_method != "none" + ) visual_hash_method = cache_file.metadata.visual_verification_method if cache_file.metadata.visual_validation_threshold is not None: - visual_validation_threshold = cache_file.metadata.visual_validation_threshold + visual_validation_threshold = ( + cache_file.metadata.visual_validation_threshold + ) if cache_file.metadata.visual_validation_region_size is not None: - visual_validation_region_size = cache_file.metadata.visual_validation_region_size + visual_validation_region_size = ( + cache_file.metadata.visual_validation_region_size + ) logger.debug( - "Visual validation enabled from cache metadata: method=%s, threshold=%d, region_size=%d", - visual_hash_method, - visual_validation_threshold, - visual_validation_region_size, + ( + "Visual validation enabled from cache metadata: " + "method=%s, threshold=%d, region_size=%d", + visual_hash_method, + visual_validation_threshold, + visual_validation_region_size, + ) ) else: # Cache doesn't have visual validation metadata - don't validate logger.debug( - "Visual validation disabled: cache file has no visual validation metadata" + ( + "Visual validation disabled: cache file has no visual " + "validation metadata" + ) ) logger.debug( diff --git 
a/src/askui/utils/caching/cache_writer.py b/src/askui/utils/caching/cache_writer.py index 521edc83..40fa47ef 100644 --- a/src/askui/utils/caching/cache_writer.py +++ b/src/askui/utils/caching/cache_writer.py @@ -66,9 +66,15 @@ def __init__( self._accumulated_usage = UsageParam() # Extract visual verification settings from cache_writing_settings - self._visual_verification_method = self._cache_writing_settings.visual_verification_method - self._visual_validation_region_size = self._cache_writing_settings.visual_validation_region_size - self._visual_validation_threshold = self._cache_writing_settings.visual_validation_threshold + self._visual_verification_method = ( + self._cache_writing_settings.visual_verification_method + ) + self._visual_validation_region_size = ( + self._cache_writing_settings.visual_validation_region_size + ) + self._visual_validation_threshold = ( + self._cache_writing_settings.visual_validation_threshold + ) # Set toolbox for cache writer so it can check which tools are cacheable self._toolbox = toolbox @@ -84,7 +90,8 @@ def add_message_cb(self, param: OnMessageCbParam) -> MessageParam: if content.name == "execute_cached_executions_tool": self.was_cached_execution = True - # Enhance with visual validation if applicable (skip during cached execution) + # Enhance with visual validation if applicable + # (skip during cached execution) enhanced_content = self._enhance_with_visual_validation(content) self.messages.append(enhanced_content) @@ -250,7 +257,7 @@ def _generate_cache_file( json.dump(cache_file.model_dump(mode="json"), f, indent=4) logger.info("Cache file successfully written: %s ", cache_file_path) - def _enhance_with_visual_validation( + def _enhance_with_visual_validation( # noqa: C901 self, tool_block: ToolUseBlockParam ) -> ToolUseBlockParam: """Enhance ToolUseBlockParam with visual validation data if applicable. @@ -312,7 +319,8 @@ def _enhance_with_visual_validation( screenshot = self._capture_screenshot() if screenshot is None: logger.warning( - "Visual validation skipped for %s action=%s: screenshot capture failed", + "Visual validation skipped for %s action=%s: " + "screenshot capture failed", tool_block.name, action, ) @@ -326,14 +334,8 @@ def _enhance_with_visual_validation( # Compute hash based on method if self._visual_verification_method == "phash": visual_hash = compute_phash(region) - elif self._visual_verification_method == "ahash": + else: # ahash (none was already handled earlier) visual_hash = compute_ahash(region) - else: - logger.warning( - "Unknown visual verification method: %s", - self._visual_verification_method, - ) - return tool_block # Create enhanced ToolUseBlockParam with visual validation data enhanced = ToolUseBlockParam( @@ -346,23 +348,26 @@ def _enhance_with_visual_validation( ) logger.info( - "✓ Visual validation added to %s action=%s at coordinate %s (hash=%s...)", + "✓ Visual validation added to %s action=%s at coordinate %s " + "(hash=%s...)", tool_block.name, action, coordinate, visual_hash[:16], ) + return enhanced # noqa: TRY300 - return enhanced - - except Exception as e: + except Exception as e: # noqa: BLE001 logger.warning( - "Visual validation skipped for %s action=%s: error during enhancement: %s", + "Visual validation skipped for %s action=%s: " + "error during enhancement: %s", tool_block.name, action, str(e), ) - return tool_block + # Fall through to return original tool_block + + return tool_block def _capture_screenshot(self) -> Image.Image | None: """Capture current screenshot using the computer tool. 
@@ -390,13 +395,13 @@ def _capture_screenshot(self) -> Image.Image | None: try: # Try to call _screenshot() method directly if available if hasattr(computer_tool, "_screenshot"): - result = computer_tool._screenshot() # type: ignore[attr-defined] + result = computer_tool._screenshot() # noqa: SLF001 if isinstance(result, Image.Image): logger.debug("Screenshot captured successfully via _screenshot()") return result # Fallback to calling via __call__ with action parameter - result = computer_tool(action="screenshot") # type: ignore[call-arg] + result = computer_tool(action="screenshot") if isinstance(result, Image.Image): logger.debug("Screenshot captured successfully via __call__") return result @@ -406,14 +411,14 @@ def _capture_screenshot(self) -> Image.Image | None: type(result).__name__, str(result)[:100], ) - return None - except Exception as e: + except Exception as e: # noqa: BLE001 logger.warning( "Error capturing screenshot for visual validation: %s: %s", type(e).__name__, str(e), ) - return None + + return None def _accumulate_usage(self, step_usage: UsageParam) -> None: """Accumulate usage statistics from a single API call. diff --git a/src/askui/utils/trajectory_executor.py b/src/askui/utils/trajectory_executor.py index 59824bfe..a0400b9c 100644 --- a/src/askui/utils/trajectory_executor.py +++ b/src/askui/utils/trajectory_executor.py @@ -352,7 +352,7 @@ def validate_step_visually( region = extract_region( current_screenshot, coordinate, size=self.visual_validation_region_size ) - except Exception as e: + except Exception as e: # noqa: BLE001 logger.warning( "Error extracting region for visual validation at step %d: %s", self.current_step_index, @@ -396,17 +396,19 @@ def _capture_screenshot(self) -> Image.Image | None: try: # Try to call _screenshot() method directly if available if hasattr(computer_tool, "_screenshot"): - result = computer_tool._screenshot() # type: ignore[attr-defined] + result = computer_tool._screenshot() # noqa: SLF001 if isinstance(result, Image.Image): return result # Fallback to calling via __call__ with action parameter - result = computer_tool(action="screenshot") # type: ignore[call-arg] + result = computer_tool(action="screenshot") if isinstance(result, Image.Image): return result - logger.warning("Screenshot action did not return an Image: %s", type(result)) - return None + logger.warning( + "Screenshot action did not return an Image: %s", type(result) + ) + return None # noqa: TRY300 except Exception: logger.exception("Error capturing screenshot") return None diff --git a/src/askui/utils/visual_validation.py b/src/askui/utils/visual_validation.py index f1e0591e..041807d5 100644 --- a/src/askui/utils/visual_validation.py +++ b/src/askui/utils/visual_validation.py @@ -127,7 +127,7 @@ def hamming_distance(hash1: str, hash2: str) -> int: xor_result = int(hash1, 16) ^ int(hash2, 16) # Count number of 1s (differing bits) - return bin(xor_result).count("1") + return (xor_result).bit_count() def extract_region( diff --git a/tests/e2e/agent/test_act_caching.py b/tests/e2e/agent/test_act_caching.py index 52910f92..568ef5a5 100644 --- a/tests/e2e/agent/test_act_caching.py +++ b/tests/e2e/agent/test_act_caching.py @@ -9,7 +9,11 @@ from askui.agent import VisionAgent from askui.models.shared.agent_message_param import MessageParam from askui.models.shared.agent_on_message_cb import OnMessageCbParam -from askui.models.shared.settings import CacheExecutionSettings, CachingSettings +from askui.models.shared.settings import ( + CacheExecutionSettings, + 
CacheWritingSettings,
+    CachingSettings,
+)
 
 
 def test_act_with_caching_strategy_read(vision_agent: VisionAgent) -> None:
@@ -24,7 +28,7 @@ def test_act_with_caching_strategy_read(vision_agent: VisionAgent) -> None:
     vision_agent.act(
         goal="Tell me a joke",
         caching_settings=CachingSettings(
-            strategy="read",
+            strategy="execute",
             cache_dir=str(cache_dir),
         ),
     )
@@ -41,9 +45,11 @@ def test_act_with_caching_strategy_write(vision_agent: VisionAgent) -> None:
     vision_agent.act(
         goal="Tell me a joke",
         caching_settings=CachingSettings(
-            strategy="write",
+            strategy="record",
             cache_dir=str(cache_dir),
-            filename=cache_filename,
+            writing_settings=CacheWritingSettings(
+                filename=cache_filename,
+            ),
         ),
     )
@@ -68,7 +74,9 @@ def test_act_with_caching_strategy_both(vision_agent: VisionAgent) -> None:
         caching_settings=CachingSettings(
             strategy="both",
             cache_dir=str(cache_dir),
-            filename=cache_filename,
+            writing_settings=CacheWritingSettings(
+                filename=cache_filename,
+            ),
         ),
     )
@@ -86,8 +94,7 @@ def test_act_with_caching_strategy_no(vision_agent: VisionAgent) -> None:
     vision_agent.act(
         goal="Tell me a joke",
         caching_settings=CachingSettings(
-            strategy="no",
-            cache_dir=str(cache_dir),
+            strategy=None,
         ),
     )
@@ -106,9 +113,11 @@ def test_act_with_custom_cache_dir_and_filename(vision_agent: VisionAgent) -> No
     vision_agent.act(
         goal="Tell me a joke",
         caching_settings=CachingSettings(
-            strategy="write",
+            strategy="record",
             cache_dir=str(custom_cache_dir),
-            filename=custom_filename,
+            writing_settings=CacheWritingSettings(
+                filename=custom_filename,
+            ),
         ),
     )
@@ -132,7 +141,7 @@ def dummy_callback(param: OnMessageCbParam) -> MessageParam:
     vision_agent.act(
         goal="Tell me a joke",
         caching_settings=CachingSettings(
-            strategy="write",
+            strategy="record",
             cache_dir=str(temp_dir),
         ),
         on_message=dummy_callback,
@@ -170,9 +179,11 @@ def test_cache_file_contains_tool_use_blocks(vision_agent: VisionAgent) -> None:
     vision_agent.act(
         goal="Tell me a joke",
         caching_settings=CachingSettings(
-            strategy="write",
+            strategy="record",
             cache_dir=str(cache_dir),
-            filename=cache_filename,
+            writing_settings=CacheWritingSettings(
+                filename=cache_filename,
+            ),
         ),
     )
@@ -205,13 +216,14 @@ def test_act_with_custom_cached_execution_tool_settings(
     cache_file.write_text("[]", encoding="utf-8")
 
     # Act with custom execution tool settings
-    custom_settings = CachedExecutionToolSettings(delay_time_between_action=2.0)
     vision_agent.act(
         goal="Tell me a joke",
         caching_settings=CachingSettings(
-            strategy="read",
+            strategy="execute",
             cache_dir=str(cache_dir),
-            execute_cached_trajectory_tool_settings=custom_settings,
+            execution_settings=CacheExecutionSettings(
+                delay_time_between_action=2.0,
+            ),
         ),
     )
diff --git a/tests/unit/models/test_messages_api_field_stripping.py b/tests/unit/models/test_messages_api_field_stripping.py
index 3cf46c64..464305ab 100644
--- a/tests/unit/models/test_messages_api_field_stripping.py
+++ b/tests/unit/models/test_messages_api_field_stripping.py
@@ -3,7 +3,7 @@
 from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam
 
 
-def test_tool_use_block_includes_fields_by_default():
+def test_tool_use_block_includes_fields_by_default() -> None:
     """Test that visual validation fields are included in normal serialization."""
     tool_block = ToolUseBlockParam(
         id="test_id",
@@ -21,7 +21,7 @@
     assert serialized["name"] == "computer"
 
 
-def test_tool_use_block_excludes_fields_for_api():
+def test_tool_use_block_excludes_fields_for_api() -> None: 
"""Test that visual validation fields are excluded when for_api context is set.""" tool_block = ToolUseBlockParam( id="test_id", @@ -43,7 +43,7 @@ def test_tool_use_block_excludes_fields_for_api(): assert serialized["input"] == {"action": "left_click", "coordinate": [100, 200]} -def test_tool_use_block_without_visual_validation(): +def test_tool_use_block_without_visual_validation() -> None: """Test serialization of tool block without visual validation fields.""" tool_block = ToolUseBlockParam( id="test_id", @@ -57,11 +57,13 @@ def test_tool_use_block_without_visual_validation(): for_api = tool_block.model_dump(context={"for_api": True}) # Should not have visual representation field in either case (or it should be None) - assert "visual_representation" not in normal or normal["visual_representation"] is None + assert ( + "visual_representation" not in normal or normal["visual_representation"] is None + ) assert "visual_representation" not in for_api -def test_message_with_tool_use_context_propagation(): +def test_message_with_tool_use_context_propagation() -> None: """Test that context propagates through nested MessageParam serialization.""" tool_block = ToolUseBlockParam( id="test_id", @@ -82,7 +84,7 @@ def test_message_with_tool_use_context_propagation(): assert "visual_representation" not in for_api["content"][0] -def test_cache_storage_includes_all_fields(): +def test_cache_storage_includes_all_fields() -> None: """Test that cache storage (mode='json') includes all fields.""" tool_block = ToolUseBlockParam( id="test_id", diff --git a/tests/unit/models/test_token_counting_visual_validation.py b/tests/unit/models/test_token_counting_visual_validation.py index 68d798d0..f97de19d 100644 --- a/tests/unit/models/test_token_counting_visual_validation.py +++ b/tests/unit/models/test_token_counting_visual_validation.py @@ -4,7 +4,7 @@ from askui.models.shared.token_counter import SimpleTokenCounter -def test_token_counting_excludes_visual_validation_fields(): +def test_token_counting_excludes_visual_validation_fields() -> None: """Verify that visual validation fields don't inflate token counts.""" # Create two identical tool blocks, one with and one without visual validation tool_block_without = ToolUseBlockParam( @@ -38,7 +38,7 @@ def test_token_counting_excludes_visual_validation_fields(): ) -def test_token_counter_uses_api_context(): +def test_token_counter_uses_api_context() -> None: """Verify that token counter uses for_api context when stringifying objects.""" tool_block = ToolUseBlockParam( id="test_id", @@ -62,7 +62,7 @@ def test_token_counter_uses_api_context(): assert "computer" in stringified -def test_token_counting_with_multiple_tool_blocks(): +def test_token_counting_with_multiple_tool_blocks() -> None: """Test token counting with multiple tool blocks in one message.""" blocks = [ ToolUseBlockParam( diff --git a/tests/unit/utils/test_cache_writer.py b/tests/unit/utils/test_cache_writer.py index c9ce4a5d..6f6e84d7 100644 --- a/tests/unit/utils/test_cache_writer.py +++ b/tests/unit/utils/test_cache_writer.py @@ -168,8 +168,7 @@ def test_cache_writer_generate_writes_file() -> None: cache_writer = CacheWriter( cache_dir=str(cache_dir), cache_writing_settings=CacheWritingSettings( - filename="output.json", - parameter_identification_strategy="preset" + filename="output.json", parameter_identification_strategy="preset" ), ) @@ -224,8 +223,7 @@ def test_cache_writer_generate_auto_names_file() -> None: cache_writer = CacheWriter( cache_dir=str(cache_dir), 
cache_writing_settings=CacheWritingSettings( - filename="", - parameter_identification_strategy="preset" + filename="", parameter_identification_strategy="preset" ), ) @@ -407,8 +405,7 @@ def test_cache_writer_generate_resets_after_writing() -> None: cache_writer = CacheWriter( cache_dir=str(cache_dir), cache_writing_settings=CacheWritingSettings( - filename="test.json", - parameter_identification_strategy="preset" + filename="test.json", parameter_identification_strategy="preset" ), ) @@ -434,8 +431,7 @@ def test_cache_writer_detects_and_stores_parameters() -> None: cache_writer = CacheWriter( cache_dir=str(cache_dir), cache_writing_settings=CacheWritingSettings( - filename="test.json", - parameter_identification_strategy="preset" + filename="test.json", parameter_identification_strategy="preset" ), ) @@ -476,8 +472,7 @@ def test_cache_writer_empty_parameters_when_none_found() -> None: cache_writer = CacheWriter( cache_dir=str(cache_dir), cache_writing_settings=CacheWritingSettings( - filename="test.json", - parameter_identification_strategy="preset" + filename="test.json", parameter_identification_strategy="preset" ), ) diff --git a/tests/unit/utils/test_visual_validation.py b/tests/unit/utils/test_visual_validation.py index 161f3dcb..acc50a6c 100644 --- a/tests/unit/utils/test_visual_validation.py +++ b/tests/unit/utils/test_visual_validation.py @@ -17,7 +17,7 @@ class TestHashComputation: """Test hash computation functions.""" - def test_compute_phash_returns_hex_string(self): + def test_compute_phash_returns_hex_string(self) -> None: """Test that compute_phash returns a hexadecimal string.""" # Create a simple test image img = Image.new("RGB", (100, 100), color="red") @@ -29,7 +29,7 @@ def test_compute_phash_returns_hex_string(self): # Should be valid hex int(hash_result, 16) # Will raise if not valid hex - def test_compute_ahash_returns_hex_string(self): + def test_compute_ahash_returns_hex_string(self) -> None: """Test that compute_ahash returns a hexadecimal string.""" # Create a simple test image img = Image.new("RGB", (100, 100), color="blue") @@ -41,7 +41,7 @@ def test_compute_ahash_returns_hex_string(self): # Should be valid hex int(hash_result, 16) # Will raise if not valid hex - def test_identical_images_produce_same_phash(self): + def test_identical_images_produce_same_phash(self) -> None: """Test that identical images produce identical phashes.""" img1 = Image.new("RGB", (100, 100), color="green") img2 = Image.new("RGB", (100, 100), color="green") @@ -51,16 +51,16 @@ def test_identical_images_produce_same_phash(self): assert hash1 == hash2 - def test_different_images_produce_different_phash(self): + def test_different_images_produce_different_phash(self) -> None: """Test that different images produce different phashes.""" # Create images with patterns, not solid colors img1 = Image.new("RGB", (100, 100), color="white") draw1 = ImageDraw.Draw(img1) - draw1.rectangle([10, 10, 50, 50], fill="red") + draw1.rectangle((10, 10, 50, 50), fill="red") img2 = Image.new("RGB", (100, 100), color="white") draw2 = ImageDraw.Draw(img2) - draw2.rectangle([60, 60, 90, 90], fill="blue") + draw2.rectangle((60, 60, 90, 90), fill="blue") hash1 = compute_phash(img1) hash2 = compute_phash(img2) @@ -71,7 +71,7 @@ def test_different_images_produce_different_phash(self): class TestHammingDistance: """Test Hamming distance calculation.""" - def test_identical_hashes_have_zero_distance(self): + def test_identical_hashes_have_zero_distance(self) -> None: """Test that identical hashes have Hamming 
         hash1 = "a1b2c3d4"
         hash2 = "a1b2c3d4"

@@ -79,7 +79,7 @@ def test_identical_hashes_have_zero_distance(self):
         distance = hamming_distance(hash1, hash2)
         assert distance == 0

-    def test_different_hashes_have_nonzero_distance(self):
+    def test_different_hashes_have_nonzero_distance(self) -> None:
         """Test that different hashes have non-zero Hamming distance."""
         hash1 = "ffffffff"  # All 1s in binary
         hash2 = "00000000"  # All 0s in binary
@@ -87,7 +87,7 @@ def test_different_hashes_have_nonzero_distance(self):
         distance = hamming_distance(hash1, hash2)
         assert distance > 0

-    def test_hamming_distance_raises_on_different_lengths(self):
+    def test_hamming_distance_raises_on_different_lengths(self) -> None:
         """Test that hamming_distance raises ValueError for different lengths."""
         hash1 = "a1b2"
         hash2 = "a1b2c3"
@@ -99,7 +99,7 @@ class TestExtractRegion:
     """Test region extraction from images."""

-    def test_extract_region_returns_image(self):
+    def test_extract_region_returns_image(self) -> None:
         """Test that extract_region returns a PIL Image."""
         img = Image.new("RGB", (200, 200), color="red")
         center = (100, 100)
@@ -108,7 +108,7 @@ def test_extract_region_returns_image(self):

         assert isinstance(region, Image.Image)

-    def test_extract_region_has_correct_size(self):
+    def test_extract_region_has_correct_size(self) -> None:
         """Test that extracted region has correct size."""
         img = Image.new("RGB", (200, 200), color="red")
         center = (100, 100)
@@ -120,7 +120,7 @@ def test_extract_region_has_correct_size(self):
         assert region.width <= size
         assert region.height <= size

-    def test_extract_region_at_edge(self):
+    def test_extract_region_at_edge(self) -> None:
         """Test that extract_region handles edge cases."""
         img = Image.new("RGB", (100, 100), color="red")
         center = (10, 10)  # Near edge
@@ -133,7 +133,7 @@ class TestValidateVisualHash:
     """Test visual hash validation."""

-    def test_validate_visual_hash_passes_for_identical_images(self):
+    def test_validate_visual_hash_passes_for_identical_images(self) -> None:
         """Test validation passes for identical images."""
         img = Image.new("RGB", (100, 100), color="red")
         stored_hash = compute_phash(img)
@@ -146,16 +146,16 @@ def test_validate_visual_hash_passes_for_identical_images(self):
         assert error_msg is None
         assert distance == 0

-    def test_validate_visual_hash_fails_for_different_images(self):
+    def test_validate_visual_hash_fails_for_different_images(self) -> None:
         """Test validation fails for very different images."""
         # Create images with different patterns
         img1 = Image.new("RGB", (100, 100), color="white")
         draw1 = ImageDraw.Draw(img1)
-        draw1.rectangle([10, 10, 50, 50], fill="red")
+        draw1.rectangle((10, 10, 50, 50), fill="red")

         img2 = Image.new("RGB", (100, 100), color="white")
         draw2 = ImageDraw.Draw(img2)
-        draw2.rectangle([60, 60, 90, 90], fill="blue")
+        draw2.rectangle((60, 60, 90, 90), fill="blue")

         stored_hash = compute_phash(img1)

@@ -168,7 +168,7 @@ def test_validate_visual_hash_fails_for_different_images(self):
         assert error_msg is not None
         assert "Visual validation failed" in error_msg

-    def test_validate_visual_hash_with_ahash_method(self):
+    def test_validate_visual_hash_with_ahash_method(self) -> None:
         """Test validation works with ahash method."""
         img = Image.new("RGB", (100, 100), color="green")
         stored_hash = compute_ahash(img)
@@ -181,7 +181,7 @@ def test_validate_visual_hash_with_ahash_method(self):
         assert error_msg is None
         assert distance == 0

-    def test_validate_visual_hash_unknown_method(self):
+    def test_validate_visual_hash_unknown_method(self) -> None:
         """Test validation fails gracefully with unknown hash method."""
         img = Image.new("RGB", (100, 100), color="red")
         stored_hash = "abcdef"
@@ -198,23 +198,23 @@ class TestShouldValidateStep:
     """Test step validation logic."""

-    def test_should_validate_left_click(self):
+    def test_should_validate_left_click(self) -> None:
         """Test that left_click actions should be validated."""
         assert should_validate_step("computer", "left_click") is True

-    def test_should_validate_right_click(self):
+    def test_should_validate_right_click(self) -> None:
         """Test that right_click actions should be validated."""
         assert should_validate_step("computer", "right_click") is True

-    def test_should_validate_type_action(self):
+    def test_should_validate_type_action(self) -> None:
         """Test that type actions should be validated."""
         assert should_validate_step("computer", "type") is True

-    def test_should_not_validate_screenshot(self):
+    def test_should_not_validate_screenshot(self) -> None:
         """Test that screenshot actions should not be validated."""
         assert should_validate_step("computer", "screenshot") is False

-    def test_should_not_validate_unknown_tool(self):
+    def test_should_not_validate_unknown_tool(self) -> None:
         """Test that unknown tools should not be validated."""
         assert should_validate_step("unknown_tool", None) is False

@@ -222,7 +222,7 @@ class TestGetValidationCoordinate:
     """Test coordinate extraction for validation."""

-    def test_get_validation_coordinate_from_computer_tool(self):
+    def test_get_validation_coordinate_from_computer_tool(self) -> None:
         """Test extracting coordinate from computer tool input."""
         tool_input = {"action": "left_click", "coordinate": [450, 300]}

@@ -230,7 +230,7 @@ def test_get_validation_coordinate_from_computer_tool(self):

         assert coord == (450, 300)

-    def test_get_validation_coordinate_returns_none_without_coordinate(self):
+    def test_get_validation_coordinate_returns_none_without_coordinate(self) -> None:
         """Test returns None when no coordinate in input."""
         tool_input = {"action": "screenshot"}

@@ -238,7 +238,7 @@ def test_get_validation_coordinate_returns_none_without_coordinate(self):

         assert coord is None

-    def test_get_validation_coordinate_handles_invalid_format(self):
+    def test_get_validation_coordinate_handles_invalid_format(self) -> None:
         """Test handles invalid coordinate format gracefully."""
         tool_input = {"coordinate": "invalid"}

From 9099c1462b7be64172abe88ec1a7d325092b9daa Mon Sep 17 00:00:00 2001
From: philipph-askui
Date: Thu, 8 Jan 2026 15:00:27 +0100
Subject: [PATCH 30/30] fix(caching): fix crash when parameter values contain
 regex-special characters (e.g. backslashes)

---
 src/askui/utils/cache_parameter_handler.py       |  7 +-
 tests/e2e/agent/test_act_caching.py              |  2 +-
 .../utils/test_cache_parameter_handler.py        | 71 +++++++++++++++++++
 3 files changed, 76 insertions(+), 4 deletions(-)

diff --git a/src/askui/utils/cache_parameter_handler.py b/src/askui/utils/cache_parameter_handler.py
index 953eb36b..b0eaaf27 100644
--- a/src/askui/utils/cache_parameter_handler.py
+++ b/src/askui/utils/cache_parameter_handler.py
@@ -470,11 +470,12 @@ def _substitute_in_value(value: Any, parameter_values: dict[str, str]) -> Any:
         New value with parameters substituted
     """
     if isinstance(value, str):
-        # Replace all parameters in the string
+        # Replace all parameters in the string using literal string replacement
+        # This avoids regex interpretation issues with backslashes in values
         result = value
         for name, replacement in parameter_values.items():
-            pattern = r"\{\{" + re.escape(name) + r"\}\}"
-            result = re.sub(pattern, replacement, result)
+            placeholder = f"{{{{{name}}}}}"  # Creates "{{parameter_name}}"
+            result = result.replace(placeholder, replacement)
         return result
     if isinstance(value, dict):
         # Recursively substitute in dict values
diff --git a/tests/e2e/agent/test_act_caching.py b/tests/e2e/agent/test_act_caching.py
index 568ef5a5..1166826a 100644
--- a/tests/e2e/agent/test_act_caching.py
+++ b/tests/e2e/agent/test_act_caching.py
@@ -113,7 +113,7 @@ def test_act_with_custom_cache_dir_and_filename(vision_agent: VisionAgent) -> No
     vision_agent.act(
         goal="Tell me a joke",
         caching_settings=CachingSettings(
-            strategy="execute",
+            strategy="record",
             cache_dir=str(custom_cache_dir),
             writing_settings=CacheWritingSettings(
                 filename=custom_filename,
diff --git a/tests/unit/utils/test_cache_parameter_handler.py b/tests/unit/utils/test_cache_parameter_handler.py
index e03e2ec6..287409e7 100644
--- a/tests/unit/utils/test_cache_parameter_handler.py
+++ b/tests/unit/utils/test_cache_parameter_handler.py
@@ -368,6 +368,77 @@ def test_substitute_parameters_with_special_characters() -> None:
     assert result.input["text"] == r"Pattern: .*[test]$"  # type: ignore[index]


+def test_substitute_parameters_with_backslashes() -> None:
+    """Test substitution with values containing backslashes (e.g., Windows paths).
+
+    This test reveals the bug where backslashes in parameter values cause
+    re.PatternError because re.sub() treats the replacement string as a
+    regex replacement pattern, not a literal string.
+    """
+    tool_block = ToolUseBlockParam(
+        id="1",
+        name="tool",
+        input={"text": "Open file at {{file_path}}"},
+        type="tool_use",
+    )
+
+    # Value contains backslashes (like Windows paths)
+    # Previously raised: re.PatternError: bad escape \A at position 2
+    result = CacheParameterHandler.substitute_parameters(
+        tool_block, {"file_path": "C:\\AskUI\\test"}
+    )
+
+    assert result.input["text"] == "Open file at C:\\AskUI\\test"  # type: ignore[index]
+
+
+def test_substitute_parameters_with_various_backslash_sequences() -> None:
+    """Test substitution with various backslash escape sequences.
+
+    Tests multiple scenarios where backslashes could cause issues:
+    - Windows UNC paths
+    - Regex patterns as values
+    - Various escape sequences that could be misinterpreted
+    """
+    # Test case 1: Windows path with \D
+    tool_block = ToolUseBlockParam(
+        id="1",
+        name="tool",
+        input={"text": "Save to {{path}}"},
+        type="tool_use",
+    )
+
+    result = CacheParameterHandler.substitute_parameters(
+        tool_block, {"path": "D:\\Data\\file.txt"}
+    )
+    assert result.input["text"] == "Save to D:\\Data\\file.txt"  # type: ignore[index]
+
+    # Test case 2: UNC path
+    tool_block2 = ToolUseBlockParam(
+        id="2",
+        name="tool",
+        input={"text": "Network path: {{unc_path}}"},
+        type="tool_use",
+    )
+
+    result2 = CacheParameterHandler.substitute_parameters(
+        tool_block2, {"unc_path": "\\\\server\\share\\folder"}
+    )
+    assert result2.input["text"] == "Network path: \\\\server\\share\\folder"  # type: ignore[index]
+
+    # Test case 3: Regex pattern as value (with backslashes)
+    tool_block3 = ToolUseBlockParam(
+        id="3",
+        name="tool",
+        input={"text": "Match pattern {{regex}}"},
+        type="tool_use",
+    )
+
+    result3 = CacheParameterHandler.substitute_parameters(
+        tool_block3, {"regex": "\\d+\\s+\\w+"}
+    )
+    assert result3.input["text"] == "Match pattern \\d+\\s+\\w+"  # type: ignore[index]
+
+
 def test_substitute_parameters_same_parameter_multiple_times() -> None:
     """Test substituting the same parameter appearing multiple times."""
     tool_block = ToolUseBlockParam(
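The fix in this patch swaps `re.sub()` for literal `str.replace()` because Python parses the *replacement* argument of `re.sub()` for escape sequences, so a value like `C:\AskUI\test` raises `re.error: bad escape \A`. Below is a minimal standalone sketch of the failure mode and the fix; the `substitute()` helper is illustrative only, mirroring the patched `_substitute_in_value()` rather than reproducing the library's API:

```python
import re

template = "Open file at {{file_path}}"
value = "C:\\AskUI\\test"  # literal C:\AskUI\test; "\A" is not a valid replacement escape

# Buggy approach: re.sub() parses the replacement string for escapes and crashes.
try:
    re.sub(r"\{\{" + re.escape("file_path") + r"\}\}", value, template)
except re.error as exc:
    print(f"re.sub failed: {exc}")  # bad escape \A at position 2

# Fixed approach: build the literal "{{name}}" token and use str.replace().
def substitute(text: str, params: dict[str, str]) -> str:
    for name, replacement in params.items():
        text = text.replace(f"{{{{{name}}}}}", replacement)
    return text

print(substitute(template, {"file_path": value}))  # Open file at C:\AskUI\test
```

An alternative would have been to keep `re.sub()` and neutralize the replacement, e.g. by escaping backslashes (`value.replace("\\", "\\\\")`) or by passing a callable (`re.sub(pattern, lambda m: value, text)`), which bypasses template parsing. Dropping the regex entirely is simpler and avoids compiling a pattern per parameter for what is a fixed `{{...}}` token.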