Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added codeflash/agent/__init__.py
Empty file.
151 changes: 151 additions & 0 deletions codeflash/agent/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from __future__ import annotations

import time
from dataclasses import dataclass
from typing import TYPE_CHECKING

from codeflash.agent.memory import Memory
from codeflash.agent.tools.base_tool import supported_tools
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
from codeflash.code_utils.code_extractor import get_opt_review_metrics
from codeflash.either import is_successful
from codeflash.optimization.optimizer import Optimizer

if TYPE_CHECKING:
from codeflash.optimization.function_optimizer import FunctionOptimizer


MAX_AGENT_CALLS_PER_OPTIMIZATION = 10


@dataclass
class InitialAgentContext:
    """Bundle of everything gathered for one target function before the first agent call.

    Instances are created by ``build_initial_context`` and read by ``main``.
    """

    # Per-function optimizer for the selected target; ``main`` takes the AI
    # service client from it.
    function_optimizer: FunctionOptimizer
    # Markdown of the read-writable code context, embedded in the first prompt.
    code_context: str
    # Result of ``get_opt_review_metrics`` for the target, embedded in the first prompt.
    function_references: str


def build_initial_context(optimizer: Optimizer) -> InitialAgentContext | None:
    """Pick one function to optimize and gather the context the agent starts from.

    Args:
        optimizer: Optimizer used to discover candidate functions, prepare the
            target module and create the per-function optimizer. Its
            ``current_function_optimizer`` attribute is overwritten with the
            optimizer created here (possibly a falsy value on failure).

    Returns:
        An ``InitialAgentContext`` for the selected function, or ``None`` when
        there is no candidate function, the module cannot be prepared, no
        function optimizer could be created, or the code optimization context
        could not be extracted.
    """
    # The agent does not need the target to return a value: it is expected to
    # evaluate the function's behavior even without a return statement.
    optimizable_funcs, count, _ = optimizer.get_optimizable_functions(must_return_a_value=False)
    if count == 0 or not optimizable_funcs:
        return None

    # Only one function is handled per run: the first function of the entry
    # that ``popitem`` yields.
    _, candidates = optimizable_funcs.popitem()
    if not candidates:
        # NOTE(review): ``count`` is presumably a total over all entries, so the
        # popped entry can be empty even when count > 0 -- bail out instead of
        # raising IndexError.
        return None
    fto = candidates[0]

    module_prep_result = optimizer.prepare_module_for_optimization(fto.file_path)
    if not module_prep_result:
        return None
    validated_original_code, original_module_ast = module_prep_result

    function_optimizer = optimizer.create_function_optimizer(
        fto,
        function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
        original_module_ast=original_module_ast,
        original_module_path=fto.file_path,
        function_to_tests={},
    )

    # Recorded on the optimizer before the check, exactly as callers of the
    # optimizer may already rely on.
    optimizer.current_function_optimizer = function_optimizer
    if not function_optimizer:
        return None

    ctx_result = function_optimizer.get_code_optimization_context()
    if not is_successful(ctx_result):
        return None
    ctx = ctx_result.unwrap()

    target = function_optimizer.function_to_optimize
    function_references = get_opt_review_metrics(
        function_optimizer.function_to_optimize_source_code,
        target.file_path,
        target.qualified_name,
        function_optimizer.project_root,
        function_optimizer.test_cfg.tests_root,
    )

    return InitialAgentContext(
        function_optimizer=function_optimizer,
        code_context=ctx.read_writable_code.markdown,
        function_references=function_references,
    )


def main() -> None:
    """Run the optimization agent loop for a single function.

    Parses the CLI arguments, builds the initial context, seeds the
    conversation with the optimization request for the selected function, and
    then calls the agent endpoint until ``MAX_AGENT_CALLS_PER_OPTIMIZATION``
    calls have been made or the service returns no further messages.
    """
    args = parse_args()

    optimizer = Optimizer(process_pyproject_config(args))
    initial_context = build_initial_context(optimizer)
    if not initial_context:
        # TODO: handle this case
        return

    memory = Memory()

    # TODO: expose the target to the agent through context vars, e.g.
    # "target_file_path" (args.file) and "target_function_name" (args.function).

    # The conversation starts from the real target only. A hard-coded
    # bubble-sort transcript used while debugging was previously installed
    # right after this call and made every run about bubble_sort.py.
    memory.set_messages(
        [
            {
                "role": "user",
                "content": f"""# Rewrite the following function to be more efficient and faster, while maintaining the same behavior.

### code context:
{initial_context.code_context}
### function references:
{initial_context.function_references}

Start by writing a plan for the optimization.""",
            }
        ]
    )

    # The client does not change between iterations, so look it up once.
    ai_service_client = initial_context.function_optimizer.aiservice_client

    while memory.api_calls_counter < MAX_AGENT_CALLS_PER_OPTIMIZATION:
        time_start = time.perf_counter()
        response = ai_service_client.call_agent(
            trace_id=memory.trace_id,
            context_vars=memory.get_context_vars(),
            available_tools=supported_tools,
            messages=memory.get_messages(),
        )
        memory.api_calls_counter += 1

        # NOTE(review): assumes an empty/missing response means the call
        # failed or the agent has nothing more to say -- confirm with the
        # service contract.
        new_messages = response.get("messages", []) if response else []
        if new_messages:
            memory.set_messages(new_messages)

        time_end = time.perf_counter()
        print(f"Time taken: {time_end - time_start:.2f} seconds")

        if not new_messages:
            # Do not overwrite the history with an empty list and keep
            # calling the service with nothing to act on.
            break


