From d0e176d54c7e066fc113678c21915b0c52a272fe Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Thu, 1 Jan 2026 16:14:42 -0500 Subject: [PATCH] feat: add 2-stage tracing with cProfile for Top N function selection Use cProfile (Stage 1) for fast profiling to identify high-impact functions, then run detailed tracer (Stage 2) only on Top N functions by addressable time. --- codeflash/benchmarking/cprofile_stage.py | 144 ++++++++++++++++++++++ codeflash/benchmarking/function_ranker.py | 13 ++ codeflash/cli_cmds/cli.py | 7 ++ codeflash/code_utils/config_consts.py | 3 + codeflash/optimization/optimizer.py | 74 +++++++++-- 5 files changed, 232 insertions(+), 9 deletions(-) create mode 100644 codeflash/benchmarking/cprofile_stage.py diff --git a/codeflash/benchmarking/cprofile_stage.py b/codeflash/benchmarking/cprofile_stage.py new file mode 100644 index 000000000..fe5539e99 --- /dev/null +++ b/codeflash/benchmarking/cprofile_stage.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +import json +import os +import pstats +import sqlite3 +import subprocess +import sys +import tempfile +from pathlib import Path + +from codeflash.cli_cmds.console import logger + + +def run_cprofile_stage(test_root: Path, project_root: Path, output_trace_file: Path) -> bool: + with tempfile.NamedTemporaryFile(suffix=".prof", delete=False) as tmp: + prof_file = Path(tmp.name) + + try: + env = os.environ.copy() + env["PYTHONPATH"] = str(project_root) + + cmd = [ + sys.executable, + "-m", + "cProfile", + "-o", + str(prof_file), + "-m", + "pytest", + str(test_root), + "-x", + "-q", + "--tb=short", + "-p", + "no:warnings", + ] + + logger.debug(f"Running cProfile stage: {' '.join(cmd)}") + + result = subprocess.run(cmd, check=False, env=env, cwd=str(project_root)) + + if result.returncode not in {0, 1}: + logger.warning(f"cProfile stage returned {result.returncode}") + + if not prof_file.exists() or prof_file.stat().st_size == 0: + logger.warning("cProfile output file is empty or missing") + return False + + stats = pstats.Stats(str(prof_file)) + convert_pstats_to_sqlite(stats, output_trace_file, project_root) + + logger.info(f"cProfile stage complete: {len(stats.stats)} functions profiled") # type: ignore[attr-defined] + + except Exception as e: + logger.warning(f"cProfile stage failed: {e}") + return False + else: + return True + finally: + if prof_file.exists(): + prof_file.unlink() + + +def convert_pstats_to_sqlite(stats: pstats.Stats, output_file: Path, project_root: Path) -> None: + if output_file.exists(): + output_file.unlink() + + con = sqlite3.connect(output_file) + cur = con.cursor() + + cur.execute(""" + CREATE TABLE pstats ( + filename TEXT, + line_number INTEGER, + function TEXT, + class_name TEXT, + call_count_nonrecursive INTEGER, + num_callers INTEGER, + total_time_ns INTEGER, + cumulative_time_ns INTEGER, + callers TEXT + ) + """) + cur.execute("CREATE TABLE total_time (time_ns INTEGER)") + + total_time_s = 0.0 + project_root_str = str(project_root.resolve()) + + for func_key, (cc, nc, tt, ct, callers) in stats.stats.items(): # type: ignore[attr-defined] + filename, line_number, func_name = func_key + + if not _is_project_file(filename, project_root_str): + continue + + total_time_s += tt + + class_name = None + base_func_name = func_name + if "." 
in func_name and not func_name.startswith("<"): + parts = func_name.rsplit(".", 1) + if len(parts) == 2: + class_name, base_func_name = parts + + callers_json = json.dumps([{"key": list(k), "value": v} for k, v in callers.items()]) + total_time_ns = int(tt * 1e9) + cumulative_time_ns = int(ct * 1e9) + + cur.execute( + "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + filename, + line_number, + base_func_name, + class_name, + cc, + nc, + total_time_ns, + cumulative_time_ns, + callers_json, + ), + ) + + total_time_ns = int(total_time_s * 1e9) + cur.execute("INSERT INTO total_time VALUES (?)", (total_time_ns,)) + + con.commit() + con.close() + + +def _is_project_file(filename: str, project_root: str) -> bool: + if not filename or filename.startswith("<"): + return False + + try: + abs_filename = str(Path(filename).resolve()) + except (OSError, ValueError): + return False + + if not abs_filename.startswith(project_root): + return False + + exclude_patterns = ("site-packages", ".venv", "venv", "__pycache__", ".pyc", "_pytest", "pluggy") + return not any(pattern in abs_filename for pattern in exclude_patterns) diff --git a/codeflash/benchmarking/function_ranker.py b/codeflash/benchmarking/function_ranker.py index 20c45f443..b6e183699 100644 --- a/codeflash/benchmarking/function_ranker.py +++ b/codeflash/benchmarking/function_ranker.py @@ -223,3 +223,16 @@ def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> lis f"Function ranking order: {[f'{func.function_name} (addressable_time={self.get_function_addressable_time(func):.2f}ns)' for func in ranked]}" ) return ranked + + def get_top_n_functions(self, functions_to_optimize: list[FunctionToOptimize], n: int) -> list[FunctionToOptimize]: + if n <= 0: + return [] + + if not self._function_stats: + logger.warning("No function stats available, returning first N functions") + return functions_to_optimize[:n] + + sorted_funcs = sorted(functions_to_optimize, key=self.get_function_addressable_time, reverse=True) + top_n = sorted_funcs[:n] + logger.info(f"Selected Top {len(top_n)} functions by addressable time") + return top_n diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index 271eb06bc..8f9c7f77f 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -66,6 +66,13 @@ def parse_args() -> Namespace: const="", default=SUPPRESS, ) + parser.add_argument( + "--top-n-functions", + type=int, + default=None, + help="Limit optimization to Top N functions by addressable time. Uses fast cProfile profiling to identify " + "high-impact functions before detailed tracing. 
Best used with --all to focus on highest-impact functions.", + ) parser.add_argument( "--module-root", type=str, diff --git a/codeflash/code_utils/config_consts.py b/codeflash/code_utils/config_consts.py index 6abefe3e7..b6ca77fa1 100644 --- a/codeflash/code_utils/config_consts.py +++ b/codeflash/code_utils/config_consts.py @@ -35,6 +35,9 @@ REPAIR_UNMATCHED_PERCENTAGE_LIMIT = 0.4 # if the percentage of unmatched tests is greater than this, we won't fix it (lowering this value makes the repair more stricted) MAX_REPAIRS_PER_TRACE = 4 # maximum number of repairs we will do for each function +# 2-stage tracing +DEFAULT_TOP_N_FUNCTIONS = None # None means no limit, optimize all above threshold + MAX_N_CANDIDATES = 5 MAX_N_CANDIDATES_LP = 6 diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index c98d37042..4fa58c880 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -55,6 +55,7 @@ def __init__(self, args: Namespace) -> None: self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None self.replay_tests_dir = None self.trace_file: Path | None = None + self.cprofile_trace_file: Path | None = None # For 2-stage tracing self.functions_checkpoint: CodeflashRunCheckpoint | None = None self.current_function_being_optimized: FunctionToOptimize | None = None # current only for the LSP self.current_function_optimizer: FunctionOptimizer | None = None @@ -124,6 +125,29 @@ def run_benchmarks( console.rule() return function_benchmark_timings, total_benchmark_timings + def run_cprofile_stage1(self) -> Path | None: + from codeflash.benchmarking.cprofile_stage import run_cprofile_stage + + console.rule() + logger.info("Running cProfile to identify high-impact functions...") + self.cprofile_trace_file = get_run_tmp_file(Path("cprofile_stage1.trace")) + if self.cprofile_trace_file.exists(): + self.cprofile_trace_file.unlink() + + success = run_cprofile_stage( + test_root=self.test_cfg.tests_root, + project_root=self.args.project_root, + output_trace_file=self.cprofile_trace_file, + ) + + if success and self.cprofile_trace_file.exists(): + console.rule() + return self.cprofile_trace_file + + logger.warning("Stage 1 cProfile profiling failed, falling back to default ranking") + console.rule() + return None + def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]: """Discover functions to optimize.""" from codeflash.discovery.functions_to_optimize import get_functions_to_optimize @@ -322,7 +346,10 @@ def display_global_ranking( console.print(f"[dim]... and {len(globally_ranked) - display_count} more functions[/dim]") def rank_all_functions_globally( - self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], trace_file_path: Path | None + self, + file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], + trace_file_path: Path | None, + top_n: int | None = None, ) -> list[tuple[Path, FunctionToOptimize]]: """Rank all functions globally across all files based on trace data. @@ -332,6 +359,7 @@ def rank_all_functions_globally( Args: file_to_funcs_to_optimize: Mapping of file paths to functions to optimize trace_file_path: Path to trace file with performance data + top_n: If set, limit to Top N functions by addressable time (bypasses threshold) Returns: List of (file_path, function) tuples in globally ranked order by addressable time. 
@@ -342,9 +370,11 @@ def rank_all_functions_globally( for file_path, functions in file_to_funcs_to_optimize.items(): all_functions.extend((file_path, func) for func in functions) - # If no trace file, return in original order + # If no trace file, return in original order (limited to top_n if set) if not trace_file_path or not trace_file_path.exists(): logger.debug("No trace file available, using original function order") + if top_n is not None and top_n > 0: + return all_functions[:top_n] return all_functions try: @@ -359,8 +389,11 @@ def rank_all_functions_globally( # Extract just the functions for ranking (without file paths) functions_only = [func for _, func in all_functions] - # Rank globally - ranked_functions = ranker.rank_functions(functions_only) + # Use Top N selection if specified, otherwise use threshold-based ranking + if top_n is not None and top_n > 0: + ranked_functions = ranker.get_top_n_functions(functions_only, top_n) + else: + ranked_functions = ranker.rank_functions(functions_only) # Reconstruct with file paths by looking up original file for each ranked function # Build reverse mapping: function -> file path @@ -378,10 +411,16 @@ def rank_all_functions_globally( globally_ranked.append((file_path, func)) console.rule() - logger.info( - f"Globally ranked {len(ranked_functions)} functions by addressable time " - f"(filtered {len(functions_only) - len(ranked_functions)} low-importance functions)" - ) + if top_n is not None and top_n > 0: + logger.info( + f"Selected Top {len(ranked_functions)} functions by addressable time " + f"(from {len(functions_only)} total functions)" + ) + else: + logger.info( + f"Globally ranked {len(ranked_functions)} functions by addressable time " + f"(filtered {len(functions_only) - len(ranked_functions)} low-importance functions)" + ) # Display ranking table for user visibility self.display_global_ranking(globally_ranked, ranker) @@ -436,6 +475,16 @@ def run(self) -> None: self.test_cfg.concolic_test_root_dir = Path( tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_") ) + + # Get top_n_functions from args (for 2-stage tracing) + top_n_functions = getattr(self.args, "top_n_functions", None) + + # Run cProfile Stage 1 if top_n_functions is specified and no existing trace + if top_n_functions is not None and top_n_functions > 0 and not trace_file_path: + cprofile_trace = self.run_cprofile_stage1() + if cprofile_trace: + trace_file_path = cprofile_trace + try: ph("cli-optimize-functions-to-optimize", {"num_functions": num_optimizable_functions}) if num_optimizable_functions == 0: @@ -447,7 +496,9 @@ def run(self) -> None: self.functions_checkpoint = CodeflashRunCheckpoint(self.args.module_root) # GLOBAL RANKING: Rank all functions together before optimizing - globally_ranked_functions = self.rank_all_functions_globally(file_to_funcs_to_optimize, trace_file_path) + globally_ranked_functions = self.rank_all_functions_globally( + file_to_funcs_to_optimize, trace_file_path, top_n=top_n_functions + ) # Cache for module preparation (avoid re-parsing same files) prepared_modules: dict[Path, tuple[dict[Path, ValidCode], ast.Module]] = {} @@ -570,6 +621,9 @@ def cleanup_replay_tests(self) -> None: if self.trace_file and self.trace_file.exists(): logger.debug(f"Cleaning up trace file: {self.trace_file}") paths_to_cleanup.append(self.trace_file) + if self.cprofile_trace_file and self.cprofile_trace_file.exists(): + logger.debug(f"Cleaning up cProfile trace file: {self.cprofile_trace_file}") + 
paths_to_cleanup.append(self.cprofile_trace_file) if paths_to_cleanup: cleanup_paths(paths_to_cleanup) @@ -590,6 +644,8 @@ def cleanup_temporary_paths(self) -> None: paths_to_cleanup = [self.replay_tests_dir] if self.trace_file: paths_to_cleanup.append(self.trace_file) + if self.cprofile_trace_file: + paths_to_cleanup.append(self.cprofile_trace_file) if self.test_cfg.tests_root.exists(): for trace_file in self.test_cfg.tests_root.glob("*.trace"): if trace_file not in paths_to_cleanup:
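
As a quick illustration of the Stage 1 output (not part of the patch): once run_cprofile_stage has written the SQLite trace via convert_pstats_to_sqlite, a sketch like the one below could be used to preview the most expensive project functions. The preview_top_n helper is hypothetical; it only assumes the pstats/total_time schema created above, and it uses total_time_ns as a rough proxy for the addressable time that FunctionRanker.get_function_addressable_time actually computes for the real Top N selection. With the new flag, the feature would typically be driven as something like `codeflash --all --top-n-functions 10` (assuming a standard project setup).

# Illustrative sketch only -- not part of the patch. Reads the Stage 1 SQLite
# trace produced by convert_pstats_to_sqlite() and lists the top functions by
# own time. total_time_ns is used here as a stand-in for addressable time.
import sqlite3
from pathlib import Path


def preview_top_n(trace_file: Path, n: int = 10) -> None:
    con = sqlite3.connect(trace_file)
    try:
        rows = con.execute(
            "SELECT filename, line_number, function, total_time_ns, cumulative_time_ns "
            "FROM pstats ORDER BY total_time_ns DESC LIMIT ?",
            (n,),
        ).fetchall()
        # total_time holds a single row with the summed own time of project functions.
        (total_ns,) = con.execute("SELECT time_ns FROM total_time").fetchone()
        for filename, line_number, function, own_ns, cum_ns in rows:
            share = (own_ns / total_ns * 100) if total_ns else 0.0
            print(f"{function} ({filename}:{line_number}) own={own_ns}ns cum={cum_ns}ns ({share:.1f}%)")
    finally:
        con.close()


# Example (hypothetical path, matching the name used in run_cprofile_stage1):
# preview_top_n(Path("cprofile_stage1.trace"), n=10)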