diff --git a/codeflash/agent/__init__.py b/codeflash/agent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/codeflash/agent/main.py b/codeflash/agent/main.py new file mode 100644 index 000000000..f696d7655 --- /dev/null +++ b/codeflash/agent/main.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from codeflash.agent.memory import Memory +from codeflash.agent.tools.base_tool import supported_tools +from codeflash.cli_cmds.cli import parse_args, process_pyproject_config +from codeflash.code_utils.code_extractor import get_opt_review_metrics +from codeflash.either import is_successful +from codeflash.optimization.optimizer import Optimizer + +if TYPE_CHECKING: + from codeflash.optimization.function_optimizer import FunctionOptimizer + + +MAX_AGENT_CALLS_PER_OPTIMIZATION = 10 + + +@dataclass +class InitialAgentContext: + function_optimizer: FunctionOptimizer + code_context: str + function_references: str + + +def build_initial_context(optimizer: Optimizer) -> InitialAgentContext | None: + optimizable_funcs, count, _ = optimizer.get_optimizable_functions( + must_return_a_value=False + ) # no need for the function to return a value for agent, agent should be smart enough to evaluate the behavior of the function without a return statement + + if count == 0: + return None + + fto = optimizable_funcs.popitem()[1][0] + + module_prep_result = optimizer.prepare_module_for_optimization(fto.file_path) + if not module_prep_result: + return None + + validated_original_code, original_module_ast = module_prep_result + + function_optimizer = optimizer.create_function_optimizer( + fto, + function_to_optimize_source_code=validated_original_code[fto.file_path].source_code, + original_module_ast=original_module_ast, + original_module_path=fto.file_path, + function_to_tests={}, + ) + + optimizer.current_function_optimizer = function_optimizer + if not function_optimizer: + return None + + ctx_result = function_optimizer.get_code_optimization_context() + if not is_successful(ctx_result): + return None + ctx = ctx_result.unwrap() + + function_references = get_opt_review_metrics( + function_optimizer.function_to_optimize_source_code, + function_optimizer.function_to_optimize.file_path, + function_optimizer.function_to_optimize.qualified_name, + function_optimizer.project_root, + function_optimizer.test_cfg.tests_root, + ) + + return InitialAgentContext( + function_optimizer=function_optimizer, + code_context=ctx.read_writable_code.markdown, + function_references=function_references, + ) + + +def main() -> None: + args = parse_args() + + optimizer = Optimizer(process_pyproject_config(args)) + initial_context = build_initial_context(optimizer) + if not initial_context: + # TODO: handle this case + return + + memory = Memory() + + # file_path = args.file + # function_name = args.function + # memory.add_to_context_vars("target_file_path", str(file_path)) + # memory.add_to_context_vars("target_function_name", str(function_name)) + + memory.set_messages( + [ + { + "role": "user", + "content": f"""# Rewrite the following function to be more efficient and faster, while maintaining the same behavior. + +### code context: +{initial_context.code_context} +### function references: +{initial_context.function_references} + +Start by writing a plan for the optimization.""", + } + ] + ) + + memory.set_messages( + [ + { + "role": "user", + "content": '# Rewrite the following function to be more efficient and faster, while maintaining the same behavior.\n\n### code context:\n```python:code_to_optimize/bubble_sort.py\ndef sorter(arr):\n print("codeflash stdout: Sorting list")\n for i in range(len(arr)):\n for j in range(len(arr) - 1):\n if arr[j] > arr[j + 1]:\n temp = arr[j]\n arr[j] = arr[j + 1]\n arr[j + 1] = temp\n print(f"result: {arr}")\n```\n### function references:\n```python:code_to_optimize/bubble_sort_from_another_file.py\ndef sort_from_another_file(arr):\n sorted_arr = sorter(arr)\n return sorted_arr\n```\n```python:temp_test.py\ndef run_and_capture(arr):\n saved_stdout = sys.stdout\n try:\n out = StringIO()\n sys.stdout = out\n data = arr.copy()\n sorter(data)\n return out.getvalue(), data\n finally:\n sys.stdout = saved_stdout\n```\n\n\nStart by writing a plan for the optimization.', + }, + { + "role": "assistant", + "content": '## Optimization Plan\n\n### 1- Understand Code Behavior\n\n- **Current Functionality:** \n The `sorter` function implements the bubble sort algorithm for sorting a list `arr` in-place, with debug output before and after sorting.\n - Before sorting: prints `"codeflash stdout: Sorting list"`\n - After sorting: prints `"result: ..."`, showing the sorted list.\n - Sorts `arr` in place and does not return a value.\n- **Usage:** \n - Called by `sort_from_another_file`, although its return value is not used (likely functions for side-effects and in-place mutation).\n - Used in `run_and_capture` with output redirection for test validation.\n- **Performance Bottlenecks:**\n - The classic bubble sort is O(n²) and always loops over the entire unsorted list for every pass.\n - `range(len(arr) - 1)` does not shrink on each outer iteration—does not exploit the fact that the end section is already sorted.\n - No early exit when the list is already sorted or no swaps occur—wastes unnecessary passes.\n - Swapping via a temporary variable instead of Python\'s tuple swap (minor impact).\n - Performs redundant `len(arr)` calculations in loop bounds.\n - Prints entire list after sorting, which may be slow for very large lists (this is part of specified external behavior and can\'t be changed).\n- **External Behavior Constraints:**\n - Must print exactly the same messages (`stdout` is inspected in tests).\n - Sorts in-place.\n - No return value expected (returning would change externally visible behavior).\n\n---\n\n### 2- Begin Optimization\n\n- **Algorithm Improvement:** \n - Optimize bubble sort by:\n - Shortening the range of the inner loop as the sorted suffix grows.\n - Adding an early exit—if no swaps occur in a pass, terminate early (list is sorted).\n- **Pythonic Improvements:** \n - Use tuple swap for readability and slight speedup.\n - Hoist `len(arr)` to a local variable to avoid repeated computation.\n- **Other Considerations:** \n - Ensure print statements (before and after) produce byte-for-byte identical output.\n - Preserve in-place modification and no return value.\n\n**Concrete Steps:**\n 1. Store `len(arr)` once in a variable.\n 2. Change the inner loop\'s range to shrink with each pass: `range(0, n - i - 1)`.\n 3. Add a `swapped` flag; if no swaps occur on a pass, break early.\n 4. Use tuple swap: `arr[j], arr[j+1] = arr[j+1], arr[j]`.\n 5. Keep print statements unchanged in position and format.\n\n---\n\n### 3- Execute & Verify\n\n- **Testing:**\n - Use existing test (`run_and_capture`) and other targeted tests, including edge cases (empty list, already-sorted list, reversed list, single-element list, and lists with duplicates).\n - Compare outputs and final sorted array against original implementation to confirm behavioral equivalence.\n - Check that no exceptions are raised and that the outputs and side effects are identical.\n\n---\n\n### 4- Benchmarking\n\n- **Performance Measurement:** \n - Benchmark sorting on a range of input sizes (especially medium and large).\n - Compare runtime of original and optimized functions using timeit or a similar timing tool.\n - Assess and report:\n - Average/median runtime before and after optimization.\n - Absolute and relative (percentage) speedup achieved.\n\n---', + }, + { + "role": "assistant", + "content": '\ncode_to_optimize/bubble_sort.py\n\n<<<<<<< SEARCH\ndef sorter(arr):\n print("codeflash stdout: Sorting list")\n for i in range(len(arr)):\n for j in range(len(arr) - 1):\n if arr[j] > arr[j + 1]:\n temp = arr[j]\n arr[j] = arr[j + 1]\n arr[j + 1] = temp\n print(f"result: {arr}")\n=======\ndef sorter(arr):\n print("codeflash stdout: Sorting list")\n n = len(arr)\n for i in range(n):\n swapped = False\n for j in range(0, n - i - 1):\n if arr[j] > arr[j + 1]:\n arr[j], arr[j + 1] = arr[j + 1], arr[j]\n swapped = True\n if not swapped:\n break\n print(f"result: {arr}")\n>>>>>>> REPLACE\n\n', + }, + {"role": "user", "content": 'replace_in_file: {"status": "OK"}'}, + { + "role": "assistant", + "content": '\npython\n\nimport sys\nfrom io import StringIO\nfrom code_to_optimize.bubble_sort import sorter\n\ndef run_and_capture(arr):\n saved_stdout = sys.stdout\n try:\n out = StringIO()\n sys.stdout = out\n data = arr.copy()\n sorter(data)\n return out.getvalue(), data\n finally:\n sys.stdout = saved_stdout\n\n# Test cases for behavioral equivalence\ntest_vectors = [\n [],\n [1],\n [2, 1],\n [1, 2, 3],\n [3, 2, 1],\n [5, 3, 8, 4, 2],\n [5, 5, 5, 5],\n [3, 1, 2, 3, 2, 1],\n list(range(100, 0, -1)),\n]\n\nfor vec in test_vectors:\n out, sorted_vec = run_and_capture(vec)\n assert "codeflash stdout: Sorting list" in out\n assert "result: " in out\n # result: [sorted version]\n result_line = out.splitlines()[-1]\n # safely eval the list (since printed via str)\n import ast\n printed_list = ast.literal_eval(result_line.split(\': \', 1)[1])\n assert printed_list == sorted(vec)\n assert sorted_vec == sorted(vec)\nprint("All behavioral tests passed for optimized bubble sort.")\n\n', + }, + { + "role": "user", + "content": 'execute_code: {"status": "OK", "stdout": "All behavioral tests passed for optimized bubble sort."}', + }, + ] + ) + while memory.api_calls_counter < MAX_AGENT_CALLS_PER_OPTIMIZATION: + ai_service_client = initial_context.function_optimizer.aiservice_client + time_start = time.perf_counter() + response = ai_service_client.call_agent( + trace_id=memory.trace_id, + context_vars=memory.get_context_vars(), + available_tools=supported_tools, + messages=memory.get_messages(), + ) + memory.api_calls_counter += 1 + new_messages = response.get("messages", []) + memory.set_messages(new_messages) + + time_end = time.perf_counter() + print(f"Time taken: {time_end - time_start:.2f} seconds") + + +if __name__ == "__main__": + main() diff --git a/codeflash/agent/memory.py b/codeflash/agent/memory.py new file mode 100644 index 000000000..84d0a4933 --- /dev/null +++ b/codeflash/agent/memory.py @@ -0,0 +1,47 @@ +import uuid +from pathlib import Path +from typing import Any + +from codeflash.code_utils.code_utils import encoded_tokens_len + +json_primitive_types = (str, float, int, bool) + + +class Memory: + def __init__(self) -> None: + self._context_vars: dict[str, str] = {} + self._messages: list[dict[str, str]] = [] + self.api_calls_counter = 0 + self.trace_id = str(uuid.uuid4()) + self.max_tokens = 16000 + + def _serialize(self, obj: Any) -> Any: # noqa: ANN401 + if isinstance(obj, list): + return [self._serialize(i) for i in obj] + if isinstance(obj, dict): + return {k: self._serialize(v) for k, v in obj.items()} + if isinstance(obj, json_primitive_types) or obj is None: + return obj + if isinstance(obj, Path): + return obj.as_posix() + return str(obj) + + def add_to_context_vars(self, key: str, value: any) -> dict[str, str]: + self._context_vars[key] = self._serialize(value) + return self._context_vars + + def get_context_vars(self) -> dict[str, str]: + return self._context_vars + + def set_messages(self, messages: list[dict[str, str]]) -> list[dict[str, str]]: + self._messages = messages + if self.get_total_tokens() > self.max_tokens: + # TODO: summarize messages + pass + return self._messages + + def get_messages(self) -> list[dict[str, str]]: + return self._messages + + def get_total_tokens(self) -> int: + return sum(encoded_tokens_len(message["content"]) for message in self._messages) diff --git a/codeflash/agent/tools/__init__.py b/codeflash/agent/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/codeflash/agent/tools/base_tool.py b/codeflash/agent/tools/base_tool.py new file mode 100644 index 000000000..99981bcfd --- /dev/null +++ b/codeflash/agent/tools/base_tool.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Any + + +class Tool(str, Enum): + REPLACE_IN_FILE = "replace_in_file" + ADD_TO_CONTEXT_VARS = "add_to_context_vars" + EXECUTE_CODE = "execute_code" # currently only supports python + SEARCH_FUNCTION_REFRENCES = "search_function_references" + GET_NAME_DEFINITION = "get_name_definition" + TERMINATE = "terminate" # terminates either with success or failure + + +# TODO: use this as a type for the api response +@dataclass(frozen=True) +class ToolCall: + tool_name: str + args: dict[str, Any] + needs_context_vars: bool = False + + +supported_tools: list[str] = [ + Tool.REPLACE_IN_FILE, + Tool.ADD_TO_CONTEXT_VARS, + Tool.EXECUTE_CODE, + # Tool.SEARCH_FUNCTION_REFRENCES, + # Tool.GET_NAME_DEFINITION, + # Tool.TERMINATE, +] diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 94fbb44ad..c0ec82f75 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -388,6 +388,38 @@ def code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate console.rule() return None + def call_agent( + self, trace_id: str, context_vars: dict[str, str], available_tools: list[str], messages: list[dict[str, str]] + ) -> dict[str, str]: + try: + payload = { + "trace_id": trace_id, + "context_vars": context_vars, + "available_tools": available_tools, + "messages": messages, + } + response = self.make_ai_service_request("/agent", payload=payload, timeout=self.timeout) + except (requests.exceptions.RequestException, TypeError) as e: + logger.exception(f"Error calling codeflash agent: {e}") + return None + + if response.status_code == 200: + res = response.json() + console.rule() + + if res["status"] == "success": + return {"messages": res["messages"]} + + return None + + try: + error = response.json()["error"] + except Exception: + error = response.text + logger.error(f"Error generating optimized candidates: {response.status_code} - {error}") + ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error}) + return None + def get_new_explanation( # noqa: D417 self, source_code: str, diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 0f9f76a75..36be6375a 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -95,21 +95,22 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None: class FunctionWithReturnStatement(ast.NodeVisitor): - def __init__(self, file_path: Path) -> None: + def __init__(self, file_path: Path, must_return_a_value: bool = True) -> None: # noqa: FBT001, FBT002 self.functions: list[FunctionToOptimize] = [] self.ast_path: list[FunctionParent] = [] self.file_path: Path = file_path + self.must_return_a_value: bool = must_return_a_value def visit_FunctionDef(self, node: FunctionDef) -> None: # Check if the function has a return statement and add it to the list - if function_has_return_statement(node) and not function_is_a_property(node): + if not self.must_return_a_value or (function_has_return_statement(node) and not function_is_a_property(node)): self.functions.append( FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:]) ) def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None: # Check if the async function has a return statement and add it to the list - if function_has_return_statement(node) and not function_is_a_property(node): + if not self.must_return_a_value or (function_has_return_statement(node) and not function_is_a_property(node)): self.functions.append( FunctionToOptimize( function_name=node.name, file_path=self.file_path, parents=self.ast_path[:], is_async=True @@ -182,6 +183,7 @@ def get_functions_to_optimize( project_root: Path, module_root: Path, previous_checkpoint_functions: dict[str, dict[str, str]] | None = None, + must_return_a_value: bool = True, # noqa: FBT001, FBT002 ) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]: assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, ( "Only one of optimize_all, replay_test, or file should be provided" @@ -203,7 +205,9 @@ def get_functions_to_optimize( logger.info("!lsp|Finding all functions in the file '%s'…", file) console.rule() file = Path(file) if isinstance(file, str) else file - functions: dict[Path, list[FunctionToOptimize]] = find_all_functions_in_file(file) + functions: dict[Path, list[FunctionToOptimize]] = find_all_functions_in_file( + file, must_return_a_value=must_return_a_value + ) if only_get_this_function is not None: split_function = only_get_this_function.split(".") if len(split_function) > 2: @@ -368,7 +372,10 @@ def get_all_files_and_functions(module_root_path: Path) -> dict[str, list[Functi return dict(files_list) -def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOptimize]]: +def find_all_functions_in_file( + file_path: Path, + must_return_a_value: bool = True, # noqa: FBT001, FBT002 +) -> dict[Path, list[FunctionToOptimize]]: functions: dict[Path, list[FunctionToOptimize]] = {} with file_path.open(encoding="utf8") as f: try: @@ -377,7 +384,7 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt if DEBUG_MODE: logger.exception(e) return functions - function_name_visitor = FunctionWithReturnStatement(file_path) + function_name_visitor = FunctionWithReturnStatement(file_path, must_return_a_value=must_return_a_value) function_name_visitor.visit(ast_module) functions[file_path] = function_name_visitor.functions return functions diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 1e1ddefcf..32dd34add 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -124,7 +124,10 @@ def run_benchmarks( console.rule() return function_benchmark_timings, total_benchmark_timings - def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]: + def get_optimizable_functions( + self, + must_return_a_value: bool = True, # noqa: FBT001, FBT002 + ) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]: """Discover functions to optimize.""" from codeflash.discovery.functions_to_optimize import get_functions_to_optimize @@ -137,7 +140,8 @@ def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize] ignore_paths=self.args.ignore_paths, project_root=self.args.project_root, module_root=self.args.module_root, - previous_checkpoint_functions=self.args.previous_checkpoint_functions, + previous_checkpoint_functions=getattr(self.args, "previous_checkpoint_functions", None), + must_return_a_value=must_return_a_value, ) def create_function_optimizer(