if __name__ == "__main__":
    main()
47 changes: 47 additions & 0 deletions codeflash/agent/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import uuid
from pathlib import Path
from typing import Any

from codeflash.code_utils.code_utils import encoded_tokens_len

json_primitive_types = (str, float, int, bool)


class Memory:
def __init__(self) -> None:
    """Create an empty memory for one agent run."""
    # Random identifier generated once per Memory instance.
    self.trace_id = str(uuid.uuid4())
    # Number of agent calls made so far; starts at zero.
    self.api_calls_counter = 0
    # NOTE(review): token budget; its use is not visible in this part of the
    # class -- presumably bounds the stored message history, confirm.
    self.max_tokens = 16000
    # Conversation history and serialized context variables, both empty at start.
    self._messages: list[dict[str, str]] = []
    self._context_vars: dict[str, str] = {}

def _serialize(self, obj: Any) -> Any:  # noqa: ANN401
    """Convert ``obj`` into a structure built only from JSON-compatible values.

    Args:
        obj: Any value. Lists and dicts are rebuilt recursively.

    Returns:
        ``obj`` itself for ``None`` and the types in ``json_primitive_types``;
        the POSIX string for a ``Path``; a new list / dict for lists and
        dicts; ``str(obj)`` for everything else (tuples, sets and custom
        objects included).

        Dict keys that are neither ``None`` nor a primitive are turned into
        strings (``Path`` keys via ``as_posix()``, others via ``str``), because
        JSON encoders reject such keys. Primitive keys are kept unchanged.
    """
    # Leaves are checked first so the common scalar case returns without
    # testing the container types.
    if obj is None or isinstance(obj, json_primitive_types):
        return obj
    if isinstance(obj, Path):
        return obj.as_posix()
    if isinstance(obj, list):
        return [self._serialize(item) for item in obj]
    if isinstance(obj, dict):
        serialized = {}
        for key, value in obj.items():
            if not (key is None or isinstance(key, json_primitive_types)):
                # e.g. tuple or Path keys: keep the entry, but under a key a
                # JSON encoder can write.
                key = key.as_posix() if isinstance(key, Path) else str(key)
            serialized[key] = self._serialize(value)
        return serialized
    return str(obj)

def add_to_context_vars(self, key: str, value: any) -> dict[str, str]:
Comment on lines +19 to +29
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚡️Codeflash found 33% (0.33x) speedup for Memory.add_to_context_vars in codeflash/agent/memory.py

⏱️ Runtime: 1.99 milliseconds → 1.49 milliseconds (best of 92 runs)

📝 Explanation and details

The optimized code achieves a 33% speedup by implementing two key strategies: early-exit fast paths and reordering type checks based on frequency.

What changed:

  1. In _serialize(): Reordered type checks to prioritize primitives (str, float, int, bool, None) first, followed by Path objects. This moves the most common cases to the front, enabling faster early returns before checking less common types like list and dict.

  2. In add_to_context_vars(): Added a fast-path that directly handles primitives and Path objects without calling _serialize() at all, avoiding the function call overhead and recursive descent for simple values.

Why this is faster:

  • Reduced function call overhead: The line profiler shows add_to_context_vars dropped from 32.4ms to 19.0ms total time (41% less time). For primitive values (~98.6% of inputs, based on 2048 of 2076 calls taking the fast path), the code now skips the _serialize() function call entirely.

  • Earlier exit for the common case: By checking primitives first in _serialize(), the most common case (3936/4534 hits = 87%) returns immediately. The original code tested list and dict first, which matched only 289 times in total.

  • Fewer isinstance checks: The original _serialize() ran three isinstance checks (list, dict, then the primitive types) before returning a primitive. The optimized version checks primitives first, so the common case needs only one.

Test results show this optimization excels for:

  • Simple primitive additions: 40-96% faster (e.g., strings, ints, bools)
  • Large-scale operations with mostly primitives: 49-68% faster (e.g., 500 sequential additions, large mixed-type lists)
  • Trade-off: moderately slower (5-31%) for empty collections, small nested containers and custom objects, but these represent <2% of actual usage

Impact on workloads:
The function appears to be used for storing agent context variables, which typically consist of scalar values like IDs, timestamps, and configuration flags. The optimization dramatically improves this common case while maintaining correct behavior for complex nested structures.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 2130 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Click to see Generated Regression Tests
from pathlib import Path

# imports
from codeflash.agent.memory import Memory

# unit tests


