144 changes: 144 additions & 0 deletions codeflash/benchmarking/cprofile_stage.py
@@ -0,0 +1,144 @@
from __future__ import annotations

import json
import os
import pstats
import sqlite3
import subprocess
import sys
import tempfile
from pathlib import Path

from codeflash.cli_cmds.console import logger


def run_cprofile_stage(test_root: Path, project_root: Path, output_trace_file: Path) -> bool:
    """Run pytest under cProfile and convert the profile into a SQLite trace at output_trace_file.

    Returns True on success, False if profiling or conversion failed.
    """
    with tempfile.NamedTemporaryFile(suffix=".prof", delete=False) as tmp:
        prof_file = Path(tmp.name)

    try:
        env = os.environ.copy()
        # Prepend the project root so test imports resolve, preserving any existing PYTHONPATH.
        existing_pythonpath = env.get("PYTHONPATH")
        env["PYTHONPATH"] = f"{project_root}{os.pathsep}{existing_pythonpath}" if existing_pythonpath else str(project_root)

cmd = [
sys.executable,
"-m",
"cProfile",
"-o",
str(prof_file),
"-m",
"pytest",
str(test_root),
"-x",
"-q",
"--tb=short",
"-p",
"no:warnings",
]

logger.debug(f"Running cProfile stage: {' '.join(cmd)}")

result = subprocess.run(cmd, check=False, env=env, cwd=str(project_root))

        # pytest exits 0 (all tests passed) or 1 (some tests failed); both still yield a usable profile.
        if result.returncode not in {0, 1}:
            logger.warning(f"cProfile stage returned {result.returncode}")

if not prof_file.exists() or prof_file.stat().st_size == 0:
logger.warning("cProfile output file is empty or missing")
return False

stats = pstats.Stats(str(prof_file))
convert_pstats_to_sqlite(stats, output_trace_file, project_root)

logger.info(f"cProfile stage complete: {len(stats.stats)} functions profiled") # type: ignore[attr-defined]

except Exception as e:
logger.warning(f"cProfile stage failed: {e}")
return False
else:
return True
finally:
if prof_file.exists():
prof_file.unlink()


def convert_pstats_to_sqlite(stats: pstats.Stats, output_file: Path, project_root: Path) -> None:
if output_file.exists():
output_file.unlink()

con = sqlite3.connect(output_file)
cur = con.cursor()

cur.execute("""
CREATE TABLE pstats (
filename TEXT,
line_number INTEGER,
function TEXT,
class_name TEXT,
call_count_nonrecursive INTEGER,
num_callers INTEGER,
total_time_ns INTEGER,
cumulative_time_ns INTEGER,
callers TEXT
)
""")
cur.execute("CREATE TABLE total_time (time_ns INTEGER)")

total_time_s = 0.0
project_root_str = str(project_root.resolve())

    # Each pstats entry maps (filename, line_number, function_name) to
    # (primitive call count, total call count, own time, cumulative time, callers); times are in seconds.
    for func_key, (cc, nc, tt, ct, callers) in stats.stats.items():  # type: ignore[attr-defined]
        filename, line_number, func_name = func_key

if not _is_project_file(filename, project_root_str):
continue

total_time_s += tt

class_name = None
base_func_name = func_name
if "." in func_name and not func_name.startswith("<"):
parts = func_name.rsplit(".", 1)
if len(parts) == 2:
class_name, base_func_name = parts

callers_json = json.dumps([{"key": list(k), "value": v} for k, v in callers.items()])
total_time_ns = int(tt * 1e9)
cumulative_time_ns = int(ct * 1e9)

cur.execute(
"INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
filename,
line_number,
base_func_name,
class_name,
cc,
nc,
total_time_ns,
cumulative_time_ns,
callers_json,
),
)

total_time_ns = int(total_time_s * 1e9)
cur.execute("INSERT INTO total_time VALUES (?)", (total_time_ns,))

con.commit()
con.close()


def _is_project_file(filename: str, project_root: str) -> bool:
if not filename or filename.startswith("<"):
return False

try:
abs_filename = str(Path(filename).resolve())
except (OSError, ValueError):
return False

if not abs_filename.startswith(project_root):
return False

exclude_patterns = ("site-packages", ".venv", "venv", "__pycache__", ".pyc", "_pytest", "pluggy")
return not any(pattern in abs_filename for pattern in exclude_patterns)
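
For orientation, a minimal sketch (not part of this PR) of how the SQLite trace written by convert_pstats_to_sqlite could be inspected; the trace path is hypothetical:

import sqlite3

con = sqlite3.connect("cprofile_stage1.trace")  # hypothetical path
top = con.execute(
    "SELECT filename, line_number, function, total_time_ns "
    "FROM pstats ORDER BY total_time_ns DESC LIMIT 10"
).fetchall()
(total_ns,) = con.execute("SELECT time_ns FROM total_time").fetchone()
for filename, line_number, function, tt_ns in top:
    # total_time stores the summed own time of project files, so this is each function's share of project time
    share = tt_ns / total_ns if total_ns else 0.0
    print(f"{function} ({filename}:{line_number}): {tt_ns} ns ({share:.1%} of project own time)")
con.close()
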
13 changes: 13 additions & 0 deletions codeflash/benchmarking/function_ranker.py
@@ -223,3 +223,16 @@ def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> lis
f"Function ranking order: {[f'{func.function_name} (addressable_time={self.get_function_addressable_time(func):.2f}ns)' for func in ranked]}"
)
return ranked

    def get_top_n_functions(self, functions_to_optimize: list[FunctionToOptimize], n: int) -> list[FunctionToOptimize]:
        """Return the n functions with the highest addressable time."""
        if n <= 0:
            return []

if not self._function_stats:
logger.warning("No function stats available, returning first N functions")
return functions_to_optimize[:n]

sorted_funcs = sorted(functions_to_optimize, key=self.get_function_addressable_time, reverse=True)
top_n = sorted_funcs[:n]
logger.info(f"Selected Top {len(top_n)} functions by addressable time")
return top_n
7 changes: 7 additions & 0 deletions codeflash/cli_cmds/cli.py
@@ -66,6 +66,13 @@ def parse_args() -> Namespace:
const="",
default=SUPPRESS,
)
    parser.add_argument(
        "--top-n-functions",
        type=int,
        default=None,
        help="Limit optimization to the top N functions by addressable time. Uses a fast cProfile pass to identify "
        "high-impact functions before detailed tracing. Best combined with --all to focus on the highest-impact functions.",
    )
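
A hypothetical invocation (assuming the standard codeflash entry point), combining this flag with --all as the help text suggests:

codeflash --all --top-n-functions 10

This runs the fast cProfile pass first, then restricts detailed tracing and optimization to the ten functions with the highest addressable time.
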
parser.add_argument(
"--module-root",
type=str,
Expand Down
3 changes: 3 additions & 0 deletions codeflash/code_utils/config_consts.py
@@ -35,6 +35,9 @@
REPAIR_UNMATCHED_PERCENTAGE_LIMIT = 0.4  # if the percentage of unmatched tests is greater than this, we won't fix it (lowering this value makes the repair stricter)
MAX_REPAIRS_PER_TRACE = 4 # maximum number of repairs we will do for each function

# 2-stage tracing
DEFAULT_TOP_N_FUNCTIONS = None # None means no limit, optimize all above threshold
Review comment (PR author): I'm not entirely sure what number this should be.

MAX_N_CANDIDATES = 5
MAX_N_CANDIDATES_LP = 6

Expand Down
74 changes: 65 additions & 9 deletions codeflash/optimization/optimizer.py
@@ -55,6 +55,7 @@ def __init__(self, args: Namespace) -> None:
self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None
self.replay_tests_dir = None
self.trace_file: Path | None = None
self.cprofile_trace_file: Path | None = None # For 2-stage tracing
self.functions_checkpoint: CodeflashRunCheckpoint | None = None
self.current_function_being_optimized: FunctionToOptimize | None = None # current only for the LSP
self.current_function_optimizer: FunctionOptimizer | None = None
@@ -124,6 +125,29 @@ def run_benchmarks(
console.rule()
return function_benchmark_timings, total_benchmark_timings

    def run_cprofile_stage1(self) -> Path | None:
        """Run the Stage 1 cProfile pass; return the trace file path, or None if profiling failed."""
        from codeflash.benchmarking.cprofile_stage import run_cprofile_stage

console.rule()
logger.info("Running cProfile to identify high-impact functions...")
self.cprofile_trace_file = get_run_tmp_file(Path("cprofile_stage1.trace"))
if self.cprofile_trace_file.exists():
self.cprofile_trace_file.unlink()

success = run_cprofile_stage(
test_root=self.test_cfg.tests_root,
project_root=self.args.project_root,
output_trace_file=self.cprofile_trace_file,
)

if success and self.cprofile_trace_file.exists():
console.rule()
return self.cprofile_trace_file

logger.warning("Stage 1 cProfile profiling failed, falling back to default ranking")
console.rule()
return None

def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]:
"""Discover functions to optimize."""
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
@@ -322,7 +346,10 @@ def display_global_ranking(
console.print(f"[dim]... and {len(globally_ranked) - display_count} more functions[/dim]")

def rank_all_functions_globally(
self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], trace_file_path: Path | None
self,
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]],
trace_file_path: Path | None,
top_n: int | None = None,
) -> list[tuple[Path, FunctionToOptimize]]:
"""Rank all functions globally across all files based on trace data.

@@ -332,6 +359,7 @@
Args:
file_to_funcs_to_optimize: Mapping of file paths to functions to optimize
trace_file_path: Path to trace file with performance data
top_n: If set, limit to Top N functions by addressable time (bypasses threshold)

Returns:
List of (file_path, function) tuples in globally ranked order by addressable time.
@@ -342,9 +370,11 @@
for file_path, functions in file_to_funcs_to_optimize.items():
all_functions.extend((file_path, func) for func in functions)

# If no trace file, return in original order
# If no trace file, return in original order (limited to top_n if set)
if not trace_file_path or not trace_file_path.exists():
logger.debug("No trace file available, using original function order")
if top_n is not None and top_n > 0:
return all_functions[:top_n]
return all_functions

try:
@@ -359,8 +389,11 @@
# Extract just the functions for ranking (without file paths)
functions_only = [func for _, func in all_functions]

# Rank globally
ranked_functions = ranker.rank_functions(functions_only)
# Use Top N selection if specified, otherwise use threshold-based ranking
if top_n is not None and top_n > 0:
ranked_functions = ranker.get_top_n_functions(functions_only, top_n)
else:
ranked_functions = ranker.rank_functions(functions_only)

# Reconstruct with file paths by looking up original file for each ranked function
# Build reverse mapping: function -> file path
@@ -378,10 +411,16 @@
globally_ranked.append((file_path, func))

console.rule()
logger.info(
f"Globally ranked {len(ranked_functions)} functions by addressable time "
f"(filtered {len(functions_only) - len(ranked_functions)} low-importance functions)"
)
if top_n is not None and top_n > 0:
logger.info(
f"Selected Top {len(ranked_functions)} functions by addressable time "
f"(from {len(functions_only)} total functions)"
)
else:
logger.info(
f"Globally ranked {len(ranked_functions)} functions by addressable time "
f"(filtered {len(functions_only) - len(ranked_functions)} low-importance functions)"
)

# Display ranking table for user visibility
self.display_global_ranking(globally_ranked, ranker)
@@ -436,6 +475,16 @@ def run(self) -> None:
self.test_cfg.concolic_test_root_dir = Path(
tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_")
)

# Get top_n_functions from args (for 2-stage tracing)
top_n_functions = getattr(self.args, "top_n_functions", None)

# Run cProfile Stage 1 if top_n_functions is specified and no existing trace
if top_n_functions is not None and top_n_functions > 0 and not trace_file_path:
cprofile_trace = self.run_cprofile_stage1()
if cprofile_trace:
trace_file_path = cprofile_trace

try:
ph("cli-optimize-functions-to-optimize", {"num_functions": num_optimizable_functions})
if num_optimizable_functions == 0:
@@ -447,7 +496,9 @@
self.functions_checkpoint = CodeflashRunCheckpoint(self.args.module_root)

# GLOBAL RANKING: Rank all functions together before optimizing
globally_ranked_functions = self.rank_all_functions_globally(file_to_funcs_to_optimize, trace_file_path)
globally_ranked_functions = self.rank_all_functions_globally(
file_to_funcs_to_optimize, trace_file_path, top_n=top_n_functions
)
# Cache for module preparation (avoid re-parsing same files)
prepared_modules: dict[Path, tuple[dict[Path, ValidCode], ast.Module]] = {}

@@ -570,6 +621,9 @@ def cleanup_replay_tests(self) -> None:
if self.trace_file and self.trace_file.exists():
logger.debug(f"Cleaning up trace file: {self.trace_file}")
paths_to_cleanup.append(self.trace_file)
if self.cprofile_trace_file and self.cprofile_trace_file.exists():
logger.debug(f"Cleaning up cProfile trace file: {self.cprofile_trace_file}")
paths_to_cleanup.append(self.cprofile_trace_file)
if paths_to_cleanup:
cleanup_paths(paths_to_cleanup)

@@ -590,6 +644,8 @@ def cleanup_temporary_paths(self) -> None:
paths_to_cleanup = [self.replay_tests_dir]
if self.trace_file:
paths_to_cleanup.append(self.trace_file)
if self.cprofile_trace_file:
paths_to_cleanup.append(self.cprofile_trace_file)
if self.test_cfg.tests_root.exists():
for trace_file in self.test_cfg.tests_root.glob("*.trace"):
if trace_file not in paths_to_cleanup: