diff --git a/codeflash/agent/__init__.py b/codeflash/agent/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/codeflash/agent/main.py b/codeflash/agent/main.py
new file mode 100644
index 000000000..f696d7655
--- /dev/null
+++ b/codeflash/agent/main.py
@@ -0,0 +1,151 @@
+from __future__ import annotations
+
+import time
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from codeflash.agent.memory import Memory
+from codeflash.agent.tools.base_tool import supported_tools
+from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
+from codeflash.code_utils.code_extractor import get_opt_review_metrics
+from codeflash.either import is_successful
+from codeflash.optimization.optimizer import Optimizer
+
+if TYPE_CHECKING:
+ from codeflash.optimization.function_optimizer import FunctionOptimizer
+
+
+MAX_AGENT_CALLS_PER_OPTIMIZATION = 10
+
+
+@dataclass
+class InitialAgentContext:
+ function_optimizer: FunctionOptimizer
+ code_context: str
+ function_references: str
+
+
+def build_initial_context(optimizer: Optimizer) -> InitialAgentContext | None:
+ optimizable_funcs, count, _ = optimizer.get_optimizable_functions(
+ must_return_a_value=False
+ ) # no need for the function to return a value for agent, agent should be smart enough to evaluate the behavior of the function without a return statement
+
+ if count == 0:
+ return None
+
+ fto = optimizable_funcs.popitem()[1][0]
+
+ module_prep_result = optimizer.prepare_module_for_optimization(fto.file_path)
+ if not module_prep_result:
+ return None
+
+ validated_original_code, original_module_ast = module_prep_result
+
+ function_optimizer = optimizer.create_function_optimizer(
+ fto,
+ function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
+ original_module_ast=original_module_ast,
+ original_module_path=fto.file_path,
+ function_to_tests={},
+ )
+
+ optimizer.current_function_optimizer = function_optimizer
+ if not function_optimizer:
+ return None
+
+ ctx_result = function_optimizer.get_code_optimization_context()
+ if not is_successful(ctx_result):
+ return None
+ ctx = ctx_result.unwrap()
+
+ function_references = get_opt_review_metrics(
+ function_optimizer.function_to_optimize_source_code,
+ function_optimizer.function_to_optimize.file_path,
+ function_optimizer.function_to_optimize.qualified_name,
+ function_optimizer.project_root,
+ function_optimizer.test_cfg.tests_root,
+ )
+
+ return InitialAgentContext(
+ function_optimizer=function_optimizer,
+ code_context=ctx.read_writable_code.markdown,
+ function_references=function_references,
+ )
+
+
+def main() -> None:
+ args = parse_args()
+
+ optimizer = Optimizer(process_pyproject_config(args))
+ initial_context = build_initial_context(optimizer)
+ if not initial_context:
+ # TODO: handle this case
+ return
+
+ memory = Memory()
+
+ # file_path = args.file
+ # function_name = args.function
+ # memory.add_to_context_vars("target_file_path", str(file_path))
+ # memory.add_to_context_vars("target_function_name", str(function_name))
+
+ memory.set_messages(
+ [
+ {
+ "role": "user",
+ "content": f"""# Rewrite the following function to be more efficient and faster, while maintaining the same behavior.
+
+### code context:
+{initial_context.code_context}
+### function references:
+{initial_context.function_references}
+
+Start by writing a plan for the optimization.""",
+ }
+ ]
+ )
+
+ memory.set_messages(
+ [
+ {
+ "role": "user",
+ "content": '# Rewrite the following function to be more efficient and faster, while maintaining the same behavior.\n\n### code context:\n```python:code_to_optimize/bubble_sort.py\ndef sorter(arr):\n print("codeflash stdout: Sorting list")\n for i in range(len(arr)):\n for j in range(len(arr) - 1):\n if arr[j] > arr[j + 1]:\n temp = arr[j]\n arr[j] = arr[j + 1]\n arr[j + 1] = temp\n print(f"result: {arr}")\n```\n### function references:\n```python:code_to_optimize/bubble_sort_from_another_file.py\ndef sort_from_another_file(arr):\n sorted_arr = sorter(arr)\n return sorted_arr\n```\n```python:temp_test.py\ndef run_and_capture(arr):\n saved_stdout = sys.stdout\n try:\n out = StringIO()\n sys.stdout = out\n data = arr.copy()\n sorter(data)\n return out.getvalue(), data\n finally:\n sys.stdout = saved_stdout\n```\n\n\nStart by writing a plan for the optimization.',
+ },
+ {
+ "role": "assistant",
+ "content": '## Optimization Plan\n\n### 1- Understand Code Behavior\n\n- **Current Functionality:** \n The `sorter` function implements the bubble sort algorithm for sorting a list `arr` in-place, with debug output before and after sorting.\n - Before sorting: prints `"codeflash stdout: Sorting list"`\n - After sorting: prints `"result: ..."`, showing the sorted list.\n - Sorts `arr` in place and does not return a value.\n- **Usage:** \n - Called by `sort_from_another_file`, although its return value is not used (likely functions for side-effects and in-place mutation).\n - Used in `run_and_capture` with output redirection for test validation.\n- **Performance Bottlenecks:**\n - The classic bubble sort is O(n²) and always loops over the entire unsorted list for every pass.\n - `range(len(arr) - 1)` does not shrink on each outer iteration—does not exploit the fact that the end section is already sorted.\n - No early exit when the list is already sorted or no swaps occur—wastes unnecessary passes.\n - Swapping via a temporary variable instead of Python\'s tuple swap (minor impact).\n - Performs redundant `len(arr)` calculations in loop bounds.\n - Prints entire list after sorting, which may be slow for very large lists (this is part of specified external behavior and can\'t be changed).\n- **External Behavior Constraints:**\n - Must print exactly the same messages (`stdout` is inspected in tests).\n - Sorts in-place.\n - No return value expected (returning would change externally visible behavior).\n\n---\n\n### 2- Begin Optimization\n\n- **Algorithm Improvement:** \n - Optimize bubble sort by:\n - Shortening the range of the inner loop as the sorted suffix grows.\n - Adding an early exit—if no swaps occur in a pass, terminate early (list is sorted).\n- **Pythonic Improvements:** \n - Use tuple swap for readability and slight speedup.\n - Hoist `len(arr)` to a local variable to avoid repeated computation.\n- **Other Considerations:** \n - Ensure print statements (before and after) produce byte-for-byte identical output.\n - Preserve in-place modification and no return value.\n\n**Concrete Steps:**\n 1. Store `len(arr)` once in a variable.\n 2. Change the inner loop\'s range to shrink with each pass: `range(0, n - i - 1)`.\n 3. Add a `swapped` flag; if no swaps occur on a pass, break early.\n 4. Use tuple swap: `arr[j], arr[j+1] = arr[j+1], arr[j]`.\n 5. Keep print statements unchanged in position and format.\n\n---\n\n### 3- Execute & Verify\n\n- **Testing:**\n - Use existing test (`run_and_capture`) and other targeted tests, including edge cases (empty list, already-sorted list, reversed list, single-element list, and lists with duplicates).\n - Compare outputs and final sorted array against original implementation to confirm behavioral equivalence.\n - Check that no exceptions are raised and that the outputs and side effects are identical.\n\n---\n\n### 4- Benchmarking\n\n- **Performance Measurement:** \n - Benchmark sorting on a range of input sizes (especially medium and large).\n - Compare runtime of original and optimized functions using timeit or a similar timing tool.\n - Assess and report:\n - Average/median runtime before and after optimization.\n - Absolute and relative (percentage) speedup achieved.\n\n---',
+ },
+ {
+ "role": "assistant",
+ "content": '\ncode_to_optimize/bubble_sort.py\n\n<<<<<<< SEARCH\ndef sorter(arr):\n print("codeflash stdout: Sorting list")\n for i in range(len(arr)):\n for j in range(len(arr) - 1):\n if arr[j] > arr[j + 1]:\n temp = arr[j]\n arr[j] = arr[j + 1]\n arr[j + 1] = temp\n print(f"result: {arr}")\n=======\ndef sorter(arr):\n print("codeflash stdout: Sorting list")\n n = len(arr)\n for i in range(n):\n swapped = False\n for j in range(0, n - i - 1):\n if arr[j] > arr[j + 1]:\n arr[j], arr[j + 1] = arr[j + 1], arr[j]\n swapped = True\n if not swapped:\n break\n print(f"result: {arr}")\n>>>>>>> REPLACE\n\n',
+ },
+ {"role": "user", "content": 'replace_in_file: {"status": "OK"}'},
+ {
+ "role": "assistant",
+ "content": '\npython\n\nimport sys\nfrom io import StringIO\nfrom code_to_optimize.bubble_sort import sorter\n\ndef run_and_capture(arr):\n saved_stdout = sys.stdout\n try:\n out = StringIO()\n sys.stdout = out\n data = arr.copy()\n sorter(data)\n return out.getvalue(), data\n finally:\n sys.stdout = saved_stdout\n\n# Test cases for behavioral equivalence\ntest_vectors = [\n [],\n [1],\n [2, 1],\n [1, 2, 3],\n [3, 2, 1],\n [5, 3, 8, 4, 2],\n [5, 5, 5, 5],\n [3, 1, 2, 3, 2, 1],\n list(range(100, 0, -1)),\n]\n\nfor vec in test_vectors:\n out, sorted_vec = run_and_capture(vec)\n assert "codeflash stdout: Sorting list" in out\n assert "result: " in out\n # result: [sorted version]\n result_line = out.splitlines()[-1]\n # safely eval the list (since printed via str)\n import ast\n printed_list = ast.literal_eval(result_line.split(\': \', 1)[1])\n assert printed_list == sorted(vec)\n assert sorted_vec == sorted(vec)\nprint("All behavioral tests passed for optimized bubble sort.")\n\n',
+ },
+ {
+ "role": "user",
+ "content": 'execute_code: {"status": "OK", "stdout": "All behavioral tests passed for optimized bubble sort."}',
+ },
+ ]
+ )
+ while memory.api_calls_counter < MAX_AGENT_CALLS_PER_OPTIMIZATION:
+ ai_service_client = initial_context.function_optimizer.aiservice_client
+ time_start = time.perf_counter()
+ response = ai_service_client.call_agent(
+ trace_id=memory.trace_id,
+ context_vars=memory.get_context_vars(),
+ available_tools=supported_tools,
+ messages=memory.get_messages(),
+ )
+ memory.api_calls_counter += 1
+ new_messages = response.get("messages", [])
+ memory.set_messages(new_messages)
+
+ time_end = time.perf_counter()
+ print(f"Time taken: {time_end - time_start:.2f} seconds")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/codeflash/agent/memory.py b/codeflash/agent/memory.py
new file mode 100644
index 000000000..84d0a4933
--- /dev/null
+++ b/codeflash/agent/memory.py
@@ -0,0 +1,47 @@
+import uuid
+from pathlib import Path
+from typing import Any
+
+from codeflash.code_utils.code_utils import encoded_tokens_len
+
+json_primitive_types = (str, float, int, bool)
+
+
+class Memory:
+ def __init__(self) -> None:
+ self._context_vars: dict[str, str] = {}
+ self._messages: list[dict[str, str]] = []
+ self.api_calls_counter = 0
+ self.trace_id = str(uuid.uuid4())
+ self.max_tokens = 16000
+
+ def _serialize(self, obj: Any) -> Any: # noqa: ANN401
+ if isinstance(obj, list):
+ return [self._serialize(i) for i in obj]
+ if isinstance(obj, dict):
+ return {k: self._serialize(v) for k, v in obj.items()}
+ if isinstance(obj, json_primitive_types) or obj is None:
+ return obj
+ if isinstance(obj, Path):
+ return obj.as_posix()
+ return str(obj)
+
+ def add_to_context_vars(self, key: str, value: any) -> dict[str, str]:
+ self._context_vars[key] = self._serialize(value)
+ return self._context_vars
+
+ def get_context_vars(self) -> dict[str, str]:
+ return self._context_vars
+
+ def set_messages(self, messages: list[dict[str, str]]) -> list[dict[str, str]]:
+ self._messages = messages
+ if self.get_total_tokens() > self.max_tokens:
+ # TODO: summarize messages
+ pass
+ return self._messages
+
+ def get_messages(self) -> list[dict[str, str]]:
+ return self._messages
+
+ def get_total_tokens(self) -> int:
+ return sum(encoded_tokens_len(message["content"]) for message in self._messages)
diff --git a/codeflash/agent/tools/__init__.py b/codeflash/agent/tools/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/codeflash/agent/tools/base_tool.py b/codeflash/agent/tools/base_tool.py
new file mode 100644
index 000000000..99981bcfd
--- /dev/null
+++ b/codeflash/agent/tools/base_tool.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any
+
+
+class Tool(str, Enum):
+ REPLACE_IN_FILE = "replace_in_file"
+ ADD_TO_CONTEXT_VARS = "add_to_context_vars"
+ EXECUTE_CODE = "execute_code" # currently only supports python
+ SEARCH_FUNCTION_REFRENCES = "search_function_references"
+ GET_NAME_DEFINITION = "get_name_definition"
+ TERMINATE = "terminate" # terminates either with success or failure
+
+
+# TODO: use this as a type for the api response
+@dataclass(frozen=True)
+class ToolCall:
+ tool_name: str
+ args: dict[str, Any]
+ needs_context_vars: bool = False
+
+
+supported_tools: list[str] = [
+ Tool.REPLACE_IN_FILE,
+ Tool.ADD_TO_CONTEXT_VARS,
+ Tool.EXECUTE_CODE,
+ # Tool.SEARCH_FUNCTION_REFRENCES,
+ # Tool.GET_NAME_DEFINITION,
+ # Tool.TERMINATE,
+]
diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py
index 94fbb44ad..c0ec82f75 100644
--- a/codeflash/api/aiservice.py
+++ b/codeflash/api/aiservice.py
@@ -388,6 +388,38 @@ def code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate
console.rule()
return None
+ def call_agent(
+ self, trace_id: str, context_vars: dict[str, str], available_tools: list[str], messages: list[dict[str, str]]
+ ) -> dict[str, str]:
+ try:
+ payload = {
+ "trace_id": trace_id,
+ "context_vars": context_vars,
+ "available_tools": available_tools,
+ "messages": messages,
+ }
+ response = self.make_ai_service_request("/agent", payload=payload, timeout=self.timeout)
+ except (requests.exceptions.RequestException, TypeError) as e:
+ logger.exception(f"Error calling codeflash agent: {e}")
+ return None
+
+ if response.status_code == 200:
+ res = response.json()
+ console.rule()
+
+ if res["status"] == "success":
+ return {"messages": res["messages"]}
+
+ return None
+
+ try:
+ error = response.json()["error"]
+ except Exception:
+ error = response.text
+ logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
+ ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
+ return None
+
def get_new_explanation( # noqa: D417
self,
source_code: str,
diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py
index 0f9f76a75..36be6375a 100644
--- a/codeflash/discovery/functions_to_optimize.py
+++ b/codeflash/discovery/functions_to_optimize.py
@@ -95,21 +95,22 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
class FunctionWithReturnStatement(ast.NodeVisitor):
- def __init__(self, file_path: Path) -> None:
+ def __init__(self, file_path: Path, must_return_a_value: bool = True) -> None: # noqa: FBT001, FBT002
self.functions: list[FunctionToOptimize] = []
self.ast_path: list[FunctionParent] = []
self.file_path: Path = file_path
+ self.must_return_a_value: bool = must_return_a_value
def visit_FunctionDef(self, node: FunctionDef) -> None:
# Check if the function has a return statement and add it to the list
- if function_has_return_statement(node) and not function_is_a_property(node):
+ if not self.must_return_a_value or (function_has_return_statement(node) and not function_is_a_property(node)):
self.functions.append(
FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:])
)
def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
# Check if the async function has a return statement and add it to the list
- if function_has_return_statement(node) and not function_is_a_property(node):
+ if not self.must_return_a_value or (function_has_return_statement(node) and not function_is_a_property(node)):
self.functions.append(
FunctionToOptimize(
function_name=node.name, file_path=self.file_path, parents=self.ast_path[:], is_async=True
@@ -182,6 +183,7 @@ def get_functions_to_optimize(
project_root: Path,
module_root: Path,
previous_checkpoint_functions: dict[str, dict[str, str]] | None = None,
+ must_return_a_value: bool = True, # noqa: FBT001, FBT002
) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]:
assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, (
"Only one of optimize_all, replay_test, or file should be provided"
@@ -203,7 +205,9 @@ def get_functions_to_optimize(
logger.info("!lsp|Finding all functions in the file '%s'…", file)
console.rule()
file = Path(file) if isinstance(file, str) else file
- functions: dict[Path, list[FunctionToOptimize]] = find_all_functions_in_file(file)
+ functions: dict[Path, list[FunctionToOptimize]] = find_all_functions_in_file(
+ file, must_return_a_value=must_return_a_value
+ )
if only_get_this_function is not None:
split_function = only_get_this_function.split(".")
if len(split_function) > 2:
@@ -368,7 +372,10 @@ def get_all_files_and_functions(module_root_path: Path) -> dict[str, list[Functi
return dict(files_list)
-def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:
+def find_all_functions_in_file(
+ file_path: Path,
+ must_return_a_value: bool = True, # noqa: FBT001, FBT002
+) -> dict[Path, list[FunctionToOptimize]]:
functions: dict[Path, list[FunctionToOptimize]] = {}
with file_path.open(encoding="utf8") as f:
try:
@@ -377,7 +384,7 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt
if DEBUG_MODE:
logger.exception(e)
return functions
- function_name_visitor = FunctionWithReturnStatement(file_path)
+ function_name_visitor = FunctionWithReturnStatement(file_path, must_return_a_value=must_return_a_value)
function_name_visitor.visit(ast_module)
functions[file_path] = function_name_visitor.functions
return functions
diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py
index 1e1ddefcf..32dd34add 100644
--- a/codeflash/optimization/optimizer.py
+++ b/codeflash/optimization/optimizer.py
@@ -124,7 +124,10 @@ def run_benchmarks(
console.rule()
return function_benchmark_timings, total_benchmark_timings
- def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]:
+ def get_optimizable_functions(
+ self,
+ must_return_a_value: bool = True, # noqa: FBT001, FBT002
+ ) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]:
"""Discover functions to optimize."""
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
@@ -137,7 +140,8 @@ def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize]
ignore_paths=self.args.ignore_paths,
project_root=self.args.project_root,
module_root=self.args.module_root,
- previous_checkpoint_functions=self.args.previous_checkpoint_functions,
+ previous_checkpoint_functions=getattr(self.args, "previous_checkpoint_functions", None),
+ must_return_a_value=must_return_a_value,
)
def create_function_optimizer(