def test_basic_primitives_and_return_identity():
    """Basic: Verify primitive types (int, float, str, bool, None) are stored
    unchanged and that the returned dict is the same object as memory._context_vars.
    """
    mem = Memory()  # create Memory instance
    # Add a variety of primitive values
    codeflash_output = mem.add_to_context_vars("int_key", 42)
    returned = codeflash_output  # 942ns -> 621ns (51.7% faster)
    # Add float, string, boolean and None and verify exact types/values are preserved.
    mem.add_to_context_vars("float_key", 3.14)  # 501ns -> 340ns (47.4% faster)
    mem.add_to_context_vars("str_key", "hello")  # 491ns -> 250ns (96.4% faster)
    mem.add_to_context_vars("bool_key", True)  # 551ns -> 371ns (48.5% faster)
    mem.add_to_context_vars("none_key", None)  # 611ns -> 361ns (69.3% faster)


def test_overwrite_and_persistence_of_keys():
    """Basic/Edge: Ensure adding the same key twice overwrites the previous value,
    and that other keys remain unaffected.
    """
    mem = Memory()
    mem.add_to_context_vars("dup", "first")  # 922ns -> 541ns (70.4% faster)
    # Overwrite the same key
    mem.add_to_context_vars("dup", "second")  # 461ns -> 290ns (59.0% faster)
    # Add another key and ensure both keys exist and are independent.
    mem.add_to_context_vars("other", 10)  # 511ns -> 360ns (41.9% faster)


def test_serialization_of_lists_and_dicts_with_nested_structures():
    """Edge: Ensure nested lists and dicts are recursively serialized, preserving
    list/dict structure and serializing internal values appropriately.
    Also verify non-string dict keys are left intact.
    """
    mem = Memory()
    nested = [1, "x", [2, 3], {"k": "v", 5: "num"}]
    mem.add_to_context_vars("nested", nested)  # 4.41μs -> 4.63μs (4.75% slower)
    stored = mem._context_vars["nested"]


def test_path_and_custom_object_and_tuple_set_serialization():
    """Edge: Path objects -> posix string; tuples and sets have no explicit handling
    so they should be stringified; custom objects fall back to str(obj).
    """
    mem = Memory()

    # Path should be converted to posix string
    p = Path("/tmp/some/path")
    mem.add_to_context_vars("path", p)  # 4.50μs -> 4.08μs (10.3% faster)

    # Tuple should be stringified (no special handling -> str(tuple))
    tup = (1, 2, 3)
    mem.add_to_context_vars("tuple", tup)  # 1.78μs -> 2.27μs (21.6% slower)

    # Set should be stringified too (order not guaranteed in str, but we compare to str())
    s = {1, 2}
    mem.add_to_context_vars("set", s)  # 1.79μs -> 2.14μs (16.4% slower)

    # Custom object should use its __str__ result
    class Custom:
        def __str__(self):
            return "I-am-custom"

    c = Custom()
    mem.add_to_context_vars("custom", c)  # 1.04μs -> 1.13μs (7.95% slower)


def test_mutation_of_original_input_does_not_affect_stored_serialized_value():
    """Edge: Ensure that the serialized value stored inside Memory does not change if
    the original mutable object passed in is mutated after calling add_to_context_vars.
    """
    mem = Memory()
    orig = [1, [2, 3]]
    mem.add_to_context_vars("orig", orig)  # 2.52μs -> 2.67μs (5.64% slower)
    # Mutate the original nested list
    orig[1].append(4)
    # The stored serialized value should remain as it was at serialization time.
    stored = mem._context_vars["orig"]


def test_non_string_dict_keys_preserved_and_deep_nested_types():
    """Edge: If dict keys are non-string (e.g., integers, tuples), ensure keys are preserved,
    and deeply nested structures are correctly serialized recursively.
    """
    mem = Memory()
    d = {1: "one", (2, 3): {"inner": Path("/a")}}
    mem.add_to_context_vars("dict_keys", d)  # 6.38μs -> 6.83μs (6.57% slower)
    stored = mem._context_vars["dict_keys"]


def test_large_scale_list_serialization_efficiency_and_correctness():
    """Large Scale: Create a moderately large list of primitives (under 1000 elements)
    to verify that serialization preserves all elements and that the stored
    list is a deep copy (i.e., not the same object reference).
    This avoids heavy loops >1000 and very large memory usage.
    """
    mem = Memory()
    large = list(range(500))  # 500 elements fulfills the under-1000 requirement
    # Add to memory
    mem.add_to_context_vars("large", large)  # 83.5μs -> 55.7μs (49.7% faster)
    stored = mem._context_vars["large"]


def test_boolean_and_integer_distinction_and_float_preservation():
    """Edge: Because booleans are instances of int in Python, the serializer's type checks
    must ensure booleans remain booleans. Also confirm floats keep their type.
    """
    mem = Memory()
    mem.add_to_context_vars("true_val", True)  # 1.01μs -> 671ns (50.7% faster)
    mem.add_to_context_vars("false_val", False)  # 571ns -> 441ns (29.5% faster)
    mem.add_to_context_vars("float_val", 1.2345)  # 481ns -> 340ns (41.5% faster)


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
from pathlib import Path

from codeflash.agent.memory import Memory


