diff --git a/codeflash/benchmarking/instrument_codeflash_trace.py b/codeflash/benchmarking/instrument_codeflash_trace.py index f54674cfa..40a9a0390 100644 --- a/codeflash/benchmarking/instrument_codeflash_trace.py +++ b/codeflash/benchmarking/instrument_codeflash_trace.py @@ -105,6 +105,15 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct def instrument_codeflash_trace_decorator(file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]) -> None: """Instrument codeflash_trace decorator to functions to optimize.""" for file_path, functions_to_optimize in file_to_funcs_to_optimize.items(): + # Skip codeflash's own benchmarking and picklepatch modules to avoid circular imports + # (codeflash_trace.py imports from picklepatch, and instrumenting these would cause + # them to import codeflash_trace back, creating a circular import) + # Use rpartition to find the last "codeflash" in path (handles nested paths) + _, sep, after = file_path.as_posix().rpartition("/codeflash/") + if sep: + submodule = after.partition("/")[0] + if submodule in ("benchmarking", "picklepatch"): + continue original_code = file_path.read_text(encoding="utf-8") new_code = add_codeflash_decorator_to_code(original_code, functions_to_optimize) # Modify the code diff --git a/codeflash/benchmarking/trace_benchmarks.py b/codeflash/benchmarking/trace_benchmarks.py index 6de415a55..8217ac37e 100644 --- a/codeflash/benchmarking/trace_benchmarks.py +++ b/codeflash/benchmarking/trace_benchmarks.py @@ -32,16 +32,22 @@ def trace_benchmarks_pytest( **run_args, ) if result.returncode != 0: - if "ERROR collecting" in result.stdout: + # Combine stdout and stderr for error reporting (errors often go to stderr) + combined_output = result.stdout + if result.stderr: + combined_output = combined_output + "\n" + result.stderr if combined_output else result.stderr + + if "ERROR collecting" in combined_output: # Pattern matches "===== ERRORS =====" (any number of =) and captures 
everything after error_pattern = r"={3,}\s*ERRORS\s*={3,}\n([\s\S]*?)(?:={3,}|$)" - match = re.search(error_pattern, result.stdout) - error_section = match.group(1) if match else result.stdout - elif "FAILURES" in result.stdout: + match = re.search(error_pattern, combined_output) + error_section = match.group(1) if match else combined_output + elif "FAILURES" in combined_output: # Pattern matches "===== FAILURES =====" (any number of =) and captures everything after error_pattern = r"={3,}\s*FAILURES\s*={3,}\n([\s\S]*?)(?:={3,}|$)" - match = re.search(error_pattern, result.stdout) - error_section = match.group(1) if match else result.stdout + match = re.search(error_pattern, combined_output) + error_section = match.group(1) if match else combined_output else: - error_section = result.stdout + error_section = combined_output logger.warning(f"Error collecting benchmarks - Pytest Exit code: {result.returncode}, {error_section}") + logger.debug(f"Full pytest output:\n{combined_output}") diff --git a/codeflash/code_utils/concolic_utils.py b/codeflash/code_utils/concolic_utils.py index f59cb7aab..9148c6145 100644 --- a/codeflash/code_utils/concolic_utils.py +++ b/codeflash/code_utils/concolic_utils.py @@ -33,40 +33,28 @@ def _transform_assert_line(self, line: str) -> Optional[str]: indent, assert_method, args = unittest_match.groups() if args: - arg_parts = self._split_top_level_args(args) - if arg_parts and arg_parts[0]: - return f"{indent}{arg_parts[0]}" + arg_parts = self._first_top_level_arg(args) + if arg_parts: + return f"{indent}{arg_parts}" return None - def _split_top_level_args(self, args_str: str) -> list[str]: - result = [] - current = [] - depth = 0 - - for char in args_str: - if char in "([{": - depth += 1 - current.append(char) - elif char in ")]}": - depth -= 1 - current.append(char) - elif char == "," and depth == 0: - result.append("".join(current).strip()) - current = [] - else: - current.append(char) - - if current: - 
result.append("".join(current).strip()) - - return result - def __init__(self) -> None: # Pre-compiling regular expressions for faster execution self.assert_re = re.compile(r"\s*assert\s+(.*?)(?:\s*==\s*.*)?$") self.unittest_re = re.compile(r"(\s*)self\.assert([A-Za-z]+)\((.*)\)$") + def _first_top_level_arg(self, args: str) -> str: + depth = 0 + for i, ch in enumerate(args): + if ch in "([{": + depth += 1 + elif ch in ")]}": + depth -= 1 + elif ch == "," and depth == 0: + return args[:i].strip() + return args.strip() + def clean_concolic_tests(test_suite_code: str) -> str: try: diff --git a/tests/code_utils/test_concolic_utils.py b/tests/code_utils/test_concolic_utils.py new file mode 100644 index 000000000..2117216f2 --- /dev/null +++ b/tests/code_utils/test_concolic_utils.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import pytest + +from codeflash.code_utils.concolic_utils import AssertCleanup + + +class TestFirstTopLevelArg: + @pytest.fixture + def cleanup(self) -> AssertCleanup: + return AssertCleanup() + + def test_single_argument_no_comma(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("foo") == "foo" + + def test_single_argument_with_whitespace(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg(" foo ") == "foo" + + def test_two_simple_arguments(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("a, b") == "a" + + def test_multiple_arguments(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("x, y, z") == "x" + + def test_nested_parentheses_comma_ignored(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("func(a, b), c") == "func(a, b)" + + def test_nested_brackets_comma_ignored(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("[1, 2], x") == "[1, 2]" + + def test_nested_braces_comma_ignored(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("{a: b, c: d}, e") 
== "{a: b, c: d}" + + def test_deeply_nested_parentheses(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("f(g(h(i))), j") == "f(g(h(i)))" + + def test_mixed_bracket_types(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("func([{a, b}], c), d") == "func([{a, b}], c)" + + def test_empty_string(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("") == "" + + def test_only_whitespace(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg(" ") == "" + + def test_comma_at_start(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg(", a") == "" + + def test_no_top_level_comma(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("func(a, b)") == "func(a, b)" + + def test_empty_parentheses(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("(), b") == "()" + + def test_empty_brackets(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("[], b") == "[]" + + def test_empty_braces(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("{}, b") == "{}" + + def test_whitespace_around_comma(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("a , b") == "a" + + def test_complex_nested_structure(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("{'key': [1, (2, 3)]}, other") == "{'key': [1, (2, 3)]}" + + def test_string_literal_with_comma(self, cleanup: AssertCleanup) -> None: + # Note: this function doesn't handle string literals specially + # commas inside strings are treated as top-level + assert cleanup._first_top_level_arg('"a,b", c') == '"a' + + def test_unbalanced_opening_bracket(self, cleanup: AssertCleanup) -> None: + # With unbalanced opening, no top-level comma found + assert cleanup._first_top_level_arg("(a, b") == "(a, b" + + def test_multiple_consecutive_commas(self, cleanup: AssertCleanup) -> 
None: + assert cleanup._first_top_level_arg(",,") == "" + + def test_attribute_access(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("obj.attr, b") == "obj.attr" + + def test_numeric_arguments(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("123, 456") == "123" + + def test_negative_number(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("-42, x") == "-42" + + def test_float_argument(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("3.14, x") == "3.14" + + def test_newline_in_argument(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("a\nb, c") == "a\nb" + + def test_tab_whitespace(self, cleanup: AssertCleanup) -> None: + assert cleanup._first_top_level_arg("\ta\t, b") == "a" diff --git a/tests/test_instrument_codeflash_trace.py b/tests/test_instrument_codeflash_trace.py index 38a6381e2..b1a2ca17f 100644 --- a/tests/test_instrument_codeflash_trace.py +++ b/tests/test_instrument_codeflash_trace.py @@ -544,4 +544,144 @@ def target_function(): return "Hello from target function after nested function" """ - assert modified_code.strip() == expected_code.strip() \ No newline at end of file + assert modified_code.strip() == expected_code.strip() + + +def test_instrument_codeflash_trace_skips_benchmarking_module() -> None: + """Test that files in codeflash/benchmarking/ are skipped to avoid circular imports.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Create a directory structure that mimics codeflash/benchmarking/ + benchmarking_dir = Path(temp_dir) / "codeflash" / "benchmarking" + benchmarking_dir.mkdir(parents=True) + + test_file_path = benchmarking_dir / "some_module.py" + original_content = """ +def some_function(): + return "This should not be modified" +""" + test_file_path.write_text(original_content, encoding="utf-8") + + fto = FunctionToOptimize( + function_name="some_function", + file_path=test_file_path, + 
def test_instrument_codeflash_trace_nested_codeflash_path_skips_benchmarking() -> None:
    """Test that nested codeflash paths like /project/codeflash/codeflash/benchmarking/ are skipped.

    The rpartition logic should find the LAST 'codeflash' in the path. The
    directory tree built here has two ``codeflash`` components, and only the
    last one is followed by ``benchmarking``.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # Create nested structure: project/codeflash/codeflash/benchmarking/
        # Both "codeflash" entries must be whole path components. A name such
        # as "project_codeflash" contains no "/codeflash/" substring, so it
        # would leave a single match and the test could not tell a first-match
        # lookup from a last-match lookup.
        nested_dir = Path(temp_dir) / "project" / "codeflash" / "codeflash" / "benchmarking"
        nested_dir.mkdir(parents=True)

        test_file_path = nested_dir / "trace_module.py"
        original_content = """
def trace_func():
    return "Should not be modified"
"""
        test_file_path.write_text(original_content, encoding="utf-8")

        fto = FunctionToOptimize(
            function_name="trace_func",
            file_path=test_file_path,
            parents=[],
        )

        instrument_codeflash_trace_decorator({test_file_path: [fto]})

        # File should remain unchanged because last /codeflash/ is followed by benchmarking
        assert test_file_path.read_text(encoding="utf-8") == original_content
def test_instrument_codeflash_trace_no_codeflash_in_path() -> None:
    """Check that a file under ``myproject/src`` is instrumented.

    The test writes a one-function module into a temporary
    ``myproject/src`` directory, runs the instrumentation on it, and expects
    both ``codeflash_trace`` and ``@codeflash_trace`` to appear in the
    rewritten file.
    """
    source_before = """
def main_func():
    return "Should be modified"
"""
    with tempfile.TemporaryDirectory() as tmp_root:
        # None of the directories created here is named 'codeflash'.
        src_dir = Path(tmp_root).joinpath("myproject", "src")
        src_dir.mkdir(parents=True)

        target = src_dir / "main.py"
        target.write_text(source_before, encoding="utf-8")

        funcs_by_file = {
            target: [
                FunctionToOptimize(function_name="main_func", file_path=target, parents=[]),
            ],
        }
        instrument_codeflash_trace_decorator(funcs_by_file)

        # Read the result before the temporary directory is removed.
        source_after = target.read_text(encoding="utf-8")

    # The file SHOULD have been modified.
    for marker in ("codeflash_trace", "@codeflash_trace"):
        assert marker in source_after