[WIP][FEAT] Codeflash Agent #1059
Draft: mohammedahmed18 wants to merge 1 commit into main from feat/agentic-codeflash
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,151 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import time | ||
| from dataclasses import dataclass | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| from codeflash.agent.memory import Memory | ||
| from codeflash.agent.tools.base_tool import supported_tools | ||
| from codeflash.cli_cmds.cli import parse_args, process_pyproject_config | ||
| from codeflash.code_utils.code_extractor import get_opt_review_metrics | ||
| from codeflash.either import is_successful | ||
| from codeflash.optimization.optimizer import Optimizer | ||
|
|
||
| if TYPE_CHECKING: | ||
| from codeflash.optimization.function_optimizer import FunctionOptimizer | ||
|
|
||
|
|
||
| MAX_AGENT_CALLS_PER_OPTIMIZATION = 10 | ||
|
|
||
|
|
||
| @dataclass | ||
| class InitialAgentContext: | ||
| function_optimizer: FunctionOptimizer | ||
| code_context: str | ||
| function_references: str | ||
|
|
||
|
|
||
| def build_initial_context(optimizer: Optimizer) -> InitialAgentContext | None: | ||
| optimizable_funcs, count, _ = optimizer.get_optimizable_functions( | ||
| must_return_a_value=False | ||
| ) # The agent does not need the function to return a value; it should be able to evaluate the function's behavior without a return statement. | ||
|
|
||
| if count == 0: | ||
| return None | ||
|
|
||
| fto = optimizable_funcs.popitem()[1][0] | ||
|
|
||
| module_prep_result = optimizer.prepare_module_for_optimization(fto.file_path) | ||
| if not module_prep_result: | ||
| return None | ||
|
|
||
| validated_original_code, original_module_ast = module_prep_result | ||
|
|
||
| function_optimizer = optimizer.create_function_optimizer( | ||
| fto, | ||
| function_to_optimize_source_code=validated_original_code[fto.file_path].source_code, | ||
| original_module_ast=original_module_ast, | ||
| original_module_path=fto.file_path, | ||
| function_to_tests={}, | ||
| ) | ||
|
|
||
| optimizer.current_function_optimizer = function_optimizer | ||
| if not function_optimizer: | ||
| return None | ||
|
|
||
| ctx_result = function_optimizer.get_code_optimization_context() | ||
| if not is_successful(ctx_result): | ||
| return None | ||
| ctx = ctx_result.unwrap() | ||
|
|
||
| function_references = get_opt_review_metrics( | ||
| function_optimizer.function_to_optimize_source_code, | ||
| function_optimizer.function_to_optimize.file_path, | ||
| function_optimizer.function_to_optimize.qualified_name, | ||
| function_optimizer.project_root, | ||
| function_optimizer.test_cfg.tests_root, | ||
| ) | ||
|
|
||
| return InitialAgentContext( | ||
| function_optimizer=function_optimizer, | ||
| code_context=ctx.read_writable_code.markdown, | ||
| function_references=function_references, | ||
| ) | ||
|
|
||
|
|
||
| def main() -> None: | ||
| args = parse_args() | ||
|
|
||
| optimizer = Optimizer(process_pyproject_config(args)) | ||
| initial_context = build_initial_context(optimizer) | ||
| if not initial_context: | ||
| # TODO: handle this case | ||
| return | ||
|
|
||
| memory = Memory() | ||
|
|
||
| # file_path = args.file | ||
| # function_name = args.function | ||
| # memory.add_to_context_vars("target_file_path", str(file_path)) | ||
| # memory.add_to_context_vars("target_function_name", str(function_name)) | ||
|
|
||
| memory.set_messages( | ||
| [ | ||
| { | ||
| "role": "user", | ||
| "content": f"""# Rewrite the following function to be more efficient and faster, while maintaining the same behavior. | ||
|
|
||
| ### code context: | ||
| {initial_context.code_context} | ||
| ### function references: | ||
| {initial_context.function_references} | ||
|
|
||
| Start by writing a plan for the optimization.""", | ||
| } | ||
| ] | ||
| ) | ||
|
|
||
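| # NOTE: set_messages() replaces the message list rather than appending, so this hardcoded transcript overwrites the prompt built above. | ||
| memory.set_messages( | ||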
| [ | ||
| { | ||
| "role": "user", | ||
| "content": '# Rewrite the following function to be more efficient and faster, while maintaining the same behavior.\n\n### code context:\n```python:code_to_optimize/bubble_sort.py\ndef sorter(arr):\n print("codeflash stdout: Sorting list")\n for i in range(len(arr)):\n for j in range(len(arr) - 1):\n if arr[j] > arr[j + 1]:\n temp = arr[j]\n arr[j] = arr[j + 1]\n arr[j + 1] = temp\n print(f"result: {arr}")\n```\n### function references:\n```python:code_to_optimize/bubble_sort_from_another_file.py\ndef sort_from_another_file(arr):\n sorted_arr = sorter(arr)\n return sorted_arr\n```\n```python:temp_test.py\ndef run_and_capture(arr):\n saved_stdout = sys.stdout\n try:\n out = StringIO()\n sys.stdout = out\n data = arr.copy()\n sorter(data)\n return out.getvalue(), data\n finally:\n sys.stdout = saved_stdout\n```\n\n\nStart by writing a plan for the optimization.', | ||
| }, | ||
| { | ||
| "role": "assistant", | ||
| "content": '## Optimization Plan\n\n### 1- Understand Code Behavior\n\n- **Current Functionality:** \n The `sorter` function implements the bubble sort algorithm for sorting a list `arr` in-place, with debug output before and after sorting.\n - Before sorting: prints `"codeflash stdout: Sorting list"`\n - After sorting: prints `"result: ..."`, showing the sorted list.\n - Sorts `arr` in place and does not return a value.\n- **Usage:** \n - Called by `sort_from_another_file`, although its return value is not used (likely functions for side-effects and in-place mutation).\n - Used in `run_and_capture` with output redirection for test validation.\n- **Performance Bottlenecks:**\n - The classic bubble sort is O(n²) and always loops over the entire unsorted list for every pass.\n - `range(len(arr) - 1)` does not shrink on each outer iteration—does not exploit the fact that the end section is already sorted.\n - No early exit when the list is already sorted or no swaps occur—wastes unnecessary passes.\n - Swapping via a temporary variable instead of Python\'s tuple swap (minor impact).\n - Performs redundant `len(arr)` calculations in loop bounds.\n - Prints entire list after sorting, which may be slow for very large lists (this is part of specified external behavior and can\'t be changed).\n- **External Behavior Constraints:**\n - Must print exactly the same messages (`stdout` is inspected in tests).\n - Sorts in-place.\n - No return value expected (returning would change externally visible behavior).\n\n---\n\n### 2- Begin Optimization\n\n- **Algorithm Improvement:** \n - Optimize bubble sort by:\n - Shortening the range of the inner loop as the sorted suffix grows.\n - Adding an early exit—if no swaps occur in a pass, terminate early (list is sorted).\n- **Pythonic Improvements:** \n - Use tuple swap for readability and slight speedup.\n - Hoist `len(arr)` to a local variable to avoid repeated computation.\n- **Other Considerations:** \n - Ensure print statements 
(before and after) produce byte-for-byte identical output.\n - Preserve in-place modification and no return value.\n\n**Concrete Steps:**\n 1. Store `len(arr)` once in a variable.\n 2. Change the inner loop\'s range to shrink with each pass: `range(0, n - i - 1)`.\n 3. Add a `swapped` flag; if no swaps occur on a pass, break early.\n 4. Use tuple swap: `arr[j], arr[j+1] = arr[j+1], arr[j]`.\n 5. Keep print statements unchanged in position and format.\n\n---\n\n### 3- Execute & Verify\n\n- **Testing:**\n - Use existing test (`run_and_capture`) and other targeted tests, including edge cases (empty list, already-sorted list, reversed list, single-element list, and lists with duplicates).\n - Compare outputs and final sorted array against original implementation to confirm behavioral equivalence.\n - Check that no exceptions are raised and that the outputs and side effects are identical.\n\n---\n\n### 4- Benchmarking\n\n- **Performance Measurement:** \n - Benchmark sorting on a range of input sizes (especially medium and large).\n - Compare runtime of original and optimized functions using timeit or a similar timing tool.\n - Assess and report:\n - Average/median runtime before and after optimization.\n - Absolute and relative (percentage) speedup achieved.\n\n---', | ||
| }, | ||
| { | ||
| "role": "assistant", | ||
| "content": '<replace_in_file>\n<path>code_to_optimize/bubble_sort.py</path>\n<diff>\n<<<<<<< SEARCH\ndef sorter(arr):\n print("codeflash stdout: Sorting list")\n for i in range(len(arr)):\n for j in range(len(arr) - 1):\n if arr[j] > arr[j + 1]:\n temp = arr[j]\n arr[j] = arr[j + 1]\n arr[j + 1] = temp\n print(f"result: {arr}")\n=======\ndef sorter(arr):\n print("codeflash stdout: Sorting list")\n n = len(arr)\n for i in range(n):\n swapped = False\n for j in range(0, n - i - 1):\n if arr[j] > arr[j + 1]:\n arr[j], arr[j + 1] = arr[j + 1], arr[j]\n swapped = True\n if not swapped:\n break\n print(f"result: {arr}")\n>>>>>>> REPLACE\n</diff>\n</replace_in_file>', | ||
| }, | ||
| {"role": "user", "content": 'replace_in_file: {"status": "OK"}'}, | ||
| { | ||
| "role": "assistant", | ||
| "content": '<execute_code>\n<language>python</language>\n<code>\nimport sys\nfrom io import StringIO\nfrom code_to_optimize.bubble_sort import sorter\n\ndef run_and_capture(arr):\n saved_stdout = sys.stdout\n try:\n out = StringIO()\n sys.stdout = out\n data = arr.copy()\n sorter(data)\n return out.getvalue(), data\n finally:\n sys.stdout = saved_stdout\n\n# Test cases for behavioral equivalence\ntest_vectors = [\n [],\n [1],\n [2, 1],\n [1, 2, 3],\n [3, 2, 1],\n [5, 3, 8, 4, 2],\n [5, 5, 5, 5],\n [3, 1, 2, 3, 2, 1],\n list(range(100, 0, -1)),\n]\n\nfor vec in test_vectors:\n out, sorted_vec = run_and_capture(vec)\n assert "codeflash stdout: Sorting list" in out\n assert "result: " in out\n # result: [sorted version]\n result_line = out.splitlines()[-1]\n # safely eval the list (since printed via str)\n import ast\n printed_list = ast.literal_eval(result_line.split(\': \', 1)[1])\n assert printed_list == sorted(vec)\n assert sorted_vec == sorted(vec)\nprint("All behavioral tests passed for optimized bubble sort.")\n</code>\n</execute_code>', | ||
| }, | ||
| { | ||
| "role": "user", | ||
| "content": 'execute_code: {"status": "OK", "stdout": "All behavioral tests passed for optimized bubble sort."}', | ||
| }, | ||
| ] | ||
| ) | ||
| while memory.api_calls_counter < MAX_AGENT_CALLS_PER_OPTIMIZATION: | ||
| ai_service_client = initial_context.function_optimizer.aiservice_client | ||
| time_start = time.perf_counter() | ||
| response = ai_service_client.call_agent( | ||
| trace_id=memory.trace_id, | ||
| context_vars=memory.get_context_vars(), | ||
| available_tools=supported_tools, | ||
| messages=memory.get_messages(), | ||
| ) | ||
| memory.api_calls_counter += 1 | ||
| new_messages = response.get("messages", []) | ||
| memory.set_messages(new_messages) | ||
|
|
||
| time_end = time.perf_counter() | ||
| print(f"Time taken: {time_end - time_start:.2f} seconds") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
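The hardcoded transcript above has the assistant emit a `replace_in_file` call whose `<diff>` body is a SEARCH/REPLACE block. The PR does not show how that tool applies the block, so the following is only a minimal sketch of one plausible approach. The helper name `apply_search_replace` and the rule that the SEARCH text must match exactly once are assumptions, not part of the diff.

```python
import re

# Hypothetical helper: the PR does not include the replace_in_file implementation.
SEARCH_REPLACE_RE = re.compile(
    r"<<<<<<< SEARCH\n(?P<search>.*?)\n=======\n(?P<replace>.*?)\n>>>>>>> REPLACE",
    re.DOTALL,
)


def apply_search_replace(source: str, diff: str) -> str:
    """Apply every SEARCH/REPLACE block in `diff` to `source`, in order."""
    blocks = list(SEARCH_REPLACE_RE.finditer(diff))
    if not blocks:
        raise ValueError("no SEARCH/REPLACE block found")
    for block in blocks:
        search, replace = block.group("search"), block.group("replace")
        # Assumed policy: require exactly one match so an ambiguous edit fails loudly.
        if source.count(search) != 1:
            raise ValueError("SEARCH text must match exactly once")
        source = source.replace(search, replace, 1)
    return source
```

Failing on zero or multiple matches gives the agent a clear error to react to instead of silently editing the wrong location.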
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| import uuid | ||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
| from codeflash.code_utils.code_utils import encoded_tokens_len | ||
|
|
||
| json_primitive_types = (str, float, int, bool) | ||
|
|
||
|
|
||
| class Memory: | ||
| def __init__(self) -> None: | ||
| self._context_vars: dict[str, str] = {} | ||
| self._messages: list[dict[str, str]] = [] | ||
| self.api_calls_counter = 0 | ||
| self.trace_id = str(uuid.uuid4()) | ||
| self.max_tokens = 16000 | ||
|
|
||
| def _serialize(self, obj: Any) -> Any: # noqa: ANN401 | ||
| if isinstance(obj, list): | ||
| return [self._serialize(i) for i in obj] | ||
| if isinstance(obj, dict): | ||
| return {k: self._serialize(v) for k, v in obj.items()} | ||
| if isinstance(obj, json_primitive_types) or obj is None: | ||
| return obj | ||
| if isinstance(obj, Path): | ||
| return obj.as_posix() | ||
| return str(obj) | ||
|
|
||
| def add_to_context_vars(self, key: str, value: Any) -> dict[str, str]: | ||
| self._context_vars[key] = self._serialize(value) | ||
| return self._context_vars | ||
|
|
||
| def get_context_vars(self) -> dict[str, str]: | ||
| return self._context_vars | ||
|
|
||
| def set_messages(self, messages: list[dict[str, str]]) -> list[dict[str, str]]: | ||
| self._messages = messages | ||
| if self.get_total_tokens() > self.max_tokens: | ||
| # TODO: summarize messages | ||
| pass | ||
| return self._messages | ||
|
|
||
| def get_messages(self) -> list[dict[str, str]]: | ||
| return self._messages | ||
|
|
||
| def get_total_tokens(self) -> int: | ||
| return sum(encoded_tokens_len(message["content"]) for message in self._messages) | ||
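`Memory.set_messages` leaves the over-budget case as a TODO ("summarize messages"). As a rough illustration of one simpler alternative, the sketch below truncates rather than summarizes: it keeps the first message (the task prompt) and as many of the most recent messages as fit. The names `approx_tokens` and `trim_messages` are hypothetical, and `approx_tokens` is a crude stand-in for `encoded_tokens_len`, not its actual behavior.

```python
def approx_tokens(text: str) -> int:
    # Hypothetical stand-in for encoded_tokens_len: roughly 4 characters per token.
    return max(1, len(text) // 4)


def trim_messages(messages: list[dict[str, str]], max_tokens: int) -> list[dict[str, str]]:
    """Keep the first message plus as many of the most recent ones as fit."""
    if not messages:
        return []
    head, tail = messages[0], messages[1:]
    budget = max_tokens - approx_tokens(head["content"])
    kept: list[dict[str, str]] = []
    # Walk backwards so the newest messages are retained first.
    for message in reversed(tail):
        cost = approx_tokens(message["content"])
        if cost > budget:
            break
        kept.append(message)
        budget -= cost
    # Restore chronological order for the retained tail.
    return [head, *reversed(kept)]
```

Truncation loses earlier tool results outright, so a real summarization step may still be preferable once the agent runs longer conversations.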
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,32 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from dataclasses import dataclass | ||
| from enum import Enum | ||
| from typing import Any | ||
|
|
||
|
|
||
| class Tool(str, Enum): | ||
| REPLACE_IN_FILE = "replace_in_file" | ||
| ADD_TO_CONTEXT_VARS = "add_to_context_vars" | ||
| EXECUTE_CODE = "execute_code" # currently only supports python | ||
| SEARCH_FUNCTION_REFERENCES = "search_function_references" | ||
| GET_NAME_DEFINITION = "get_name_definition" | ||
| TERMINATE = "terminate" # terminates either with success or failure | ||
|
|
||
|
|
||
| # TODO: use this as a type for the api response | ||
| @dataclass(frozen=True) | ||
| class ToolCall: | ||
| tool_name: str | ||
| args: dict[str, Any] | ||
| needs_context_vars: bool = False | ||
|
|
||
|
|
||
| supported_tools: list[str] = [ | ||
| Tool.REPLACE_IN_FILE, | ||
| Tool.ADD_TO_CONTEXT_VARS, | ||
| Tool.EXECUTE_CODE, | ||
| # Tool.SEARCH_FUNCTION_REFERENCES, | ||
| # Tool.GET_NAME_DEFINITION, | ||
| # Tool.TERMINATE, | ||
| ] |
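The TODO above suggests using `ToolCall` as the type for the API response. Assuming assistant messages carry tool calls in the tag format seen in the hardcoded transcript (for example `<execute_code><language>…</language><code>…</code></execute_code>`), a parser might look roughly like the sketch below. `parse_tool_call` and `KNOWN_TOOLS` are hypothetical names, and the real response format may differ.

```python
from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ToolCall:  # mirrors the dataclass in the diff above
    tool_name: str
    args: dict[str, Any]
    needs_context_vars: bool = False


# Assumed to match the enabled entries of supported_tools.
KNOWN_TOOLS = ("replace_in_file", "add_to_context_vars", "execute_code")


def parse_tool_call(content: str) -> ToolCall | None:
    """Return the first known tool call found in an assistant message, if any."""
    for name in KNOWN_TOOLS:
        outer = re.search(rf"<{name}>(.*?)</{name}>", content, re.DOTALL)
        if outer is None:
            continue
        # Each child tag becomes one argument; bodies are kept as raw text.
        args = {
            tag: body.strip("\n")
            for tag, body in re.findall(r"<(\w+)>(.*?)</\1>", outer.group(1), re.DOTALL)
        }
        return ToolCall(tool_name=name, args=args)
    return None
```

Regular expressions are used instead of an XML parser because tool arguments such as code bodies can contain `<` and `>` characters that are not valid XML.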
⚡️ Codeflash found 33% (0.33x) speedup for `Memory.add_to_context_vars` in `codeflash/agent/memory.py`
⏱️ Runtime: 1.99 milliseconds → 1.49 milliseconds (best of 92 runs)
📝 Explanation and details
The optimized code achieves a 33% speedup by implementing two key strategies: early-exit fast paths and reordering type checks based on frequency.

What changed:
- In `_serialize()`: Reordered type checks to prioritize primitives (str, float, int, bool, None) first, followed by Path objects. This moves the most common cases to the front, enabling faster early returns before checking less common types like list and dict.
- In `add_to_context_vars()`: Added a fast path that directly handles primitives and Path objects without calling `_serialize()` at all, avoiding the function call overhead and recursive descent for simple values.

Why this is faster:
- Reduced function call overhead: The line profiler shows `add_to_context_vars` dropped from 32.4ms to 19.0ms total time (41% faster). For primitive values (which comprise ~98.6% of inputs based on the 2048/2076 fast-path hits), the code now skips the `_serialize()` function call entirely.
- Better branch prediction: By checking primitives first in `_serialize()`, the most common case (3936/4534 hits = 87%) exits immediately. The original code wasted cycles checking list/dict first, which only matched 289 total times.
- Fewer isinstance checks: The original `_serialize()` performed ~4 isinstance checks per call on average before finding primitives. The optimized version checks primitives first, averaging ~1-2 checks for the common case.

Test results show this optimization excels for:

Impact on workloads:
The function appears to be used for storing agent context variables, which typically consist of scalar values like IDs, timestamps, and configuration flags. The optimization dramatically improves this common case while maintaining correct behavior for complex nested structures.
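The suggested change itself is collapsed on this page, so the following is only a sketch reconstructed from the description above, not the bot's actual patch. It is written as free functions over an explicit `context_vars` dict (an assumption made to keep the example self-contained) rather than as methods on `Memory`.

```python
from pathlib import Path
from typing import Any

JSON_PRIMITIVES = (str, float, int, bool)


def serialize(obj: Any) -> Any:
    # Primitives first: per the report they are by far the most common input.
    if obj is None or isinstance(obj, JSON_PRIMITIVES):
        return obj
    if isinstance(obj, Path):
        return obj.as_posix()
    if isinstance(obj, list):
        return [serialize(i) for i in obj]
    if isinstance(obj, dict):
        return {k: serialize(v) for k, v in obj.items()}
    return str(obj)


def add_to_context_vars(context_vars: dict[str, Any], key: str, value: Any) -> dict[str, Any]:
    # Fast path: store primitives and paths directly, skipping the serialize() call.
    if value is None or isinstance(value, JSON_PRIMITIVES):
        context_vars[key] = value
    elif isinstance(value, Path):
        context_vars[key] = value.as_posix()
    else:
        context_vars[key] = serialize(value)
    return context_vars
```

Reordering the checks does not change results here, because no value can be both a primitive or `Path` and a `list` or `dict`; only the number of `isinstance` calls on the common path changes.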
To test or edit this optimization locally, run `git merge codeflash/optimize-pr1059-2026-01-15T14.12.43`.