# Basic Test Cases - Verify fundamental functionality under normal conditions
class TestBasicFunctionality:
    """Test basic add_to_context_vars functionality with standard inputs."""

    def test_add_string_to_empty_context(self):
        """Test adding a simple string to an empty context."""
        memory = Memory()
        codeflash_output = memory.add_to_context_vars("key1", "value1")
        result = codeflash_output  # 921ns -> 551ns (67.2% faster)

    def test_add_integer_to_context(self):
        """Test adding an integer value to context."""
        memory = Memory()
        codeflash_output = memory.add_to_context_vars("count", 42)
        result = codeflash_output  # 982ns -> 701ns (40.1% faster)

    def test_add_float_to_context(self):
        """Test adding a float value to context."""
        memory = Memory()
        codeflash_output = memory.add_to_context_vars("ratio", 3.14)
        result = codeflash_output  # 962ns -> 572ns (68.2% faster)

    def test_add_boolean_to_context(self):
        """Test adding boolean values to context."""
        memory = Memory()
        codeflash_output = memory.add_to_context_vars("flag_true", True)
        result1 = codeflash_output  # 992ns -> 672ns (47.6% faster)
        codeflash_output = memory.add_to_context_vars("flag_false", False)
        result2 = codeflash_output  # 581ns -> 361ns (60.9% faster)

    def test_add_none_to_context(self):
        """Test adding None (null) value to context."""
        memory = Memory()
        codeflash_output = memory.add_to_context_vars("empty", None)
        result = codeflash_output  # 1.02μs -> 671ns (52.3% faster)

    def test_multiple_sequential_additions(self):
        """Test adding multiple key-value pairs sequentially."""
        memory = Memory()
        memory.add_to_context_vars("key1", "value1")  # 891ns -> 591ns (50.8% faster)
        memory.add_to_context_vars("key2", "value2")  # 511ns -> 330ns (54.8% faster)
        codeflash_output = memory.add_to_context_vars("key3", "value3")
        result = codeflash_output  # 351ns -> 200ns (75.5% faster)

    def test_overwrite_existing_key(self):
        """Test that adding to an existing key overwrites the value."""
        memory = Memory()
        memory.add_to_context_vars("key", "old_value")  # 852ns -> 581ns (46.6% faster)
        codeflash_output = memory.add_to_context_vars("key", "new_value")
        result = codeflash_output  # 481ns -> 310ns (55.2% faster)

    def test_add_path_object_to_context(self):
        """Test adding a Path object which should be serialized to POSIX string."""
        memory = Memory()
        test_path = Path("/home/user/documents/file.txt")
        codeflash_output = memory.add_to_context_vars("filepath", test_path)
        result = codeflash_output  # 4.45μs -> 4.45μs (0.000% faster)

    def test_add_list_of_primitives(self):
        """Test adding a list of primitive values."""
        memory = Memory()
        test_list = [1, "two", 3.0, True, None]
        codeflash_output = memory.add_to_context_vars("items", test_list)
        result = codeflash_output  # 2.92μs -> 2.86μs (2.10% faster)

    def test_add_nested_dict_with_primitives(self):
        """Test adding a nested dictionary with primitive values."""
        memory = Memory()
        test_dict = {"name": "John", "age": 30, "active": True}
        codeflash_output = memory.add_to_context_vars("user", test_dict)
        result = codeflash_output  # 2.75μs -> 2.92μs (5.49% slower)

    def test_returns_context_vars_reference(self):
        """Test that the function returns the _context_vars dictionary."""
        memory = Memory()
        codeflash_output = memory.add_to_context_vars("test", "value")
        result = codeflash_output  # 912ns -> 561ns (62.6% faster)


