From 2e0c5788caccfa897ced8bc58cbfc0fc90fa6d38 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Tue, 30 Dec 2025 09:44:29 -0800 Subject: [PATCH 01/27] need to test now --- .../code_utils/instrument_existing_tests.py | 315 +++++++++++++++++- 1 file changed, 313 insertions(+), 2 deletions(-) diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 6184782b3..43d195ece 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -671,6 +671,8 @@ def inject_profiling_into_existing_test( function_to_optimize: FunctionToOptimize, tests_project_root: Path, mode: TestingMode = TestingMode.BEHAVIOR, + *, + jit_warmup: bool = False, ) -> tuple[bool, str | None]: if function_to_optimize.is_async: return inject_async_profiling_into_existing_test( @@ -704,13 +706,277 @@ def inject_profiling_into_existing_test( ast.Import(names=[ast.alias(name="dill", asname="pickle")]), ] ) - additional_functions = [create_wrapper_function(mode)] + additional_functions = [create_wrapper_function(mode, jit_warmup=jit_warmup)] + if jit_warmup: + additional_functions.insert(0, create_jit_sync_helper()) tree.body = [*new_imports, *additional_functions, *tree.body] return True, sort_imports(ast.unparse(tree), float_to_top=True) -def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.FunctionDef: +def create_jit_sync_helper() -> ast.FunctionDef: + """Create a helper function that synchronizes JIT-compiled frameworks (PyTorch, TensorFlow, JAX, MLX). + + This function generates AST for: + def _codeflash_jit_sync(): + try: + import torch + if torch.cuda.is_available(): + torch.cuda.synchronize() + if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + torch.mps.synchronize() + except ImportError: + pass + try: + import jax + # Block until all JAX computations are complete + jax.effects_barrier() + except ImportError: + pass + try: + import mlx.core as mx + mx.synchronize() + except ImportError: + pass + # Note: TensorFlow in eager mode auto-syncs; Numba JIT is CPU-based and doesn't need sync + """ + lineno = 1 + + # PyTorch sync block + pytorch_sync = ast.Try( + body=[ + ast.Import(names=[ast.alias(name="torch")], lineno=lineno), + # if torch.cuda.is_available(): torch.cuda.synchronize() + ast.If( + test=ast.Call( + func=ast.Attribute( + value=ast.Attribute(value=ast.Name(id="torch", ctx=ast.Load()), attr="cuda", ctx=ast.Load()), + attr="is_available", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ), + body=[ + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id="torch", ctx=ast.Load()), attr="cuda", ctx=ast.Load() + ), + attr="synchronize", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ) + ) + ], + orelse=[], + lineno=lineno, + ), + # if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): torch.mps.synchronize() + ast.If( + test=ast.BoolOp( + op=ast.And(), + values=[ + ast.Call( + func=ast.Name(id="hasattr", ctx=ast.Load()), + args=[ + ast.Attribute( + value=ast.Name(id="torch", ctx=ast.Load()), attr="backends", ctx=ast.Load() + ), + ast.Constant(value="mps"), + ], + keywords=[], + ), + ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id="torch", ctx=ast.Load()), attr="backends", ctx=ast.Load() + ), + attr="mps", + ctx=ast.Load(), + ), + attr="is_available", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ), + ], + ), + body=[ + ast.Expr( + 
value=ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id="torch", ctx=ast.Load()), attr="mps", ctx=ast.Load() + ), + attr="synchronize", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ) + ) + ], + orelse=[], + lineno=lineno, + ), + ], + handlers=[ + ast.ExceptHandler( + type=ast.Name(id="ImportError", ctx=ast.Load()), + name=None, + body=[ast.Pass(lineno=lineno)], + lineno=lineno, + ) + ], + orelse=[], + finalbody=[], + lineno=lineno, + ) + + # JAX sync block - use effects_barrier() to wait for all computations + jax_sync = ast.Try( + body=[ + ast.Import(names=[ast.alias(name="jax")], lineno=lineno), + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id="jax", ctx=ast.Load()), attr="effects_barrier", ctx=ast.Load() + ), + args=[], + keywords=[], + ) + ), + ], + handlers=[ + ast.ExceptHandler( + type=ast.Name(id="ImportError", ctx=ast.Load()), + name=None, + body=[ast.Pass(lineno=lineno)], + lineno=lineno, + ) + ], + orelse=[], + finalbody=[], + lineno=lineno, + ) + + # MLX sync block + mlx_sync = ast.Try( + body=[ + ast.Import(names=[ast.alias(name="mlx.core", asname="mx")], lineno=lineno), + ast.Expr( + value=ast.Call( + func=ast.Attribute(value=ast.Name(id="mx", ctx=ast.Load()), attr="synchronize", ctx=ast.Load()), + args=[], + keywords=[], + ) + ), + ], + handlers=[ + ast.ExceptHandler( + type=ast.Name(id="ImportError", ctx=ast.Load()), + name=None, + body=[ast.Pass(lineno=lineno)], + lineno=lineno, + ) + ], + orelse=[], + finalbody=[], + lineno=lineno, + ) + + # TensorFlow sync block - sync XLA/TPU devices + tensorflow_sync = ast.Try( + body=[ + ast.Import(names=[ast.alias(name="tensorflow", asname="tf")], lineno=lineno), + # For TPU: tf.tpu.experimental.initialize_tpu_system if available + # For GPU: operations complete synchronously in eager mode but we can force sync + ast.If( + test=ast.Call( + func=ast.Name(id="hasattr", ctx=ast.Load()), + args=[ + ast.Attribute(value=ast.Name(id="tf", ctx=ast.Load()), attr="config", ctx=ast.Load()), + ast.Constant(value="experimental"), + ], + keywords=[], + ), + body=[ + # Get all physical devices and sync GPUs + ast.For( + target=ast.Name(id="_device", ctx=ast.Store()), + iter=ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id="tf", ctx=ast.Load()), attr="config", ctx=ast.Load() + ), + attr="list_physical_devices", + ctx=ast.Load(), + ), + args=[ast.Constant(value="GPU")], + keywords=[], + ), + body=[ + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id="tf", ctx=ast.Load()), attr="test", ctx=ast.Load() + ), + attr="experimental", + ctx=ast.Load(), + ), + attr="sync_devices", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ) + ) + ], + orelse=[], + lineno=lineno, + ) + ], + orelse=[], + lineno=lineno, + ), + ], + handlers=[ + ast.ExceptHandler( + type=ast.Tuple( + elts=[ast.Name(id="ImportError", ctx=ast.Load()), ast.Name(id="AttributeError", ctx=ast.Load())], + ctx=ast.Load(), + ), + name=None, + body=[ast.Pass(lineno=lineno)], + lineno=lineno, + ) + ], + orelse=[], + finalbody=[], + lineno=lineno, + ) + + return ast.FunctionDef( + name="_codeflash_jit_sync", + args=ast.arguments( + args=[], vararg=None, kwarg=None, posonlyargs=[], kwonlyargs=[], kw_defaults=[], defaults=[] + ), + body=[pytorch_sync, jax_sync, mlx_sync, tensorflow_sync], + decorator_list=[], + returns=None, + lineno=lineno, + ) + + +def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR, *, jit_warmup: bool = False) -> 
ast.FunctionDef: lineno = 1 wrapper_body: list[ast.stmt] = [ ast.Assign( @@ -871,6 +1137,25 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun ast.Assign( targets=[ast.Name(id="exception", ctx=ast.Store())], value=ast.Constant(value=None), lineno=lineno + 10 ), + # JIT warmup: call function once to trigger JIT compilation before timing + *( + [ + ast.Expr( + value=ast.Call( + func=ast.Name(id="codeflash_wrapped", ctx=ast.Load()), + args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())], + keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))], + ), + lineno=lineno + 10, + ), + ast.Expr( + value=ast.Call(func=ast.Name(id="_codeflash_jit_sync", ctx=ast.Load()), args=[], keywords=[]), + lineno=lineno + 10, + ), + ] + if jit_warmup + else [] + ), ast.Expr( value=ast.Call( func=ast.Attribute(value=ast.Name(id="gc", ctx=ast.Load()), attr="disable", ctx=ast.Load()), @@ -881,6 +1166,19 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun ), ast.Try( body=[ + # Sync before starting timer (ensure previous operations are complete) + *( + [ + ast.Expr( + value=ast.Call( + func=ast.Name(id="_codeflash_jit_sync", ctx=ast.Load()), args=[], keywords=[] + ), + lineno=lineno + 11, + ) + ] + if jit_warmup + else [] + ), ast.Assign( targets=[ast.Name(id="counter", ctx=ast.Store())], value=ast.Call( @@ -901,6 +1199,19 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun ), lineno=lineno + 12, ), + # Sync after function call to ensure all GPU/async operations complete before stopping timer + *( + [ + ast.Expr( + value=ast.Call( + func=ast.Name(id="_codeflash_jit_sync", ctx=ast.Load()), args=[], keywords=[] + ), + lineno=lineno + 12, + ) + ] + if jit_warmup + else [] + ), ast.Assign( targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())], value=ast.BinOp( From fed05952bb1ce06de5b480bbebc0d849a74b79ef Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Tue, 30 Dec 2025 16:20:26 -0800 Subject: [PATCH 02/27] wip --- .../code_utils/instrument_existing_tests.py | 19 ------------------- codeflash/optimization/function_optimizer.py | 2 ++ tests/test_instrument_tests.py | 2 +- 3 files changed, 3 insertions(+), 20 deletions(-) diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 43d195ece..d6c5f913b 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -1137,25 +1137,6 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR, *, jit_war ast.Assign( targets=[ast.Name(id="exception", ctx=ast.Store())], value=ast.Constant(value=None), lineno=lineno + 10 ), - # JIT warmup: call function once to trigger JIT compilation before timing - *( - [ - ast.Expr( - value=ast.Call( - func=ast.Name(id="codeflash_wrapped", ctx=ast.Load()), - args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())], - keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))], - ), - lineno=lineno + 10, - ), - ast.Expr( - value=ast.Call(func=ast.Name(id="_codeflash_jit_sync", ctx=ast.Load()), args=[], keywords=[]), - lineno=lineno + 10, - ), - ] - if jit_warmup - else [] - ), ast.Expr( value=ast.Call( func=ast.Attribute(value=ast.Name(id="gc", ctx=ast.Load()), attr="disable", ctx=ast.Load()), diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 416bdc8df..5cab36eda 
100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1230,6 +1230,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio call_positions=[test.position for test in tests_in_file_list], function_to_optimize=self.function_to_optimize, tests_project_root=self.test_cfg.tests_project_rootdir, + jit_warmup=True, ) if not success: continue @@ -1239,6 +1240,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio call_positions=[test.position for test in tests_in_file_list], function_to_optimize=self.function_to_optimize, tests_project_root=self.test_cfg.tests_project_rootdir, + jit_warmup=True, ) if not success: continue diff --git a/tests/test_instrument_tests.py b/tests/test_instrument_tests.py index a74f41533..d6439550d 100644 --- a/tests/test_instrument_tests.py +++ b/tests/test_instrument_tests.py @@ -193,7 +193,7 @@ def test_sort(self): Path(f.name), [CodePosition(9, 17), CodePosition(13, 17), CodePosition(17, 17)], func, - Path(f.name).parent, + Path(f.name).parent, jit_warmup=True ) os.chdir(original_cwd) assert success From e38df6282a6394f9d8a556b147c1ec30d75b5036 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Tue, 30 Dec 2025 16:25:31 -0800 Subject: [PATCH 03/27] wip --- codeflash/verification/comparator.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index 7737900df..d95c45012 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -24,6 +24,7 @@ HAS_JAX = find_spec("jax") is not None HAS_XARRAY = find_spec("xarray") is not None HAS_TENSORFLOW = find_spec("tensorflow") is not None +HAS_MLX = find_spec("mlx") is not None def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911 @@ -138,6 +139,17 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 return False return comparator(orig.to_list(), new.to_list(), superset_obj) + if HAS_MLX: + import mlx.core as mx # type: ignore # noqa: PGH003 + + if isinstance(orig, mx.array): + if orig.dtype != new.dtype: + return False + if orig.shape != new.shape: + return False + # MLX allclose handles NaN comparison via equal_nan parameter + return bool(mx.allclose(orig, new, equal_nan=True)) + if HAS_SQLALCHEMY: import sqlalchemy # type: ignore # noqa: PGH003 From 17042a2040f0c65cbd3b59967b9dd35effd6e72e Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Tue, 30 Dec 2025 16:30:32 -0800 Subject: [PATCH 04/27] wip --- code_to_optimize/discrete_riccati.py | 170 +++++++++++ .../tests/pytest/test_gridmake2.py | 216 ++++++++++++++ .../tests/pytest/test_gridmake2_torch.py | 267 ++++++++++++++++++ codeflash/api/aiservice.py | 69 ++++- codeflash/optimization/function_optimizer.py | 20 ++ 5 files changed, 741 insertions(+), 1 deletion(-) create mode 100644 code_to_optimize/discrete_riccati.py create mode 100644 code_to_optimize/tests/pytest/test_gridmake2.py create mode 100644 code_to_optimize/tests/pytest/test_gridmake2_torch.py diff --git a/code_to_optimize/discrete_riccati.py b/code_to_optimize/discrete_riccati.py new file mode 100644 index 000000000..53fe30891 --- /dev/null +++ b/code_to_optimize/discrete_riccati.py @@ -0,0 +1,170 @@ +""" +Utility functions used in CompEcon + +Based on routines found in the CompEcon toolbox by Miranda and Fackler. + +References +---------- +Miranda, Mario J, and Paul L Fackler.
Applied Computational +and Finance, MIT Press, 2002. + +""" +from functools import reduce +import numpy as np +import torch + + +def ckron(*arrays): + """ + Repeatedly applies the np.kron function to an arbitrary number of + input arrays + + Parameters + ---------- + *arrays : tuple/list of np.ndarray + + Returns + ------- + out : np.ndarray + The result of repeated Kronecker products. + + Notes + ----- + Based on original function `ckron` in CompEcon toolbox by Miranda + and Fackler. + + References + ---------- + Miranda, Mario J, and Paul L Fackler. Applied Computational + Economics and Finance, MIT Press, 2002. + + """ + return reduce(np.kron, arrays) + + +def gridmake(*arrays): + """ + Expands one or more vectors (or matrices) into a matrix where rows span the + cartesian product of combinations of the input arrays. Each column of the + input arrays will correspond to one column of the output matrix. + + Parameters + ---------- + *arrays : tuple/list of np.ndarray + Tuple/list of vectors to be expanded. + + Returns + ------- + out : np.ndarray + The cartesian product of combinations of the input arrays. + + Notes + ----- + Based on original function ``gridmake`` in CompEcon toolbox by + Miranda and Fackler. + + References + ---------- + Miranda, Mario J, and Paul L Fackler. Applied Computational Economics + and Finance, MIT Press, 2002. + + """ + if all([i.ndim == 1 for i in arrays]): + d = len(arrays) + if d == 2: + out = _gridmake2(*arrays) + else: + out = _gridmake2(arrays[0], arrays[1]) + for arr in arrays[2:]: + out = _gridmake2(out, arr) + + return out + else: + raise NotImplementedError("Come back here") + + +def _gridmake2(x1, x2): + """ + Expands two vectors (or matrices) into a matrix where rows span the + cartesian product of combinations of the input arrays. Each column of the + input arrays will correspond to one column of the output matrix. + + Parameters + ---------- + x1 : np.ndarray + First vector to be expanded. + + x2 : np.ndarray + Second vector to be expanded. + + Returns + ------- + out : np.ndarray + The cartesian product of combinations of the input arrays. + + Notes + ----- + Based on original function ``gridmake2`` in CompEcon toolbox by + Miranda and Fackler. + + References + ---------- + Miranda, Mario J, and Paul L Fackler. Applied Computational Economics + and Finance, MIT Press, 2002. + + """ + if x1.ndim == 1 and x2.ndim == 1: + return np.column_stack([np.tile(x1, x2.shape[0]), + np.repeat(x2, x1.shape[0])]) + elif x1.ndim > 1 and x2.ndim == 1: + first = np.tile(x1, (x2.shape[0], 1)) + second = np.repeat(x2, x1.shape[0]) + return np.column_stack([first, second]) + else: + raise NotImplementedError("Come back here") + + +def _gridmake2_torch(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + """ + PyTorch version of _gridmake2. + + Expands two tensors into a matrix where rows span the cartesian product + of combinations of the input tensors. Each column of the input tensors + will correspond to one column of the output matrix. + + Parameters + ---------- + x1 : torch.Tensor + First tensor to be expanded. + + x2 : torch.Tensor + Second tensor to be expanded. + + Returns + ------- + out : torch.Tensor + The cartesian product of combinations of the input tensors. + + Notes + ----- + Based on original function ``gridmake2`` in CompEcon toolbox by + Miranda and Fackler. + + References + ---------- + Miranda, Mario J, and Paul L Fackler. Applied Computational Economics + and Finance, MIT Press, 2002.
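+ + Examples + -------- + A minimal doctest-style sketch for two 1-D tensors; the values are chosen purely for illustration and follow directly from the tile/repeat_interleave logic below: + + >>> import torch + >>> _gridmake2_torch(torch.tensor([1, 2]), torch.tensor([3, 4])) + tensor([[1, 3], + [2, 3], + [1, 4], + [2, 4]])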
+ + """ + if x1.dim() == 1 and x2.dim() == 1: + # tile x1 by x2.shape[0] times, repeat_interleave x2 by x1.shape[0] + first = x1.tile(x2.shape[0]) + second = x2.repeat_interleave(x1.shape[0]) + return torch.column_stack([first, second]) + elif x1.dim() > 1 and x2.dim() == 1: + # tile x1 along first dimension + first = x1.tile(x2.shape[0], 1) + second = x2.repeat_interleave(x1.shape[0]) + return torch.column_stack([first, second]) + else: + raise NotImplementedError("Come back here") diff --git a/code_to_optimize/tests/pytest/test_gridmake2.py b/code_to_optimize/tests/pytest/test_gridmake2.py new file mode 100644 index 000000000..60d7bfe56 --- /dev/null +++ b/code_to_optimize/tests/pytest/test_gridmake2.py @@ -0,0 +1,216 @@ +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +from code_to_optimize.discrete_riccati import _gridmake2 + + +class TestGridmake2With1DArrays: + """Tests for _gridmake2 with two 1D arrays.""" + + def test_basic_two_element_arrays(self): + """Test basic cartesian product of two 2-element arrays.""" + x1 = np.array([1, 2]) + x2 = np.array([3, 4]) + result = _gridmake2(x1, x2) + + # Expected: x1 is tiled len(x2) times, x2 is repeated len(x1) times + expected = np.array([ + [1, 3], + [2, 3], + [1, 4], + [2, 4] + ]) + assert_array_equal(result, expected) + + def test_different_length_arrays(self): + """Test cartesian product with arrays of different lengths.""" + x1 = np.array([1, 2, 3]) + x2 = np.array([10, 20]) + result = _gridmake2(x1, x2) + + # Result should have len(x1) * len(x2) = 6 rows + expected = np.array([ + [1, 10], + [2, 10], + [3, 10], + [1, 20], + [2, 20], + [3, 20] + ]) + assert_array_equal(result, expected) + assert result.shape == (6, 2) + + def test_single_element_arrays(self): + """Test with single-element arrays.""" + x1 = np.array([5]) + x2 = np.array([7]) + result = _gridmake2(x1, x2) + + expected = np.array([[5, 7]]) + assert_array_equal(result, expected) + assert result.shape == (1, 2) + + def test_single_element_with_multi_element(self): + """Test single-element array with multi-element array.""" + x1 = np.array([1]) + x2 = np.array([10, 20, 30]) + result = _gridmake2(x1, x2) + + expected = np.array([ + [1, 10], + [1, 20], + [1, 30] + ]) + assert_array_equal(result, expected) + + def test_float_arrays(self): + """Test with float arrays.""" + x1 = np.array([1.5, 2.5]) + x2 = np.array([0.1, 0.2]) + result = _gridmake2(x1, x2) + + expected = np.array([ + [1.5, 0.1], + [2.5, 0.1], + [1.5, 0.2], + [2.5, 0.2] + ]) + assert_array_equal(result, expected) + + def test_negative_values(self): + """Test with negative values.""" + x1 = np.array([-1, 0, 1]) + x2 = np.array([-10, 10]) + result = _gridmake2(x1, x2) + + expected = np.array([ + [-1, -10], + [0, -10], + [1, -10], + [-1, 10], + [0, 10], + [1, 10] + ]) + assert_array_equal(result, expected) + + def test_result_shape(self): + """Test that result shape is (len(x1)*len(x2), 2).""" + x1 = np.array([1, 2, 3, 4]) + x2 = np.array([5, 6, 7]) + result = _gridmake2(x1, x2) + + assert result.shape == (12, 2) + + def test_larger_arrays(self): + """Test with larger arrays.""" + x1 = np.arange(10) + x2 = np.arange(5) + result = _gridmake2(x1, x2) + + assert result.shape == (50, 2) + # Verify first column is x1 tiled 5 times + assert_array_equal(result[:10, 0], x1) + assert_array_equal(result[10:20, 0], x1) + # Verify second column is x2 repeated 10 times each + assert all(result[:10, 1] == 0) + assert all(result[10:20, 1] == 1) + + +class TestGridmake2With2DFirst: + """Tests for 
_gridmake2 when x1 is 2D and x2 is 1D.""" + + def test_2d_first_1d_second(self): + """Test with 2D first array and 1D second array.""" + x1 = np.array([[1, 2], [3, 4]]) # 2 rows, 2 cols + x2 = np.array([10, 20]) + result = _gridmake2(x1, x2) + + # x1 is tiled len(x2) times vertically + # x2 is repeated len(x1) times (2 rows) + expected = np.array([ + [1, 2, 10], + [3, 4, 10], + [1, 2, 20], + [3, 4, 20] + ]) + assert_array_equal(result, expected) + + def test_2d_single_column(self): + """Test with 2D array having single column.""" + x1 = np.array([[1], [2], [3]]) # 3 rows, 1 col + x2 = np.array([10, 20]) + result = _gridmake2(x1, x2) + + expected = np.array([ + [1, 10], + [2, 10], + [3, 10], + [1, 20], + [2, 20], + [3, 20] + ]) + assert_array_equal(result, expected) + + def test_2d_multiple_columns(self): + """Test with 2D array having multiple columns.""" + x1 = np.array([[1, 2, 3], [4, 5, 6]]) # 2 rows, 3 cols + x2 = np.array([100]) + result = _gridmake2(x1, x2) + + expected = np.array([ + [1, 2, 3, 100], + [4, 5, 6, 100] + ]) + assert_array_equal(result, expected) + + +class TestGridmake2EdgeCases: + """Edge case tests for _gridmake2.""" + + def test_empty_arrays_raise_or_return_empty(self): + """Test behavior with empty arrays.""" + x1 = np.array([]) + x2 = np.array([1, 2]) + result = _gridmake2(x1, x2) + # Empty x1 should result in empty output + assert result.shape[0] == 0 + + def test_both_empty_arrays(self): + """Test with both empty arrays.""" + x1 = np.array([]) + x2 = np.array([]) + result = _gridmake2(x1, x2) + assert result.shape[0] == 0 + + def test_integer_dtype_preserved(self): + """Test that integer dtype is handled correctly.""" + x1 = np.array([1, 2], dtype=np.int64) + x2 = np.array([3, 4], dtype=np.int64) + result = _gridmake2(x1, x2) + assert result.dtype == np.int64 + + def test_float_dtype_preserved(self): + """Test that float dtype is handled correctly.""" + x1 = np.array([1.0, 2.0], dtype=np.float64) + x2 = np.array([3.0, 4.0], dtype=np.float64) + result = _gridmake2(x1, x2) + assert result.dtype == np.float64 + + +class TestGridmake2NotImplemented: + """Tests for NotImplementedError cases.""" + + def test_both_2d_raises(self): + """Test that two 2D arrays raises NotImplementedError.""" + x1 = np.array([[1, 2], [3, 4]]) + x2 = np.array([[5, 6], [7, 8]]) + with pytest.raises(NotImplementedError): + _gridmake2(x1, x2) + + def test_1d_first_2d_second_raises(self): + """Test that 1D first and 2D second raises NotImplementedError.""" + x1 = np.array([1, 2]) + x2 = np.array([[5, 6], [7, 8]]) + with pytest.raises(NotImplementedError): + _gridmake2(x1, x2) \ No newline at end of file diff --git a/code_to_optimize/tests/pytest/test_gridmake2_torch.py b/code_to_optimize/tests/pytest/test_gridmake2_torch.py new file mode 100644 index 000000000..f2ee737a2 --- /dev/null +++ b/code_to_optimize/tests/pytest/test_gridmake2_torch.py @@ -0,0 +1,267 @@ +import pytest +import torch + +from code_to_optimize.discrete_riccati import _gridmake2_torch + + +class TestGridmake2TorchCPU: + """Tests for _gridmake2_torch with CPU tensors.""" + + def test_both_1d_simple(self): + """Test with two simple 1D tensors.""" + x1 = torch.tensor([1, 2, 3]) + x2 = torch.tensor([10, 20]) + + result = _gridmake2_torch(x1, x2) + + # Expected: x1 tiled x2.shape[0] times, x2 repeat_interleaved x1.shape[0] + # x1 tiled: [1, 2, 3, 1, 2, 3] + # x2 repeated: [10, 10, 10, 20, 20, 20] + expected = torch.tensor([ + [1, 10], + [2, 10], + [3, 10], + [1, 20], + [2, 20], + [3, 20], + ]) + assert torch.equal(result, 
expected) + + def test_both_1d_single_element(self): + """Test with single element tensors.""" + x1 = torch.tensor([5]) + x2 = torch.tensor([10]) + + result = _gridmake2_torch(x1, x2) + + expected = torch.tensor([[5, 10]]) + assert torch.equal(result, expected) + + def test_both_1d_float_tensors(self): + """Test with float tensors.""" + x1 = torch.tensor([1.5, 2.5]) + x2 = torch.tensor([0.1, 0.2, 0.3]) + + result = _gridmake2_torch(x1, x2) + + assert result.shape == (6, 2) + assert result.dtype == torch.float32 + + def test_2d_and_1d_simple(self): + """Test with 2D x1 and 1D x2.""" + x1 = torch.tensor([[1, 2], [3, 4]]) + x2 = torch.tensor([10, 20]) + + result = _gridmake2_torch(x1, x2) + + # x1 tiled along first dim: [[1, 2], [3, 4], [1, 2], [3, 4]] + # x2 repeated: [10, 10, 20, 20] + # column_stack: [[1, 2, 10], [3, 4, 10], [1, 2, 20], [3, 4, 20]] + expected = torch.tensor([ + [1, 2, 10], + [3, 4, 10], + [1, 2, 20], + [3, 4, 20], + ]) + assert torch.equal(result, expected) + + def test_2d_and_1d_single_column(self): + """Test with 2D x1 having a single column and 1D x2.""" + x1 = torch.tensor([[1], [2], [3]]) + x2 = torch.tensor([10, 20]) + + result = _gridmake2_torch(x1, x2) + + expected = torch.tensor([ + [1, 10], + [2, 10], + [3, 10], + [1, 20], + [2, 20], + [3, 20], + ]) + assert torch.equal(result, expected) + + def test_output_shape_1d_1d(self): + """Test output shape for two 1D tensors.""" + x1 = torch.tensor([1, 2, 3, 4, 5]) + x2 = torch.tensor([10, 20, 30]) + + result = _gridmake2_torch(x1, x2) + + # Shape should be (len(x1) * len(x2), 2) + assert result.shape == (15, 2) + + def test_output_shape_2d_1d(self): + """Test output shape for 2D and 1D tensors.""" + x1 = torch.tensor([[1, 2, 3], [4, 5, 6]]) # Shape (2, 3) + x2 = torch.tensor([10, 20, 30, 40]) # Shape (4,) + + result = _gridmake2_torch(x1, x2) + + # Shape should be (2 * 4, 3 + 1) = (8, 4) + assert result.shape == (8, 4) + + def test_not_implemented_for_2d_2d(self): + """Test that NotImplementedError is raised for two 2D tensors.""" + x1 = torch.tensor([[1, 2], [3, 4]]) + x2 = torch.tensor([[10, 20], [30, 40]]) + + with pytest.raises(NotImplementedError, match="Come back here"): + _gridmake2_torch(x1, x2) + + def test_not_implemented_for_1d_2d(self): + """Test that NotImplementedError is raised for 1D and 2D tensors.""" + x1 = torch.tensor([1, 2, 3]) + x2 = torch.tensor([[10, 20], [30, 40]]) + + with pytest.raises(NotImplementedError, match="Come back here"): + _gridmake2_torch(x1, x2) + + def test_preserves_dtype_int(self): + """Test that integer dtype is preserved.""" + x1 = torch.tensor([1, 2, 3], dtype=torch.int32) + x2 = torch.tensor([10, 20], dtype=torch.int32) + + result = _gridmake2_torch(x1, x2) + + assert result.dtype == torch.int32 + + def test_preserves_dtype_float64(self): + """Test that float64 dtype is preserved.""" + x1 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64) + x2 = torch.tensor([10.0, 20.0], dtype=torch.float64) + + result = _gridmake2_torch(x1, x2) + + assert result.dtype == torch.float64 + + def test_large_tensors(self): + """Test with larger tensors.""" + x1 = torch.arange(100) + x2 = torch.arange(50) + + result = _gridmake2_torch(x1, x2) + + assert result.shape == (5000, 2) + # Verify first and last elements + assert result[0, 0] == 0 and result[0, 1] == 0 + assert result[-1, 0] == 99 and result[-1, 1] == 49 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestGridmake2TorchCUDA: + """Tests for _gridmake2_torch with CUDA tensors.""" + + def 
test_both_1d_simple_cuda(self): + """Test with two simple 1D CUDA tensors.""" + x1 = torch.tensor([1, 2, 3], device="cuda") + x2 = torch.tensor([10, 20], device="cuda") + + result = _gridmake2_torch(x1, x2) + + expected = torch.tensor([ + [1, 10], + [2, 10], + [3, 10], + [1, 20], + [2, 20], + [3, 20], + ], device="cuda") + assert result.device.type == "cuda" + assert torch.equal(result, expected) + + def test_both_1d_matches_cpu(self): + """Test that CUDA version matches CPU version.""" + x1_cpu = torch.tensor([1.0, 2.0, 3.0, 4.0]) + x2_cpu = torch.tensor([10.0, 20.0, 30.0]) + + x1_cuda = x1_cpu.cuda() + x2_cuda = x2_cpu.cuda() + + result_cpu = _gridmake2_torch(x1_cpu, x2_cpu) + result_cuda = _gridmake2_torch(x1_cuda, x2_cuda) + + assert result_cuda.device.type == "cuda" + torch.testing.assert_close(result_cpu, result_cuda.cpu()) + + def test_2d_and_1d_cuda(self): + """Test with 2D x1 and 1D x2 on CUDA.""" + x1 = torch.tensor([[1, 2], [3, 4]], device="cuda") + x2 = torch.tensor([10, 20], device="cuda") + + result = _gridmake2_torch(x1, x2) + + expected = torch.tensor([ + [1, 2, 10], + [3, 4, 10], + [1, 2, 20], + [3, 4, 20], + ], device="cuda") + assert result.device.type == "cuda" + assert torch.equal(result, expected) + + def test_2d_and_1d_matches_cpu(self): + """Test that CUDA version matches CPU version for 2D, 1D inputs.""" + x1_cpu = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + x2_cpu = torch.tensor([10.0, 20.0]) + + x1_cuda = x1_cpu.cuda() + x2_cuda = x2_cpu.cuda() + + result_cpu = _gridmake2_torch(x1_cpu, x2_cpu) + result_cuda = _gridmake2_torch(x1_cuda, x2_cuda) + + assert result_cuda.device.type == "cuda" + torch.testing.assert_close(result_cpu, result_cuda.cpu()) + + def test_output_stays_on_cuda(self): + """Test that output tensor stays on CUDA device.""" + x1 = torch.tensor([1, 2, 3], device="cuda") + x2 = torch.tensor([10, 20], device="cuda") + + result = _gridmake2_torch(x1, x2) + + assert result.is_cuda + + def test_preserves_dtype_float32_cuda(self): + """Test that float32 dtype is preserved on CUDA.""" + x1 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32, device="cuda") + x2 = torch.tensor([10.0, 20.0], dtype=torch.float32, device="cuda") + + result = _gridmake2_torch(x1, x2) + + assert result.dtype == torch.float32 + assert result.device.type == "cuda" + + def test_preserves_dtype_float64_cuda(self): + """Test that float64 dtype is preserved on CUDA.""" + x1 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64, device="cuda") + x2 = torch.tensor([10.0, 20.0], dtype=torch.float64, device="cuda") + + result = _gridmake2_torch(x1, x2) + + assert result.dtype == torch.float64 + assert result.device.type == "cuda" + + def test_large_tensors_cuda(self): + """Test with larger tensors on CUDA.""" + x1 = torch.arange(100, device="cuda") + x2 = torch.arange(50, device="cuda") + + result = _gridmake2_torch(x1, x2) + + assert result.shape == (5000, 2) + assert result.device.type == "cuda" + # Verify first and last elements + assert result[0, 0].item() == 0 and result[0, 1].item() == 0 + assert result[-1, 0].item() == 99 and result[-1, 1].item() == 49 + + def test_not_implemented_for_2d_2d_cuda(self): + """Test that NotImplementedError is raised for two 2D CUDA tensors.""" + x1 = torch.tensor([[1, 2], [3, 4]], device="cuda") + x2 = torch.tensor([[10, 20], [30, 40]], device="cuda") + + with pytest.raises(NotImplementedError, match="Come back here"): + _gridmake2_torch(x1, x2) + diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 2eedb9fae..9718f1756 
100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -46,7 +46,7 @@ def get_aiservice_base_url(self) -> str: logger.info("Using local AI Service at http://localhost:8000") console.rule() return "http://localhost:8000" - return "https://app.codeflash.ai" + return "http://localhost:8000" def make_ai_service_request( self, @@ -177,6 +177,73 @@ def optimize_python_code( # noqa: D417 console.rule() return [] + def get_jit_rewritten_code( # noqa: D417 + self, + source_code: str, + dependency_code: str, + trace_id: str, + num_candidates: int = 1, + experiment_metadata: ExperimentMetadata | None = None, + *, + is_async: bool = False, + ) -> list[OptimizedCandidate]: + """Rewrite the given python code to use JIT-compiled frameworks by making a request to the Django endpoint. + + Parameters + ---------- + - source_code (str): The python code to optimize. + - dependency_code (str): The dependency code used as read-only context for the optimization + - trace_id (str): Trace id of optimization run + - num_candidates (int): Number of optimization variants to generate. Default is 1. + - experiment_metadata (Optional[ExperimentMetadata]): Any available experiment metadata for this optimization + + Returns + ------- + - List[OptimizedCandidate]: A list of optimization candidates. + + """ + start_time = time.perf_counter() + git_repo_owner, git_repo_name = safe_get_repo_owner_and_name() + + payload = { + "source_code": source_code, + "dependency_code": dependency_code, + "num_variants": num_candidates, + "trace_id": trace_id, + "python_version": platform.python_version(), + "experiment_metadata": experiment_metadata, + "codeflash_version": codeflash_version, + "current_username": get_last_commit_author_if_pr_exists(None), + "repo_owner": git_repo_owner, + "repo_name": git_repo_name, + "n_candidates": N_CANDIDATES_EFFECTIVE, + "is_async": is_async, + } + + logger.info("!lsp|Generating optimized candidates…") + console.rule() + try: + response = self.make_ai_service_request("/rewrite_jit", payload=payload, timeout=60) + except requests.exceptions.RequestException as e: + logger.exception(f"Error generating optimized candidates: {e}") + ph("cli-optimize-error-caught", {"error": str(e)}) + return [] + + if response.status_code == 200: + optimizations_json = response.json()["optimizations"] + console.rule() + end_time = time.perf_counter() + logger.debug(f"!lsp|Generating possible optimizations took {end_time - start_time:.2f} seconds.") + return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE) + try: + error = response.json()["error"] + except Exception: + error = response.text + logger.error(f"Error generating optimized candidates: {response.status_code} - {error}") + ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error}) + console.rule() + return [] + def optimize_python_code_line_profiler( # noqa: D417 self, source_code: str, diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 5cab36eda..2bdc0b05c 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -453,6 +453,26 @@ def optimize_function(self) -> Result[BestOptimization, str]: revert_to_print=bool(get_pr_number()), ): console.rule() + # get new opt candidate + + jit_compiled_opt_candidate = self.aiservice_client.get_jit_rewritten_code( + code_context.read_writable_code.markdown, code_context.read_only_context_code, self.function_trace_id + ) + #
write files + # Try to replace function with optimized code + self.replace_function_and_helpers_with_optimized_code( + code_context=code_context, + optimized_code=jit_compiled_opt_candidate[0].source_code, + original_helper_code=original_helper_code, + ) + # get codecontext + new_code_context = self.get_code_optimization_context().unwrap() + # unwrite files + self.write_code_and_helpers( + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + ) + # Generate tests and optimizations in parallel + future_tests = self.executor.submit(self.generate_and_instrument_tests, new_code_context) # Generate tests and optimizations in parallel future_tests = self.executor.submit(self.generate_and_instrument_tests, code_context) future_optimizations = self.executor.submit( From f945fefaecadfb9511b5a497d8016c105289e589 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Tue, 30 Dec 2025 17:28:19 -0800 Subject: [PATCH 05/27] bugfix --- codeflash/optimization/function_optimizer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 2bdc0b05c..037460624 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -473,8 +473,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: ) # Generate tests and optimizations in parallel future_tests = self.executor.submit(self.generate_and_instrument_tests, new_code_context) - # Generate tests and optimizations in parallel - future_tests = self.executor.submit(self.generate_and_instrument_tests, code_context) future_optimizations = self.executor.submit( self.generate_optimizations, read_writable_code=code_context.read_writable_code, From d555e8bd164ab2ecc5a4fdec6f0c1263aac4f9aa Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Tue, 30 Dec 2025 18:42:46 -0800 Subject: [PATCH 06/27] mlx is problematic --- .../code_utils/instrument_existing_tests.py | 51 ++++++++++--------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index d6c5f913b..e259edb8e 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -866,30 +866,30 @@ def _codeflash_jit_sync(): lineno=lineno, ) - # MLX sync block - mlx_sync = ast.Try( - body=[ - ast.Import(names=[ast.alias(name="mlx.core", asname="mx")], lineno=lineno), - ast.Expr( - value=ast.Call( - func=ast.Attribute(value=ast.Name(id="mx", ctx=ast.Load()), attr="synchronize", ctx=ast.Load()), - args=[], - keywords=[], - ) - ), - ], - handlers=[ - ast.ExceptHandler( - type=ast.Name(id="ImportError", ctx=ast.Load()), - name=None, - body=[ast.Pass(lineno=lineno)], - lineno=lineno, - ) - ], - orelse=[], - finalbody=[], - lineno=lineno, - ) + # # MLX sync block + # mlx_sync = ast.Try( + # body=[ + # ast.Import(names=[ast.alias(name="mlx.core", asname="mx")], lineno=lineno), + # ast.Expr( + # value=ast.Call( + # func=ast.Attribute(value=ast.Name(id="mx", ctx=ast.Load()), attr="synchronize", ctx=ast.Load()), + # args=[], + # keywords=[], + # ) + # ), + # ], + # handlers=[ + # ast.ExceptHandler( + # type=ast.Name(id="ImportError", ctx=ast.Load()), + # name=None, + # body=[ast.Pass(lineno=lineno)], + # lineno=lineno, + # ) + # ], + # orelse=[], + # finalbody=[], + # lineno=lineno, + # ) # TensorFlow sync block - sync XLA/TPU devices tensorflow_sync = ast.Try( @@ -969,7 +969,8 
@@ def _codeflash_jit_sync(): args=ast.arguments( args=[], vararg=None, kwarg=None, posonlyargs=[], kwonlyargs=[], kw_defaults=[], defaults=[] ), - body=[pytorch_sync, jax_sync, mlx_sync, tensorflow_sync], + # body=[pytorch_sync, jax_sync, mlx_sync, tensorflow_sync], + body=[pytorch_sync, jax_sync, tensorflow_sync], decorator_list=[], returns=None, lineno=lineno, From 7bf6681ce13d372caff855ad33886d4f19f16aaa Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Wed, 31 Dec 2025 18:27:49 -0800 Subject: [PATCH 07/27] failsafe --- codeflash/optimization/function_optimizer.py | 29 +++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 037460624..51c0c43bf 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -458,19 +458,22 @@ def optimize_function(self) -> Result[BestOptimization, str]: jit_compiled_opt_candidate = self.aiservice_client.get_jit_rewritten_code( code_context.read_writable_code.markdown, code_context.read_only_context_code, self.function_trace_id ) - # write files - # Try to replace function with optimized code - self.replace_function_and_helpers_with_optimized_code( - code_context=code_context, - optimized_code=jit_compiled_opt_candidate[0].source_code, - original_helper_code=original_helper_code, - ) - # get codecontext - new_code_context = self.get_code_optimization_context().unwrap() - # unwrite files - self.write_code_and_helpers( - self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path - ) + if len(jit_compiled_opt_candidate) > 0: + # write files + # Try to replace function with optimized code + self.replace_function_and_helpers_with_optimized_code( + code_context=code_context, + optimized_code=jit_compiled_opt_candidate[0].source_code, + original_helper_code=original_helper_code, + ) + # get codecontext + new_code_context = self.get_code_optimization_context().unwrap() + # unwrite files + self.write_code_and_helpers( + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + ) + else: + new_code_context = code_context # Generate tests and optimizations in parallel future_tests = self.executor.submit(self.generate_and_instrument_tests, new_code_context) future_optimizations = self.executor.submit( From f1e473576197a97b737d52b8df939f375acb2616 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Thu, 1 Jan 2026 14:25:34 -0800 Subject: [PATCH 08/27] comparator fix --- codeflash/verification/comparator.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index d95c45012..704d19b3c 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -24,7 +24,6 @@ HAS_JAX = find_spec("jax") is not None HAS_XARRAY = find_spec("xarray") is not None HAS_TENSORFLOW = find_spec("tensorflow") is not None -HAS_MLX = find_spec("mlx") is not None def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911 @@ -139,17 +138,6 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 return False return comparator(orig.to_list(), new.to_list(), superset_obj) - if HAS_MLX: - import mlx.core as mx # type: ignore # noqa: PGH003 - - if isinstance(orig, mx.array): - if orig.dtype != new.dtype: - return False - if orig.shape != 
new.shape: - return False - # MLX allclose handles NaN comparison via equal_nan parameter - return bool(mx.allclose(orig, new, equal_nan=True)) - if HAS_SQLALCHEMY: import sqlalchemy # type: ignore # noqa: PGH003 @@ -235,6 +223,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 return False return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields) + # Handle np.dtype instances (including numpy.dtypes.* classes like Float64DType, Int64DType, etc.) + if isinstance(orig, np.dtype): + return orig == new + if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix): if orig.dtype != new.dtype: return False From 81d0599cfaddf473f3adde57ed195a94be8fbad0 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Mon, 12 Jan 2026 17:36:37 -0800 Subject: [PATCH 09/27] new gpu instrumentation --- .../code_utils/instrument_existing_tests.py | 517 ++++++++++-------- 1 file changed, 277 insertions(+), 240 deletions(-) diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index e259edb8e..5a1c32991 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -665,14 +665,50 @@ def inject_async_profiling_into_existing_test( return True, sort_imports(ast.unparse(tree), float_to_top=True) +def detect_frameworks_from_code(code: str) -> dict[str, str]: + """Detect GPU/device frameworks (torch, tensorflow, jax) used in the code by analyzing imports. + + Returns: + A dictionary mapping framework names to their import aliases. + For example: {"torch": "th", "tensorflow": "tf", "jax": "jax"} + + """ + frameworks: dict[str, str] = {} + try: + tree = ast.parse(code) + except SyntaxError: + return frameworks + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + module_name = alias.name.split(".")[0] + if module_name == "torch": + # Use asname if available, otherwise use the module name + frameworks["torch"] = alias.asname if alias.asname else module_name + elif module_name == "tensorflow": + frameworks["tensorflow"] = alias.asname if alias.asname else module_name + elif module_name == "jax": + frameworks["jax"] = alias.asname if alias.asname else module_name + elif isinstance(node, ast.ImportFrom): # noqa: SIM102 + if node.module: + module_name = node.module.split(".")[0] + if module_name == "torch" and "torch" not in frameworks: + frameworks["torch"] = module_name + elif module_name == "tensorflow" and "tensorflow" not in frameworks: + frameworks["tensorflow"] = module_name + elif module_name == "jax" and "jax" not in frameworks: + frameworks["jax"] = module_name + + return frameworks + + def inject_profiling_into_existing_test( test_path: Path, call_positions: list[CodePosition], function_to_optimize: FunctionToOptimize, tests_project_root: Path, mode: TestingMode = TestingMode.BEHAVIOR, - *, - jit_warmup: bool = False, ) -> tuple[bool, str | None]: if function_to_optimize.is_async: return inject_async_profiling_into_existing_test( @@ -681,6 +717,8 @@ def inject_profiling_into_existing_test( with test_path.open(encoding="utf8") as f: test_code = f.read() + + used_frameworks = detect_frameworks_from_code(test_code) try: tree = ast.parse(test_code) except SyntaxError: @@ -706,85 +744,91 @@ def inject_profiling_into_existing_test( ast.Import(names=[ast.alias(name="dill", asname="pickle")]), ] ) - additional_functions = [create_wrapper_function(mode, jit_warmup=jit_warmup)] - if jit_warmup: - 
additional_functions.insert(0, create_jit_sync_helper()) + # Add framework imports for GPU sync code (needed when framework is only imported via submodule) + for framework_name, framework_alias in used_frameworks.items(): + if framework_alias == framework_name: + # Only add import if we're using the framework name directly (not an alias) + # This handles cases like "from torch.nn import Module" where torch needs to be imported + new_imports.append(ast.Import(names=[ast.alias(name=framework_name)])) + else: + # If there's an alias, use it (e.g., "import torch as th") + new_imports.append(ast.Import(names=[ast.alias(name=framework_name, asname=framework_alias)])) + additional_functions = [create_wrapper_function(mode, used_frameworks)] tree.body = [*new_imports, *additional_functions, *tree.body] return True, sort_imports(ast.unparse(tree), float_to_top=True) -def create_jit_sync_helper() -> ast.FunctionDef: - """Create a helper function that synchronizes JIT-compiled frameworks (PyTorch, TensorFlow, JAX, MLX). - - This function generates AST for: - def _codeflash_jit_sync(): - try: - import torch - if torch.cuda.is_available(): - torch.cuda.synchronize() - if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): - torch.mps.synchronize() - except ImportError: - pass - try: - import jax - # Block until all JAX computations are complete - jax.effects_barrier() - except ImportError: - pass - try: - import mlx.core as mx - mx.synchronize() - except ImportError: - pass - # Note: TensorFlow in eager mode auto-syncs; Numba JIT is CPU-based and doesn't need sync - """ - lineno = 1 +def _create_device_sync_precompute_statements(used_frameworks: dict[str, str] | None) -> list[ast.stmt]: + """Create AST statements to pre-compute device sync conditions before profiling. - # PyTorch sync block - pytorch_sync = ast.Try( - body=[ - ast.Import(names=[ast.alias(name="torch")], lineno=lineno), - # if torch.cuda.is_available(): torch.cuda.synchronize() - ast.If( - test=ast.Call( - func=ast.Attribute( - value=ast.Attribute(value=ast.Name(id="torch", ctx=ast.Load()), attr="cuda", ctx=ast.Load()), - attr="is_available", - ctx=ast.Load(), - ), - args=[], - keywords=[], - ), - body=[ - ast.Expr( - value=ast.Call( + This moves the conditional checks (like is_available(), hasattr(), etc.) outside + the timing block to avoid their overhead affecting the measurements. 
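+ + For example, with used_frameworks={"torch": "torch"} the emitted assignments correspond to the following illustrative sketch (the _codeflash_* names are the generated temporaries): + + _codeflash_should_sync_cuda = torch.cuda.is_available() and torch.cuda.is_initialized() + _codeflash_should_sync_mps = (not _codeflash_should_sync_cuda and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and hasattr(torch.mps, 'synchronize'))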
+ + Args: + used_frameworks: Dict mapping framework names to their import aliases + + Returns: + List of AST statements that pre-compute sync conditions into boolean variables + + """ + if not used_frameworks: + return [] + + precompute_statements: list[ast.stmt] = [] + + # PyTorch: pre-compute whether to sync CUDA or MPS + if "torch" in used_frameworks: + torch_alias = used_frameworks["torch"] + # _codeflash_should_sync_cuda = torch.cuda.is_available() and torch.cuda.is_initialized() + precompute_statements.append( + ast.Assign( + targets=[ast.Name(id="_codeflash_should_sync_cuda", ctx=ast.Store())], + value=ast.BoolOp( + op=ast.And(), + values=[ + ast.Call( func=ast.Attribute( value=ast.Attribute( - value=ast.Name(id="torch", ctx=ast.Load()), attr="cuda", ctx=ast.Load() + value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="cuda", ctx=ast.Load() ), - attr="synchronize", + attr="is_available", ctx=ast.Load(), ), args=[], keywords=[], - ) - ) - ], - orelse=[], - lineno=lineno, - ), - # if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): torch.mps.synchronize() - ast.If( - test=ast.BoolOp( + ), + ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="cuda", ctx=ast.Load() + ), + attr="is_initialized", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ), + ], + ), + lineno=1, + ) + ) + # _codeflash_should_sync_mps = (not _codeflash_should_sync_cuda and + # hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and + # hasattr(torch.mps, 'synchronize')) + precompute_statements.append( + ast.Assign( + targets=[ast.Name(id="_codeflash_should_sync_mps", ctx=ast.Store())], + value=ast.BoolOp( op=ast.And(), values=[ + ast.UnaryOp(op=ast.Not(), operand=ast.Name(id="_codeflash_should_sync_cuda", ctx=ast.Load())), ast.Call( func=ast.Name(id="hasattr", ctx=ast.Load()), args=[ ast.Attribute( - value=ast.Name(id="torch", ctx=ast.Load()), attr="backends", ctx=ast.Load() + value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="backends", ctx=ast.Load() ), ast.Constant(value="mps"), ], @@ -794,7 +838,7 @@ def _codeflash_jit_sync(): func=ast.Attribute( value=ast.Attribute( value=ast.Attribute( - value=ast.Name(id="torch", ctx=ast.Load()), attr="backends", ctx=ast.Load() + value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="backends", ctx=ast.Load() ), attr="mps", ctx=ast.Load(), @@ -805,179 +849,192 @@ def _codeflash_jit_sync(): args=[], keywords=[], ), - ], - ), - body=[ - ast.Expr( - value=ast.Call( - func=ast.Attribute( - value=ast.Attribute( - value=ast.Name(id="torch", ctx=ast.Load()), attr="mps", ctx=ast.Load() + ast.Call( + func=ast.Name(id="hasattr", ctx=ast.Load()), + args=[ + ast.Attribute( + value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="mps", ctx=ast.Load() ), - attr="synchronize", - ctx=ast.Load(), - ), - args=[], + ast.Constant(value="synchronize"), + ], keywords=[], - ) - ) - ], - orelse=[], - lineno=lineno, - ), - ], - handlers=[ - ast.ExceptHandler( - type=ast.Name(id="ImportError", ctx=ast.Load()), - name=None, - body=[ast.Pass(lineno=lineno)], - lineno=lineno, + ), + ], + ), + lineno=1, ) - ], - orelse=[], - finalbody=[], - lineno=lineno, - ) + ) - # JAX sync block - use effects_barrier() to wait for all computations - jax_sync = ast.Try( - body=[ - ast.Import(names=[ast.alias(name="jax")], lineno=lineno), - ast.Expr( + # JAX: pre-compute whether jax.block_until_ready exists + if "jax" in used_frameworks: + jax_alias = used_frameworks["jax"] + # _codeflash_should_sync_jax = hasattr(jax, 
'block_until_ready') + precompute_statements.append( + ast.Assign( + targets=[ast.Name(id="_codeflash_should_sync_jax", ctx=ast.Store())], value=ast.Call( - func=ast.Attribute( - value=ast.Name(id="jax", ctx=ast.Load()), attr="effects_barrier", ctx=ast.Load() - ), - args=[], + func=ast.Name(id="hasattr", ctx=ast.Load()), + args=[ast.Name(id=jax_alias, ctx=ast.Load()), ast.Constant(value="block_until_ready")], keywords=[], - ) - ), - ], - handlers=[ - ast.ExceptHandler( - type=ast.Name(id="ImportError", ctx=ast.Load()), - name=None, - body=[ast.Pass(lineno=lineno)], - lineno=lineno, + ), + lineno=1, ) - ], - orelse=[], - finalbody=[], - lineno=lineno, - ) + ) - # # MLX sync block - # mlx_sync = ast.Try( - # body=[ - # ast.Import(names=[ast.alias(name="mlx.core", asname="mx")], lineno=lineno), - # ast.Expr( - # value=ast.Call( - # func=ast.Attribute(value=ast.Name(id="mx", ctx=ast.Load()), attr="synchronize", ctx=ast.Load()), - # args=[], - # keywords=[], - # ) - # ), - # ], - # handlers=[ - # ast.ExceptHandler( - # type=ast.Name(id="ImportError", ctx=ast.Load()), - # name=None, - # body=[ast.Pass(lineno=lineno)], - # lineno=lineno, - # ) - # ], - # orelse=[], - # finalbody=[], - # lineno=lineno, - # ) - - # TensorFlow sync block - sync XLA/TPU devices - tensorflow_sync = ast.Try( - body=[ - ast.Import(names=[ast.alias(name="tensorflow", asname="tf")], lineno=lineno), - # For TPU: tf.tpu.experimental.initialize_tpu_system if available - # For GPU: operations complete synchronously in eager mode but we can force sync - ast.If( - test=ast.Call( + # TensorFlow: pre-compute whether tf.test.experimental.sync_devices exists + if "tensorflow" in used_frameworks: + tf_alias = used_frameworks["tensorflow"] + # _codeflash_should_sync_tf = hasattr(tf.test.experimental, 'sync_devices') + precompute_statements.append( + ast.Assign( + targets=[ast.Name(id="_codeflash_should_sync_tf", ctx=ast.Store())], + value=ast.Call( func=ast.Name(id="hasattr", ctx=ast.Load()), args=[ - ast.Attribute(value=ast.Name(id="tf", ctx=ast.Load()), attr="config", ctx=ast.Load()), - ast.Constant(value="experimental"), + ast.Attribute( + value=ast.Attribute( + value=ast.Name(id=tf_alias, ctx=ast.Load()), attr="test", ctx=ast.Load() + ), + attr="experimental", + ctx=ast.Load(), + ), + ast.Constant(value="sync_devices"), ], keywords=[], ), - body=[ - # Get all physical devices and sync GPUs - ast.For( - target=ast.Name(id="_device", ctx=ast.Store()), - iter=ast.Call( - func=ast.Attribute( + lineno=1, + ) + ) + + return precompute_statements + + +def _create_device_sync_statements( + used_frameworks: dict[str, str] | None, + for_return_value: bool = False, # noqa: FBT001, FBT002 +) -> list[ast.stmt]: + """Create AST statements for device synchronization using pre-computed conditions. 
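+ + For used_frameworks={"torch": "torch"} the generated statements reduce to this illustrative sketch, gated on the pre-computed booleans: + + if _codeflash_should_sync_cuda: + torch.cuda.synchronize() + elif _codeflash_should_sync_mps: + torch.mps.synchronize()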
+ + Args: + used_frameworks: Dict mapping framework names to their import aliases + (e.g., {'torch': 'th', 'tensorflow': 'tf', 'jax': 'jax'}) + for_return_value: If True, creates sync for after function call (includes JAX block_until_ready) + + Returns: + List of AST statements for device synchronization using pre-computed boolean variables + + """ + if not used_frameworks: + return [] + + sync_statements: list[ast.stmt] = [] + + # PyTorch synchronization using pre-computed conditions + if "torch" in used_frameworks: + torch_alias = used_frameworks["torch"] + # if _codeflash_should_sync_cuda: + # torch.cuda.synchronize() + # elif _codeflash_should_sync_mps: + # torch.mps.synchronize() + cuda_sync = ast.If( + test=ast.Name(id="_codeflash_should_sync_cuda", ctx=ast.Load()), + body=[ + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="cuda", ctx=ast.Load() + ), + attr="synchronize", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ) + ) + ], + orelse=[ + ast.If( + test=ast.Name(id="_codeflash_should_sync_mps", ctx=ast.Load()), + body=[ + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="mps", ctx=ast.Load() + ), + attr="synchronize", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ) + ) + ], + orelse=[], + ) + ], + ) + sync_statements.append(cuda_sync) + + # JAX synchronization (only after function call, using block_until_ready on return value) + if "jax" in used_frameworks and for_return_value: + jax_alias = used_frameworks["jax"] + # if _codeflash_should_sync_jax: + # jax.block_until_ready(return_value) + jax_sync = ast.If( + test=ast.Name(id="_codeflash_should_sync_jax", ctx=ast.Load()), + body=[ + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=jax_alias, ctx=ast.Load()), attr="block_until_ready", ctx=ast.Load() + ), + args=[ast.Name(id="return_value", ctx=ast.Load())], + keywords=[], + ) + ) + ], + orelse=[], + ) + sync_statements.append(jax_sync) + + # TensorFlow synchronization using pre-computed condition + if "tensorflow" in used_frameworks: + tf_alias = used_frameworks["tensorflow"] + # if _codeflash_should_sync_tf: + # tf.test.experimental.sync_devices() + tf_sync = ast.If( + test=ast.Name(id="_codeflash_should_sync_tf", ctx=ast.Load()), + body=[ + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Attribute( value=ast.Attribute( - value=ast.Name(id="tf", ctx=ast.Load()), attr="config", ctx=ast.Load() + value=ast.Name(id=tf_alias, ctx=ast.Load()), attr="test", ctx=ast.Load() ), - attr="list_physical_devices", + attr="experimental", ctx=ast.Load(), ), - args=[ast.Constant(value="GPU")], - keywords=[], + attr="sync_devices", + ctx=ast.Load(), ), - body=[ - ast.Expr( - value=ast.Call( - func=ast.Attribute( - value=ast.Attribute( - value=ast.Attribute( - value=ast.Name(id="tf", ctx=ast.Load()), attr="test", ctx=ast.Load() - ), - attr="experimental", - ctx=ast.Load(), - ), - attr="sync_devices", - ctx=ast.Load(), - ), - args=[], - keywords=[], - ) - ) - ], - orelse=[], - lineno=lineno, + args=[], + keywords=[], ) - ], - orelse=[], - lineno=lineno, - ), - ], - handlers=[ - ast.ExceptHandler( - type=ast.Tuple( - elts=[ast.Name(id="ImportError", ctx=ast.Load()), ast.Name(id="AttributeError", ctx=ast.Load())], - ctx=ast.Load(), - ), - name=None, - body=[ast.Pass(lineno=lineno)], - lineno=lineno, - ) - ], - orelse=[], - finalbody=[], - lineno=lineno, - ) + ) + ], + orelse=[], + ) + 
sync_statements.append(tf_sync) - return ast.FunctionDef( - name="_codeflash_jit_sync", - args=ast.arguments( - args=[], vararg=None, kwarg=None, posonlyargs=[], kwonlyargs=[], kw_defaults=[], defaults=[] - ), - # body=[pytorch_sync, jax_sync, mlx_sync, tensorflow_sync], - body=[pytorch_sync, jax_sync, tensorflow_sync], - decorator_list=[], - returns=None, - lineno=lineno, - ) + return sync_statements -def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR, *, jit_warmup: bool = False) -> ast.FunctionDef: +def create_wrapper_function( + mode: TestingMode = TestingMode.BEHAVIOR, used_frameworks: dict[str, str] | None = None +) -> ast.FunctionDef: lineno = 1 wrapper_body: list[ast.stmt] = [ ast.Assign( @@ -1138,6 +1195,8 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR, *, jit_war ast.Assign( targets=[ast.Name(id="exception", ctx=ast.Store())], value=ast.Constant(value=None), lineno=lineno + 10 ), + # Pre-compute device sync conditions before profiling to avoid overhead during timing + *_create_device_sync_precompute_statements(used_frameworks), ast.Expr( value=ast.Call( func=ast.Attribute(value=ast.Name(id="gc", ctx=ast.Load()), attr="disable", ctx=ast.Load()), @@ -1148,19 +1207,8 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR, *, jit_war ), ast.Try( body=[ - # Sync before starting timer (ensure previous operations are complete) - *( - [ - ast.Expr( - value=ast.Call( - func=ast.Name(id="_codeflash_jit_sync", ctx=ast.Load()), args=[], keywords=[] - ), - lineno=lineno + 11, - ) - ] - if jit_warmup - else [] - ), + # Pre-sync: synchronize device before starting timer + *_create_device_sync_statements(used_frameworks, for_return_value=False), ast.Assign( targets=[ast.Name(id="counter", ctx=ast.Store())], value=ast.Call( @@ -1181,19 +1229,8 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR, *, jit_war ), lineno=lineno + 12, ), - # Sync after function call to ensure all GPU/async operations complete before stopping timer - *( - [ - ast.Expr( - value=ast.Call( - func=ast.Name(id="_codeflash_jit_sync", ctx=ast.Load()), args=[], keywords=[] - ), - lineno=lineno + 12, - ) - ] - if jit_warmup - else [] - ), + # Post-sync: synchronize device after function call to ensure all device work is complete + *_create_device_sync_statements(used_frameworks, for_return_value=True), ast.Assign( targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())], value=ast.BinOp( From 6e19826fb8912ed5e01f930647ce5bc14918dff4 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Mon, 12 Jan 2026 17:37:15 -0800 Subject: [PATCH 10/27] bug fix --- codeflash/optimization/function_optimizer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 51c0c43bf..d827d1328 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1251,7 +1251,6 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio call_positions=[test.position for test in tests_in_file_list], function_to_optimize=self.function_to_optimize, tests_project_root=self.test_cfg.tests_project_rootdir, - jit_warmup=True, ) if not success: continue @@ -1261,7 +1260,6 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio call_positions=[test.position for test in tests_in_file_list], function_to_optimize=self.function_to_optimize, tests_project_root=self.test_cfg.tests_project_rootdir, - 
jit_warmup=True, ) if not success: continue From d7a69abe65348c0ee61b354c43615a74302ad254 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Wed, 14 Jan 2026 05:56:07 +0000 Subject: [PATCH 11/27] improve comparator --- codeflash/verification/comparator.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index 704d19b3c..2241531c5 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -227,6 +227,19 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 if isinstance(orig, np.dtype): return orig == new + # Handle numpy random generators + if isinstance(orig, np.random.Generator): + # Compare the underlying BitGenerator state + orig_state = orig.bit_generator.state + new_state = new.bit_generator.state + return comparator(orig_state, new_state, superset_obj) + + if isinstance(orig, np.random.RandomState): + # Compare the internal state + orig_state = orig.get_state(legacy=False) + new_state = new.get_state(legacy=False) + return comparator(orig_state, new_state, superset_obj) + if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix): if orig.dtype != new.dtype: return False @@ -283,6 +296,9 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 if isinstance(orig, torch.dtype): return orig == new + if isinstance(orig, torch.device): + return orig == new + if HAS_PYRSISTENT: import pyrsistent # type: ignore # noqa: PGH003 From 722f2527c9e0e47931731a36fab771d5958c7d0f Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Wed, 14 Jan 2026 12:19:07 -0800 Subject: [PATCH 12/27] reverting tests --- tests/test_instrument_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_instrument_tests.py b/tests/test_instrument_tests.py index d6439550d..a74f41533 100644 --- a/tests/test_instrument_tests.py +++ b/tests/test_instrument_tests.py @@ -193,7 +193,7 @@ def test_sort(self): Path(f.name), [CodePosition(9, 17), CodePosition(13, 17), CodePosition(17, 17)], func, - Path(f.name).parent, jit_warmup=True + Path(f.name).parent, ) os.chdir(original_cwd) assert success From c597702bbaadb3cc300d240a2319fc1c80ff5dcd Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Wed, 14 Jan 2026 13:19:15 -0800 Subject: [PATCH 13/27] Cleaning up --- codeflash/api/aiservice.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index a3d8adcf7..a43934d2c 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -55,7 +55,7 @@ def get_aiservice_base_url(self) -> str: logger.info("Using local AI Service at http://localhost:8000") console.rule() return "http://localhost:8000" - return "http://localhost:8000" + return "https://app.codeflash.ai" def make_ai_service_request( self, @@ -230,7 +230,7 @@ def get_jit_rewritten_code( # noqa: D417 "current_username": get_last_commit_author_if_pr_exists(None), "repo_owner": git_repo_owner, "repo_name": git_repo_name, - "n_candidates": N_CANDIDATES_EFFECTIVE, + "n_candidates": 1, "is_async": is_async, } @@ -239,7 +239,7 @@ def get_jit_rewritten_code( # noqa: D417 try: response = self.make_ai_service_request("/rewrite_jit", payload=payload, timeout=60) except requests.exceptions.RequestException as e: - logger.exception(f"Error generating optimized candidates: {e}") + logger.exception(f"Error generating jit rewritten candidate: {e}") ph("cli-optimize-error-caught", {"error": 
str(e)}) return [] @@ -247,13 +247,13 @@ def get_jit_rewritten_code( # noqa: D417 optimizations_json = response.json()["optimizations"] console.rule() end_time = time.perf_counter() - logger.debug(f"!lsp|Generating possible optimizations took {end_time - start_time:.2f} seconds.") - return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE) + logger.debug(f"!lsp|Generating jit rewritten code took {end_time - start_time:.2f} seconds.") + return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.JIT_REWRITE) try: error = response.json()["error"] except Exception: error = response.text - logger.error(f"Error generating optimized candidates: {response.status_code} - {error}") + logger.error(f"Error generating jit rewritten candidate: {response.status_code} - {error}") ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error}) console.rule() return [] From 2441426f9f5741eaae8152b8a3bbd02f7588f553 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Wed, 14 Jan 2026 13:20:15 -0800 Subject: [PATCH 14/27] jit rewrite type --- codeflash/models/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 133299589..d850e3827 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -488,6 +488,7 @@ class OptimizedCandidateSource(str, Enum): REFINE = "REFINE" REPAIR = "REPAIR" ADAPTIVE = "ADAPTIVE" + JIT_REWRITE = "JIT_REWRITE" @dataclass(frozen=True) From d0c66436b551d57b594330f72f30039b51c8125a Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Wed, 14 Jan 2026 18:46:44 -0800 Subject: [PATCH 15/27] cleaning up --- codeflash/api/aiservice.py | 19 +- codeflash/code_utils/code_extractor.py | 12 +- codeflash/optimization/function_optimizer.py | 48 +++-- tests/test_is_numerical_code.py | 197 +++++++++++++++++++ 4 files changed, 233 insertions(+), 43 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index a43934d2c..1e44c9b6b 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -192,24 +192,14 @@ def optimize_python_code( # noqa: D417 return [] def get_jit_rewritten_code( # noqa: D417 - self, - source_code: str, - dependency_code: str, - trace_id: str, - num_candidates: int = 1, - experiment_metadata: ExperimentMetadata | None = None, - *, - is_async: bool = False, + self, source_code: str, trace_id: str ) -> list[OptimizedCandidate]: """Optimize the given python code for performance by making a request to the Django endpoint. Parameters ---------- - source_code (str): The python code to optimize. - - dependency_code (str): The dependency code used as read-only context for the optimization - trace_id (str): Trace id of optimization run - - num_candidates (int): Number of optimization variants to generate. Default is 10. 
- - experiment_metadata (Optional[ExperimentalMetadata, None]): Any available experiment metadata for this optimization Returns ------- @@ -221,17 +211,10 @@ def get_jit_rewritten_code( # noqa: D417 payload = { "source_code": source_code, - "dependency_code": dependency_code, - "num_variants": num_candidates, "trace_id": trace_id, - "python_version": platform.python_version(), - "experiment_metadata": experiment_metadata, - "codeflash_version": codeflash_version, "current_username": get_last_commit_author_if_pr_exists(None), "repo_owner": git_repo_owner, "repo_name": git_repo_name, - "n_candidates": 1, - "is_async": is_async, } logger.info("!lsp|Generating optimized candidates…") diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 935d0a369..66dfd5eb4 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -1298,7 +1298,7 @@ def _find_function_node(tree: ast.Module, name_parts: list[str]) -> ast.Function return None -def is_numerical_code(code_string: str, function_name: str) -> bool: +def is_numerical_code(code_string: str, function_name: str | None = None) -> bool: """Check if a function uses numerical computing libraries. Detects usage of numpy, torch, numba, jax, tensorflow, scipy, and math libraries @@ -1339,6 +1339,13 @@ def is_numerical_code(code_string: str, function_name: str) -> bool: except SyntaxError: return False + # Collect names that reference numerical modules from imports + numerical_names, modules_used = _collect_numerical_imports(tree) + + if not function_name: + # Return True if modules used and (numba available or modules don't all require numba) + return bool(modules_used) and (has_numba or not modules_used.issubset(NUMBA_REQUIRED_MODULES)) + # Split the function name to handle class methods name_parts = function_name.split(".") @@ -1347,9 +1354,6 @@ def is_numerical_code(code_string: str, function_name: str) -> bool: if target_function is None: return False - # Collect names that reference numerical modules from imports - numerical_names, modules_used = _collect_numerical_imports(tree) - # Check if the function body uses any numerical library checker = NumericalUsageChecker(numerical_names) checker.visit(target_function) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 95e216c0e..07bc2b863 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -24,7 +24,7 @@ from codeflash.benchmarking.utils import process_benchmark_data from codeflash.cli_cmds.console import code_print, console, logger, lsp_log, progress_bar from codeflash.code_utils import env_utils -from codeflash.code_utils.code_extractor import get_opt_review_metrics +from codeflash.code_utils.code_extractor import get_opt_review_metrics, is_numerical_code from codeflash.code_utils.code_replacer import ( add_custom_marker_to_all_tests, modify_autouse_fixture, @@ -581,6 +581,10 @@ def generate_and_instrument_tests( ) ) + # def_is_numerical_code_fto_helpers(): + # #get mapping of code string to function names + # # run is_numerical_code on this mapping and return if any of these is true + # note: this isn't called by the lsp, only called by cli def optimize_function(self) -> Result[BestOptimization, str]: initialization_result = self.can_be_optimized() @@ -600,27 +604,29 @@ def optimize_function(self) -> Result[BestOptimization, str]: revert_to_print=bool(get_pr_number()), ): console.rule() - # get new opt 
candidate - - jit_compiled_opt_candidate = self.aiservice_client.get_jit_rewritten_code( - code_context.read_writable_code.markdown, code_context.read_only_context_code, self.function_trace_id - ) - if len(jit_compiled_opt_candidate) > 0: - # write files - # Try to replace function with optimized code - self.replace_function_and_helpers_with_optimized_code( - code_context=code_context, - optimized_code=jit_compiled_opt_candidate[0].source_code, - original_helper_code=original_helper_code, - ) - # get codecontext - new_code_context = self.get_code_optimization_context().unwrap() - # unwrite files - self.write_code_and_helpers( - self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + new_code_context = code_context + if is_numerical_code( + code_string=code_context.read_writable_code.flat + ): # if the code is numerical in nature (uses numpy/tensorflow/math/pytorch/jax) + jit_compiled_opt_candidate = self.aiservice_client.get_jit_rewritten_code( + code_context.read_writable_code.markdown, + code_context.read_only_context_code, + self.function_trace_id, ) - else: - new_code_context = code_context + if jit_compiled_opt_candidate: # jit rewrite was successful + # write files + # Try to replace function with optimized code + self.replace_function_and_helpers_with_optimized_code( + code_context=code_context, + optimized_code=jit_compiled_opt_candidate[0].source_code, + original_helper_code=original_helper_code, + ) + # get code context + new_code_context = self.get_code_optimization_context().unwrap() + # unwrite files + self.write_code_and_helpers( + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + ) # Generate tests and optimizations in parallel future_tests = self.executor.submit(self.generate_and_instrument_tests, new_code_context) future_optimizations = self.executor.submit( diff --git a/tests/test_is_numerical_code.py b/tests/test_is_numerical_code.py index 97b4b304e..5fedce8d1 100644 --- a/tests/test_is_numerical_code.py +++ b/tests/test_is_numerical_code.py @@ -691,6 +691,203 @@ def method(self): assert is_numerical_code(code, "ClassB.method") is False +@patch("codeflash.code_utils.code_extractor.has_numba", True) +class TestEmptyFunctionName: + """Test behavior when function_name is empty/None. + + When function_name is not provided, the function should just check for the + presence of numerical imports without looking at a specific function body. 
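+
+    Since has_numba is patched to True for this class, any recognized numerical
+    import (including math/numpy/scipy) is enough for a True result.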
+ """ + + def test_empty_string_with_numpy_import(self): + """Empty function_name with numpy import should return True.""" + code = """ +import numpy as np +def some_func(): + pass +""" + assert is_numerical_code(code, "") is True + + def test_none_with_numpy_import(self): + """None function_name with numpy import should return True.""" + code = """ +import numpy as np +def some_func(): + pass +""" + assert is_numerical_code(code, None) is True + + def test_empty_string_with_torch_import(self): + """Empty function_name with torch import should return True.""" + code = """ +import torch +def some_func(): + pass +""" + assert is_numerical_code(code, "") is True + + def test_empty_string_with_multiple_numerical_imports(self): + """Empty function_name with multiple numerical imports should return True.""" + code = """ +import numpy as np +import torch +from scipy import stats +def some_func(): + pass +""" + assert is_numerical_code(code, "") is True + + def test_empty_string_without_numerical_imports(self): + """Empty function_name without numerical imports should return False.""" + code = """ +import os +import json +from pathlib import Path + +def some_func(): + pass +""" + assert is_numerical_code(code, "") is False + + def test_none_without_numerical_imports(self): + """None function_name without numerical imports should return False.""" + code = """ +import os +def some_func(): + pass +""" + assert is_numerical_code(code, None) is False + + def test_empty_string_with_jax_import(self): + """Empty function_name with jax import should return True.""" + code = """ +import jax +import jax.numpy as jnp +""" + assert is_numerical_code(code, "") is True + + def test_empty_string_with_tensorflow_import(self): + """Empty function_name with tensorflow import should return True.""" + code = """ +import tensorflow as tf +""" + assert is_numerical_code(code, "") is True + + def test_empty_string_with_math_import(self): + """Empty function_name with math import should return True (numba available).""" + code = """ +import math +def calculate(x): + return math.sqrt(x) +""" + assert is_numerical_code(code, "") is True + + def test_empty_string_with_scipy_submodule(self): + """Empty function_name with scipy submodule import should return True.""" + code = """ +from scipy.stats import norm +""" + assert is_numerical_code(code, "") is True + + def test_empty_string_with_numba_import(self): + """Empty function_name with numba import should return True.""" + code = """ +from numba import jit +""" + assert is_numerical_code(code, "") is True + + def test_empty_code_with_empty_function_name(self): + """Empty code with empty function_name should return False.""" + assert is_numerical_code("", "") is False + + def test_syntax_error_with_empty_function_name(self): + """Syntax error code with empty function_name should return False.""" + code = """ +def broken( + import numpy +""" + assert is_numerical_code(code, "") is False + + +@patch("codeflash.code_utils.code_extractor.has_numba", False) +class TestEmptyFunctionNameWithoutNumba: + """Test empty function_name behavior when numba is NOT available. + + When numba is not installed, code using only math/numpy/scipy should return False, + since numba is required to optimize such code. Code using torch/jax/tensorflow/numba + should still return True. 
+ """ + + def test_empty_string_numpy_returns_false_without_numba(self): + """Empty function_name with numpy should return False when numba unavailable.""" + code = """ +import numpy as np +def some_func(): + pass +""" + assert is_numerical_code(code, "") is False + + def test_empty_string_math_returns_false_without_numba(self): + """Empty function_name with math should return False when numba unavailable.""" + code = """ +import math +""" + assert is_numerical_code(code, "") is False + + def test_empty_string_scipy_returns_false_without_numba(self): + """Empty function_name with scipy should return False when numba unavailable.""" + code = """ +from scipy import stats +""" + assert is_numerical_code(code, "") is False + + def test_empty_string_torch_returns_true_without_numba(self): + """Empty function_name with torch should return True even without numba.""" + code = """ +import torch +""" + assert is_numerical_code(code, "") is True + + def test_empty_string_jax_returns_true_without_numba(self): + """Empty function_name with jax should return True even without numba.""" + code = """ +import jax +""" + assert is_numerical_code(code, "") is True + + def test_empty_string_tensorflow_returns_true_without_numba(self): + """Empty function_name with tensorflow should return True even without numba.""" + code = """ +import tensorflow as tf +""" + assert is_numerical_code(code, "") is True + + def test_empty_string_numba_import_returns_true_without_numba(self): + """Empty function_name with numba import should return True.""" + code = """ +from numba import jit +""" + assert is_numerical_code(code, "") is True + + def test_empty_string_numpy_and_torch_returns_true_without_numba(self): + """Empty function_name with numpy+torch should return True (torch doesn't need numba).""" + code = """ +import numpy as np +import torch +""" + # Returns True because torch is in modules_used and doesn't require numba + assert is_numerical_code(code, "") is True + + def test_empty_string_math_and_scipy_returns_false_without_numba(self): + """Empty function_name with only math+scipy should return False without numba.""" + code = """ +import math +from scipy import stats +""" + # Both math and scipy are in NUMBA_REQUIRED_MODULES + assert is_numerical_code(code, "") is False + + @patch("codeflash.code_utils.code_extractor.has_numba", False) class TestNumbaNotAvailable: """Test behavior when numba is NOT available in the environment. 
From 256332fa2cbc51880e08920bdf35833f60e6dbe9 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Wed, 14 Jan 2026 19:00:02 -0800 Subject: [PATCH 16/27] expected payload is correct --- codeflash/api/aiservice.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 1e44c9b6b..a20f20cbd 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -212,6 +212,8 @@ def get_jit_rewritten_code( # noqa: D417 payload = { "source_code": source_code, "trace_id": trace_id, + "dependency_code": "", # dummy value to please the api endpoint + "python_version": "3.12.1", # dummy value to please the api endpoint "current_username": get_last_commit_author_if_pr_exists(None), "repo_owner": git_repo_owner, "repo_name": git_repo_name, From 6c83c349b34c9c85ac70a802d7c6a269af1e26ab Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Thu, 15 Jan 2026 10:33:08 -0800 Subject: [PATCH 17/27] almost ready --- codeflash/api/aiservice.py | 2 +- codeflash/optimization/function_optimizer.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index a20f20cbd..cc6db6e5d 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -194,7 +194,7 @@ def optimize_python_code( # noqa: D417 def get_jit_rewritten_code( # noqa: D417 self, source_code: str, trace_id: str ) -> list[OptimizedCandidate]: - """Optimize the given python code for performance by making a request to the Django endpoint. + """Rewrite the given python code for performance via jit compilation by making a request to the Django endpoint. Parameters ---------- diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 07bc2b863..3863dc983 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -609,9 +609,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: code_string=code_context.read_writable_code.flat ): # if the code is numerical in nature (uses numpy/tensorflow/math/pytorch/jax) jit_compiled_opt_candidate = self.aiservice_client.get_jit_rewritten_code( - code_context.read_writable_code.markdown, - code_context.read_only_context_code, - self.function_trace_id, + code_context.read_writable_code.markdown, self.function_trace_id ) if jit_compiled_opt_candidate: # jit rewrite was successful # write files From 97c9a10bd0ec2785a5aed59696f64f50642d7138 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Thu, 15 Jan 2026 10:40:38 -0800 Subject: [PATCH 18/27] cleaning up --- codeflash/optimization/function_optimizer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 3863dc983..0fbefee0d 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -581,10 +581,6 @@ def generate_and_instrument_tests( ) ) - # def_is_numerical_code_fto_helpers(): - # #get mapping of code string to function names - # # run is_numerical_code on this mapping and return if any of these is true - # note: this isn't called by the lsp, only called by cli def optimize_function(self) -> Result[BestOptimization, str]: initialization_result = self.can_be_optimized() From ecce649a7a21bce95d9c11e2b500d4f8dd542d92 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Thu, 15 Jan 2026 14:43:32 -0800 Subject: [PATCH 19/27] add prompt conditionally --- codeflash/api/aiservice.py | 4 
++++ codeflash/optimization/function_optimizer.py | 11 +++++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index cc6db6e5d..65c4c20a0 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -129,6 +129,7 @@ def optimize_python_code( # noqa: D417 *, is_async: bool = False, n_candidates: int = 5, + is_numerical_code: bool | None = None, ) -> list[OptimizedCandidate]: """Optimize the given python code for performance by making a request to the Django endpoint. @@ -164,6 +165,7 @@ def optimize_python_code( # noqa: D417 "is_async": is_async, "call_sequence": self.get_next_sequence(), "n_candidates": n_candidates, + "is_numerical_code": is_numerical_code, } logger.debug(f"Sending optimize request: trace_id={trace_id}, n_candidates={payload['n_candidates']}") @@ -251,6 +253,7 @@ def optimize_python_code_line_profiler( # noqa: D417 line_profiler_results: str, n_candidates: int, experiment_metadata: ExperimentMetadata | None = None, + is_numerical_code: bool | None = None, # noqa: FBT001 ) -> list[OptimizedCandidate]: """Optimize the given python code for performance using line profiler results. @@ -285,6 +288,7 @@ def optimize_python_code_line_profiler( # noqa: D417 "experiment_metadata": experiment_metadata, "codeflash_version": codeflash_version, "call_sequence": self.get_next_sequence(), + "is_numerical_code": is_numerical_code, } try: diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 0fbefee0d..d534fc297 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -465,6 +465,7 @@ def __init__( self.future_adaptive_optimizations: list[concurrent.futures.Future] = [] self.repair_counter = 0 # track how many repairs we did for each function self.adaptive_optimization_counter = 0 # track how many adaptive optimizations we did for each function + self.is_numerical_code: bool | None = None def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]: should_run_experiment = self.experiment_id is not None @@ -587,7 +588,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: if not is_successful(initialization_result): return Failure(initialization_result.failure()) should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() - + self.is_numerical_code = is_numerical_code(code_string=code_context.read_writable_code.flat) code_print( code_context.read_writable_code.flat, file_name=self.function_to_optimize.file_path, @@ -601,9 +602,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: ): console.rule() new_code_context = code_context - if is_numerical_code( - code_string=code_context.read_writable_code.flat - ): # if the code is numerical in nature (uses numpy/tensorflow/math/pytorch/jax) + if self.is_numerical_code: # if the code is numerical in nature (uses numpy/tensorflow/math/pytorch/jax) jit_compiled_opt_candidate = self.aiservice_client.get_jit_rewritten_code( code_context.read_writable_code.markdown, self.function_trace_id ) @@ -628,6 +627,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: read_writable_code=code_context.read_writable_code, read_only_context_code=code_context.read_only_context_code, run_experiment=should_run_experiment, + is_numerical_code=self.is_numerical_code, ) concurrent.futures.wait([future_tests, future_optimizations]) @@ -1127,6 +1127,7 @@ def 
determine_best_candidate( ) if self.experiment_id else None, + is_numerical_code=self.is_numerical_code, ) processor = CandidateProcessor( @@ -1594,6 +1595,7 @@ def generate_optimizations( read_writable_code: CodeStringsMarkdown, read_only_context_code: str, run_experiment: bool = False, # noqa: FBT001, FBT002 + is_numerical_code: bool | None = None, # noqa: FBT001 ) -> Result[tuple[OptimizationSet, str], str]: """Generate optimization candidates for the function. Backend handles multi-model diversity.""" n_candidates = get_effort_value(EffortKeys.N_OPTIMIZER_CANDIDATES, self.effort) @@ -1605,6 +1607,7 @@ def generate_optimizations( ExperimentMetadata(id=self.experiment_id, group="control") if run_experiment else None, is_async=self.function_to_optimize.is_async, n_candidates=n_candidates, + is_numerical_code=is_numerical_code, ) future_references = self.executor.submit( From 184b8533e0e181f3b7fac51603085404f3afb440 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Thu, 15 Jan 2026 17:37:52 -0800 Subject: [PATCH 20/27] some jit heavy functions --- code_to_optimize/discrete_riccati.py | 100 ---- code_to_optimize/sample_jit_code.py | 476 ++++++++++++++++++ .../tests/pytest/test_gridmake2.py | 216 -------- .../tests/pytest/test_gridmake2_torch.py | 267 ---------- .../tests/pytest/test_jax_jit_code.py | 256 ++++++++++ .../tests/pytest/test_numba_jit_code.py | 242 +++++++++ .../tests/pytest/test_tensorflow_jit_code.py | 296 +++++++++++ .../tests/pytest/test_torch_jit_code.py | 285 +++++++++++ 8 files changed, 1555 insertions(+), 583 deletions(-) delete mode 100644 code_to_optimize/discrete_riccati.py create mode 100644 code_to_optimize/sample_jit_code.py delete mode 100644 code_to_optimize/tests/pytest/test_gridmake2.py delete mode 100644 code_to_optimize/tests/pytest/test_gridmake2_torch.py create mode 100644 code_to_optimize/tests/pytest/test_jax_jit_code.py create mode 100644 code_to_optimize/tests/pytest/test_numba_jit_code.py create mode 100644 code_to_optimize/tests/pytest/test_tensorflow_jit_code.py create mode 100644 code_to_optimize/tests/pytest/test_torch_jit_code.py diff --git a/code_to_optimize/discrete_riccati.py b/code_to_optimize/discrete_riccati.py deleted file mode 100644 index 5a032b9c5..000000000 --- a/code_to_optimize/discrete_riccati.py +++ /dev/null @@ -1,100 +0,0 @@ -""" -Utility functions used in CompEcon - -Based routines found in the CompEcon toolbox by Miranda and Fackler. - -References ----------- -Miranda, Mario J, and Paul L Fackler. Applied Computational Economics -and Finance, MIT Press, 2002. - -""" -from functools import reduce -import numpy as np -import torch - -def _gridmake2(x1, x2): - """ - Expands two vectors (or matrices) into a matrix where rows span the - cartesian product of combinations of the input arrays. Each column of the - input arrays will correspond to one column of the output matrix. - - Parameters - ---------- - x1 : np.ndarray - First vector to be expanded. - - x2 : np.ndarray - Second vector to be expanded. - - Returns - ------- - out : np.ndarray - The cartesian product of combinations of the input arrays. - - Notes - ----- - Based of original function ``gridmake2`` in CompEcon toolbox by - Miranda and Fackler. - - References - ---------- - Miranda, Mario J, and Paul L Fackler. Applied Computational Economics - and Finance, MIT Press, 2002. 
- - """ - if x1.ndim == 1 and x2.ndim == 1: - return np.column_stack([np.tile(x1, x2.shape[0]), - np.repeat(x2, x1.shape[0])]) - elif x1.ndim > 1 and x2.ndim == 1: - first = np.tile(x1, (x2.shape[0], 1)) - second = np.repeat(x2, x1.shape[0]) - return np.column_stack([first, second]) - else: - raise NotImplementedError("Come back here") - - -def _gridmake2_torch(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - """ - PyTorch version of _gridmake2. - - Expands two tensors into a matrix where rows span the cartesian product - of combinations of the input tensors. Each column of the input tensors - will correspond to one column of the output matrix. - - Parameters - ---------- - x1 : torch.Tensor - First tensor to be expanded. - - x2 : torch.Tensor - Second tensor to be expanded. - - Returns - ------- - out : torch.Tensor - The cartesian product of combinations of the input tensors. - - Notes - ----- - Based on original function ``gridmake2`` in CompEcon toolbox by - Miranda and Fackler. - - References - ---------- - Miranda, Mario J, and Paul L Fackler. Applied Computational Economics - and Finance, MIT Press, 2002. - - """ - if x1.dim() == 1 and x2.dim() == 1: - # tile x1 by x2.shape[0] times, repeat_interleave x2 by x1.shape[0] - first = x1.tile(x2.shape[0]) - second = x2.repeat_interleave(x1.shape[0]) - return torch.column_stack([first, second]) - elif x1.dim() > 1 and x2.dim() == 1: - # tile x1 along first dimension - first = x1.tile(x2.shape[0], 1) - second = x2.repeat_interleave(x1.shape[0]) - return torch.column_stack([first, second]) - else: - raise NotImplementedError("Come back here") diff --git a/code_to_optimize/sample_jit_code.py b/code_to_optimize/sample_jit_code.py new file mode 100644 index 000000000..50232a07e --- /dev/null +++ b/code_to_optimize/sample_jit_code.py @@ -0,0 +1,476 @@ +from functools import partial + +import jax.numpy as jnp +import numpy as np +import tensorflow as tf +import torch +from jax import lax + + +def tridiagonal_solve(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray: + n = len(b) + + c_prime = np.empty(n - 1, dtype=np.float64) + d_prime = np.empty(n, dtype=np.float64) + x = np.empty(n, dtype=np.float64) + + # Alias arrays to local variables to avoid repeated attribute lookups + a_arr = a + b_arr = b + c_arr = c + d_arr = d + cp = c_prime + dp = d_prime + x_arr = x + + # First element + prev_cprime = c_arr[0] / b_arr[0] + cp[0] = prev_cprime + prev_dprime = d_arr[0] / b_arr[0] + dp[0] = prev_dprime + + # Forward sweep (compute c_prime and d_prime) + + for i in range(1, n - 1): + ai_1 = a_arr[i - 1] + denom = b_arr[i] - ai_1 * prev_cprime + curr_cprime = c_arr[i] / denom + curr_dprime = (d_arr[i] - ai_1 * prev_dprime) / denom + cp[i] = curr_cprime + dp[i] = curr_dprime + prev_cprime = curr_cprime + prev_dprime = curr_dprime + + # Last d_prime entry + denom = b_arr[n - 1] - a_arr[n - 2] * prev_cprime + dp[n - 1] = (d_arr[n - 1] - a_arr[n - 2] * prev_dprime) / denom + + # Back substitution using a scalar for the "next x" value + prev_x = dp[n - 1] + x_arr[n - 1] = prev_x + for i in range(n - 2, -1, -1): + xi = dp[i] - cp[i] * prev_x + x_arr[i] = xi + prev_x = xi + + return x + + +def leapfrog_integration( + positions: np.ndarray, + velocities: np.ndarray, + masses: np.ndarray, + dt: float, + n_steps: int, + softening: float = 0.01 +) -> tuple[np.ndarray, np.ndarray]: + n_particles = len(masses) + pos = positions.copy() + vel = velocities.copy() + acc = np.zeros_like(pos) + + G = 1.0 + + for step in range(n_steps): + 
acc.fill(0.0) + + for i in range(n_particles): + for j in range(i + 1, n_particles): + dx = pos[j, 0] - pos[i, 0] + dy = pos[j, 1] - pos[i, 1] + dz = pos[j, 2] - pos[i, 2] + + dist_sq = dx * dx + dy * dy + dz * dz + softening * softening + dist = np.sqrt(dist_sq) + dist_cubed = dist_sq * dist + + force_over_dist = G / dist_cubed + + acc[i, 0] += masses[j] * force_over_dist * dx + acc[i, 1] += masses[j] * force_over_dist * dy + acc[i, 2] += masses[j] * force_over_dist * dz + + acc[j, 0] -= masses[i] * force_over_dist * dx + acc[j, 1] -= masses[i] * force_over_dist * dy + acc[j, 2] -= masses[i] * force_over_dist * dz + + for i in range(n_particles): + vel[i, 0] += 0.5 * dt * acc[i, 0] + vel[i, 1] += 0.5 * dt * acc[i, 1] + vel[i, 2] += 0.5 * dt * acc[i, 2] + + for i in range(n_particles): + pos[i, 0] += dt * vel[i, 0] + pos[i, 1] += dt * vel[i, 1] + pos[i, 2] += dt * vel[i, 2] + + for i in range(n_particles): + vel[i, 0] += 0.5 * dt * acc[i, 0] + vel[i, 1] += 0.5 * dt * acc[i, 1] + vel[i, 2] += 0.5 * dt * acc[i, 2] + + return pos, vel + + +def longest_increasing_subsequence_length(arr: np.ndarray) -> int: + n = len(arr) + if n == 0: + return 0 + + dp = np.ones(n, dtype=np.int64) + + for i in range(1, n): + for j in range(i): + if arr[j] < arr[i]: + if dp[j] + 1 > dp[i]: + dp[i] = dp[j] + 1 + + max_length = dp[0] + for i in range(1, n): + if dp[i] > max_length: + max_length = dp[i] + + return max_length + + +def _tridiagonal_forward_step_jax(carry, inputs): + c_prev, d_prev = carry + a_i, b_i, c_i, d_i = inputs + denom = b_i - a_i * c_prev + c_new = c_i / denom + d_new = (d_i - a_i * d_prev) / denom + return (c_new, d_new), (c_new, d_new) + + +def _tridiagonal_back_step_jax(x_next, inputs): + d_prime_i, c_prime_i = inputs + x_i = d_prime_i - c_prime_i * x_next + return x_i, x_i + + +def tridiagonal_solve_jax(a, b, c, d): + n = b.shape[0] + + c_prime_0 = c[0] / b[0] + d_prime_0 = d[0] / b[0] + + scan_inputs = (a[:-1], b[1:-1], c[1:], d[1:-1]) + + _, (c_prime_rest, d_prime_mid) = lax.scan( + _tridiagonal_forward_step_jax, + (c_prime_0, d_prime_0), + scan_inputs + ) + + c_prime = jnp.concatenate([jnp.array([c_prime_0]), c_prime_rest]) + + denom_last = b[n - 1] - a[n - 2] * c_prime[n - 2] + d_prime_last = (d[n - 1] - a[n - 2] * d_prime_mid[-1]) / denom_last + d_prime = jnp.concatenate([jnp.array([d_prime_0]), d_prime_mid, jnp.array([d_prime_last])]) + + x_last = d_prime[n - 1] + _, x_rest = lax.scan( + _tridiagonal_back_step_jax, + x_last, + (d_prime[:-1], c_prime), + reverse=True + ) + + x = jnp.concatenate([x_rest, jnp.array([x_last])]) + return x + + +def _leapfrog_compute_accelerations_jax(pos, masses, softening): + G = 1.0 + diff = pos[jnp.newaxis, :, :] - pos[:, jnp.newaxis, :] + + dist_sq = jnp.sum(diff ** 2, axis=-1) + softening ** 2 + dist = jnp.sqrt(dist_sq) + dist_cubed = dist_sq * dist + + dist_cubed = jnp.where(dist_cubed == 0, 1.0, dist_cubed) + + force_factor = G * masses[jnp.newaxis, :] / dist_cubed + + acc = jnp.sum(force_factor[:, :, jnp.newaxis] * diff, axis=1) + return acc + + +def _leapfrog_step_jax(carry, _, masses, softening, dt): + pos, vel = carry + acc = _leapfrog_compute_accelerations_jax(pos, masses, softening) + + vel = vel + 0.5 * dt * acc + pos = pos + dt * vel + vel = vel + 0.5 * dt * acc + + return (pos, vel), None + + +def leapfrog_integration_jax( + positions, + velocities, + masses, + dt: float, + n_steps: int, + softening: float = 0.01 +): + step_fn = partial(_leapfrog_step_jax, masses=masses, softening=softening, dt=dt) + (final_pos, final_vel), _ = 
lax.scan(step_fn, (positions, velocities), None, length=n_steps) + return final_pos, final_vel + + +def _lis_inner_body_jax(j, dp_inner, arr, i): + condition = (arr[j] < arr[i]) & (dp_inner[j] + 1 > dp_inner[i]) + new_val = jnp.where(condition, dp_inner[j] + 1, dp_inner[i]) + return dp_inner.at[i].set(new_val) + + +def _lis_outer_body_jax(i, dp, arr): + inner_fn = partial(_lis_inner_body_jax, arr=arr, i=i) + dp = lax.fori_loop(0, i, inner_fn, dp) + return dp + + +def longest_increasing_subsequence_length_jax(arr): + n = arr.shape[0] + + if n == 0: + return 0 + + outer_fn = partial(_lis_outer_body_jax, arr=arr) + dp = jnp.ones(n, dtype=jnp.int32) + dp = lax.fori_loop(1, n, outer_fn, dp) + + return int(jnp.max(dp)) + + +def tridiagonal_solve_torch(a, b, c, d): + device = b.device + dtype = b.dtype + n = b.shape[0] + + c_prime = torch.zeros(n - 1, device=device, dtype=dtype) + d_prime = torch.zeros(n, device=device, dtype=dtype) + x = torch.zeros(n, device=device, dtype=dtype) + + c_prime[0] = c[0] / b[0] + d_prime[0] = d[0] / b[0] + + for i in range(1, n - 1): + denom = b[i] - a[i - 1] * c_prime[i - 1] + c_prime[i] = c[i] / denom + d_prime[i] = (d[i] - a[i - 1] * d_prime[i - 1]) / denom + + denom = b[n - 1] - a[n - 2] * c_prime[n - 2] + d_prime[n - 1] = (d[n - 1] - a[n - 2] * d_prime[n - 2]) / denom + + x[n - 1] = d_prime[n - 1] + for i in range(n - 2, -1, -1): + x[i] = d_prime[i] - c_prime[i] * x[i + 1] + + return x + + +def leapfrog_integration_torch( + positions, + velocities, + masses, + dt: float, + n_steps: int, + softening: float = 0.01 +): + G = 1.0 + + pos = positions.clone() + vel = velocities.clone() + + for _ in range(n_steps): + diff = pos.unsqueeze(0) - pos.unsqueeze(1) + + dist_sq = torch.sum(diff ** 2, dim=-1) + softening ** 2 + dist = torch.sqrt(dist_sq) + dist_cubed = dist_sq * dist + + dist_cubed = torch.where(dist_cubed == 0, torch.ones_like(dist_cubed), dist_cubed) + + force_factor = G * masses.unsqueeze(0) / dist_cubed + + acc = torch.sum(force_factor.unsqueeze(-1) * diff, dim=1) + + vel = vel + 0.5 * dt * acc + pos = pos + dt * vel + vel = vel + 0.5 * dt * acc + + return pos, vel + + +def longest_increasing_subsequence_length_torch(arr): + n = arr.shape[0] + + if n == 0: + return 0 + + device = arr.device + dp = torch.ones(n, device=device, dtype=torch.int64) + + for i in range(1, n): + for j in range(i): + if arr[j] < arr[i]: + if dp[j] + 1 > dp[i]: + dp[i] = dp[j] + 1 + + return int(torch.max(dp).item()) + + +def _tridiagonal_forward_cond_tf(i, _c_prime, _d_prime, n, _a, _b, _c, _d): + return i < n - 1 + + +def _tridiagonal_forward_body_tf(i, c_prime, d_prime, n, a, b, c, d): + c_prev = c_prime[i - 1] + d_prev = d_prime[i - 1] + denom = b[i] - a[i - 1] * c_prev + c_val = c[i] / denom + d_val = (d[i] - a[i - 1] * d_prev) / denom + c_prime = tf.tensor_scatter_nd_update(c_prime, tf.reshape(i, [1, 1]), tf.reshape(c_val, [1])) + d_prime = tf.tensor_scatter_nd_update(d_prime, tf.reshape(i, [1, 1]), tf.reshape(d_val, [1])) + return i + 1, c_prime, d_prime, n, a, b, c, d + + +def _tridiagonal_back_cond_tf(i, _x, _c_prime, _d_prime): + return i >= 0 + + +def _tridiagonal_back_body_tf(i, x, c_prime, d_prime): + x_next = x[i + 1] + x_val = d_prime[i] - c_prime[i] * x_next + x = tf.tensor_scatter_nd_update(x, tf.reshape(i, [1, 1]), tf.reshape(x_val, [1])) + return i - 1, x, c_prime, d_prime + + +def tridiagonal_solve_tf(a, b, c, d): + n = tf.shape(b)[0] + dtype = b.dtype + + c_prime = tf.zeros([n - 1], dtype=dtype) + d_prime = tf.zeros([n], dtype=dtype) + + c_prime = 
tf.tensor_scatter_nd_update(c_prime, [[0]], tf.reshape(c[0] / b[0], [1])) + d_prime = tf.tensor_scatter_nd_update(d_prime, [[0]], tf.reshape(d[0] / b[0], [1])) + + _, c_prime, d_prime, _, _, _, _, _ = tf.while_loop( + _tridiagonal_forward_cond_tf, + _tridiagonal_forward_body_tf, + [1, c_prime, d_prime, n, a, b, c, d] + ) + + c_last = c_prime[n - 2] + d_prev = d_prime[n - 2] + denom = b[n - 1] - a[n - 2] * c_last + d_last = (d[n - 1] - a[n - 2] * d_prev) / denom + d_prime = tf.tensor_scatter_nd_update(d_prime, tf.reshape(n - 1, [1, 1]), tf.reshape(d_last, [1])) + + x = tf.zeros([n], dtype=dtype) + x = tf.tensor_scatter_nd_update(x, tf.reshape(n - 1, [1, 1]), tf.reshape(d_prime[n - 1], [1])) + + _, x, _, _ = tf.while_loop( + _tridiagonal_back_cond_tf, + _tridiagonal_back_body_tf, + [n - 2, x, c_prime, d_prime] + ) + + return x + + +def _leapfrog_compute_accelerations_tf(pos, masses, softening, G): + diff = tf.expand_dims(pos, 0) - tf.expand_dims(pos, 1) + + dist_sq = tf.reduce_sum(diff ** 2, axis=-1) + softening ** 2 + dist = tf.sqrt(dist_sq) + dist_cubed = dist_sq * dist + + dist_cubed = tf.where(dist_cubed == 0, tf.ones_like(dist_cubed), dist_cubed) + + force_factor = G * tf.expand_dims(masses, 0) / dist_cubed + + acc = tf.reduce_sum(tf.expand_dims(force_factor, -1) * diff, axis=1) + return acc + + +def _leapfrog_step_body_tf(i, pos, vel, masses, softening, dt, n_steps): + G = 1.0 + acc = _leapfrog_compute_accelerations_tf(pos, masses, softening, G) + + vel = vel + 0.5 * dt * acc + pos = pos + dt * vel + vel = vel + 0.5 * dt * acc + + return i + 1, pos, vel, masses, softening, dt, n_steps + + +def _leapfrog_step_cond_tf(i, _pos, _vel, _masses, _softening, _dt, n_steps): + return i < n_steps + + +def leapfrog_integration_tf( + positions, + velocities, + masses, + dt: float, + n_steps: int, + softening: float = 0.01 +): + dt = tf.constant(dt, dtype=positions.dtype) + softening = tf.constant(softening, dtype=positions.dtype) + + _, final_pos, final_vel, _, _, _, _ = tf.while_loop( + _leapfrog_step_cond_tf, + _leapfrog_step_body_tf, + [0, positions, velocities, masses, softening, dt, n_steps] + ) + + return final_pos, final_vel + + +def _lis_inner_body_tf(j, dp_inner, arr, i): + condition = tf.logical_and(arr[j] < arr[i], dp_inner[j] + 1 > dp_inner[i]) + new_val = tf.where(condition, dp_inner[j] + 1, dp_inner[i]) + indices = tf.reshape(i, [1, 1]) + updates = tf.reshape(new_val, [1]) + dp_updated = tf.tensor_scatter_nd_update(dp_inner, indices, updates) + return j + 1, dp_updated, arr, i + + +def _lis_inner_cond_tf(j, _dp_inner, _arr, i): + return j < i + + +def _lis_outer_body_tf(i, dp, arr, n): + _, dp, _, _ = tf.while_loop( + _lis_inner_cond_tf, + _lis_inner_body_tf, + [0, dp, arr, i] + ) + return i + 1, dp, arr, n + + +def _lis_outer_cond_tf(i, _dp, _arr, n): + return i < n + + +def longest_increasing_subsequence_length_tf(arr): + n = tf.shape(arr)[0] + + if n == 0: + return 0 + + dp = tf.ones(n, dtype=tf.int32) + + _, dp, _, _ = tf.while_loop( + _lis_outer_cond_tf, + _lis_outer_body_tf, + [1, dp, arr, n] + ) + + return int(tf.reduce_max(dp)) diff --git a/code_to_optimize/tests/pytest/test_gridmake2.py b/code_to_optimize/tests/pytest/test_gridmake2.py deleted file mode 100644 index ae96a54bb..000000000 --- a/code_to_optimize/tests/pytest/test_gridmake2.py +++ /dev/null @@ -1,216 +0,0 @@ -import numpy as np -import pytest -from numpy.testing import assert_array_equal - -from code_to_optimize.discrete_riccati import _gridmake2 - - -class TestGridmake2With1DArrays: - """Tests for _gridmake2 
with two 1D arrays.""" - - def test_basic_two_element_arrays(self): - """Test basic cartesian product of two 2-element arrays.""" - x1 = np.array([1, 2]) - x2 = np.array([3, 4]) - result = _gridmake2(x1, x2) - - # Expected: x1 is tiled len(x2) times, x2 is repeated len(x1) times - expected = np.array([ - [1, 3], - [2, 3], - [1, 4], - [2, 4] - ]) - assert_array_equal(result, expected) - - def test_different_length_arrays(self): - """Test cartesian product with arrays of different lengths.""" - x1 = np.array([1, 2, 3]) - x2 = np.array([10, 20]) - result = _gridmake2(x1, x2) - - # Result should have len(x1) * len(x2) = 6 rows - expected = np.array([ - [1, 10], - [2, 10], - [3, 10], - [1, 20], - [2, 20], - [3, 20] - ]) - assert_array_equal(result, expected) - assert result.shape == (6, 2) - - def test_single_element_arrays(self): - """Test with single-element arrays.""" - x1 = np.array([5]) - x2 = np.array([7]) - result = _gridmake2(x1, x2) - - expected = np.array([[5, 7]]) - assert_array_equal(result, expected) - assert result.shape == (1, 2) - - def test_single_element_with_multi_element(self): - """Test single-element array with multi-element array.""" - x1 = np.array([1]) - x2 = np.array([10, 20, 30]) - result = _gridmake2(x1, x2) - - expected = np.array([ - [1, 10], - [1, 20], - [1, 30] - ]) - assert_array_equal(result, expected) - - def test_float_arrays(self): - """Test with float arrays.""" - x1 = np.array([1.5, 2.5]) - x2 = np.array([0.1, 0.2]) - result = _gridmake2(x1, x2) - - expected = np.array([ - [1.5, 0.1], - [2.5, 0.1], - [1.5, 0.2], - [2.5, 0.2] - ]) - assert_array_equal(result, expected) - - def test_negative_values(self): - """Test with negative values.""" - x1 = np.array([-1, 0, 1]) - x2 = np.array([-10, 10]) - result = _gridmake2(x1, x2) - - expected = np.array([ - [-1, -10], - [0, -10], - [1, -10], - [-1, 10], - [0, 10], - [1, 10] - ]) - assert_array_equal(result, expected) - - def test_result_shape(self): - """Test that result shape is (len(x1)*len(x2), 2).""" - x1 = np.array([1, 2, 3, 4]) - x2 = np.array([5, 6, 7]) - result = _gridmake2(x1, x2) - - assert result.shape == (12, 2) - - def test_larger_arrays(self): - """Test with larger arrays.""" - x1 = np.arange(10) - x2 = np.arange(5) - result = _gridmake2(x1, x2) - - assert result.shape == (50, 2) - # Verify first column is x1 tiled 5 times - assert_array_equal(result[:10, 0], x1) - assert_array_equal(result[10:20, 0], x1) - # Verify second column is x2 repeated 10 times each - assert all(result[:10, 1] == 0) - assert all(result[10:20, 1] == 1) - - -class TestGridmake2With2DFirst: - """Tests for _gridmake2 when x1 is 2D and x2 is 1D.""" - - def test_2d_first_1d_second(self): - """Test with 2D first array and 1D second array.""" - x1 = np.array([[1, 2], [3, 4]]) # 2 rows, 2 cols - x2 = np.array([10, 20]) - result = _gridmake2(x1, x2) - - # x1 is tiled len(x2) times vertically - # x2 is repeated len(x1) times (2 rows) - expected = np.array([ - [1, 2, 10], - [3, 4, 10], - [1, 2, 20], - [3, 4, 20] - ]) - assert_array_equal(result, expected) - - def test_2d_single_column(self): - """Test with 2D array having single column.""" - x1 = np.array([[1], [2], [3]]) # 3 rows, 1 col - x2 = np.array([10, 20]) - result = _gridmake2(x1, x2) - - expected = np.array([ - [1, 10], - [2, 10], - [3, 10], - [1, 20], - [2, 20], - [3, 20] - ]) - assert_array_equal(result, expected) - - def test_2d_multiple_columns(self): - """Test with 2D array having multiple columns.""" - x1 = np.array([[1, 2, 3], [4, 5, 6]]) # 2 rows, 3 cols - x2 = 
np.array([100]) - result = _gridmake2(x1, x2) - - expected = np.array([ - [1, 2, 3, 100], - [4, 5, 6, 100] - ]) - assert_array_equal(result, expected) - - -class TestGridmake2EdgeCases: - """Edge case tests for _gridmake2.""" - - def test_empty_arrays_raise_or_return_empty(self): - """Test behavior with empty arrays.""" - x1 = np.array([]) - x2 = np.array([1, 2]) - result = _gridmake2(x1, x2) - # Empty x1 should result in empty output - assert result.shape[0] == 0 - - def test_both_empty_arrays(self): - """Test with both empty arrays.""" - x1 = np.array([]) - x2 = np.array([]) - result = _gridmake2(x1, x2) - assert result.shape[0] == 0 - - def test_integer_dtype_preserved(self): - """Test that integer dtype is handled correctly.""" - x1 = np.array([1, 2], dtype=np.int64) - x2 = np.array([3, 4], dtype=np.int64) - result = _gridmake2(x1, x2) - assert result.dtype == np.int64 - - def test_float_dtype_preserved(self): - """Test that float dtype is handled correctly.""" - x1 = np.array([1.0, 2.0], dtype=np.float64) - x2 = np.array([3.0, 4.0], dtype=np.float64) - result = _gridmake2(x1, x2) - assert result.dtype == np.float64 - - -class TestGridmake2NotImplemented: - """Tests for NotImplementedError cases.""" - - def test_both_2d_raises(self): - """Test that two 2D arrays raises NotImplementedError.""" - x1 = np.array([[1, 2], [3, 4]]) - x2 = np.array([[5, 6], [7, 8]]) - with pytest.raises(NotImplementedError): - _gridmake2(x1, x2) - - def test_1d_first_2d_second_raises(self): - """Test that 1D first and 2D second raises NotImplementedError.""" - x1 = np.array([1, 2]) - x2 = np.array([[5, 6], [7, 8]]) - with pytest.raises(NotImplementedError): - _gridmake2(x1, x2) diff --git a/code_to_optimize/tests/pytest/test_gridmake2_torch.py b/code_to_optimize/tests/pytest/test_gridmake2_torch.py deleted file mode 100644 index f2ee737a2..000000000 --- a/code_to_optimize/tests/pytest/test_gridmake2_torch.py +++ /dev/null @@ -1,267 +0,0 @@ -import pytest -import torch - -from code_to_optimize.discrete_riccati import _gridmake2_torch - - -class TestGridmake2TorchCPU: - """Tests for _gridmake2_torch with CPU tensors.""" - - def test_both_1d_simple(self): - """Test with two simple 1D tensors.""" - x1 = torch.tensor([1, 2, 3]) - x2 = torch.tensor([10, 20]) - - result = _gridmake2_torch(x1, x2) - - # Expected: x1 tiled x2.shape[0] times, x2 repeat_interleaved x1.shape[0] - # x1 tiled: [1, 2, 3, 1, 2, 3] - # x2 repeated: [10, 10, 10, 20, 20, 20] - expected = torch.tensor([ - [1, 10], - [2, 10], - [3, 10], - [1, 20], - [2, 20], - [3, 20], - ]) - assert torch.equal(result, expected) - - def test_both_1d_single_element(self): - """Test with single element tensors.""" - x1 = torch.tensor([5]) - x2 = torch.tensor([10]) - - result = _gridmake2_torch(x1, x2) - - expected = torch.tensor([[5, 10]]) - assert torch.equal(result, expected) - - def test_both_1d_float_tensors(self): - """Test with float tensors.""" - x1 = torch.tensor([1.5, 2.5]) - x2 = torch.tensor([0.1, 0.2, 0.3]) - - result = _gridmake2_torch(x1, x2) - - assert result.shape == (6, 2) - assert result.dtype == torch.float32 - - def test_2d_and_1d_simple(self): - """Test with 2D x1 and 1D x2.""" - x1 = torch.tensor([[1, 2], [3, 4]]) - x2 = torch.tensor([10, 20]) - - result = _gridmake2_torch(x1, x2) - - # x1 tiled along first dim: [[1, 2], [3, 4], [1, 2], [3, 4]] - # x2 repeated: [10, 10, 20, 20] - # column_stack: [[1, 2, 10], [3, 4, 10], [1, 2, 20], [3, 4, 20]] - expected = torch.tensor([ - [1, 2, 10], - [3, 4, 10], - [1, 2, 20], - [3, 4, 20], - ]) - assert 
torch.equal(result, expected) - - def test_2d_and_1d_single_column(self): - """Test with 2D x1 having a single column and 1D x2.""" - x1 = torch.tensor([[1], [2], [3]]) - x2 = torch.tensor([10, 20]) - - result = _gridmake2_torch(x1, x2) - - expected = torch.tensor([ - [1, 10], - [2, 10], - [3, 10], - [1, 20], - [2, 20], - [3, 20], - ]) - assert torch.equal(result, expected) - - def test_output_shape_1d_1d(self): - """Test output shape for two 1D tensors.""" - x1 = torch.tensor([1, 2, 3, 4, 5]) - x2 = torch.tensor([10, 20, 30]) - - result = _gridmake2_torch(x1, x2) - - # Shape should be (len(x1) * len(x2), 2) - assert result.shape == (15, 2) - - def test_output_shape_2d_1d(self): - """Test output shape for 2D and 1D tensors.""" - x1 = torch.tensor([[1, 2, 3], [4, 5, 6]]) # Shape (2, 3) - x2 = torch.tensor([10, 20, 30, 40]) # Shape (4,) - - result = _gridmake2_torch(x1, x2) - - # Shape should be (2 * 4, 3 + 1) = (8, 4) - assert result.shape == (8, 4) - - def test_not_implemented_for_2d_2d(self): - """Test that NotImplementedError is raised for two 2D tensors.""" - x1 = torch.tensor([[1, 2], [3, 4]]) - x2 = torch.tensor([[10, 20], [30, 40]]) - - with pytest.raises(NotImplementedError, match="Come back here"): - _gridmake2_torch(x1, x2) - - def test_not_implemented_for_1d_2d(self): - """Test that NotImplementedError is raised for 1D and 2D tensors.""" - x1 = torch.tensor([1, 2, 3]) - x2 = torch.tensor([[10, 20], [30, 40]]) - - with pytest.raises(NotImplementedError, match="Come back here"): - _gridmake2_torch(x1, x2) - - def test_preserves_dtype_int(self): - """Test that integer dtype is preserved.""" - x1 = torch.tensor([1, 2, 3], dtype=torch.int32) - x2 = torch.tensor([10, 20], dtype=torch.int32) - - result = _gridmake2_torch(x1, x2) - - assert result.dtype == torch.int32 - - def test_preserves_dtype_float64(self): - """Test that float64 dtype is preserved.""" - x1 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64) - x2 = torch.tensor([10.0, 20.0], dtype=torch.float64) - - result = _gridmake2_torch(x1, x2) - - assert result.dtype == torch.float64 - - def test_large_tensors(self): - """Test with larger tensors.""" - x1 = torch.arange(100) - x2 = torch.arange(50) - - result = _gridmake2_torch(x1, x2) - - assert result.shape == (5000, 2) - # Verify first and last elements - assert result[0, 0] == 0 and result[0, 1] == 0 - assert result[-1, 0] == 99 and result[-1, 1] == 49 - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -class TestGridmake2TorchCUDA: - """Tests for _gridmake2_torch with CUDA tensors.""" - - def test_both_1d_simple_cuda(self): - """Test with two simple 1D CUDA tensors.""" - x1 = torch.tensor([1, 2, 3], device="cuda") - x2 = torch.tensor([10, 20], device="cuda") - - result = _gridmake2_torch(x1, x2) - - expected = torch.tensor([ - [1, 10], - [2, 10], - [3, 10], - [1, 20], - [2, 20], - [3, 20], - ], device="cuda") - assert result.device.type == "cuda" - assert torch.equal(result, expected) - - def test_both_1d_matches_cpu(self): - """Test that CUDA version matches CPU version.""" - x1_cpu = torch.tensor([1.0, 2.0, 3.0, 4.0]) - x2_cpu = torch.tensor([10.0, 20.0, 30.0]) - - x1_cuda = x1_cpu.cuda() - x2_cuda = x2_cpu.cuda() - - result_cpu = _gridmake2_torch(x1_cpu, x2_cpu) - result_cuda = _gridmake2_torch(x1_cuda, x2_cuda) - - assert result_cuda.device.type == "cuda" - torch.testing.assert_close(result_cpu, result_cuda.cpu()) - - def test_2d_and_1d_cuda(self): - """Test with 2D x1 and 1D x2 on CUDA.""" - x1 = torch.tensor([[1, 2], [3, 4]], 
device="cuda") - x2 = torch.tensor([10, 20], device="cuda") - - result = _gridmake2_torch(x1, x2) - - expected = torch.tensor([ - [1, 2, 10], - [3, 4, 10], - [1, 2, 20], - [3, 4, 20], - ], device="cuda") - assert result.device.type == "cuda" - assert torch.equal(result, expected) - - def test_2d_and_1d_matches_cpu(self): - """Test that CUDA version matches CPU version for 2D, 1D inputs.""" - x1_cpu = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) - x2_cpu = torch.tensor([10.0, 20.0]) - - x1_cuda = x1_cpu.cuda() - x2_cuda = x2_cpu.cuda() - - result_cpu = _gridmake2_torch(x1_cpu, x2_cpu) - result_cuda = _gridmake2_torch(x1_cuda, x2_cuda) - - assert result_cuda.device.type == "cuda" - torch.testing.assert_close(result_cpu, result_cuda.cpu()) - - def test_output_stays_on_cuda(self): - """Test that output tensor stays on CUDA device.""" - x1 = torch.tensor([1, 2, 3], device="cuda") - x2 = torch.tensor([10, 20], device="cuda") - - result = _gridmake2_torch(x1, x2) - - assert result.is_cuda - - def test_preserves_dtype_float32_cuda(self): - """Test that float32 dtype is preserved on CUDA.""" - x1 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32, device="cuda") - x2 = torch.tensor([10.0, 20.0], dtype=torch.float32, device="cuda") - - result = _gridmake2_torch(x1, x2) - - assert result.dtype == torch.float32 - assert result.device.type == "cuda" - - def test_preserves_dtype_float64_cuda(self): - """Test that float64 dtype is preserved on CUDA.""" - x1 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64, device="cuda") - x2 = torch.tensor([10.0, 20.0], dtype=torch.float64, device="cuda") - - result = _gridmake2_torch(x1, x2) - - assert result.dtype == torch.float64 - assert result.device.type == "cuda" - - def test_large_tensors_cuda(self): - """Test with larger tensors on CUDA.""" - x1 = torch.arange(100, device="cuda") - x2 = torch.arange(50, device="cuda") - - result = _gridmake2_torch(x1, x2) - - assert result.shape == (5000, 2) - assert result.device.type == "cuda" - # Verify first and last elements - assert result[0, 0].item() == 0 and result[0, 1].item() == 0 - assert result[-1, 0].item() == 99 and result[-1, 1].item() == 49 - - def test_not_implemented_for_2d_2d_cuda(self): - """Test that NotImplementedError is raised for two 2D CUDA tensors.""" - x1 = torch.tensor([[1, 2], [3, 4]], device="cuda") - x2 = torch.tensor([[10, 20], [30, 40]], device="cuda") - - with pytest.raises(NotImplementedError, match="Come back here"): - _gridmake2_torch(x1, x2) - diff --git a/code_to_optimize/tests/pytest/test_jax_jit_code.py b/code_to_optimize/tests/pytest/test_jax_jit_code.py new file mode 100644 index 000000000..5968bc0e2 --- /dev/null +++ b/code_to_optimize/tests/pytest/test_jax_jit_code.py @@ -0,0 +1,256 @@ +""" +Unit tests for JAX implementations of JIT-suitable functions. + +Tests run on CPU and CUDA devices. 
+""" + +import numpy as np +import pytest + +jax = pytest.importorskip("jax") +import jax.numpy as jnp + +from code_to_optimize.sample_jit_code import ( + leapfrog_integration_jax, + longest_increasing_subsequence_length_jax, + tridiagonal_solve_jax, +) + + +def get_available_devices(): + """Return list of available JAX devices for testing.""" + devices = [] + + # CPU is always available + devices.append("cpu") + + # Check for CUDA/GPU + try: + gpu_devices = jax.devices("gpu") + if gpu_devices: + devices.append("cuda") + except RuntimeError: + pass + + return devices + + +DEVICES = get_available_devices() + + +def to_device(arr, device): + """Move a JAX array to the specified device.""" + if device == "cpu": + return jax.device_put(arr, jax.devices("cpu")[0]) + elif device == "cuda": + return jax.device_put(arr, jax.devices("gpu")[0]) + return arr + + +class TestTridiagonalSolveJax: + """Tests for the JAX tridiagonal_solve function.""" + + @pytest.mark.parametrize("device", DEVICES) + def test_simple_system(self, device): + """Test a simple 3x3 tridiagonal system with known solution.""" + a = jnp.array([-1.0, -1.0]) + b = jnp.array([2.0, 2.0, 2.0]) + c = jnp.array([-1.0, -1.0]) + d = jnp.array([1.0, 0.0, 1.0]) + + a, b, c, d = to_device(a, device), to_device(b, device), to_device(c, device), to_device(d, device) + + x = tridiagonal_solve_jax(a, b, c, d) + + # Verify solution by multiplying back + result = jnp.zeros(3) + result = result.at[0].set(b[0] * x[0] + c[0] * x[1]) + result = result.at[1].set(a[0] * x[0] + b[1] * x[1] + c[1] * x[2]) + result = result.at[2].set(a[1] * x[1] + b[2] * x[2]) + + np.testing.assert_array_almost_equal(np.array(result), np.array(d), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_diagonal_system(self, device): + """Test a purely diagonal system.""" + a = jnp.array([0.0, 0.0]) + b = jnp.array([2.0, 3.0, 4.0]) + c = jnp.array([0.0, 0.0]) + d = jnp.array([4.0, 9.0, 16.0]) + + a, b, c, d = to_device(a, device), to_device(b, device), to_device(c, device), to_device(d, device) + + x = tridiagonal_solve_jax(a, b, c, d) + + expected = jnp.array([2.0, 3.0, 4.0]) + np.testing.assert_array_almost_equal(np.array(x), np.array(expected), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_larger_system(self, device): + """Test a larger tridiagonal system.""" + n = 50 + a = -jnp.ones(n - 1) + b = 2.0 * jnp.ones(n) + c = -jnp.ones(n - 1) + d = jnp.zeros(n).at[0].set(1.0).at[-1].set(1.0) + + a, b, c, d = to_device(a, device), to_device(b, device), to_device(c, device), to_device(d, device) + + x = tridiagonal_solve_jax(a, b, c, d) + + # Verify by reconstructing Ax + result = jnp.zeros(n) + result = result.at[0].set(b[0] * x[0] + c[0] * x[1]) + for i in range(1, n - 1): + result = result.at[i].set(a[i - 1] * x[i - 1] + b[i] * x[i] + c[i] * x[i + 1]) + result = result.at[-1].set(a[-1] * x[-2] + b[-1] * x[-1]) + + np.testing.assert_array_almost_equal(np.array(result), np.array(d), decimal=5) + + +class TestLeapfrogIntegrationJax: + """Tests for the JAX leapfrog_integration function.""" + + @pytest.mark.parametrize("device", DEVICES) + def test_single_stationary_particle(self, device): + """A single particle with no velocity should remain stationary.""" + positions = jnp.array([[0.0, 0.0, 0.0]]) + velocities = jnp.array([[0.0, 0.0, 0.0]]) + masses = jnp.array([1.0]) + + positions = to_device(positions, device) + velocities = to_device(velocities, device) + masses = to_device(masses, device) + + final_pos, final_vel = 
leapfrog_integration_jax( + positions, velocities, masses, dt=0.01, n_steps=100 + ) + + np.testing.assert_array_almost_equal(np.array(final_pos), np.array(positions), decimal=5) + np.testing.assert_array_almost_equal(np.array(final_vel), np.array(velocities), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_single_moving_particle(self, device): + """A single moving particle should move in a straight line.""" + positions = jnp.array([[0.0, 0.0, 0.0]]) + velocities = jnp.array([[1.0, 0.0, 0.0]]) + masses = jnp.array([1.0]) + + positions = to_device(positions, device) + velocities = to_device(velocities, device) + masses = to_device(masses, device) + + dt = 0.01 + n_steps = 100 + + final_pos, final_vel = leapfrog_integration_jax( + positions, velocities, masses, dt=dt, n_steps=n_steps + ) + + np.testing.assert_array_almost_equal(np.array(final_vel), np.array(velocities), decimal=5) + expected_pos = jnp.array([[dt * n_steps, 0.0, 0.0]]) + np.testing.assert_array_almost_equal(np.array(final_pos), np.array(expected_pos), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_two_particles_approach(self, device): + """Two particles should attract each other gravitationally.""" + positions = jnp.array([[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]) + velocities = jnp.zeros((2, 3)) + masses = jnp.array([1.0, 1.0]) + + positions = to_device(positions, device) + velocities = to_device(velocities, device) + masses = to_device(masses, device) + + final_pos, _ = leapfrog_integration_jax( + positions, velocities, masses, dt=0.01, n_steps=50, softening=0.1 + ) + + initial_distance = 2.0 + final_distance = float(jnp.linalg.norm(final_pos[1] - final_pos[0])) + assert final_distance < initial_distance + + @pytest.mark.parametrize("device", DEVICES) + def test_momentum_conservation(self, device): + """Total momentum should be approximately conserved.""" + np.random.seed(42) + n_particles = 5 + positions = jnp.array(np.random.randn(n_particles, 3)) + velocities = jnp.array(np.random.randn(n_particles, 3)) + masses = jnp.array(np.abs(np.random.randn(n_particles)) + 0.1) + + positions = to_device(positions, device) + velocities = to_device(velocities, device) + masses = to_device(masses, device) + + initial_momentum = jnp.sum(masses[:, jnp.newaxis] * velocities, axis=0) + + final_pos, final_vel = leapfrog_integration_jax( + positions, velocities, masses, dt=0.001, n_steps=100, softening=0.5 + ) + + final_momentum = jnp.sum(masses[:, jnp.newaxis] * final_vel, axis=0) + + np.testing.assert_array_almost_equal( + np.array(initial_momentum), np.array(final_momentum), decimal=4 + ) + + +class TestLongestIncreasingSubsequenceLengthJax: + """Tests for the JAX longest_increasing_subsequence_length function.""" + + @pytest.mark.parametrize("device", DEVICES) + def test_empty_array(self, device): + """Empty array should return 0.""" + arr = jnp.array([], dtype=jnp.float32) + arr = to_device(arr, device) + assert longest_increasing_subsequence_length_jax(arr) == 0 + + @pytest.mark.parametrize("device", DEVICES) + def test_single_element(self, device): + """Single element array should return 1.""" + arr = jnp.array([5.0]) + arr = to_device(arr, device) + assert longest_increasing_subsequence_length_jax(arr) == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_strictly_increasing(self, device): + """Strictly increasing array - LIS is the whole array.""" + arr = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]) + arr = to_device(arr, device) + assert longest_increasing_subsequence_length_jax(arr) == 5 + 
+ @pytest.mark.parametrize("device", DEVICES) + def test_strictly_decreasing(self, device): + """Strictly decreasing array - LIS is length 1.""" + arr = jnp.array([5.0, 4.0, 3.0, 2.0, 1.0]) + arr = to_device(arr, device) + assert longest_increasing_subsequence_length_jax(arr) == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_classic_example(self, device): + """Classic LIS example.""" + arr = jnp.array([10.0, 9.0, 2.0, 5.0, 3.0, 7.0, 101.0, 18.0]) + arr = to_device(arr, device) + assert longest_increasing_subsequence_length_jax(arr) == 4 + + @pytest.mark.parametrize("device", DEVICES) + def test_all_same_elements(self, device): + """All same elements - LIS is length 1.""" + arr = jnp.array([5.0, 5.0, 5.0, 5.0, 5.0]) + arr = to_device(arr, device) + assert longest_increasing_subsequence_length_jax(arr) == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_alternating_sequence(self, device): + """Alternating high-low sequence.""" + arr = jnp.array([1.0, 10.0, 2.0, 9.0, 3.0, 8.0, 4.0, 7.0]) + arr = to_device(arr, device) + assert longest_increasing_subsequence_length_jax(arr) == 5 + + @pytest.mark.parametrize("device", DEVICES) + def test_longer_sequence(self, device): + """Test with a longer sequence.""" + arr = jnp.array([0.0, 8.0, 4.0, 12.0, 2.0, 10.0, 6.0, 14.0, 1.0, 9.0, 5.0, 13.0, 3.0, 11.0, 7.0, 15.0]) + arr = to_device(arr, device) + assert longest_increasing_subsequence_length_jax(arr) == 6 \ No newline at end of file diff --git a/code_to_optimize/tests/pytest/test_numba_jit_code.py b/code_to_optimize/tests/pytest/test_numba_jit_code.py new file mode 100644 index 000000000..06738c39b --- /dev/null +++ b/code_to_optimize/tests/pytest/test_numba_jit_code.py @@ -0,0 +1,242 @@ +import numpy as np +import pytest + +from code_to_optimize.sample_jit_code import ( + leapfrog_integration, + longest_increasing_subsequence_length, + tridiagonal_solve, +) + + +class TestTridiagonalSolve: + """Tests for the tridiagonal_solve function (Thomas algorithm).""" + + def test_simple_system(self): + """Test a simple 3x3 tridiagonal system with known solution.""" + # System: [2 -1 0] [x0] [1] + # [-1 2 -1] [x1] = [0] + # [0 -1 2] [x2] [1] + a = np.array([-1.0, -1.0]) # lower diagonal + b = np.array([2.0, 2.0, 2.0]) # main diagonal + c = np.array([-1.0, -1.0]) # upper diagonal + d = np.array([1.0, 0.0, 1.0]) # right-hand side + + x = tridiagonal_solve(a, b, c, d) + + # Verify solution by multiplying back + # Ax should equal d + result = np.zeros(3) + result[0] = b[0] * x[0] + c[0] * x[1] + result[1] = a[0] * x[0] + b[1] * x[1] + c[1] * x[2] + result[2] = a[1] * x[1] + b[2] * x[2] + + np.testing.assert_array_almost_equal(result, d) + + def test_diagonal_system(self): + """Test a purely diagonal system (a and c are zero).""" + a = np.array([0.0, 0.0]) + b = np.array([2.0, 3.0, 4.0]) + c = np.array([0.0, 0.0]) + d = np.array([4.0, 9.0, 16.0]) + + x = tridiagonal_solve(a, b, c, d) + + expected = np.array([2.0, 3.0, 4.0]) + np.testing.assert_array_almost_equal(x, expected) + + def test_larger_system(self): + """Test a larger tridiagonal system.""" + n = 100 + a = -np.ones(n - 1) + b = 2.0 * np.ones(n) + c = -np.ones(n - 1) + d = np.zeros(n) + d[0] = 1.0 + d[-1] = 1.0 + + x = tridiagonal_solve(a, b, c, d) + + # Verify by reconstructing Ax + result = np.zeros(n) + result[0] = b[0] * x[0] + c[0] * x[1] + for i in range(1, n - 1): + result[i] = a[i - 1] * x[i - 1] + b[i] * x[i] + c[i] * x[i + 1] + result[-1] = a[-1] * x[-2] + b[-1] * x[-1] + + np.testing.assert_array_almost_equal(result, d, 
decimal=10) + + def test_two_element_system(self): + """Test minimal 2x2 tridiagonal system.""" + a = np.array([1.0]) + b = np.array([4.0, 4.0]) + c = np.array([1.0]) + d = np.array([5.0, 5.0]) + + x = tridiagonal_solve(a, b, c, d) + + # Verify: [4 1] [x0] = [5] + # [1 4] [x1] [5] + result = np.array([ + b[0] * x[0] + c[0] * x[1], + a[0] * x[0] + b[1] * x[1] + ]) + np.testing.assert_array_almost_equal(result, d) + + +class TestLeapfrogIntegration: + """Tests for the leapfrog_integration function (N-body simulation).""" + + def test_single_stationary_particle(self): + """A single particle with no velocity should remain stationary.""" + positions = np.array([[0.0, 0.0, 0.0]]) + velocities = np.array([[0.0, 0.0, 0.0]]) + masses = np.array([1.0]) + + final_pos, final_vel = leapfrog_integration( + positions, velocities, masses, dt=0.01, n_steps=100 + ) + + np.testing.assert_array_almost_equal(final_pos, positions) + np.testing.assert_array_almost_equal(final_vel, velocities) + + def test_single_moving_particle(self): + """A single moving particle should move in a straight line.""" + positions = np.array([[0.0, 0.0, 0.0]]) + velocities = np.array([[1.0, 0.0, 0.0]]) + masses = np.array([1.0]) + + dt = 0.01 + n_steps = 100 + + final_pos, final_vel = leapfrog_integration( + positions, velocities, masses, dt=dt, n_steps=n_steps + ) + + # With no other particles, velocity should remain constant + np.testing.assert_array_almost_equal(final_vel, velocities) + + # Position should be initial + velocity * time + expected_pos = np.array([[dt * n_steps, 0.0, 0.0]]) + np.testing.assert_array_almost_equal(final_pos, expected_pos) + + def test_two_particles_approach(self): + """Two particles should attract each other gravitationally.""" + positions = np.array([ + [-1.0, 0.0, 0.0], + [1.0, 0.0, 0.0] + ]) + velocities = np.zeros((2, 3)) + masses = np.array([1.0, 1.0]) + + final_pos, final_vel = leapfrog_integration( + positions, velocities, masses, dt=0.01, n_steps=50, softening=0.1 + ) + + # Particles should move closer together + initial_distance = 2.0 + final_distance = np.linalg.norm(final_pos[1] - final_pos[0]) + assert final_distance < initial_distance + + def test_momentum_conservation(self): + """Total momentum should be approximately conserved.""" + np.random.seed(42) + n_particles = 5 + positions = np.random.randn(n_particles, 3) + velocities = np.random.randn(n_particles, 3) + masses = np.abs(np.random.randn(n_particles)) + 0.1 + + initial_momentum = np.sum(masses[:, np.newaxis] * velocities, axis=0) + + final_pos, final_vel = leapfrog_integration( + positions, velocities, masses, dt=0.001, n_steps=100, softening=0.5 + ) + + final_momentum = np.sum(masses[:, np.newaxis] * final_vel, axis=0) + + # Momentum should be conserved to good precision + np.testing.assert_array_almost_equal( + initial_momentum, final_momentum, decimal=5 + ) + + def test_does_not_modify_input(self): + """Input arrays should not be modified.""" + positions = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]) + velocities = np.array([[0.1, 0.0, 0.0], [-0.1, 0.0, 0.0]]) + masses = np.array([1.0, 1.0]) + + pos_copy = positions.copy() + vel_copy = velocities.copy() + + leapfrog_integration(positions, velocities, masses, dt=0.01, n_steps=10) + + np.testing.assert_array_equal(positions, pos_copy) + np.testing.assert_array_equal(velocities, vel_copy) + + +class TestLongestIncreasingSubsequenceLength: + """Tests for the longest_increasing_subsequence_length function.""" + + def test_empty_array(self): + """Empty array should return 0.""" + arr 
= np.array([], dtype=np.float64)
+        assert longest_increasing_subsequence_length(arr) == 0
+
+    def test_single_element(self):
+        """Single element array should return 1."""
+        arr = np.array([5])
+        assert longest_increasing_subsequence_length(arr) == 1
+
+    def test_strictly_increasing(self):
+        """Strictly increasing array - LIS is the whole array."""
+        arr = np.array([1, 2, 3, 4, 5])
+        assert longest_increasing_subsequence_length(arr) == 5
+
+    def test_strictly_decreasing(self):
+        """Strictly decreasing array - LIS is length 1."""
+        arr = np.array([5, 4, 3, 2, 1])
+        assert longest_increasing_subsequence_length(arr) == 1
+
+    def test_classic_example(self):
+        """Classic LIS example: [10, 9, 2, 5, 3, 7, 101, 18]."""
+        arr = np.array([10, 9, 2, 5, 3, 7, 101, 18])
+        # LIS: [2, 3, 7, 101] or [2, 5, 7, 101] or [2, 3, 7, 18] etc.
+        assert longest_increasing_subsequence_length(arr) == 4
+
+    def test_all_same_elements(self):
+        """All same elements - LIS is length 1 (strictly increasing)."""
+        arr = np.array([5, 5, 5, 5, 5])
+        assert longest_increasing_subsequence_length(arr) == 1
+
+    def test_alternating_sequence(self):
+        """Alternating high-low sequence."""
+        arr = np.array([1, 10, 2, 9, 3, 8, 4, 7])
+        # LIS: [1, 2, 3, 4, 7] - length 5
+        assert longest_increasing_subsequence_length(arr) == 5
+
+    def test_two_elements_increasing(self):
+        """Two elements in increasing order."""
+        arr = np.array([1, 2])
+        assert longest_increasing_subsequence_length(arr) == 2
+
+    def test_two_elements_decreasing(self):
+        """Two elements in decreasing order."""
+        arr = np.array([2, 1])
+        assert longest_increasing_subsequence_length(arr) == 1
+
+    def test_longer_sequence(self):
+        """Test with a longer sequence."""
+        arr = np.array([0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15])
+        # Known LIS length for this sequence is 6
+        assert longest_increasing_subsequence_length(arr) == 6
+
+    def test_negative_numbers(self):
+        """Test with negative numbers."""
+        arr = np.array([-5, -2, -8, -1, -6, 0])
+        # LIS: [-5, -2, -1, 0] - length 4
+        assert longest_increasing_subsequence_length(arr) == 4
+
+    def test_float_values(self):
+        """Test with floating point values."""
+        arr = np.array([1.5, 2.3, 1.8, 3.1, 2.9, 4.0])
+        # LIS: [1.5, 2.3, 3.1, 4.0] or [1.5, 1.8, 2.9, 4.0] - length 4
+        assert longest_increasing_subsequence_length(arr) == 4
\ No newline at end of file
diff --git a/code_to_optimize/tests/pytest/test_tensorflow_jit_code.py b/code_to_optimize/tests/pytest/test_tensorflow_jit_code.py
new file mode 100644
index 000000000..3a44d24f5
--- /dev/null
+++ b/code_to_optimize/tests/pytest/test_tensorflow_jit_code.py
@@ -0,0 +1,296 @@
+"""
+Unit tests for TensorFlow implementations of JIT-suitable functions.
+
+Tests run on CPU and CUDA devices.
+""" + +import numpy as np +import pytest + +tf = pytest.importorskip("tensorflow") + +from code_to_optimize.sample_jit_code import ( + leapfrog_integration_tf, + longest_increasing_subsequence_length_tf, + tridiagonal_solve_tf, +) + + +def get_available_devices(): + """Return list of available TensorFlow devices for testing.""" + devices = ["cpu"] + + # Check for CUDA/GPU + gpus = tf.config.list_physical_devices("GPU") + if gpus: + devices.append("cuda") + + return devices + + +DEVICES = get_available_devices() + + +def run_on_device(func, device, *args, **kwargs): + """Run a function on the specified device.""" + if device == "cpu": + device_name = "/CPU:0" + elif device == "cuda": + device_name = "/GPU:0" + else: + device_name = "/CPU:0" + + with tf.device(device_name): + return func(*args, **kwargs) + + +def to_tensor(arr, device, dtype=tf.float64): + """Create a tensor on the specified device.""" + if device == "cpu": + device_name = "/CPU:0" + elif device == "cuda": + device_name = "/GPU:0" + else: + device_name = "/CPU:0" + + with tf.device(device_name): + return tf.constant(arr, dtype=dtype) + + +class TestTridiagonalSolveTf: + """Tests for the TensorFlow tridiagonal_solve function.""" + + @pytest.mark.parametrize("device", DEVICES) + def test_simple_system(self, device): + """Test a simple 3x3 tridiagonal system with known solution.""" + a = to_tensor([-1.0, -1.0], device) + b = to_tensor([2.0, 2.0, 2.0], device) + c = to_tensor([-1.0, -1.0], device) + d = to_tensor([1.0, 0.0, 1.0], device) + + x = run_on_device(tridiagonal_solve_tf, device, a, b, c, d) + + # Verify solution by multiplying back + result = np.zeros(3) + x_np = x.numpy() + b_np = b.numpy() + c_np = c.numpy() + a_np = a.numpy() + result[0] = b_np[0] * x_np[0] + c_np[0] * x_np[1] + result[1] = a_np[0] * x_np[0] + b_np[1] * x_np[1] + c_np[1] * x_np[2] + result[2] = a_np[1] * x_np[1] + b_np[2] * x_np[2] + + np.testing.assert_array_almost_equal(result, d.numpy(), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_diagonal_system(self, device): + """Test a purely diagonal system.""" + a = to_tensor([0.0, 0.0], device) + b = to_tensor([2.0, 3.0, 4.0], device) + c = to_tensor([0.0, 0.0], device) + d = to_tensor([4.0, 9.0, 16.0], device) + + x = run_on_device(tridiagonal_solve_tf, device, a, b, c, d) + + expected = np.array([2.0, 3.0, 4.0]) + np.testing.assert_array_almost_equal(x.numpy(), expected, decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_larger_system(self, device): + """Test a larger tridiagonal system.""" + n = 50 + a_np = -np.ones(n - 1) + b_np = 2.0 * np.ones(n) + c_np = -np.ones(n - 1) + d_np = np.zeros(n) + d_np[0] = 1.0 + d_np[-1] = 1.0 + + a = to_tensor(a_np, device) + b = to_tensor(b_np, device) + c = to_tensor(c_np, device) + d = to_tensor(d_np, device) + + x = run_on_device(tridiagonal_solve_tf, device, a, b, c, d) + x_np = x.numpy() + + # Verify by reconstructing Ax + result = np.zeros(n) + result[0] = b_np[0] * x_np[0] + c_np[0] * x_np[1] + for i in range(1, n - 1): + result[i] = a_np[i - 1] * x_np[i - 1] + b_np[i] * x_np[i] + c_np[i] * x_np[i + 1] + result[-1] = a_np[-1] * x_np[-2] + b_np[-1] * x_np[-1] + + np.testing.assert_array_almost_equal(result, d_np, decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_two_element_system(self, device): + """Test minimal 2x2 tridiagonal system.""" + a = to_tensor([1.0], device) + b = to_tensor([4.0, 4.0], device) + c = to_tensor([1.0], device) + d = to_tensor([5.0, 5.0], device) + + x = 
run_on_device(tridiagonal_solve_tf, device, a, b, c, d) + x_np = x.numpy() + b_np = b.numpy() + c_np = c.numpy() + a_np = a.numpy() + + result = np.array([ + b_np[0] * x_np[0] + c_np[0] * x_np[1], + a_np[0] * x_np[0] + b_np[1] * x_np[1] + ]) + np.testing.assert_array_almost_equal(result, d.numpy(), decimal=5) + + +class TestLeapfrogIntegrationTf: + """Tests for the TensorFlow leapfrog_integration function.""" + + @pytest.mark.parametrize("device", DEVICES) + def test_single_stationary_particle(self, device): + """A single particle with no velocity should remain stationary.""" + positions = to_tensor([[0.0, 0.0, 0.0]], device) + velocities = to_tensor([[0.0, 0.0, 0.0]], device) + masses = to_tensor([1.0], device) + + final_pos, final_vel = run_on_device( + leapfrog_integration_tf, device, + positions, velocities, masses, dt=0.01, n_steps=100 + ) + + np.testing.assert_array_almost_equal(final_pos.numpy(), positions.numpy(), decimal=5) + np.testing.assert_array_almost_equal(final_vel.numpy(), velocities.numpy(), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_single_moving_particle(self, device): + """A single moving particle should move in a straight line.""" + positions = to_tensor([[0.0, 0.0, 0.0]], device) + velocities = to_tensor([[1.0, 0.0, 0.0]], device) + masses = to_tensor([1.0], device) + + dt = 0.01 + n_steps = 100 + + final_pos, final_vel = run_on_device( + leapfrog_integration_tf, device, + positions, velocities, masses, dt=dt, n_steps=n_steps + ) + + np.testing.assert_array_almost_equal(final_vel.numpy(), velocities.numpy(), decimal=5) + expected_pos = np.array([[dt * n_steps, 0.0, 0.0]]) + np.testing.assert_array_almost_equal(final_pos.numpy(), expected_pos, decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_two_particles_approach(self, device): + """Two particles should attract each other gravitationally.""" + positions = to_tensor([[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]], device) + velocities = to_tensor(np.zeros((2, 3)), device) + masses = to_tensor([1.0, 1.0], device) + + final_pos, _ = run_on_device( + leapfrog_integration_tf, device, + positions, velocities, masses, dt=0.01, n_steps=50, softening=0.1 + ) + + initial_distance = 2.0 + final_distance = np.linalg.norm(final_pos.numpy()[1] - final_pos.numpy()[0]) + assert final_distance < initial_distance + + @pytest.mark.parametrize("device", DEVICES) + def test_momentum_conservation(self, device): + """Total momentum should be approximately conserved.""" + np.random.seed(42) + n_particles = 5 + positions_np = np.random.randn(n_particles, 3) + velocities_np = np.random.randn(n_particles, 3) + masses_np = np.abs(np.random.randn(n_particles)) + 0.1 + + positions = to_tensor(positions_np, device) + velocities = to_tensor(velocities_np, device) + masses = to_tensor(masses_np, device) + + initial_momentum = np.sum(masses_np[:, np.newaxis] * velocities_np, axis=0) + + final_pos, final_vel = run_on_device( + leapfrog_integration_tf, device, + positions, velocities, masses, dt=0.001, n_steps=100, softening=0.5 + ) + + final_momentum = np.sum(masses_np[:, np.newaxis] * final_vel.numpy(), axis=0) + + np.testing.assert_array_almost_equal(initial_momentum, final_momentum, decimal=4) + + +class TestLongestIncreasingSubsequenceLengthTf: + """Tests for the TensorFlow longest_increasing_subsequence_length function.""" + + @pytest.mark.parametrize("device", DEVICES) + def test_single_element(self, device): + """Single element array should return 1.""" + arr = to_tensor([5.0], device, dtype=tf.float32) + 
result = run_on_device(longest_increasing_subsequence_length_tf, device, arr) + assert result == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_strictly_increasing(self, device): + """Strictly increasing array - LIS is the whole array.""" + arr = to_tensor([1.0, 2.0, 3.0, 4.0, 5.0], device, dtype=tf.float32) + result = run_on_device(longest_increasing_subsequence_length_tf, device, arr) + assert result == 5 + + @pytest.mark.parametrize("device", DEVICES) + def test_strictly_decreasing(self, device): + """Strictly decreasing array - LIS is length 1.""" + arr = to_tensor([5.0, 4.0, 3.0, 2.0, 1.0], device, dtype=tf.float32) + result = run_on_device(longest_increasing_subsequence_length_tf, device, arr) + assert result == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_classic_example(self, device): + """Classic LIS example.""" + arr = to_tensor([10.0, 9.0, 2.0, 5.0, 3.0, 7.0, 101.0, 18.0], device, dtype=tf.float32) + result = run_on_device(longest_increasing_subsequence_length_tf, device, arr) + assert result == 4 + + @pytest.mark.parametrize("device", DEVICES) + def test_all_same_elements(self, device): + """All same elements - LIS is length 1.""" + arr = to_tensor([5.0, 5.0, 5.0, 5.0, 5.0], device, dtype=tf.float32) + result = run_on_device(longest_increasing_subsequence_length_tf, device, arr) + assert result == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_alternating_sequence(self, device): + """Alternating high-low sequence.""" + arr = to_tensor([1.0, 10.0, 2.0, 9.0, 3.0, 8.0, 4.0, 7.0], device, dtype=tf.float32) + result = run_on_device(longest_increasing_subsequence_length_tf, device, arr) + assert result == 5 + + @pytest.mark.parametrize("device", DEVICES) + def test_two_elements_increasing(self, device): + """Two elements in increasing order.""" + arr = to_tensor([1.0, 2.0], device, dtype=tf.float32) + result = run_on_device(longest_increasing_subsequence_length_tf, device, arr) + assert result == 2 + + @pytest.mark.parametrize("device", DEVICES) + def test_two_elements_decreasing(self, device): + """Two elements in decreasing order.""" + arr = to_tensor([2.0, 1.0], device, dtype=tf.float32) + result = run_on_device(longest_increasing_subsequence_length_tf, device, arr) + assert result == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_longer_sequence(self, device): + """Test with a longer sequence.""" + arr = to_tensor([0.0, 8.0, 4.0, 12.0, 2.0, 10.0, 6.0, 14.0, 1.0, 9.0, 5.0, 13.0, 3.0, 11.0, 7.0, 15.0], device, dtype=tf.float32) + result = run_on_device(longest_increasing_subsequence_length_tf, device, arr) + assert result == 6 + + @pytest.mark.parametrize("device", DEVICES) + def test_negative_numbers(self, device): + """Test with negative numbers.""" + arr = to_tensor([-5.0, -2.0, -8.0, -1.0, -6.0, 0.0], device, dtype=tf.float32) + result = run_on_device(longest_increasing_subsequence_length_tf, device, arr) + assert result == 4 diff --git a/code_to_optimize/tests/pytest/test_torch_jit_code.py b/code_to_optimize/tests/pytest/test_torch_jit_code.py new file mode 100644 index 000000000..8c617b3dd --- /dev/null +++ b/code_to_optimize/tests/pytest/test_torch_jit_code.py @@ -0,0 +1,285 @@ +""" +Unit tests for PyTorch implementations of JIT-suitable functions. + +Tests run on CPU, CUDA, and MPS devices. 
+""" + +import numpy as np +import pytest + +torch = pytest.importorskip("torch") + +from code_to_optimize.sample_jit_code import ( + leapfrog_integration_torch, + longest_increasing_subsequence_length_torch, + tridiagonal_solve_torch, +) + + +def get_available_devices(): + """Return list of available PyTorch devices for testing.""" + devices = ["cpu"] + + # Check for CUDA + if torch.cuda.is_available(): + devices.append("cuda") + + # Check for MPS (Apple Silicon) + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + devices.append("mps") + + return devices + + +DEVICES = get_available_devices() + + +def get_dtype(device): + """Get the appropriate dtype for a device. MPS doesn't support float64.""" + if device == "mps": + return torch.float32 + return torch.float64 + + +def to_device(arr, device): + """Move a tensor to the specified device.""" + dtype = get_dtype(device) + if isinstance(arr, np.ndarray): + arr = torch.from_numpy(arr).to(dtype) + return arr.to(device) + + +class TestTridiagonalSolveTorch: + """Tests for the PyTorch tridiagonal_solve function.""" + + @pytest.mark.parametrize("device", DEVICES) + def test_simple_system(self, device): + """Test a simple 3x3 tridiagonal system with known solution.""" + a = torch.tensor([-1.0, -1.0], dtype=get_dtype(device), device=device) + b = torch.tensor([2.0, 2.0, 2.0], dtype=get_dtype(device), device=device) + c = torch.tensor([-1.0, -1.0], dtype=get_dtype(device), device=device) + d = torch.tensor([1.0, 0.0, 1.0], dtype=get_dtype(device), device=device) + + x = tridiagonal_solve_torch(a, b, c, d) + + # Verify solution by multiplying back + result = torch.zeros(3, dtype=get_dtype(device), device=device) + result[0] = b[0] * x[0] + c[0] * x[1] + result[1] = a[0] * x[0] + b[1] * x[1] + c[1] * x[2] + result[2] = a[1] * x[1] + b[2] * x[2] + + np.testing.assert_array_almost_equal(result.cpu().numpy(), d.cpu().numpy(), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_diagonal_system(self, device): + """Test a purely diagonal system.""" + a = torch.tensor([0.0, 0.0], dtype=get_dtype(device), device=device) + b = torch.tensor([2.0, 3.0, 4.0], dtype=get_dtype(device), device=device) + c = torch.tensor([0.0, 0.0], dtype=get_dtype(device), device=device) + d = torch.tensor([4.0, 9.0, 16.0], dtype=get_dtype(device), device=device) + + x = tridiagonal_solve_torch(a, b, c, d) + + expected = torch.tensor([2.0, 3.0, 4.0], dtype=get_dtype(device)) + np.testing.assert_array_almost_equal(x.cpu().numpy(), expected.numpy(), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_larger_system(self, device): + """Test a larger tridiagonal system.""" + n = 100 + a = -torch.ones(n - 1, dtype=get_dtype(device), device=device) + b = 2.0 * torch.ones(n, dtype=get_dtype(device), device=device) + c = -torch.ones(n - 1, dtype=get_dtype(device), device=device) + d = torch.zeros(n, dtype=get_dtype(device), device=device) + d[0] = 1.0 + d[-1] = 1.0 + + x = tridiagonal_solve_torch(a, b, c, d) + + # Verify by reconstructing Ax + result = torch.zeros(n, dtype=get_dtype(device), device=device) + result[0] = b[0] * x[0] + c[0] * x[1] + for i in range(1, n - 1): + result[i] = a[i - 1] * x[i - 1] + b[i] * x[i] + c[i] * x[i + 1] + result[-1] = a[-1] * x[-2] + b[-1] * x[-1] + + np.testing.assert_array_almost_equal(result.cpu().numpy(), d.cpu().numpy(), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_two_element_system(self, device): + """Test minimal 2x2 tridiagonal system.""" + a = torch.tensor([1.0], 
dtype=get_dtype(device), device=device) + b = torch.tensor([4.0, 4.0], dtype=get_dtype(device), device=device) + c = torch.tensor([1.0], dtype=get_dtype(device), device=device) + d = torch.tensor([5.0, 5.0], dtype=get_dtype(device), device=device) + + x = tridiagonal_solve_torch(a, b, c, d) + + result = torch.tensor([ + b[0] * x[0] + c[0] * x[1], + a[0] * x[0] + b[1] * x[1] + ], device=device) + np.testing.assert_array_almost_equal(result.cpu().numpy(), d.cpu().numpy(), decimal=5) + + +class TestLeapfrogIntegrationTorch: + """Tests for the PyTorch leapfrog_integration function.""" + + @pytest.mark.parametrize("device", DEVICES) + def test_single_stationary_particle(self, device): + """A single particle with no velocity should remain stationary.""" + positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=get_dtype(device), device=device) + velocities = torch.tensor([[0.0, 0.0, 0.0]], dtype=get_dtype(device), device=device) + masses = torch.tensor([1.0], dtype=get_dtype(device), device=device) + + final_pos, final_vel = leapfrog_integration_torch( + positions, velocities, masses, dt=0.01, n_steps=100 + ) + + np.testing.assert_array_almost_equal(final_pos.cpu().numpy(), positions.cpu().numpy(), decimal=5) + np.testing.assert_array_almost_equal(final_vel.cpu().numpy(), velocities.cpu().numpy(), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_single_moving_particle(self, device): + """A single moving particle should move in a straight line.""" + positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=get_dtype(device), device=device) + velocities = torch.tensor([[1.0, 0.0, 0.0]], dtype=get_dtype(device), device=device) + masses = torch.tensor([1.0], dtype=get_dtype(device), device=device) + + dt = 0.01 + n_steps = 100 + + final_pos, final_vel = leapfrog_integration_torch( + positions, velocities, masses, dt=dt, n_steps=n_steps + ) + + np.testing.assert_array_almost_equal(final_vel.cpu().numpy(), velocities.cpu().numpy(), decimal=5) + expected_pos = torch.tensor([[dt * n_steps, 0.0, 0.0]]) + np.testing.assert_array_almost_equal(final_pos.cpu().numpy(), expected_pos.numpy(), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_two_particles_approach(self, device): + """Two particles should attract each other gravitationally.""" + positions = torch.tensor([[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=get_dtype(device), device=device) + velocities = torch.zeros((2, 3), dtype=get_dtype(device), device=device) + masses = torch.tensor([1.0, 1.0], dtype=get_dtype(device), device=device) + + final_pos, _ = leapfrog_integration_torch( + positions, velocities, masses, dt=0.01, n_steps=50, softening=0.1 + ) + + initial_distance = 2.0 + final_distance = torch.linalg.norm(final_pos[1] - final_pos[0]).item() + assert final_distance < initial_distance + + @pytest.mark.parametrize("device", DEVICES) + def test_momentum_conservation(self, device): + """Total momentum should be approximately conserved.""" + np.random.seed(42) + n_particles = 5 + positions = torch.tensor(np.random.randn(n_particles, 3), dtype=get_dtype(device), device=device) + velocities = torch.tensor(np.random.randn(n_particles, 3), dtype=get_dtype(device), device=device) + masses = torch.tensor(np.abs(np.random.randn(n_particles)) + 0.1, dtype=get_dtype(device), device=device) + + initial_momentum = torch.sum(masses[:, None] * velocities, dim=0) + + final_pos, final_vel = leapfrog_integration_torch( + positions, velocities, masses, dt=0.001, n_steps=100, softening=0.5 + ) + + final_momentum = torch.sum(masses[:, None] 
* final_vel, dim=0) + + np.testing.assert_array_almost_equal( + initial_momentum.cpu().numpy(), final_momentum.cpu().numpy(), decimal=4 + ) + + @pytest.mark.parametrize("device", DEVICES) + def test_does_not_modify_input(self, device): + """Input arrays should not be modified.""" + positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=get_dtype(device), device=device) + velocities = torch.tensor([[0.1, 0.0, 0.0], [-0.1, 0.0, 0.0]], dtype=get_dtype(device), device=device) + masses = torch.tensor([1.0, 1.0], dtype=get_dtype(device), device=device) + + pos_copy = positions.clone() + vel_copy = velocities.clone() + + leapfrog_integration_torch(positions, velocities, masses, dt=0.01, n_steps=10) + + np.testing.assert_array_equal(positions.cpu().numpy(), pos_copy.cpu().numpy()) + np.testing.assert_array_equal(velocities.cpu().numpy(), vel_copy.cpu().numpy()) + + +class TestLongestIncreasingSubsequenceLengthTorch: + """Tests for the PyTorch longest_increasing_subsequence_length function.""" + + @pytest.mark.parametrize("device", DEVICES) + def test_empty_array(self, device): + """Empty array should return 0.""" + arr = torch.tensor([], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 0 + + @pytest.mark.parametrize("device", DEVICES) + def test_single_element(self, device): + """Single element array should return 1.""" + arr = torch.tensor([5.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_strictly_increasing(self, device): + """Strictly increasing array - LIS is the whole array.""" + arr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 5 + + @pytest.mark.parametrize("device", DEVICES) + def test_strictly_decreasing(self, device): + """Strictly decreasing array - LIS is length 1.""" + arr = torch.tensor([5.0, 4.0, 3.0, 2.0, 1.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_classic_example(self, device): + """Classic LIS example.""" + arr = torch.tensor([10.0, 9.0, 2.0, 5.0, 3.0, 7.0, 101.0, 18.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 4 + + @pytest.mark.parametrize("device", DEVICES) + def test_all_same_elements(self, device): + """All same elements - LIS is length 1.""" + arr = torch.tensor([5.0, 5.0, 5.0, 5.0, 5.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_alternating_sequence(self, device): + """Alternating high-low sequence.""" + arr = torch.tensor([1.0, 10.0, 2.0, 9.0, 3.0, 8.0, 4.0, 7.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 5 + + @pytest.mark.parametrize("device", DEVICES) + def test_two_elements_increasing(self, device): + """Two elements in increasing order.""" + arr = torch.tensor([1.0, 2.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 2 + + @pytest.mark.parametrize("device", DEVICES) + def test_two_elements_decreasing(self, device): + """Two elements in decreasing order.""" + arr = torch.tensor([2.0, 1.0], dtype=get_dtype(device), device=device) + assert 
longest_increasing_subsequence_length_torch(arr) == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_longer_sequence(self, device): + """Test with a longer sequence.""" + arr = torch.tensor([0.0, 8.0, 4.0, 12.0, 2.0, 10.0, 6.0, 14.0, 1.0, 9.0, 5.0, 13.0, 3.0, 11.0, 7.0, 15.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 6 + + @pytest.mark.parametrize("device", DEVICES) + def test_negative_numbers(self, device): + """Test with negative numbers.""" + arr = torch.tensor([-5.0, -2.0, -8.0, -1.0, -6.0, 0.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 4 + + @pytest.mark.parametrize("device", DEVICES) + def test_float_values(self, device): + """Test with floating point values.""" + arr = torch.tensor([1.5, 2.3, 1.8, 3.1, 2.9, 4.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 4 \ No newline at end of file From 3a0e41861c39218b86bf78eed8e85327b6bdaf10 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Thu, 15 Jan 2026 18:03:16 -0800 Subject: [PATCH 21/27] local aiservice temporary --- codeflash/api/aiservice.py | 4 +++- codeflash/optimization/function_optimizer.py | 1 + codeflash/verification/verifier.py | 2 ++ 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 65c4c20a0..817382731 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -55,7 +55,7 @@ def get_aiservice_base_url(self) -> str: logger.info("Using local AI Service at http://localhost:8000") console.rule() return "http://localhost:8000" - return "https://app.codeflash.ai" + return "http://localhost:8000" def make_ai_service_request( self, @@ -634,6 +634,7 @@ def generate_regression_tests( # noqa: D417 test_timeout: int, trace_id: str, test_index: int, + is_numerical_code: bool | None = None, # noqa: FBT001 ) -> tuple[str, str, str] | None: """Generate regression tests for the given function by making a request to the Django endpoint. @@ -670,6 +671,7 @@ def generate_regression_tests( # noqa: D417 "codeflash_version": codeflash_version, "is_async": function_to_optimize.is_async, "call_sequence": self.get_next_sequence(), + "is_numerical_code": is_numerical_code, } try: response = self.make_ai_service_request("/testgen", payload=payload, timeout=self.timeout) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index d534fc297..0beef4f89 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -2485,6 +2485,7 @@ def submit_test_generation_tasks( test_index, test_path, test_perf_path, + self.is_numerical_code, ) for test_index, (test_path, test_perf_path) in enumerate( zip(generated_test_paths, generated_perf_test_paths) diff --git a/codeflash/verification/verifier.py b/codeflash/verification/verifier.py index 8d187f2b1..f60718020 100644 --- a/codeflash/verification/verifier.py +++ b/codeflash/verification/verifier.py @@ -27,6 +27,7 @@ def generate_tests( test_index: int, test_path: Path, test_perf_path: Path, + is_numerical_code: bool | None = None, # noqa: FBT001 ) -> tuple[str, str, Path] | None: # TODO: Sometimes this recreates the original Class definition. This overrides and messes up the original # class import. 
Remove the recreation of the class definition @@ -42,6 +43,7 @@ def generate_tests( test_timeout=test_timeout, trace_id=function_trace_id, test_index=test_index, + is_numerical_code=is_numerical_code, ) if response and isinstance(response, tuple) and len(response) == 3: generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source = response From 4d28c1779f9bce61414bb4e108aa6070462f639c Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Thu, 15 Jan 2026 18:58:13 -0800 Subject: [PATCH 22/27] almost ready --- .../{sample_jit_code.py => sample_code.py} | 58 ++++++------------- .../tests/pytest/test_jax_jit_code.py | 2 +- .../tests/pytest/test_numba_jit_code.py | 2 +- .../tests/pytest/test_tensorflow_jit_code.py | 2 +- .../tests/pytest/test_torch_jit_code.py | 2 +- 5 files changed, 23 insertions(+), 43 deletions(-) rename code_to_optimize/{sample_jit_code.py => sample_code.py} (91%) diff --git a/code_to_optimize/sample_jit_code.py b/code_to_optimize/sample_code.py similarity index 91% rename from code_to_optimize/sample_jit_code.py rename to code_to_optimize/sample_code.py index 50232a07e..d356ce807 100644 --- a/code_to_optimize/sample_jit_code.py +++ b/code_to_optimize/sample_code.py @@ -10,48 +10,28 @@ def tridiagonal_solve(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray: n = len(b) - c_prime = np.empty(n - 1, dtype=np.float64) - d_prime = np.empty(n, dtype=np.float64) - x = np.empty(n, dtype=np.float64) - - # Alias arrays to local variables to avoid repeated attribute lookups - a_arr = a - b_arr = b - c_arr = c - d_arr = d - cp = c_prime - dp = d_prime - x_arr = x - - # First element - prev_cprime = c_arr[0] / b_arr[0] - cp[0] = prev_cprime - prev_dprime = d_arr[0] / b_arr[0] - dp[0] = prev_dprime - - # Forward sweep (compute c_prime and d_prime) + # Create working copies to avoid modifying input + c_prime = np.zeros(n - 1, dtype=np.float64) + d_prime = np.zeros(n, dtype=np.float64) + x = np.zeros(n, dtype=np.float64) + + # Forward sweep - sequential dependency: c_prime[i] depends on c_prime[i-1] + c_prime[0] = c[0] / b[0] + d_prime[0] = d[0] / b[0] for i in range(1, n - 1): - ai_1 = a_arr[i - 1] - denom = b_arr[i] - ai_1 * prev_cprime - curr_cprime = c_arr[i] / denom - curr_dprime = (d_arr[i] - ai_1 * prev_dprime) / denom - cp[i] = curr_cprime - dp[i] = curr_dprime - prev_cprime = curr_cprime - prev_dprime = curr_dprime - - # Last d_prime entry - denom = b_arr[n - 1] - a_arr[n - 2] * prev_cprime - dp[n - 1] = (d_arr[n - 1] - a_arr[n - 2] * prev_dprime) / denom - - # Back substitution using a scalar for the "next x" value - prev_x = dp[n - 1] - x_arr[n - 1] = prev_x + denom = b[i] - a[i - 1] * c_prime[i - 1] + c_prime[i] = c[i] / denom + d_prime[i] = (d[i] - a[i - 1] * d_prime[i - 1]) / denom + + # Last row of forward sweep + denom = b[n - 1] - a[n - 2] * c_prime[n - 2] + d_prime[n - 1] = (d[n - 1] - a[n - 2] * d_prime[n - 2]) / denom + + # Back substitution - sequential dependency: x[i] depends on x[i+1] + x[n - 1] = d_prime[n - 1] for i in range(n - 2, -1, -1): - xi = dp[i] - cp[i] * prev_x - x_arr[i] = xi - prev_x = xi + x[i] = d_prime[i] - c_prime[i] * x[i + 1] return x diff --git a/code_to_optimize/tests/pytest/test_jax_jit_code.py b/code_to_optimize/tests/pytest/test_jax_jit_code.py index 5968bc0e2..4a6c44f5c 100644 --- a/code_to_optimize/tests/pytest/test_jax_jit_code.py +++ b/code_to_optimize/tests/pytest/test_jax_jit_code.py @@ -10,7 +10,7 @@ jax = pytest.importorskip("jax") import jax.numpy as jnp -from 
code_to_optimize.sample_jit_code import ( +from code_to_optimize.sample_code import ( leapfrog_integration_jax, longest_increasing_subsequence_length_jax, tridiagonal_solve_jax, diff --git a/code_to_optimize/tests/pytest/test_numba_jit_code.py b/code_to_optimize/tests/pytest/test_numba_jit_code.py index 06738c39b..a81152901 100644 --- a/code_to_optimize/tests/pytest/test_numba_jit_code.py +++ b/code_to_optimize/tests/pytest/test_numba_jit_code.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from code_to_optimize.sample_jit_code import ( +from code_to_optimize.sample_code import ( leapfrog_integration, longest_increasing_subsequence_length, tridiagonal_solve, diff --git a/code_to_optimize/tests/pytest/test_tensorflow_jit_code.py b/code_to_optimize/tests/pytest/test_tensorflow_jit_code.py index 3a44d24f5..1a545bb00 100644 --- a/code_to_optimize/tests/pytest/test_tensorflow_jit_code.py +++ b/code_to_optimize/tests/pytest/test_tensorflow_jit_code.py @@ -9,7 +9,7 @@ tf = pytest.importorskip("tensorflow") -from code_to_optimize.sample_jit_code import ( +from code_to_optimize.sample_code import ( leapfrog_integration_tf, longest_increasing_subsequence_length_tf, tridiagonal_solve_tf, diff --git a/code_to_optimize/tests/pytest/test_torch_jit_code.py b/code_to_optimize/tests/pytest/test_torch_jit_code.py index 8c617b3dd..4681ed89f 100644 --- a/code_to_optimize/tests/pytest/test_torch_jit_code.py +++ b/code_to_optimize/tests/pytest/test_torch_jit_code.py @@ -9,7 +9,7 @@ torch = pytest.importorskip("torch") -from code_to_optimize.sample_jit_code import ( +from code_to_optimize.sample_code import ( leapfrog_integration_torch, longest_increasing_subsequence_length_torch, tridiagonal_solve_torch, From ee6872c317a16df47981161b4472616ad3e8f8fd Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Thu, 15 Jan 2026 19:04:31 -0800 Subject: [PATCH 23/27] tensorflow, jax, pytorch now working on mac metal --- .../tests/pytest/test_jax_jit_code.py | 14 ++++++++++++-- .../tests/pytest/test_tensorflow_jit_code.py | 16 +++++++++++----- .../tests/pytest/test_torch_jit_code.py | 2 +- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/code_to_optimize/tests/pytest/test_jax_jit_code.py b/code_to_optimize/tests/pytest/test_jax_jit_code.py index 4a6c44f5c..3e9afe4e9 100644 --- a/code_to_optimize/tests/pytest/test_jax_jit_code.py +++ b/code_to_optimize/tests/pytest/test_jax_jit_code.py @@ -1,13 +1,13 @@ """ Unit tests for JAX implementations of JIT-suitable functions. -Tests run on CPU and CUDA devices. +Tests run on CPU, CUDA, and Metal (Mac) devices. 
""" import numpy as np import pytest -jax = pytest.importorskip("jax") +import jax import jax.numpy as jnp from code_to_optimize.sample_code import ( @@ -32,6 +32,14 @@ def get_available_devices(): except RuntimeError: pass + # Check for Metal (Mac) + try: + metal_devices = jax.devices("METAL") + if metal_devices: + devices.append("metal") + except RuntimeError: + pass + return devices @@ -44,6 +52,8 @@ def to_device(arr, device): return jax.device_put(arr, jax.devices("cpu")[0]) elif device == "cuda": return jax.device_put(arr, jax.devices("gpu")[0]) + elif device == "metal": + return jax.device_put(arr, jax.devices("METAL")[0]) return arr diff --git a/code_to_optimize/tests/pytest/test_tensorflow_jit_code.py b/code_to_optimize/tests/pytest/test_tensorflow_jit_code.py index 1a545bb00..cbeb0b308 100644 --- a/code_to_optimize/tests/pytest/test_tensorflow_jit_code.py +++ b/code_to_optimize/tests/pytest/test_tensorflow_jit_code.py @@ -1,9 +1,11 @@ """ Unit tests for TensorFlow implementations of JIT-suitable functions. -Tests run on CPU and CUDA devices. +Tests run on CPU, CUDA, and Metal (Mac) devices. """ +import platform + import numpy as np import pytest @@ -20,10 +22,14 @@ def get_available_devices(): """Return list of available TensorFlow devices for testing.""" devices = ["cpu"] - # Check for CUDA/GPU + # Check for GPU devices gpus = tf.config.list_physical_devices("GPU") if gpus: - devices.append("cuda") + # On macOS, GPUs are Metal devices; on other platforms, they're CUDA + if platform.system() == "Darwin": + devices.append("metal") + else: + devices.append("cuda") return devices @@ -35,7 +41,7 @@ def run_on_device(func, device, *args, **kwargs): """Run a function on the specified device.""" if device == "cpu": device_name = "/CPU:0" - elif device == "cuda": + elif device in ("cuda", "metal"): device_name = "/GPU:0" else: device_name = "/CPU:0" @@ -48,7 +54,7 @@ def to_tensor(arr, device, dtype=tf.float64): """Create a tensor on the specified device.""" if device == "cpu": device_name = "/CPU:0" - elif device == "cuda": + elif device in ("cuda", "metal"): device_name = "/GPU:0" else: device_name = "/CPU:0" diff --git a/code_to_optimize/tests/pytest/test_torch_jit_code.py b/code_to_optimize/tests/pytest/test_torch_jit_code.py index 4681ed89f..63b0e6889 100644 --- a/code_to_optimize/tests/pytest/test_torch_jit_code.py +++ b/code_to_optimize/tests/pytest/test_torch_jit_code.py @@ -7,7 +7,7 @@ import numpy as np import pytest -torch = pytest.importorskip("torch") +import torch from code_to_optimize.sample_code import ( leapfrog_integration_torch, From 9d4fa82b52d7b4aa63c22cd1b6dea49bbfdfe833 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Thu, 15 Jan 2026 21:25:40 -0800 Subject: [PATCH 24/27] ready to review --- codeflash/api/aiservice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 817382731..1ec502ac9 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -55,7 +55,7 @@ def get_aiservice_base_url(self) -> str: logger.info("Using local AI Service at http://localhost:8000") console.rule() return "http://localhost:8000" - return "http://localhost:8000" + return "https://app.codeflash.ai" def make_ai_service_request( self, From 01a8ebf7e794ef880a04a50e3853bdab00aea559 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Fri, 16 Jan 2026 12:14:40 -0800 Subject: [PATCH 25/27] Apply suggestions from code review --- codeflash/api/aiservice.py | 6 +++--- 
codeflash/optimization/function_optimizer.py | 5 ++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 1ec502ac9..e4ed074fd 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -221,13 +221,13 @@ def get_jit_rewritten_code( # noqa: D417 "repo_name": git_repo_name, } - logger.info("!lsp|Generating optimized candidates…") + logger.info("!lsp|Rewriting as a JIT function…") console.rule() try: response = self.make_ai_service_request("/rewrite_jit", payload=payload, timeout=60) except requests.exceptions.RequestException as e: logger.exception(f"Error generating jit rewritten candidate: {e}") - ph("cli-optimize-error-caught", {"error": str(e)}) + ph("cli-jit-rewrite-error-caught", {"error": str(e)}) return [] if response.status_code == 200: @@ -241,7 +241,7 @@ def get_jit_rewritten_code( # noqa: D417 except Exception: error = response.text logger.error(f"Error generating jit rewritten candidate: {response.status_code} - {error}") - ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error}) + ph("cli-jit-rewrite-error-response", {"response_status_code": response.status_code, "error": error}) console.rule() return [] diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 0beef4f89..82cd2d208 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -615,7 +615,10 @@ def optimize_function(self) -> Result[BestOptimization, str]: original_helper_code=original_helper_code, ) # get code context - new_code_context = self.get_code_optimization_context().unwrap() + try: + new_code_context = self.get_code_optimization_context().unwrap() + except Exception as e: + logger.debug(f"!lsp|Getting new code context failed, revert to original one") # unwrite files self.write_code_and_helpers( self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path From b41d387368be86b0e83aa20ac12488ad9d552bbd Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Fri, 16 Jan 2026 12:15:38 -0800 Subject: [PATCH 26/27] sentry capture --- codeflash/optimization/function_optimizer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 82cd2d208..8397b130f 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -618,6 +618,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: try: new_code_context = self.get_code_optimization_context().unwrap() except Exception as e: + sentry_sdk.capture_exception(e) logger.debug(f"!lsp|Getting new code context failed, revert to original one") # unwrite files self.write_code_and_helpers( From 074ad24d8b0c8683b771cb7c7994098f1d9090b2 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Fri, 16 Jan 2026 12:24:44 -0800 Subject: [PATCH 27/27] final formatting --- codeflash/optimization/function_optimizer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 8397b130f..761b8ea0c 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Callable import libcst as cst +import sentry_sdk from rich.console import Group from rich.panel import Panel from 
rich.syntax import Syntax @@ -619,7 +620,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: new_code_context = self.get_code_optimization_context().unwrap() except Exception as e: sentry_sdk.capture_exception(e) - logger.debug(f"!lsp|Getting new code context failed, revert to original one") + logger.debug("!lsp|Getting new code context failed, revert to original one") # unwrite files self.write_code_and_helpers( self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path