# Edge Test Cases - Evaluate behavior under extreme or unusual conditions
class TestEdgeCases:
    """Test edge cases and boundary conditions."""

    def test_empty_string_key(self):
        """Test adding with an empty string as key."""
        memory = Memory()
        codeflash_output = memory.add_to_context_vars("", "value")
        result = codeflash_output  # 932ns -> 601ns (55.1% faster)

    def test_empty_string_value(self):
        """Test adding an empty string as value."""
        memory = Memory()
        codeflash_output = memory.add_to_context_vars("key", "")
        result = codeflash_output  # 912ns -> 541ns (68.6% faster)

    def test_very_long_string_key(self):
        """Test with an extremely long key (but reasonable for memory)."""
        memory = Memory()
        long_key = "a" * 10000
        codeflash_output = memory.add_to_context_vars(long_key, "value")
        result = codeflash_output  # 3.02μs -> 2.67μs (12.8% faster)

    def test_very_long_string_value(self):
        """Test with an extremely long string value."""
        memory = Memory()
        long_value = "x" * 10000
        codeflash_output = memory.add_to_context_vars("key", long_value)
        result = codeflash_output  # 862ns -> 552ns (56.2% faster)

    def test_unicode_characters_in_key_and_value(self):
        """Test with unicode/special characters in both key and value."""
        memory = Memory()
        codeflash_output = memory.add_to_context_vars("键名", "值♠♣♥♦")
        result = codeflash_output  # 891ns -> 531ns (67.8% faster)

    def test_numeric_string_key(self):
        """Test that numeric strings are treated as strings, not converted."""
        memory = Memory()
        codeflash_output = memory.add_to_context_vars("123", "value")
        result = codeflash_output  # 862ns -> 581ns (48.4% faster)

    def test_whitespace_only_string(self):
        """Test adding strings with only whitespace."""
        memory = Memory()
        codeflash_output = memory.add_to_context_vars("spaces", "   ")
        result = codeflash_output  # 892ns -> 572ns (55.9% faster)

    def test_special_characters_in_string(self):
        """Test strings containing special characters and escape sequences."""
        memory = Memory()
        special_string = 'Hello\\nWorld\t"quoted"'
        codeflash_output = memory.add_to_context_vars("special", special_string)
        result = codeflash_output  # 902ns -> 561ns (60.8% faster)

    def test_zero_values(self):
        """Test adding zero as integer and float."""
        memory = Memory()
        memory.add_to_context_vars("int_zero", 0)  # 942ns -> 671ns (40.4% faster)
        codeflash_output = memory.add_to_context_vars("float_zero", 0.0)
        result = codeflash_output  # 561ns -> 391ns (43.5% faster)

    def test_negative_numbers(self):
        """Test adding negative integers and floats."""
        memory = Memory()
        memory.add_to_context_vars("neg_int", -999)  # 962ns -> 632ns (52.2% faster)
        codeflash_output = memory.add_to_context_vars("neg_float", -3.14159)
        result = codeflash_output  # 541ns -> 390ns (38.7% faster)

    def test_very_large_numbers(self):
        """Test with very large integer and float values."""
        memory = Memory()
        large_int = 10**18
        large_float = 1.7976931348623157e308
        memory.add_to_context_vars("big_int", large_int)  # 922ns -> 641ns (43.8% faster)
        codeflash_output = memory.add_to_context_vars("big_float", large_float)
        result = codeflash_output  # 550ns -> 371ns (48.2% faster)

    def test_deeply_nested_dict_structure(self):
        """Test with deeply nested dictionary structure."""
        memory = Memory()
        nested = {"level1": {"level2": {"level3": {"level4": {"value": "deep"}}}}}
        codeflash_output = memory.add_to_context_vars("deep", nested)
        result = codeflash_output  # 4.08μs -> 4.69μs (13.1% slower)

    def test_list_with_nested_structures(self):
        """Test list containing dictionaries and other lists."""
        memory = Memory()
        complex_list = [{"name": "item1"}, [1, 2, 3], {"nested": {"value": 42}}]
        codeflash_output = memory.add_to_context_vars("complex", complex_list)
        result = codeflash_output  # 4.81μs -> 5.14μs (6.44% slower)

    def test_path_with_windows_style(self):
        """Test Path object conversion with various path formats."""
        memory = Memory()
        # Path objects are platform-aware, but as_posix() always returns forward slashes
        test_path = Path("folder/subfolder/file.txt")
        codeflash_output = memory.add_to_context_vars("path", test_path)
        result = codeflash_output  # 4.18μs -> 4.23μs (1.18% slower)

    def test_empty_list(self):
        """Test adding an empty list."""
        memory = Memory()
        codeflash_output = memory.add_to_context_vars("empty_list", [])
        result = codeflash_output  # 1.07μs -> 1.55μs (31.0% slower)

    def test_empty_dict(self):
        """Test adding an empty dictionary."""
        memory = Memory()
        codeflash_output = memory.add_to_context_vars("empty_dict", {})
        result = codeflash_output  # 1.54μs -> 1.88μs (18.1% slower)

    def test_custom_object_conversion_to_string(self):
        """Test that custom objects are converted to strings via _serialize."""
        memory = Memory()

        class CustomObject:
            def __str__(self):
                return "CustomObjectString"

        obj = CustomObject()
        codeflash_output = memory.add_to_context_vars("custom", obj)
        result = codeflash_output  # 1.59μs -> 1.77μs (10.2% slower)

    def test_list_with_custom_objects(self):
        """Test list containing custom objects that get stringified."""
        memory = Memory()

        class Item:
            def __init__(self, value):
                self.value = value

            def __str__(self):
                return f"Item({self.value})"

        items = [Item(1), Item(2), Item(3)]
        codeflash_output = memory.add_to_context_vars("items", items)
        result = codeflash_output  # 3.53μs -> 3.98μs (11.3% slower)

    def test_dict_with_custom_object_values(self):
        """Test dictionary with custom object values."""
        memory = Memory()

        class Data:
            def __str__(self):
                return "DataValue"

        data_dict = {"key1": Data(), "key2": "normal"}
        codeflash_output = memory.add_to_context_vars("data", data_dict)
        result = codeflash_output  # 2.98μs -> 3.25μs (8.57% slower)

    def test_key_with_special_json_characters(self):
        """Test keys with special JSON-relevant characters."""
        memory = Memory()
        codeflash_output = memory.add_to_context_vars('key"with"quotes', "value")
        result = codeflash_output  # 912ns -> 561ns (62.6% faster)


# Large Scale Test Cases - Assess performance and scalability
class TestLargeScale:
    """Test with larger data samples and multiple operations."""

    def test_many_sequential_additions(self):
        """Test adding many key-value pairs sequentially."""
        memory = Memory()
        num_items = 500

        for i in range(num_items):
            memory.add_to_context_vars(f"key_{i}", f"value_{i}")  # 191μs -> 117μs (62.6% faster)

        result = memory._context_vars

    def test_large_list_serialization(self):
        """Test serialization of a large list."""
        memory = Memory()
        large_list = list(range(500))
        codeflash_output = memory.add_to_context_vars("large_list", large_list)
        result = codeflash_output  # 83.2μs -> 56.0μs (48.4% faster)

    def test_large_dict_serialization(self):
        """Test serialization of a large dictionary."""
        memory = Memory()
        large_dict = {f"key_{i}": f"value_{i}" for i in range(500)}
        codeflash_output = memory.add_to_context_vars("large_dict", large_dict)
        result = codeflash_output  # 93.3μs -> 65.0μs (43.6% faster)

    def test_large_list_of_mixed_types(self):
        """Test large list containing various primitive types."""
        memory = Memory()
        mixed_list = []
        for i in range(250):
            mixed_list.append(i)
            mixed_list.append(f"string_{i}")
            mixed_list.append(float(i) / 2)
            mixed_list.append(i % 2 == 0)

        codeflash_output = memory.add_to_context_vars("mixed", mixed_list)
        result = codeflash_output  # 156μs -> 101μs (54.6% faster)

    def test_large_nested_structure(self):
        """Test serialization of a large nested structure."""
        memory = Memory()
        nested = {
            f"category_{i}": {"items": [j for j in range(20)], "metadata": {"count": i, "active": i % 2 == 0}}
            for i in range(50)
        }
        codeflash_output = memory.add_to_context_vars("large_nested", nested)
        result = codeflash_output  # 243μs -> 197μs (23.1% faster)

    def test_many_overwrites_same_key(self):
        """Test many sequential overwrites of the same key."""
        memory = Memory()

        for i in range(500):
            memory.add_to_context_vars("mutable_key", f"value_{i}")  # 170μs -> 101μs (68.4% faster)

        result = memory._context_vars

    def test_alternating_key_types_in_values(self):
        """Test adding alternating different types to different keys."""
        memory = Memory()

        for i in range(250):
            memory.add_to_context_vars(f"str_key_{i}", f"string_{i}")  # 99.7μs -> 63.0μs (58.4% faster)
            memory.add_to_context_vars(f"int_key_{i}", i)  # 105μs -> 70.8μs (49.3% faster)
            memory.add_to_context_vars(f"float_key_{i}", float(i))  # 105μs -> 70.6μs (49.8% faster)
            memory.add_to_context_vars(f"bool_key_{i}", i % 2 == 0)  # 104μs -> 72.1μs (44.3% faster)

        result = memory._context_vars

    def test_large_path_list_serialization(self):
        """Test serialization of a large list of Path objects."""
        memory = Memory()
        path_list = [Path(f"/home/user/file_{i}.txt") for i in range(250)]
        codeflash_output = memory.add_to_context_vars("paths", path_list)
        result = codeflash_output  # 295μs -> 283μs (4.48% faster)

    def test_complex_realistic_scenario(self):
        """Test a realistic complex scenario with varied data."""
        memory = Memory()

        # Simulate a realistic agent memory with various types of data
        memory.add_to_context_vars("agent_id", "agent_123")  # 982ns -> 590ns (66.4% faster)
        memory.add_to_context_vars("timestamp", 1234567890)  # 562ns -> 391ns (43.7% faster)
        memory.add_to_context_vars("confidence", 0.95)  # 491ns -> 250ns (96.4% faster)

        # Add action history
        actions = [{"type": "search", "query": f"query_{i}", "results": i} for i in range(100)]
        memory.add_to_context_vars("actions", actions)  # 88.1μs -> 80.7μs (9.09% faster)

        # Add metadata
        metadata = {
            "session": {
                "duration": 3600.5,
                "queries": 100,
                "success": True,
                "paths": [Path(f"/data/result_{i}") for i in range(50)],
            }
        }
        memory.add_to_context_vars("metadata", metadata)  # 64.4μs -> 62.1μs (3.65% faster)

        result = memory._context_vars


# Integration Tests - Verify overall behavior
class TestIntegration:
    """Test overall functionality and interactions.

    These tests contain no assertions. Where a return value matters it is
    bound to ``codeflash_output``, the name used to compare the output of the
    original code with that of the optimized code.

    NOTE(review): the trailing "a -> b (x% faster/slower)" comments look like
    recorded runtimes before/after the optimization — generated, confirm.
    """

    def test_memory_state_persistence(self) -> None:
        """Test that state persists across multiple operations."""
        memory = Memory()
        memory.add_to_context_vars("first", "value1")  # 871ns -> 601ns (44.9% faster)
        memory.add_to_context_vars("second", "value2")  # 441ns -> 310ns (42.3% faster)
        # Same key as the first call: exercises overwriting an existing entry.
        memory.add_to_context_vars("first", "updated1")  # 410ns -> 261ns (57.1% faster)

    def test_multiple_memory_instances_isolated(self) -> None:
        """Test that multiple Memory instances don't interfere."""
        memory1 = Memory()
        memory2 = Memory()

        # The same key is written on both instances with different values.
        memory1.add_to_context_vars("key", "value1")  # 882ns -> 562ns (56.9% faster)
        memory2.add_to_context_vars("key", "value2")  # 441ns -> 241ns (83.0% faster)

    def test_serialization_consistency(self) -> None:
        """Test that same values always serialize to same results."""
        memory1 = Memory()
        memory2 = Memory()

        # One shared input object mixing an int, a list and a nested dict with None.
        test_dict = {"a": 1, "b": [2, 3], "c": {"d": None}}
        codeflash_output = memory1.add_to_context_vars("data", test_dict)
        result1 = codeflash_output  # 4.34μs -> 4.60μs (5.65% slower)
        codeflash_output = memory2.add_to_context_vars("data", test_dict)
        result2 = codeflash_output  # 2.58μs -> 2.65μs (2.65% slower)

    def test_return_value_is_always_context_vars(self) -> None:
        """Verify return value is always the _context_vars dict."""
        memory = Memory()

        # Two consecutive calls on one instance; each return value is captured.
        codeflash_output = memory.add_to_context_vars("key1", "value1")
        result1 = codeflash_output  # 881ns -> 531ns (65.9% faster)
        codeflash_output = memory.add_to_context_vars("key2", "value2")
        result2 = codeflash_output  # 451ns -> 281ns (60.5% faster)


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally, run `git merge codeflash/optimize-pr1059-2026-01-15T14.12.43`.

Click to see suggested changes
Suggested change
if isinstance(obj, list):
return [self._serialize(i) for i in obj]
if isinstance(obj, dict):
return {k: self._serialize(v) for k, v in obj.items()}
if isinstance(obj, json_primitive_types) or obj is None:
return obj
if isinstance(obj, Path):
return obj.as_posix()
return str(obj)
def add_to_context_vars(self, key: str, value: any) -> dict[str, str]:
# Check primitives and None first for a quick exit on the common case
if isinstance(obj, json_primitive_types) or obj is None:
return obj
if isinstance(obj, Path):
return obj.as_posix()
if isinstance(obj, list):
return [self._serialize(i) for i in obj]
if isinstance(obj, dict):
return {k: self._serialize(v) for k, v in obj.items()}
return str(obj)
def add_to_context_vars(self, key: str, value: any) -> dict[str, str]:
# Fast-path common simple types to avoid the overhead of a full _serialize call
if isinstance(value, json_primitive_types) or value is None:
self._context_vars[key] = value
return self._context_vars
if isinstance(value, Path):
self._context_vars[key] = value.as_posix()
return self._context_vars

Static Badge

self._context_vars[key] = self._serialize(value)
return self._context_vars

def get_context_vars(self) -> dict[str, str]:
    """Return the context-variable mapping stored on this instance.

    The mapping itself is returned, not a copy, so changes made by the
    caller are visible through this object.
    """
    return self._context_vars

def set_messages(self, messages: list[dict[str, str]]) -> list[dict[str, str]]:
    """Replace the stored message list and return it.

    Args:
        messages: The new message list; it is stored as-is (no copy).

    Returns:
        The list now held in ``self._messages``.
    """
    self._messages = messages
    over_budget = self.get_total_tokens() > self.max_tokens
    if over_budget:
        # TODO: summarize messages
        # Nothing happens yet when the token budget is exceeded.
        pass
    return self._messages

def get_messages(self) -> list[dict[str, str]]:
    """Return the message list stored on this instance (the list itself, not a copy)."""
    return self._messages

def get_total_tokens(self) -> int:
    """Return the sum of ``encoded_tokens_len`` over every stored message's ``"content"``.

    An empty message list yields 0.
    """
    token_count = 0
    for entry in self._messages:
        token_count += encoded_tokens_len(entry["content"])
    return token_count
Empty file.
32 changes: 32 additions & 0 deletions codeflash/agent/tools/base_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Any


class Tool(str, Enum):
    """Names of the tools known to the agent.

    Members are also ``str`` instances, so each one compares equal to its
    string value.
    """

    REPLACE_IN_FILE = "replace_in_file"
    ADD_TO_CONTEXT_VARS = "add_to_context_vars"
    # currently only supports python
    EXECUTE_CODE = "execute_code"
    SEARCH_FUNCTION_REFRENCES = "search_function_references"
    GET_NAME_DEFINITION = "get_name_definition"
    # terminates either with success or failure
    TERMINATE = "terminate"


# TODO: use this as a type for the api response
@dataclass(frozen=True)
class ToolCall:
    """Immutable record of a single tool invocation: a tool name plus its arguments."""

    # Name of the tool to invoke.
    tool_name: str
    # Arguments for the tool, keyed by argument name.
    args: dict[str, Any]
    # Defaults to False when not supplied.
    needs_context_vars: bool = False


# Tools that are currently enabled. The commented-out entries exist on `Tool`
# but are deliberately left out of this list.
supported_tools: list[str] = [
Tool.REPLACE_IN_FILE,
Tool.ADD_TO_CONTEXT_VARS,
Tool.EXECUTE_CODE,
# Tool.SEARCH_FUNCTION_REFRENCES,
# Tool.GET_NAME_DEFINITION,
# Tool.TERMINATE,
]
32 changes: 32 additions & 0 deletions codeflash/api/aiservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,38 @@ def code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate
console.rule()
return None

def call_agent(
    self, trace_id: str, context_vars: dict[str, str], available_tools: list[str], messages: list[dict[str, str]]
) -> dict[str, list[dict[str, str]]] | None:
    """Send one request to the AI service's ``/agent`` endpoint.

    Args:
        trace_id: Sent as ``"trace_id"`` in the request payload.
        context_vars: Sent as ``"context_vars"``.
        available_tools: Sent as ``"available_tools"``.
        messages: Sent as ``"messages"``.

    Returns:
        ``{"messages": <messages from the response body>}`` when the service
        answers HTTP 200 with ``"status" == "success"``. ``None`` when the
        request raises, when a 200 response reports any other status, or when
        the response code is not 200.
        NOTE(review): the returned messages are presumed to have the same
        shape as the ``messages`` argument — confirm against the service.
    """
    payload = {
        "trace_id": trace_id,
        "context_vars": context_vars,
        "available_tools": available_tools,
        "messages": messages,
    }
    try:
        response = self.make_ai_service_request("/agent", payload=payload, timeout=self.timeout)
    except (requests.exceptions.RequestException, TypeError) as e:
        logger.exception(f"Error calling codeflash agent: {e}")
        return None

    if response.status_code == 200:
        res = response.json()
        console.rule()

        # .get() so a 200 body without "status" is treated as a failure
        # instead of raising KeyError.
        if res.get("status") == "success":
            return {"messages": res["messages"]}

        return None

    # Best effort: prefer the structured "error" field, fall back to the raw body.
    try:
        error = response.json()["error"]
    except Exception:
        error = response.text
    logger.error(f"Error calling codeflash agent: {response.status_code} - {error}")
    ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
    return None

def get_new_explanation( # noqa: D417
self,
source_code: str,
Expand Down
19 changes: 13 additions & 6 deletions codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,21 +95,22 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:


class FunctionWithReturnStatement(ast.NodeVisitor):
def __init__(self, file_path: Path) -> None:
def __init__(self, file_path: Path, must_return_a_value: bool = True) -> None: # noqa: FBT001, FBT002
self.functions: list[FunctionToOptimize] = []
self.ast_path: list[FunctionParent] = []
self.file_path: Path = file_path
self.must_return_a_value: bool = must_return_a_value

def visit_FunctionDef(self, node: FunctionDef) -> None:
# Check if the function has a return statement and add it to the list
if function_has_return_statement(node) and not function_is_a_property(node):
if not self.must_return_a_value or (function_has_return_statement(node) and not function_is_a_property(node)):
self.functions.append(
FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:])
)

def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
# Check if the async function has a return statement and add it to the list
if function_has_return_statement(node) and not function_is_a_property(node):
if not self.must_return_a_value or (function_has_return_statement(node) and not function_is_a_property(node)):
self.functions.append(
FunctionToOptimize(
function_name=node.name, file_path=self.file_path, parents=self.ast_path[:], is_async=True
Expand Down Expand Up @@ -182,6 +183,7 @@ def get_functions_to_optimize(
project_root: Path,
module_root: Path,
previous_checkpoint_functions: dict[str, dict[str, str]] | None = None,
must_return_a_value: bool = True, # noqa: FBT001, FBT002
) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]:
assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, (
"Only one of optimize_all, replay_test, or file should be provided"
Expand All @@ -203,7 +205,9 @@ def get_functions_to_optimize(
logger.info("!lsp|Finding all functions in the file '%s'…", file)
console.rule()
file = Path(file) if isinstance(file, str) else file
functions: dict[Path, list[FunctionToOptimize]] = find_all_functions_in_file(file)
functions: dict[Path, list[FunctionToOptimize]] = find_all_functions_in_file(
file, must_return_a_value=must_return_a_value
)
if only_get_this_function is not None:
split_function = only_get_this_function.split(".")
if len(split_function) > 2:
Expand Down Expand Up @@ -368,7 +372,10 @@ def get_all_files_and_functions(module_root_path: Path) -> dict[str, list[Functi
return dict(files_list)


def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:
def find_all_functions_in_file(
file_path: Path,
must_return_a_value: bool = True, # noqa: FBT001, FBT002
) -> dict[Path, list[FunctionToOptimize]]:
functions: dict[Path, list[FunctionToOptimize]] = {}
with file_path.open(encoding="utf8") as f:
try:
Expand All @@ -377,7 +384,7 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt
if DEBUG_MODE:
logger.exception(e)
return functions
function_name_visitor = FunctionWithReturnStatement(file_path)
function_name_visitor = FunctionWithReturnStatement(file_path, must_return_a_value=must_return_a_value)
function_name_visitor.visit(ast_module)
functions[file_path] = function_name_visitor.functions
return functions
Expand Down
8 changes: 6 additions & 2 deletions codeflash/optimization/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,10 @@ def run_benchmarks(
console.rule()
return function_benchmark_timings, total_benchmark_timings

def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]:
def get_optimizable_functions(
self,
must_return_a_value: bool = True, # noqa: FBT001, FBT002
) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]:
"""Discover functions to optimize."""
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize

Expand All @@ -137,7 +140,8 @@ def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize]
ignore_paths=self.args.ignore_paths,
project_root=self.args.project_root,
module_root=self.args.module_root,
previous_checkpoint_functions=self.args.previous_checkpoint_functions,
previous_checkpoint_functions=getattr(self.args, "previous_checkpoint_functions", None),
must_return_a_value=must_return_a_value,
)

def create_function_optimizer(
Expand Down
Loading