diff --git a/code_to_optimize/discrete_riccati.py b/code_to_optimize/discrete_riccati.py deleted file mode 100644 index 5a032b9c5..000000000 --- a/code_to_optimize/discrete_riccati.py +++ /dev/null @@ -1,100 +0,0 @@ -""" -Utility functions used in CompEcon - -Based routines found in the CompEcon toolbox by Miranda and Fackler. - -References ----------- -Miranda, Mario J, and Paul L Fackler. Applied Computational Economics -and Finance, MIT Press, 2002. - -""" -from functools import reduce -import numpy as np -import torch - -def _gridmake2(x1, x2): - """ - Expands two vectors (or matrices) into a matrix where rows span the - cartesian product of combinations of the input arrays. Each column of the - input arrays will correspond to one column of the output matrix. - - Parameters - ---------- - x1 : np.ndarray - First vector to be expanded. - - x2 : np.ndarray - Second vector to be expanded. - - Returns - ------- - out : np.ndarray - The cartesian product of combinations of the input arrays. - - Notes - ----- - Based of original function ``gridmake2`` in CompEcon toolbox by - Miranda and Fackler. - - References - ---------- - Miranda, Mario J, and Paul L Fackler. Applied Computational Economics - and Finance, MIT Press, 2002. - - """ - if x1.ndim == 1 and x2.ndim == 1: - return np.column_stack([np.tile(x1, x2.shape[0]), - np.repeat(x2, x1.shape[0])]) - elif x1.ndim > 1 and x2.ndim == 1: - first = np.tile(x1, (x2.shape[0], 1)) - second = np.repeat(x2, x1.shape[0]) - return np.column_stack([first, second]) - else: - raise NotImplementedError("Come back here") - - -def _gridmake2_torch(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - """ - PyTorch version of _gridmake2. - - Expands two tensors into a matrix where rows span the cartesian product - of combinations of the input tensors. Each column of the input tensors - will correspond to one column of the output matrix. - - Parameters - ---------- - x1 : torch.Tensor - First tensor to be expanded. - - x2 : torch.Tensor - Second tensor to be expanded. - - Returns - ------- - out : torch.Tensor - The cartesian product of combinations of the input tensors. - - Notes - ----- - Based on original function ``gridmake2`` in CompEcon toolbox by - Miranda and Fackler. - - References - ---------- - Miranda, Mario J, and Paul L Fackler. Applied Computational Economics - and Finance, MIT Press, 2002. 
- - """ - if x1.dim() == 1 and x2.dim() == 1: - # tile x1 by x2.shape[0] times, repeat_interleave x2 by x1.shape[0] - first = x1.tile(x2.shape[0]) - second = x2.repeat_interleave(x1.shape[0]) - return torch.column_stack([first, second]) - elif x1.dim() > 1 and x2.dim() == 1: - # tile x1 along first dimension - first = x1.tile(x2.shape[0], 1) - second = x2.repeat_interleave(x1.shape[0]) - return torch.column_stack([first, second]) - else: - raise NotImplementedError("Come back here") diff --git a/code_to_optimize/sample_code.py b/code_to_optimize/sample_code.py new file mode 100644 index 000000000..d356ce807 --- /dev/null +++ b/code_to_optimize/sample_code.py @@ -0,0 +1,456 @@ +from functools import partial + +import jax.numpy as jnp +import numpy as np +import tensorflow as tf +import torch +from jax import lax + + +def tridiagonal_solve(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray: + n = len(b) + + # Create working copies to avoid modifying input + c_prime = np.zeros(n - 1, dtype=np.float64) + d_prime = np.zeros(n, dtype=np.float64) + x = np.zeros(n, dtype=np.float64) + + # Forward sweep - sequential dependency: c_prime[i] depends on c_prime[i-1] + c_prime[0] = c[0] / b[0] + d_prime[0] = d[0] / b[0] + + for i in range(1, n - 1): + denom = b[i] - a[i - 1] * c_prime[i - 1] + c_prime[i] = c[i] / denom + d_prime[i] = (d[i] - a[i - 1] * d_prime[i - 1]) / denom + + # Last row of forward sweep + denom = b[n - 1] - a[n - 2] * c_prime[n - 2] + d_prime[n - 1] = (d[n - 1] - a[n - 2] * d_prime[n - 2]) / denom + + # Back substitution - sequential dependency: x[i] depends on x[i+1] + x[n - 1] = d_prime[n - 1] + for i in range(n - 2, -1, -1): + x[i] = d_prime[i] - c_prime[i] * x[i + 1] + + return x + + +def leapfrog_integration( + positions: np.ndarray, + velocities: np.ndarray, + masses: np.ndarray, + dt: float, + n_steps: int, + softening: float = 0.01 +) -> tuple[np.ndarray, np.ndarray]: + n_particles = len(masses) + pos = positions.copy() + vel = velocities.copy() + acc = np.zeros_like(pos) + + G = 1.0 + + for step in range(n_steps): + acc.fill(0.0) + + for i in range(n_particles): + for j in range(i + 1, n_particles): + dx = pos[j, 0] - pos[i, 0] + dy = pos[j, 1] - pos[i, 1] + dz = pos[j, 2] - pos[i, 2] + + dist_sq = dx * dx + dy * dy + dz * dz + softening * softening + dist = np.sqrt(dist_sq) + dist_cubed = dist_sq * dist + + force_over_dist = G / dist_cubed + + acc[i, 0] += masses[j] * force_over_dist * dx + acc[i, 1] += masses[j] * force_over_dist * dy + acc[i, 2] += masses[j] * force_over_dist * dz + + acc[j, 0] -= masses[i] * force_over_dist * dx + acc[j, 1] -= masses[i] * force_over_dist * dy + acc[j, 2] -= masses[i] * force_over_dist * dz + + for i in range(n_particles): + vel[i, 0] += 0.5 * dt * acc[i, 0] + vel[i, 1] += 0.5 * dt * acc[i, 1] + vel[i, 2] += 0.5 * dt * acc[i, 2] + + for i in range(n_particles): + pos[i, 0] += dt * vel[i, 0] + pos[i, 1] += dt * vel[i, 1] + pos[i, 2] += dt * vel[i, 2] + + for i in range(n_particles): + vel[i, 0] += 0.5 * dt * acc[i, 0] + vel[i, 1] += 0.5 * dt * acc[i, 1] + vel[i, 2] += 0.5 * dt * acc[i, 2] + + return pos, vel + + +def longest_increasing_subsequence_length(arr: np.ndarray) -> int: + n = len(arr) + if n == 0: + return 0 + + dp = np.ones(n, dtype=np.int64) + + for i in range(1, n): + for j in range(i): + if arr[j] < arr[i]: + if dp[j] + 1 > dp[i]: + dp[i] = dp[j] + 1 + + max_length = dp[0] + for i in range(1, n): + if dp[i] > max_length: + max_length = dp[i] + + return max_length + + +def 
_tridiagonal_forward_step_jax(carry, inputs): + c_prev, d_prev = carry + a_i, b_i, c_i, d_i = inputs + denom = b_i - a_i * c_prev + c_new = c_i / denom + d_new = (d_i - a_i * d_prev) / denom + return (c_new, d_new), (c_new, d_new) + + +def _tridiagonal_back_step_jax(x_next, inputs): + d_prime_i, c_prime_i = inputs + x_i = d_prime_i - c_prime_i * x_next + return x_i, x_i + + +def tridiagonal_solve_jax(a, b, c, d): + n = b.shape[0] + + c_prime_0 = c[0] / b[0] + d_prime_0 = d[0] / b[0] + + scan_inputs = (a[:-1], b[1:-1], c[1:], d[1:-1]) + + _, (c_prime_rest, d_prime_mid) = lax.scan( + _tridiagonal_forward_step_jax, + (c_prime_0, d_prime_0), + scan_inputs + ) + + c_prime = jnp.concatenate([jnp.array([c_prime_0]), c_prime_rest]) + + denom_last = b[n - 1] - a[n - 2] * c_prime[n - 2] + d_prime_last = (d[n - 1] - a[n - 2] * d_prime_mid[-1]) / denom_last + d_prime = jnp.concatenate([jnp.array([d_prime_0]), d_prime_mid, jnp.array([d_prime_last])]) + + x_last = d_prime[n - 1] + _, x_rest = lax.scan( + _tridiagonal_back_step_jax, + x_last, + (d_prime[:-1], c_prime), + reverse=True + ) + + x = jnp.concatenate([x_rest, jnp.array([x_last])]) + return x + + +def _leapfrog_compute_accelerations_jax(pos, masses, softening): + G = 1.0 + diff = pos[jnp.newaxis, :, :] - pos[:, jnp.newaxis, :] + + dist_sq = jnp.sum(diff ** 2, axis=-1) + softening ** 2 + dist = jnp.sqrt(dist_sq) + dist_cubed = dist_sq * dist + + dist_cubed = jnp.where(dist_cubed == 0, 1.0, dist_cubed) + + force_factor = G * masses[jnp.newaxis, :] / dist_cubed + + acc = jnp.sum(force_factor[:, :, jnp.newaxis] * diff, axis=1) + return acc + + +def _leapfrog_step_jax(carry, _, masses, softening, dt): + pos, vel = carry + acc = _leapfrog_compute_accelerations_jax(pos, masses, softening) + + vel = vel + 0.5 * dt * acc + pos = pos + dt * vel + vel = vel + 0.5 * dt * acc + + return (pos, vel), None + + +def leapfrog_integration_jax( + positions, + velocities, + masses, + dt: float, + n_steps: int, + softening: float = 0.01 +): + step_fn = partial(_leapfrog_step_jax, masses=masses, softening=softening, dt=dt) + (final_pos, final_vel), _ = lax.scan(step_fn, (positions, velocities), None, length=n_steps) + return final_pos, final_vel + + +def _lis_inner_body_jax(j, dp_inner, arr, i): + condition = (arr[j] < arr[i]) & (dp_inner[j] + 1 > dp_inner[i]) + new_val = jnp.where(condition, dp_inner[j] + 1, dp_inner[i]) + return dp_inner.at[i].set(new_val) + + +def _lis_outer_body_jax(i, dp, arr): + inner_fn = partial(_lis_inner_body_jax, arr=arr, i=i) + dp = lax.fori_loop(0, i, inner_fn, dp) + return dp + + +def longest_increasing_subsequence_length_jax(arr): + n = arr.shape[0] + + if n == 0: + return 0 + + outer_fn = partial(_lis_outer_body_jax, arr=arr) + dp = jnp.ones(n, dtype=jnp.int32) + dp = lax.fori_loop(1, n, outer_fn, dp) + + return int(jnp.max(dp)) + + +def tridiagonal_solve_torch(a, b, c, d): + device = b.device + dtype = b.dtype + n = b.shape[0] + + c_prime = torch.zeros(n - 1, device=device, dtype=dtype) + d_prime = torch.zeros(n, device=device, dtype=dtype) + x = torch.zeros(n, device=device, dtype=dtype) + + c_prime[0] = c[0] / b[0] + d_prime[0] = d[0] / b[0] + + for i in range(1, n - 1): + denom = b[i] - a[i - 1] * c_prime[i - 1] + c_prime[i] = c[i] / denom + d_prime[i] = (d[i] - a[i - 1] * d_prime[i - 1]) / denom + + denom = b[n - 1] - a[n - 2] * c_prime[n - 2] + d_prime[n - 1] = (d[n - 1] - a[n - 2] * d_prime[n - 2]) / denom + + x[n - 1] = d_prime[n - 1] + for i in range(n - 2, -1, -1): + x[i] = d_prime[i] - c_prime[i] * x[i + 1] + + return x + + 
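+# Usage sketch for the tridiagonal solvers above (coefficients borrowed from
+# the unit tests; illustrative only): every framework variant takes the same
+# banded layout, where ``a`` is the sub-diagonal (length n-1), ``b`` the main
+# diagonal (length n), ``c`` the super-diagonal (length n-1) and ``d`` the
+# right-hand side (length n).
+#
+#     a = torch.tensor([-1.0, -1.0])
+#     b = torch.tensor([2.0, 2.0, 2.0])
+#     c = torch.tensor([-1.0, -1.0])
+#     d = torch.tensor([1.0, 0.0, 1.0])
+#     x = tridiagonal_solve_torch(a, b, c, d)  # x satisfies A @ x = d
+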
+def leapfrog_integration_torch( + positions, + velocities, + masses, + dt: float, + n_steps: int, + softening: float = 0.01 +): + G = 1.0 + + pos = positions.clone() + vel = velocities.clone() + + for _ in range(n_steps): + diff = pos.unsqueeze(0) - pos.unsqueeze(1) + + dist_sq = torch.sum(diff ** 2, dim=-1) + softening ** 2 + dist = torch.sqrt(dist_sq) + dist_cubed = dist_sq * dist + + dist_cubed = torch.where(dist_cubed == 0, torch.ones_like(dist_cubed), dist_cubed) + + force_factor = G * masses.unsqueeze(0) / dist_cubed + + acc = torch.sum(force_factor.unsqueeze(-1) * diff, dim=1) + + vel = vel + 0.5 * dt * acc + pos = pos + dt * vel + vel = vel + 0.5 * dt * acc + + return pos, vel + + +def longest_increasing_subsequence_length_torch(arr): + n = arr.shape[0] + + if n == 0: + return 0 + + device = arr.device + dp = torch.ones(n, device=device, dtype=torch.int64) + + for i in range(1, n): + for j in range(i): + if arr[j] < arr[i]: + if dp[j] + 1 > dp[i]: + dp[i] = dp[j] + 1 + + return int(torch.max(dp).item()) + + +def _tridiagonal_forward_cond_tf(i, _c_prime, _d_prime, n, _a, _b, _c, _d): + return i < n - 1 + + +def _tridiagonal_forward_body_tf(i, c_prime, d_prime, n, a, b, c, d): + c_prev = c_prime[i - 1] + d_prev = d_prime[i - 1] + denom = b[i] - a[i - 1] * c_prev + c_val = c[i] / denom + d_val = (d[i] - a[i - 1] * d_prev) / denom + c_prime = tf.tensor_scatter_nd_update(c_prime, tf.reshape(i, [1, 1]), tf.reshape(c_val, [1])) + d_prime = tf.tensor_scatter_nd_update(d_prime, tf.reshape(i, [1, 1]), tf.reshape(d_val, [1])) + return i + 1, c_prime, d_prime, n, a, b, c, d + + +def _tridiagonal_back_cond_tf(i, _x, _c_prime, _d_prime): + return i >= 0 + + +def _tridiagonal_back_body_tf(i, x, c_prime, d_prime): + x_next = x[i + 1] + x_val = d_prime[i] - c_prime[i] * x_next + x = tf.tensor_scatter_nd_update(x, tf.reshape(i, [1, 1]), tf.reshape(x_val, [1])) + return i - 1, x, c_prime, d_prime + + +def tridiagonal_solve_tf(a, b, c, d): + n = tf.shape(b)[0] + dtype = b.dtype + + c_prime = tf.zeros([n - 1], dtype=dtype) + d_prime = tf.zeros([n], dtype=dtype) + + c_prime = tf.tensor_scatter_nd_update(c_prime, [[0]], tf.reshape(c[0] / b[0], [1])) + d_prime = tf.tensor_scatter_nd_update(d_prime, [[0]], tf.reshape(d[0] / b[0], [1])) + + _, c_prime, d_prime, _, _, _, _, _ = tf.while_loop( + _tridiagonal_forward_cond_tf, + _tridiagonal_forward_body_tf, + [1, c_prime, d_prime, n, a, b, c, d] + ) + + c_last = c_prime[n - 2] + d_prev = d_prime[n - 2] + denom = b[n - 1] - a[n - 2] * c_last + d_last = (d[n - 1] - a[n - 2] * d_prev) / denom + d_prime = tf.tensor_scatter_nd_update(d_prime, tf.reshape(n - 1, [1, 1]), tf.reshape(d_last, [1])) + + x = tf.zeros([n], dtype=dtype) + x = tf.tensor_scatter_nd_update(x, tf.reshape(n - 1, [1, 1]), tf.reshape(d_prime[n - 1], [1])) + + _, x, _, _ = tf.while_loop( + _tridiagonal_back_cond_tf, + _tridiagonal_back_body_tf, + [n - 2, x, c_prime, d_prime] + ) + + return x + + +def _leapfrog_compute_accelerations_tf(pos, masses, softening, G): + diff = tf.expand_dims(pos, 0) - tf.expand_dims(pos, 1) + + dist_sq = tf.reduce_sum(diff ** 2, axis=-1) + softening ** 2 + dist = tf.sqrt(dist_sq) + dist_cubed = dist_sq * dist + + dist_cubed = tf.where(dist_cubed == 0, tf.ones_like(dist_cubed), dist_cubed) + + force_factor = G * tf.expand_dims(masses, 0) / dist_cubed + + acc = tf.reduce_sum(tf.expand_dims(force_factor, -1) * diff, axis=1) + return acc + + +def _leapfrog_step_body_tf(i, pos, vel, masses, softening, dt, n_steps): + G = 1.0 + acc = _leapfrog_compute_accelerations_tf(pos, 
masses, softening, G) + + vel = vel + 0.5 * dt * acc + pos = pos + dt * vel + vel = vel + 0.5 * dt * acc + + return i + 1, pos, vel, masses, softening, dt, n_steps + + +def _leapfrog_step_cond_tf(i, _pos, _vel, _masses, _softening, _dt, n_steps): + return i < n_steps + + +def leapfrog_integration_tf( + positions, + velocities, + masses, + dt: float, + n_steps: int, + softening: float = 0.01 +): + dt = tf.constant(dt, dtype=positions.dtype) + softening = tf.constant(softening, dtype=positions.dtype) + + _, final_pos, final_vel, _, _, _, _ = tf.while_loop( + _leapfrog_step_cond_tf, + _leapfrog_step_body_tf, + [0, positions, velocities, masses, softening, dt, n_steps] + ) + + return final_pos, final_vel + + +def _lis_inner_body_tf(j, dp_inner, arr, i): + condition = tf.logical_and(arr[j] < arr[i], dp_inner[j] + 1 > dp_inner[i]) + new_val = tf.where(condition, dp_inner[j] + 1, dp_inner[i]) + indices = tf.reshape(i, [1, 1]) + updates = tf.reshape(new_val, [1]) + dp_updated = tf.tensor_scatter_nd_update(dp_inner, indices, updates) + return j + 1, dp_updated, arr, i + + +def _lis_inner_cond_tf(j, _dp_inner, _arr, i): + return j < i + + +def _lis_outer_body_tf(i, dp, arr, n): + _, dp, _, _ = tf.while_loop( + _lis_inner_cond_tf, + _lis_inner_body_tf, + [0, dp, arr, i] + ) + return i + 1, dp, arr, n + + +def _lis_outer_cond_tf(i, _dp, _arr, n): + return i < n + + +def longest_increasing_subsequence_length_tf(arr): + n = tf.shape(arr)[0] + + if n == 0: + return 0 + + dp = tf.ones(n, dtype=tf.int32) + + _, dp, _, _ = tf.while_loop( + _lis_outer_cond_tf, + _lis_outer_body_tf, + [1, dp, arr, n] + ) + + return int(tf.reduce_max(dp)) diff --git a/code_to_optimize/tests/pytest/test_gridmake2.py b/code_to_optimize/tests/pytest/test_gridmake2.py deleted file mode 100644 index ae96a54bb..000000000 --- a/code_to_optimize/tests/pytest/test_gridmake2.py +++ /dev/null @@ -1,216 +0,0 @@ -import numpy as np -import pytest -from numpy.testing import assert_array_equal - -from code_to_optimize.discrete_riccati import _gridmake2 - - -class TestGridmake2With1DArrays: - """Tests for _gridmake2 with two 1D arrays.""" - - def test_basic_two_element_arrays(self): - """Test basic cartesian product of two 2-element arrays.""" - x1 = np.array([1, 2]) - x2 = np.array([3, 4]) - result = _gridmake2(x1, x2) - - # Expected: x1 is tiled len(x2) times, x2 is repeated len(x1) times - expected = np.array([ - [1, 3], - [2, 3], - [1, 4], - [2, 4] - ]) - assert_array_equal(result, expected) - - def test_different_length_arrays(self): - """Test cartesian product with arrays of different lengths.""" - x1 = np.array([1, 2, 3]) - x2 = np.array([10, 20]) - result = _gridmake2(x1, x2) - - # Result should have len(x1) * len(x2) = 6 rows - expected = np.array([ - [1, 10], - [2, 10], - [3, 10], - [1, 20], - [2, 20], - [3, 20] - ]) - assert_array_equal(result, expected) - assert result.shape == (6, 2) - - def test_single_element_arrays(self): - """Test with single-element arrays.""" - x1 = np.array([5]) - x2 = np.array([7]) - result = _gridmake2(x1, x2) - - expected = np.array([[5, 7]]) - assert_array_equal(result, expected) - assert result.shape == (1, 2) - - def test_single_element_with_multi_element(self): - """Test single-element array with multi-element array.""" - x1 = np.array([1]) - x2 = np.array([10, 20, 30]) - result = _gridmake2(x1, x2) - - expected = np.array([ - [1, 10], - [1, 20], - [1, 30] - ]) - assert_array_equal(result, expected) - - def test_float_arrays(self): - """Test with float arrays.""" - x1 = np.array([1.5, 2.5]) - x2 = 
np.array([0.1, 0.2]) - result = _gridmake2(x1, x2) - - expected = np.array([ - [1.5, 0.1], - [2.5, 0.1], - [1.5, 0.2], - [2.5, 0.2] - ]) - assert_array_equal(result, expected) - - def test_negative_values(self): - """Test with negative values.""" - x1 = np.array([-1, 0, 1]) - x2 = np.array([-10, 10]) - result = _gridmake2(x1, x2) - - expected = np.array([ - [-1, -10], - [0, -10], - [1, -10], - [-1, 10], - [0, 10], - [1, 10] - ]) - assert_array_equal(result, expected) - - def test_result_shape(self): - """Test that result shape is (len(x1)*len(x2), 2).""" - x1 = np.array([1, 2, 3, 4]) - x2 = np.array([5, 6, 7]) - result = _gridmake2(x1, x2) - - assert result.shape == (12, 2) - - def test_larger_arrays(self): - """Test with larger arrays.""" - x1 = np.arange(10) - x2 = np.arange(5) - result = _gridmake2(x1, x2) - - assert result.shape == (50, 2) - # Verify first column is x1 tiled 5 times - assert_array_equal(result[:10, 0], x1) - assert_array_equal(result[10:20, 0], x1) - # Verify second column is x2 repeated 10 times each - assert all(result[:10, 1] == 0) - assert all(result[10:20, 1] == 1) - - -class TestGridmake2With2DFirst: - """Tests for _gridmake2 when x1 is 2D and x2 is 1D.""" - - def test_2d_first_1d_second(self): - """Test with 2D first array and 1D second array.""" - x1 = np.array([[1, 2], [3, 4]]) # 2 rows, 2 cols - x2 = np.array([10, 20]) - result = _gridmake2(x1, x2) - - # x1 is tiled len(x2) times vertically - # x2 is repeated len(x1) times (2 rows) - expected = np.array([ - [1, 2, 10], - [3, 4, 10], - [1, 2, 20], - [3, 4, 20] - ]) - assert_array_equal(result, expected) - - def test_2d_single_column(self): - """Test with 2D array having single column.""" - x1 = np.array([[1], [2], [3]]) # 3 rows, 1 col - x2 = np.array([10, 20]) - result = _gridmake2(x1, x2) - - expected = np.array([ - [1, 10], - [2, 10], - [3, 10], - [1, 20], - [2, 20], - [3, 20] - ]) - assert_array_equal(result, expected) - - def test_2d_multiple_columns(self): - """Test with 2D array having multiple columns.""" - x1 = np.array([[1, 2, 3], [4, 5, 6]]) # 2 rows, 3 cols - x2 = np.array([100]) - result = _gridmake2(x1, x2) - - expected = np.array([ - [1, 2, 3, 100], - [4, 5, 6, 100] - ]) - assert_array_equal(result, expected) - - -class TestGridmake2EdgeCases: - """Edge case tests for _gridmake2.""" - - def test_empty_arrays_raise_or_return_empty(self): - """Test behavior with empty arrays.""" - x1 = np.array([]) - x2 = np.array([1, 2]) - result = _gridmake2(x1, x2) - # Empty x1 should result in empty output - assert result.shape[0] == 0 - - def test_both_empty_arrays(self): - """Test with both empty arrays.""" - x1 = np.array([]) - x2 = np.array([]) - result = _gridmake2(x1, x2) - assert result.shape[0] == 0 - - def test_integer_dtype_preserved(self): - """Test that integer dtype is handled correctly.""" - x1 = np.array([1, 2], dtype=np.int64) - x2 = np.array([3, 4], dtype=np.int64) - result = _gridmake2(x1, x2) - assert result.dtype == np.int64 - - def test_float_dtype_preserved(self): - """Test that float dtype is handled correctly.""" - x1 = np.array([1.0, 2.0], dtype=np.float64) - x2 = np.array([3.0, 4.0], dtype=np.float64) - result = _gridmake2(x1, x2) - assert result.dtype == np.float64 - - -class TestGridmake2NotImplemented: - """Tests for NotImplementedError cases.""" - - def test_both_2d_raises(self): - """Test that two 2D arrays raises NotImplementedError.""" - x1 = np.array([[1, 2], [3, 4]]) - x2 = np.array([[5, 6], [7, 8]]) - with pytest.raises(NotImplementedError): - _gridmake2(x1, x2) - - def 
test_1d_first_2d_second_raises(self): - """Test that 1D first and 2D second raises NotImplementedError.""" - x1 = np.array([1, 2]) - x2 = np.array([[5, 6], [7, 8]]) - with pytest.raises(NotImplementedError): - _gridmake2(x1, x2) diff --git a/code_to_optimize/tests/pytest/test_gridmake2_torch.py b/code_to_optimize/tests/pytest/test_gridmake2_torch.py deleted file mode 100644 index f2ee737a2..000000000 --- a/code_to_optimize/tests/pytest/test_gridmake2_torch.py +++ /dev/null @@ -1,267 +0,0 @@ -import pytest -import torch - -from code_to_optimize.discrete_riccati import _gridmake2_torch - - -class TestGridmake2TorchCPU: - """Tests for _gridmake2_torch with CPU tensors.""" - - def test_both_1d_simple(self): - """Test with two simple 1D tensors.""" - x1 = torch.tensor([1, 2, 3]) - x2 = torch.tensor([10, 20]) - - result = _gridmake2_torch(x1, x2) - - # Expected: x1 tiled x2.shape[0] times, x2 repeat_interleaved x1.shape[0] - # x1 tiled: [1, 2, 3, 1, 2, 3] - # x2 repeated: [10, 10, 10, 20, 20, 20] - expected = torch.tensor([ - [1, 10], - [2, 10], - [3, 10], - [1, 20], - [2, 20], - [3, 20], - ]) - assert torch.equal(result, expected) - - def test_both_1d_single_element(self): - """Test with single element tensors.""" - x1 = torch.tensor([5]) - x2 = torch.tensor([10]) - - result = _gridmake2_torch(x1, x2) - - expected = torch.tensor([[5, 10]]) - assert torch.equal(result, expected) - - def test_both_1d_float_tensors(self): - """Test with float tensors.""" - x1 = torch.tensor([1.5, 2.5]) - x2 = torch.tensor([0.1, 0.2, 0.3]) - - result = _gridmake2_torch(x1, x2) - - assert result.shape == (6, 2) - assert result.dtype == torch.float32 - - def test_2d_and_1d_simple(self): - """Test with 2D x1 and 1D x2.""" - x1 = torch.tensor([[1, 2], [3, 4]]) - x2 = torch.tensor([10, 20]) - - result = _gridmake2_torch(x1, x2) - - # x1 tiled along first dim: [[1, 2], [3, 4], [1, 2], [3, 4]] - # x2 repeated: [10, 10, 20, 20] - # column_stack: [[1, 2, 10], [3, 4, 10], [1, 2, 20], [3, 4, 20]] - expected = torch.tensor([ - [1, 2, 10], - [3, 4, 10], - [1, 2, 20], - [3, 4, 20], - ]) - assert torch.equal(result, expected) - - def test_2d_and_1d_single_column(self): - """Test with 2D x1 having a single column and 1D x2.""" - x1 = torch.tensor([[1], [2], [3]]) - x2 = torch.tensor([10, 20]) - - result = _gridmake2_torch(x1, x2) - - expected = torch.tensor([ - [1, 10], - [2, 10], - [3, 10], - [1, 20], - [2, 20], - [3, 20], - ]) - assert torch.equal(result, expected) - - def test_output_shape_1d_1d(self): - """Test output shape for two 1D tensors.""" - x1 = torch.tensor([1, 2, 3, 4, 5]) - x2 = torch.tensor([10, 20, 30]) - - result = _gridmake2_torch(x1, x2) - - # Shape should be (len(x1) * len(x2), 2) - assert result.shape == (15, 2) - - def test_output_shape_2d_1d(self): - """Test output shape for 2D and 1D tensors.""" - x1 = torch.tensor([[1, 2, 3], [4, 5, 6]]) # Shape (2, 3) - x2 = torch.tensor([10, 20, 30, 40]) # Shape (4,) - - result = _gridmake2_torch(x1, x2) - - # Shape should be (2 * 4, 3 + 1) = (8, 4) - assert result.shape == (8, 4) - - def test_not_implemented_for_2d_2d(self): - """Test that NotImplementedError is raised for two 2D tensors.""" - x1 = torch.tensor([[1, 2], [3, 4]]) - x2 = torch.tensor([[10, 20], [30, 40]]) - - with pytest.raises(NotImplementedError, match="Come back here"): - _gridmake2_torch(x1, x2) - - def test_not_implemented_for_1d_2d(self): - """Test that NotImplementedError is raised for 1D and 2D tensors.""" - x1 = torch.tensor([1, 2, 3]) - x2 = torch.tensor([[10, 20], [30, 40]]) - - with 
pytest.raises(NotImplementedError, match="Come back here"): - _gridmake2_torch(x1, x2) - - def test_preserves_dtype_int(self): - """Test that integer dtype is preserved.""" - x1 = torch.tensor([1, 2, 3], dtype=torch.int32) - x2 = torch.tensor([10, 20], dtype=torch.int32) - - result = _gridmake2_torch(x1, x2) - - assert result.dtype == torch.int32 - - def test_preserves_dtype_float64(self): - """Test that float64 dtype is preserved.""" - x1 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64) - x2 = torch.tensor([10.0, 20.0], dtype=torch.float64) - - result = _gridmake2_torch(x1, x2) - - assert result.dtype == torch.float64 - - def test_large_tensors(self): - """Test with larger tensors.""" - x1 = torch.arange(100) - x2 = torch.arange(50) - - result = _gridmake2_torch(x1, x2) - - assert result.shape == (5000, 2) - # Verify first and last elements - assert result[0, 0] == 0 and result[0, 1] == 0 - assert result[-1, 0] == 99 and result[-1, 1] == 49 - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -class TestGridmake2TorchCUDA: - """Tests for _gridmake2_torch with CUDA tensors.""" - - def test_both_1d_simple_cuda(self): - """Test with two simple 1D CUDA tensors.""" - x1 = torch.tensor([1, 2, 3], device="cuda") - x2 = torch.tensor([10, 20], device="cuda") - - result = _gridmake2_torch(x1, x2) - - expected = torch.tensor([ - [1, 10], - [2, 10], - [3, 10], - [1, 20], - [2, 20], - [3, 20], - ], device="cuda") - assert result.device.type == "cuda" - assert torch.equal(result, expected) - - def test_both_1d_matches_cpu(self): - """Test that CUDA version matches CPU version.""" - x1_cpu = torch.tensor([1.0, 2.0, 3.0, 4.0]) - x2_cpu = torch.tensor([10.0, 20.0, 30.0]) - - x1_cuda = x1_cpu.cuda() - x2_cuda = x2_cpu.cuda() - - result_cpu = _gridmake2_torch(x1_cpu, x2_cpu) - result_cuda = _gridmake2_torch(x1_cuda, x2_cuda) - - assert result_cuda.device.type == "cuda" - torch.testing.assert_close(result_cpu, result_cuda.cpu()) - - def test_2d_and_1d_cuda(self): - """Test with 2D x1 and 1D x2 on CUDA.""" - x1 = torch.tensor([[1, 2], [3, 4]], device="cuda") - x2 = torch.tensor([10, 20], device="cuda") - - result = _gridmake2_torch(x1, x2) - - expected = torch.tensor([ - [1, 2, 10], - [3, 4, 10], - [1, 2, 20], - [3, 4, 20], - ], device="cuda") - assert result.device.type == "cuda" - assert torch.equal(result, expected) - - def test_2d_and_1d_matches_cpu(self): - """Test that CUDA version matches CPU version for 2D, 1D inputs.""" - x1_cpu = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) - x2_cpu = torch.tensor([10.0, 20.0]) - - x1_cuda = x1_cpu.cuda() - x2_cuda = x2_cpu.cuda() - - result_cpu = _gridmake2_torch(x1_cpu, x2_cpu) - result_cuda = _gridmake2_torch(x1_cuda, x2_cuda) - - assert result_cuda.device.type == "cuda" - torch.testing.assert_close(result_cpu, result_cuda.cpu()) - - def test_output_stays_on_cuda(self): - """Test that output tensor stays on CUDA device.""" - x1 = torch.tensor([1, 2, 3], device="cuda") - x2 = torch.tensor([10, 20], device="cuda") - - result = _gridmake2_torch(x1, x2) - - assert result.is_cuda - - def test_preserves_dtype_float32_cuda(self): - """Test that float32 dtype is preserved on CUDA.""" - x1 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32, device="cuda") - x2 = torch.tensor([10.0, 20.0], dtype=torch.float32, device="cuda") - - result = _gridmake2_torch(x1, x2) - - assert result.dtype == torch.float32 - assert result.device.type == "cuda" - - def test_preserves_dtype_float64_cuda(self): - """Test that float64 dtype is preserved 
on CUDA.""" - x1 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64, device="cuda") - x2 = torch.tensor([10.0, 20.0], dtype=torch.float64, device="cuda") - - result = _gridmake2_torch(x1, x2) - - assert result.dtype == torch.float64 - assert result.device.type == "cuda" - - def test_large_tensors_cuda(self): - """Test with larger tensors on CUDA.""" - x1 = torch.arange(100, device="cuda") - x2 = torch.arange(50, device="cuda") - - result = _gridmake2_torch(x1, x2) - - assert result.shape == (5000, 2) - assert result.device.type == "cuda" - # Verify first and last elements - assert result[0, 0].item() == 0 and result[0, 1].item() == 0 - assert result[-1, 0].item() == 99 and result[-1, 1].item() == 49 - - def test_not_implemented_for_2d_2d_cuda(self): - """Test that NotImplementedError is raised for two 2D CUDA tensors.""" - x1 = torch.tensor([[1, 2], [3, 4]], device="cuda") - x2 = torch.tensor([[10, 20], [30, 40]], device="cuda") - - with pytest.raises(NotImplementedError, match="Come back here"): - _gridmake2_torch(x1, x2) - diff --git a/code_to_optimize/tests/pytest/test_jax_jit_code.py b/code_to_optimize/tests/pytest/test_jax_jit_code.py new file mode 100644 index 000000000..3e9afe4e9 --- /dev/null +++ b/code_to_optimize/tests/pytest/test_jax_jit_code.py @@ -0,0 +1,266 @@ +""" +Unit tests for JAX implementations of JIT-suitable functions. + +Tests run on CPU, CUDA, and Metal (Mac) devices. +""" + +import numpy as np +import pytest + +import jax +import jax.numpy as jnp + +from code_to_optimize.sample_code import ( + leapfrog_integration_jax, + longest_increasing_subsequence_length_jax, + tridiagonal_solve_jax, +) + + +def get_available_devices(): + """Return list of available JAX devices for testing.""" + devices = [] + + # CPU is always available + devices.append("cpu") + + # Check for CUDA/GPU + try: + gpu_devices = jax.devices("gpu") + if gpu_devices: + devices.append("cuda") + except RuntimeError: + pass + + # Check for Metal (Mac) + try: + metal_devices = jax.devices("METAL") + if metal_devices: + devices.append("metal") + except RuntimeError: + pass + + return devices + + +DEVICES = get_available_devices() + + +def to_device(arr, device): + """Move a JAX array to the specified device.""" + if device == "cpu": + return jax.device_put(arr, jax.devices("cpu")[0]) + elif device == "cuda": + return jax.device_put(arr, jax.devices("gpu")[0]) + elif device == "metal": + return jax.device_put(arr, jax.devices("METAL")[0]) + return arr + + +class TestTridiagonalSolveJax: + """Tests for the JAX tridiagonal_solve function.""" + + @pytest.mark.parametrize("device", DEVICES) + def test_simple_system(self, device): + """Test a simple 3x3 tridiagonal system with known solution.""" + a = jnp.array([-1.0, -1.0]) + b = jnp.array([2.0, 2.0, 2.0]) + c = jnp.array([-1.0, -1.0]) + d = jnp.array([1.0, 0.0, 1.0]) + + a, b, c, d = to_device(a, device), to_device(b, device), to_device(c, device), to_device(d, device) + + x = tridiagonal_solve_jax(a, b, c, d) + + # Verify solution by multiplying back + result = jnp.zeros(3) + result = result.at[0].set(b[0] * x[0] + c[0] * x[1]) + result = result.at[1].set(a[0] * x[0] + b[1] * x[1] + c[1] * x[2]) + result = result.at[2].set(a[1] * x[1] + b[2] * x[2]) + + np.testing.assert_array_almost_equal(np.array(result), np.array(d), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_diagonal_system(self, device): + """Test a purely diagonal system.""" + a = jnp.array([0.0, 0.0]) + b = jnp.array([2.0, 3.0, 4.0]) + c = jnp.array([0.0, 0.0]) + d = 
jnp.array([4.0, 9.0, 16.0]) + + a, b, c, d = to_device(a, device), to_device(b, device), to_device(c, device), to_device(d, device) + + x = tridiagonal_solve_jax(a, b, c, d) + + expected = jnp.array([2.0, 3.0, 4.0]) + np.testing.assert_array_almost_equal(np.array(x), np.array(expected), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_larger_system(self, device): + """Test a larger tridiagonal system.""" + n = 50 + a = -jnp.ones(n - 1) + b = 2.0 * jnp.ones(n) + c = -jnp.ones(n - 1) + d = jnp.zeros(n).at[0].set(1.0).at[-1].set(1.0) + + a, b, c, d = to_device(a, device), to_device(b, device), to_device(c, device), to_device(d, device) + + x = tridiagonal_solve_jax(a, b, c, d) + + # Verify by reconstructing Ax + result = jnp.zeros(n) + result = result.at[0].set(b[0] * x[0] + c[0] * x[1]) + for i in range(1, n - 1): + result = result.at[i].set(a[i - 1] * x[i - 1] + b[i] * x[i] + c[i] * x[i + 1]) + result = result.at[-1].set(a[-1] * x[-2] + b[-1] * x[-1]) + + np.testing.assert_array_almost_equal(np.array(result), np.array(d), decimal=5) + + +class TestLeapfrogIntegrationJax: + """Tests for the JAX leapfrog_integration function.""" + + @pytest.mark.parametrize("device", DEVICES) + def test_single_stationary_particle(self, device): + """A single particle with no velocity should remain stationary.""" + positions = jnp.array([[0.0, 0.0, 0.0]]) + velocities = jnp.array([[0.0, 0.0, 0.0]]) + masses = jnp.array([1.0]) + + positions = to_device(positions, device) + velocities = to_device(velocities, device) + masses = to_device(masses, device) + + final_pos, final_vel = leapfrog_integration_jax( + positions, velocities, masses, dt=0.01, n_steps=100 + ) + + np.testing.assert_array_almost_equal(np.array(final_pos), np.array(positions), decimal=5) + np.testing.assert_array_almost_equal(np.array(final_vel), np.array(velocities), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_single_moving_particle(self, device): + """A single moving particle should move in a straight line.""" + positions = jnp.array([[0.0, 0.0, 0.0]]) + velocities = jnp.array([[1.0, 0.0, 0.0]]) + masses = jnp.array([1.0]) + + positions = to_device(positions, device) + velocities = to_device(velocities, device) + masses = to_device(masses, device) + + dt = 0.01 + n_steps = 100 + + final_pos, final_vel = leapfrog_integration_jax( + positions, velocities, masses, dt=dt, n_steps=n_steps + ) + + np.testing.assert_array_almost_equal(np.array(final_vel), np.array(velocities), decimal=5) + expected_pos = jnp.array([[dt * n_steps, 0.0, 0.0]]) + np.testing.assert_array_almost_equal(np.array(final_pos), np.array(expected_pos), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_two_particles_approach(self, device): + """Two particles should attract each other gravitationally.""" + positions = jnp.array([[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]) + velocities = jnp.zeros((2, 3)) + masses = jnp.array([1.0, 1.0]) + + positions = to_device(positions, device) + velocities = to_device(velocities, device) + masses = to_device(masses, device) + + final_pos, _ = leapfrog_integration_jax( + positions, velocities, masses, dt=0.01, n_steps=50, softening=0.1 + ) + + initial_distance = 2.0 + final_distance = float(jnp.linalg.norm(final_pos[1] - final_pos[0])) + assert final_distance < initial_distance + + @pytest.mark.parametrize("device", DEVICES) + def test_momentum_conservation(self, device): + """Total momentum should be approximately conserved.""" + np.random.seed(42) + n_particles = 5 + positions = 
jnp.array(np.random.randn(n_particles, 3)) + velocities = jnp.array(np.random.randn(n_particles, 3)) + masses = jnp.array(np.abs(np.random.randn(n_particles)) + 0.1) + + positions = to_device(positions, device) + velocities = to_device(velocities, device) + masses = to_device(masses, device) + + initial_momentum = jnp.sum(masses[:, jnp.newaxis] * velocities, axis=0) + + final_pos, final_vel = leapfrog_integration_jax( + positions, velocities, masses, dt=0.001, n_steps=100, softening=0.5 + ) + + final_momentum = jnp.sum(masses[:, jnp.newaxis] * final_vel, axis=0) + + np.testing.assert_array_almost_equal( + np.array(initial_momentum), np.array(final_momentum), decimal=4 + ) + + +class TestLongestIncreasingSubsequenceLengthJax: + """Tests for the JAX longest_increasing_subsequence_length function.""" + + @pytest.mark.parametrize("device", DEVICES) + def test_empty_array(self, device): + """Empty array should return 0.""" + arr = jnp.array([], dtype=jnp.float32) + arr = to_device(arr, device) + assert longest_increasing_subsequence_length_jax(arr) == 0 + + @pytest.mark.parametrize("device", DEVICES) + def test_single_element(self, device): + """Single element array should return 1.""" + arr = jnp.array([5.0]) + arr = to_device(arr, device) + assert longest_increasing_subsequence_length_jax(arr) == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_strictly_increasing(self, device): + """Strictly increasing array - LIS is the whole array.""" + arr = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]) + arr = to_device(arr, device) + assert longest_increasing_subsequence_length_jax(arr) == 5 + + @pytest.mark.parametrize("device", DEVICES) + def test_strictly_decreasing(self, device): + """Strictly decreasing array - LIS is length 1.""" + arr = jnp.array([5.0, 4.0, 3.0, 2.0, 1.0]) + arr = to_device(arr, device) + assert longest_increasing_subsequence_length_jax(arr) == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_classic_example(self, device): + """Classic LIS example.""" + arr = jnp.array([10.0, 9.0, 2.0, 5.0, 3.0, 7.0, 101.0, 18.0]) + arr = to_device(arr, device) + assert longest_increasing_subsequence_length_jax(arr) == 4 + + @pytest.mark.parametrize("device", DEVICES) + def test_all_same_elements(self, device): + """All same elements - LIS is length 1.""" + arr = jnp.array([5.0, 5.0, 5.0, 5.0, 5.0]) + arr = to_device(arr, device) + assert longest_increasing_subsequence_length_jax(arr) == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_alternating_sequence(self, device): + """Alternating high-low sequence.""" + arr = jnp.array([1.0, 10.0, 2.0, 9.0, 3.0, 8.0, 4.0, 7.0]) + arr = to_device(arr, device) + assert longest_increasing_subsequence_length_jax(arr) == 5 + + @pytest.mark.parametrize("device", DEVICES) + def test_longer_sequence(self, device): + """Test with a longer sequence.""" + arr = jnp.array([0.0, 8.0, 4.0, 12.0, 2.0, 10.0, 6.0, 14.0, 1.0, 9.0, 5.0, 13.0, 3.0, 11.0, 7.0, 15.0]) + arr = to_device(arr, device) + assert longest_increasing_subsequence_length_jax(arr) == 6 \ No newline at end of file diff --git a/code_to_optimize/tests/pytest/test_numba_jit_code.py b/code_to_optimize/tests/pytest/test_numba_jit_code.py new file mode 100644 index 000000000..a81152901 --- /dev/null +++ b/code_to_optimize/tests/pytest/test_numba_jit_code.py @@ -0,0 +1,242 @@ +import numpy as np +import pytest + +from code_to_optimize.sample_code import ( + leapfrog_integration, + longest_increasing_subsequence_length, + tridiagonal_solve, +) + + +class TestTridiagonalSolve: + """Tests 
for the tridiagonal_solve function (Thomas algorithm).""" + + def test_simple_system(self): + """Test a simple 3x3 tridiagonal system with known solution.""" + # System: [2 -1 0] [x0] [1] + # [-1 2 -1] [x1] = [0] + # [0 -1 2] [x2] [1] + a = np.array([-1.0, -1.0]) # lower diagonal + b = np.array([2.0, 2.0, 2.0]) # main diagonal + c = np.array([-1.0, -1.0]) # upper diagonal + d = np.array([1.0, 0.0, 1.0]) # right-hand side + + x = tridiagonal_solve(a, b, c, d) + + # Verify solution by multiplying back + # Ax should equal d + result = np.zeros(3) + result[0] = b[0] * x[0] + c[0] * x[1] + result[1] = a[0] * x[0] + b[1] * x[1] + c[1] * x[2] + result[2] = a[1] * x[1] + b[2] * x[2] + + np.testing.assert_array_almost_equal(result, d) + + def test_diagonal_system(self): + """Test a purely diagonal system (a and c are zero).""" + a = np.array([0.0, 0.0]) + b = np.array([2.0, 3.0, 4.0]) + c = np.array([0.0, 0.0]) + d = np.array([4.0, 9.0, 16.0]) + + x = tridiagonal_solve(a, b, c, d) + + expected = np.array([2.0, 3.0, 4.0]) + np.testing.assert_array_almost_equal(x, expected) + + def test_larger_system(self): + """Test a larger tridiagonal system.""" + n = 100 + a = -np.ones(n - 1) + b = 2.0 * np.ones(n) + c = -np.ones(n - 1) + d = np.zeros(n) + d[0] = 1.0 + d[-1] = 1.0 + + x = tridiagonal_solve(a, b, c, d) + + # Verify by reconstructing Ax + result = np.zeros(n) + result[0] = b[0] * x[0] + c[0] * x[1] + for i in range(1, n - 1): + result[i] = a[i - 1] * x[i - 1] + b[i] * x[i] + c[i] * x[i + 1] + result[-1] = a[-1] * x[-2] + b[-1] * x[-1] + + np.testing.assert_array_almost_equal(result, d, decimal=10) + + def test_two_element_system(self): + """Test minimal 2x2 tridiagonal system.""" + a = np.array([1.0]) + b = np.array([4.0, 4.0]) + c = np.array([1.0]) + d = np.array([5.0, 5.0]) + + x = tridiagonal_solve(a, b, c, d) + + # Verify: [4 1] [x0] = [5] + # [1 4] [x1] [5] + result = np.array([ + b[0] * x[0] + c[0] * x[1], + a[0] * x[0] + b[1] * x[1] + ]) + np.testing.assert_array_almost_equal(result, d) + + +class TestLeapfrogIntegration: + """Tests for the leapfrog_integration function (N-body simulation).""" + + def test_single_stationary_particle(self): + """A single particle with no velocity should remain stationary.""" + positions = np.array([[0.0, 0.0, 0.0]]) + velocities = np.array([[0.0, 0.0, 0.0]]) + masses = np.array([1.0]) + + final_pos, final_vel = leapfrog_integration( + positions, velocities, masses, dt=0.01, n_steps=100 + ) + + np.testing.assert_array_almost_equal(final_pos, positions) + np.testing.assert_array_almost_equal(final_vel, velocities) + + def test_single_moving_particle(self): + """A single moving particle should move in a straight line.""" + positions = np.array([[0.0, 0.0, 0.0]]) + velocities = np.array([[1.0, 0.0, 0.0]]) + masses = np.array([1.0]) + + dt = 0.01 + n_steps = 100 + + final_pos, final_vel = leapfrog_integration( + positions, velocities, masses, dt=dt, n_steps=n_steps + ) + + # With no other particles, velocity should remain constant + np.testing.assert_array_almost_equal(final_vel, velocities) + + # Position should be initial + velocity * time + expected_pos = np.array([[dt * n_steps, 0.0, 0.0]]) + np.testing.assert_array_almost_equal(final_pos, expected_pos) + + def test_two_particles_approach(self): + """Two particles should attract each other gravitationally.""" + positions = np.array([ + [-1.0, 0.0, 0.0], + [1.0, 0.0, 0.0] + ]) + velocities = np.zeros((2, 3)) + masses = np.array([1.0, 1.0]) + + final_pos, final_vel = leapfrog_integration( + positions, 
velocities, masses, dt=0.01, n_steps=50, softening=0.1 + ) + + # Particles should move closer together + initial_distance = 2.0 + final_distance = np.linalg.norm(final_pos[1] - final_pos[0]) + assert final_distance < initial_distance + + def test_momentum_conservation(self): + """Total momentum should be approximately conserved.""" + np.random.seed(42) + n_particles = 5 + positions = np.random.randn(n_particles, 3) + velocities = np.random.randn(n_particles, 3) + masses = np.abs(np.random.randn(n_particles)) + 0.1 + + initial_momentum = np.sum(masses[:, np.newaxis] * velocities, axis=0) + + final_pos, final_vel = leapfrog_integration( + positions, velocities, masses, dt=0.001, n_steps=100, softening=0.5 + ) + + final_momentum = np.sum(masses[:, np.newaxis] * final_vel, axis=0) + + # Momentum should be conserved to good precision + np.testing.assert_array_almost_equal( + initial_momentum, final_momentum, decimal=5 + ) + + def test_does_not_modify_input(self): + """Input arrays should not be modified.""" + positions = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]) + velocities = np.array([[0.1, 0.0, 0.0], [-0.1, 0.0, 0.0]]) + masses = np.array([1.0, 1.0]) + + pos_copy = positions.copy() + vel_copy = velocities.copy() + + leapfrog_integration(positions, velocities, masses, dt=0.01, n_steps=10) + + np.testing.assert_array_equal(positions, pos_copy) + np.testing.assert_array_equal(velocities, vel_copy) + + +class TestLongestIncreasingSubsequenceLength: + """Tests for the longest_increasing_subsequence_length function.""" + + def test_empty_array(self): + """Empty array should return 0.""" + arr = np.array([], dtype=np.float64) + assert longest_increasing_subsequence_length(arr) == 0 + + def test_single_element(self): + """Single element array should return 1.""" + arr = np.array([5]) + assert longest_increasing_subsequence_length(arr) == 1 + + def test_strictly_increasing(self): + """Strictly increasing array - LIS is the whole array.""" + arr = np.array([1, 2, 3, 4, 5]) + assert longest_increasing_subsequence_length(arr) == 5 + + def test_strictly_decreasing(self): + """Strictly decreasing array - LIS is length 1.""" + arr = np.array([5, 4, 3, 2, 1]) + assert longest_increasing_subsequence_length(arr) == 1 + + def test_classic_example(self): + """Classic LIS example: [10, 9, 2, 5, 3, 7, 101, 18].""" + arr = np.array([10, 9, 2, 5, 3, 7, 101, 18]) + # LIS: [2, 3, 7, 101] or [2, 5, 7, 101] or [2, 3, 7, 18] etc. 
+        assert longest_increasing_subsequence_length(arr) == 4
+
+    def test_all_same_elements(self):
+        """All same elements - LIS is length 1 (strictly increasing)."""
+        arr = np.array([5, 5, 5, 5, 5])
+        assert longest_increasing_subsequence_length(arr) == 1
+
+    def test_alternating_sequence(self):
+        """Alternating high-low sequence."""
+        arr = np.array([1, 10, 2, 9, 3, 8, 4, 7])
+        # LIS: [1, 2, 3, 4, 7] - length 5
+        assert longest_increasing_subsequence_length(arr) == 5
+
+    def test_two_elements_increasing(self):
+        """Two elements in increasing order."""
+        arr = np.array([1, 2])
+        assert longest_increasing_subsequence_length(arr) == 2
+
+    def test_two_elements_decreasing(self):
+        """Two elements in decreasing order."""
+        arr = np.array([2, 1])
+        assert longest_increasing_subsequence_length(arr) == 1
+
+    def test_longer_sequence(self):
+        """Test with a longer sequence."""
+        arr = np.array([0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15])
+        # Known LIS length for this sequence is 6
+        assert longest_increasing_subsequence_length(arr) == 6
+
+    def test_negative_numbers(self):
+        """Test with negative numbers."""
+        arr = np.array([-5, -2, -8, -1, -6, 0])
+        # LIS: [-5, -2, -1, 0] - length 4
+        assert longest_increasing_subsequence_length(arr) == 4
+
+    def test_float_values(self):
+        """Test with floating point values."""
+        arr = np.array([1.5, 2.3, 1.8, 3.1, 2.9, 4.0])
+        # LIS: [1.5, 2.3, 3.1, 4.0] or [1.5, 1.8, 2.9, 4.0] - length 4
+        assert longest_increasing_subsequence_length(arr) == 4
\ No newline at end of file
diff --git a/code_to_optimize/tests/pytest/test_tensorflow_jit_code.py b/code_to_optimize/tests/pytest/test_tensorflow_jit_code.py
new file mode 100644
index 000000000..cbeb0b308
--- /dev/null
+++ b/code_to_optimize/tests/pytest/test_tensorflow_jit_code.py
@@ -0,0 +1,302 @@
+"""
+Unit tests for TensorFlow implementations of JIT-suitable functions.
+
+Tests run on CPU, CUDA, and Metal (Mac) devices.
+""" + +import platform + +import numpy as np +import pytest + +tf = pytest.importorskip("tensorflow") + +from code_to_optimize.sample_code import ( + leapfrog_integration_tf, + longest_increasing_subsequence_length_tf, + tridiagonal_solve_tf, +) + + +def get_available_devices(): + """Return list of available TensorFlow devices for testing.""" + devices = ["cpu"] + + # Check for GPU devices + gpus = tf.config.list_physical_devices("GPU") + if gpus: + # On macOS, GPUs are Metal devices; on other platforms, they're CUDA + if platform.system() == "Darwin": + devices.append("metal") + else: + devices.append("cuda") + + return devices + + +DEVICES = get_available_devices() + + +def run_on_device(func, device, *args, **kwargs): + """Run a function on the specified device.""" + if device == "cpu": + device_name = "/CPU:0" + elif device in ("cuda", "metal"): + device_name = "/GPU:0" + else: + device_name = "/CPU:0" + + with tf.device(device_name): + return func(*args, **kwargs) + + +def to_tensor(arr, device, dtype=tf.float64): + """Create a tensor on the specified device.""" + if device == "cpu": + device_name = "/CPU:0" + elif device in ("cuda", "metal"): + device_name = "/GPU:0" + else: + device_name = "/CPU:0" + + with tf.device(device_name): + return tf.constant(arr, dtype=dtype) + + +class TestTridiagonalSolveTf: + """Tests for the TensorFlow tridiagonal_solve function.""" + + @pytest.mark.parametrize("device", DEVICES) + def test_simple_system(self, device): + """Test a simple 3x3 tridiagonal system with known solution.""" + a = to_tensor([-1.0, -1.0], device) + b = to_tensor([2.0, 2.0, 2.0], device) + c = to_tensor([-1.0, -1.0], device) + d = to_tensor([1.0, 0.0, 1.0], device) + + x = run_on_device(tridiagonal_solve_tf, device, a, b, c, d) + + # Verify solution by multiplying back + result = np.zeros(3) + x_np = x.numpy() + b_np = b.numpy() + c_np = c.numpy() + a_np = a.numpy() + result[0] = b_np[0] * x_np[0] + c_np[0] * x_np[1] + result[1] = a_np[0] * x_np[0] + b_np[1] * x_np[1] + c_np[1] * x_np[2] + result[2] = a_np[1] * x_np[1] + b_np[2] * x_np[2] + + np.testing.assert_array_almost_equal(result, d.numpy(), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_diagonal_system(self, device): + """Test a purely diagonal system.""" + a = to_tensor([0.0, 0.0], device) + b = to_tensor([2.0, 3.0, 4.0], device) + c = to_tensor([0.0, 0.0], device) + d = to_tensor([4.0, 9.0, 16.0], device) + + x = run_on_device(tridiagonal_solve_tf, device, a, b, c, d) + + expected = np.array([2.0, 3.0, 4.0]) + np.testing.assert_array_almost_equal(x.numpy(), expected, decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_larger_system(self, device): + """Test a larger tridiagonal system.""" + n = 50 + a_np = -np.ones(n - 1) + b_np = 2.0 * np.ones(n) + c_np = -np.ones(n - 1) + d_np = np.zeros(n) + d_np[0] = 1.0 + d_np[-1] = 1.0 + + a = to_tensor(a_np, device) + b = to_tensor(b_np, device) + c = to_tensor(c_np, device) + d = to_tensor(d_np, device) + + x = run_on_device(tridiagonal_solve_tf, device, a, b, c, d) + x_np = x.numpy() + + # Verify by reconstructing Ax + result = np.zeros(n) + result[0] = b_np[0] * x_np[0] + c_np[0] * x_np[1] + for i in range(1, n - 1): + result[i] = a_np[i - 1] * x_np[i - 1] + b_np[i] * x_np[i] + c_np[i] * x_np[i + 1] + result[-1] = a_np[-1] * x_np[-2] + b_np[-1] * x_np[-1] + + np.testing.assert_array_almost_equal(result, d_np, decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_two_element_system(self, device): + """Test minimal 2x2 
tridiagonal system.""" + a = to_tensor([1.0], device) + b = to_tensor([4.0, 4.0], device) + c = to_tensor([1.0], device) + d = to_tensor([5.0, 5.0], device) + + x = run_on_device(tridiagonal_solve_tf, device, a, b, c, d) + x_np = x.numpy() + b_np = b.numpy() + c_np = c.numpy() + a_np = a.numpy() + + result = np.array([ + b_np[0] * x_np[0] + c_np[0] * x_np[1], + a_np[0] * x_np[0] + b_np[1] * x_np[1] + ]) + np.testing.assert_array_almost_equal(result, d.numpy(), decimal=5) + + +class TestLeapfrogIntegrationTf: + """Tests for the TensorFlow leapfrog_integration function.""" + + @pytest.mark.parametrize("device", DEVICES) + def test_single_stationary_particle(self, device): + """A single particle with no velocity should remain stationary.""" + positions = to_tensor([[0.0, 0.0, 0.0]], device) + velocities = to_tensor([[0.0, 0.0, 0.0]], device) + masses = to_tensor([1.0], device) + + final_pos, final_vel = run_on_device( + leapfrog_integration_tf, device, + positions, velocities, masses, dt=0.01, n_steps=100 + ) + + np.testing.assert_array_almost_equal(final_pos.numpy(), positions.numpy(), decimal=5) + np.testing.assert_array_almost_equal(final_vel.numpy(), velocities.numpy(), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_single_moving_particle(self, device): + """A single moving particle should move in a straight line.""" + positions = to_tensor([[0.0, 0.0, 0.0]], device) + velocities = to_tensor([[1.0, 0.0, 0.0]], device) + masses = to_tensor([1.0], device) + + dt = 0.01 + n_steps = 100 + + final_pos, final_vel = run_on_device( + leapfrog_integration_tf, device, + positions, velocities, masses, dt=dt, n_steps=n_steps + ) + + np.testing.assert_array_almost_equal(final_vel.numpy(), velocities.numpy(), decimal=5) + expected_pos = np.array([[dt * n_steps, 0.0, 0.0]]) + np.testing.assert_array_almost_equal(final_pos.numpy(), expected_pos, decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_two_particles_approach(self, device): + """Two particles should attract each other gravitationally.""" + positions = to_tensor([[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]], device) + velocities = to_tensor(np.zeros((2, 3)), device) + masses = to_tensor([1.0, 1.0], device) + + final_pos, _ = run_on_device( + leapfrog_integration_tf, device, + positions, velocities, masses, dt=0.01, n_steps=50, softening=0.1 + ) + + initial_distance = 2.0 + final_distance = np.linalg.norm(final_pos.numpy()[1] - final_pos.numpy()[0]) + assert final_distance < initial_distance + + @pytest.mark.parametrize("device", DEVICES) + def test_momentum_conservation(self, device): + """Total momentum should be approximately conserved.""" + np.random.seed(42) + n_particles = 5 + positions_np = np.random.randn(n_particles, 3) + velocities_np = np.random.randn(n_particles, 3) + masses_np = np.abs(np.random.randn(n_particles)) + 0.1 + + positions = to_tensor(positions_np, device) + velocities = to_tensor(velocities_np, device) + masses = to_tensor(masses_np, device) + + initial_momentum = np.sum(masses_np[:, np.newaxis] * velocities_np, axis=0) + + final_pos, final_vel = run_on_device( + leapfrog_integration_tf, device, + positions, velocities, masses, dt=0.001, n_steps=100, softening=0.5 + ) + + final_momentum = np.sum(masses_np[:, np.newaxis] * final_vel.numpy(), axis=0) + + np.testing.assert_array_almost_equal(initial_momentum, final_momentum, decimal=4) + + +class TestLongestIncreasingSubsequenceLengthTf: + """Tests for the TensorFlow longest_increasing_subsequence_length function.""" + + 
@pytest.mark.parametrize("device", DEVICES) + def test_single_element(self, device): + """Single element array should return 1.""" + arr = to_tensor([5.0], device, dtype=tf.float32) + result = run_on_device(longest_increasing_subsequence_length_tf, device, arr) + assert result == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_strictly_increasing(self, device): + """Strictly increasing array - LIS is the whole array.""" + arr = to_tensor([1.0, 2.0, 3.0, 4.0, 5.0], device, dtype=tf.float32) + result = run_on_device(longest_increasing_subsequence_length_tf, device, arr) + assert result == 5 + + @pytest.mark.parametrize("device", DEVICES) + def test_strictly_decreasing(self, device): + """Strictly decreasing array - LIS is length 1.""" + arr = to_tensor([5.0, 4.0, 3.0, 2.0, 1.0], device, dtype=tf.float32) + result = run_on_device(longest_increasing_subsequence_length_tf, device, arr) + assert result == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_classic_example(self, device): + """Classic LIS example.""" + arr = to_tensor([10.0, 9.0, 2.0, 5.0, 3.0, 7.0, 101.0, 18.0], device, dtype=tf.float32) + result = run_on_device(longest_increasing_subsequence_length_tf, device, arr) + assert result == 4 + + @pytest.mark.parametrize("device", DEVICES) + def test_all_same_elements(self, device): + """All same elements - LIS is length 1.""" + arr = to_tensor([5.0, 5.0, 5.0, 5.0, 5.0], device, dtype=tf.float32) + result = run_on_device(longest_increasing_subsequence_length_tf, device, arr) + assert result == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_alternating_sequence(self, device): + """Alternating high-low sequence.""" + arr = to_tensor([1.0, 10.0, 2.0, 9.0, 3.0, 8.0, 4.0, 7.0], device, dtype=tf.float32) + result = run_on_device(longest_increasing_subsequence_length_tf, device, arr) + assert result == 5 + + @pytest.mark.parametrize("device", DEVICES) + def test_two_elements_increasing(self, device): + """Two elements in increasing order.""" + arr = to_tensor([1.0, 2.0], device, dtype=tf.float32) + result = run_on_device(longest_increasing_subsequence_length_tf, device, arr) + assert result == 2 + + @pytest.mark.parametrize("device", DEVICES) + def test_two_elements_decreasing(self, device): + """Two elements in decreasing order.""" + arr = to_tensor([2.0, 1.0], device, dtype=tf.float32) + result = run_on_device(longest_increasing_subsequence_length_tf, device, arr) + assert result == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_longer_sequence(self, device): + """Test with a longer sequence.""" + arr = to_tensor([0.0, 8.0, 4.0, 12.0, 2.0, 10.0, 6.0, 14.0, 1.0, 9.0, 5.0, 13.0, 3.0, 11.0, 7.0, 15.0], device, dtype=tf.float32) + result = run_on_device(longest_increasing_subsequence_length_tf, device, arr) + assert result == 6 + + @pytest.mark.parametrize("device", DEVICES) + def test_negative_numbers(self, device): + """Test with negative numbers.""" + arr = to_tensor([-5.0, -2.0, -8.0, -1.0, -6.0, 0.0], device, dtype=tf.float32) + result = run_on_device(longest_increasing_subsequence_length_tf, device, arr) + assert result == 4 diff --git a/code_to_optimize/tests/pytest/test_torch_jit_code.py b/code_to_optimize/tests/pytest/test_torch_jit_code.py new file mode 100644 index 000000000..63b0e6889 --- /dev/null +++ b/code_to_optimize/tests/pytest/test_torch_jit_code.py @@ -0,0 +1,285 @@ +""" +Unit tests for PyTorch implementations of JIT-suitable functions. + +Tests run on CPU, CUDA, and MPS devices. 
+""" + +import numpy as np +import pytest + +import torch + +from code_to_optimize.sample_code import ( + leapfrog_integration_torch, + longest_increasing_subsequence_length_torch, + tridiagonal_solve_torch, +) + + +def get_available_devices(): + """Return list of available PyTorch devices for testing.""" + devices = ["cpu"] + + # Check for CUDA + if torch.cuda.is_available(): + devices.append("cuda") + + # Check for MPS (Apple Silicon) + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + devices.append("mps") + + return devices + + +DEVICES = get_available_devices() + + +def get_dtype(device): + """Get the appropriate dtype for a device. MPS doesn't support float64.""" + if device == "mps": + return torch.float32 + return torch.float64 + + +def to_device(arr, device): + """Move a tensor to the specified device.""" + dtype = get_dtype(device) + if isinstance(arr, np.ndarray): + arr = torch.from_numpy(arr).to(dtype) + return arr.to(device) + + +class TestTridiagonalSolveTorch: + """Tests for the PyTorch tridiagonal_solve function.""" + + @pytest.mark.parametrize("device", DEVICES) + def test_simple_system(self, device): + """Test a simple 3x3 tridiagonal system with known solution.""" + a = torch.tensor([-1.0, -1.0], dtype=get_dtype(device), device=device) + b = torch.tensor([2.0, 2.0, 2.0], dtype=get_dtype(device), device=device) + c = torch.tensor([-1.0, -1.0], dtype=get_dtype(device), device=device) + d = torch.tensor([1.0, 0.0, 1.0], dtype=get_dtype(device), device=device) + + x = tridiagonal_solve_torch(a, b, c, d) + + # Verify solution by multiplying back + result = torch.zeros(3, dtype=get_dtype(device), device=device) + result[0] = b[0] * x[0] + c[0] * x[1] + result[1] = a[0] * x[0] + b[1] * x[1] + c[1] * x[2] + result[2] = a[1] * x[1] + b[2] * x[2] + + np.testing.assert_array_almost_equal(result.cpu().numpy(), d.cpu().numpy(), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_diagonal_system(self, device): + """Test a purely diagonal system.""" + a = torch.tensor([0.0, 0.0], dtype=get_dtype(device), device=device) + b = torch.tensor([2.0, 3.0, 4.0], dtype=get_dtype(device), device=device) + c = torch.tensor([0.0, 0.0], dtype=get_dtype(device), device=device) + d = torch.tensor([4.0, 9.0, 16.0], dtype=get_dtype(device), device=device) + + x = tridiagonal_solve_torch(a, b, c, d) + + expected = torch.tensor([2.0, 3.0, 4.0], dtype=get_dtype(device)) + np.testing.assert_array_almost_equal(x.cpu().numpy(), expected.numpy(), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_larger_system(self, device): + """Test a larger tridiagonal system.""" + n = 100 + a = -torch.ones(n - 1, dtype=get_dtype(device), device=device) + b = 2.0 * torch.ones(n, dtype=get_dtype(device), device=device) + c = -torch.ones(n - 1, dtype=get_dtype(device), device=device) + d = torch.zeros(n, dtype=get_dtype(device), device=device) + d[0] = 1.0 + d[-1] = 1.0 + + x = tridiagonal_solve_torch(a, b, c, d) + + # Verify by reconstructing Ax + result = torch.zeros(n, dtype=get_dtype(device), device=device) + result[0] = b[0] * x[0] + c[0] * x[1] + for i in range(1, n - 1): + result[i] = a[i - 1] * x[i - 1] + b[i] * x[i] + c[i] * x[i + 1] + result[-1] = a[-1] * x[-2] + b[-1] * x[-1] + + np.testing.assert_array_almost_equal(result.cpu().numpy(), d.cpu().numpy(), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_two_element_system(self, device): + """Test minimal 2x2 tridiagonal system.""" + a = torch.tensor([1.0], dtype=get_dtype(device), 
device=device) + b = torch.tensor([4.0, 4.0], dtype=get_dtype(device), device=device) + c = torch.tensor([1.0], dtype=get_dtype(device), device=device) + d = torch.tensor([5.0, 5.0], dtype=get_dtype(device), device=device) + + x = tridiagonal_solve_torch(a, b, c, d) + + result = torch.tensor([ + b[0] * x[0] + c[0] * x[1], + a[0] * x[0] + b[1] * x[1] + ], device=device) + np.testing.assert_array_almost_equal(result.cpu().numpy(), d.cpu().numpy(), decimal=5) + + +class TestLeapfrogIntegrationTorch: + """Tests for the PyTorch leapfrog_integration function.""" + + @pytest.mark.parametrize("device", DEVICES) + def test_single_stationary_particle(self, device): + """A single particle with no velocity should remain stationary.""" + positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=get_dtype(device), device=device) + velocities = torch.tensor([[0.0, 0.0, 0.0]], dtype=get_dtype(device), device=device) + masses = torch.tensor([1.0], dtype=get_dtype(device), device=device) + + final_pos, final_vel = leapfrog_integration_torch( + positions, velocities, masses, dt=0.01, n_steps=100 + ) + + np.testing.assert_array_almost_equal(final_pos.cpu().numpy(), positions.cpu().numpy(), decimal=5) + np.testing.assert_array_almost_equal(final_vel.cpu().numpy(), velocities.cpu().numpy(), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_single_moving_particle(self, device): + """A single moving particle should move in a straight line.""" + positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=get_dtype(device), device=device) + velocities = torch.tensor([[1.0, 0.0, 0.0]], dtype=get_dtype(device), device=device) + masses = torch.tensor([1.0], dtype=get_dtype(device), device=device) + + dt = 0.01 + n_steps = 100 + + final_pos, final_vel = leapfrog_integration_torch( + positions, velocities, masses, dt=dt, n_steps=n_steps + ) + + np.testing.assert_array_almost_equal(final_vel.cpu().numpy(), velocities.cpu().numpy(), decimal=5) + expected_pos = torch.tensor([[dt * n_steps, 0.0, 0.0]]) + np.testing.assert_array_almost_equal(final_pos.cpu().numpy(), expected_pos.numpy(), decimal=5) + + @pytest.mark.parametrize("device", DEVICES) + def test_two_particles_approach(self, device): + """Two particles should attract each other gravitationally.""" + positions = torch.tensor([[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=get_dtype(device), device=device) + velocities = torch.zeros((2, 3), dtype=get_dtype(device), device=device) + masses = torch.tensor([1.0, 1.0], dtype=get_dtype(device), device=device) + + final_pos, _ = leapfrog_integration_torch( + positions, velocities, masses, dt=0.01, n_steps=50, softening=0.1 + ) + + initial_distance = 2.0 + final_distance = torch.linalg.norm(final_pos[1] - final_pos[0]).item() + assert final_distance < initial_distance + + @pytest.mark.parametrize("device", DEVICES) + def test_momentum_conservation(self, device): + """Total momentum should be approximately conserved.""" + np.random.seed(42) + n_particles = 5 + positions = torch.tensor(np.random.randn(n_particles, 3), dtype=get_dtype(device), device=device) + velocities = torch.tensor(np.random.randn(n_particles, 3), dtype=get_dtype(device), device=device) + masses = torch.tensor(np.abs(np.random.randn(n_particles)) + 0.1, dtype=get_dtype(device), device=device) + + initial_momentum = torch.sum(masses[:, None] * velocities, dim=0) + + final_pos, final_vel = leapfrog_integration_torch( + positions, velocities, masses, dt=0.001, n_steps=100, softening=0.5 + ) + + final_momentum = torch.sum(masses[:, None] * final_vel, dim=0) + + 
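+ # Pairwise forces are equal and opposite (Newton's third law), so total momentum is conserved up to floating-point error.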
np.testing.assert_array_almost_equal( + initial_momentum.cpu().numpy(), final_momentum.cpu().numpy(), decimal=4 + ) + + @pytest.mark.parametrize("device", DEVICES) + def test_does_not_modify_input(self, device): + """Input arrays should not be modified.""" + positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=get_dtype(device), device=device) + velocities = torch.tensor([[0.1, 0.0, 0.0], [-0.1, 0.0, 0.0]], dtype=get_dtype(device), device=device) + masses = torch.tensor([1.0, 1.0], dtype=get_dtype(device), device=device) + + pos_copy = positions.clone() + vel_copy = velocities.clone() + + leapfrog_integration_torch(positions, velocities, masses, dt=0.01, n_steps=10) + + np.testing.assert_array_equal(positions.cpu().numpy(), pos_copy.cpu().numpy()) + np.testing.assert_array_equal(velocities.cpu().numpy(), vel_copy.cpu().numpy()) + + +class TestLongestIncreasingSubsequenceLengthTorch: + """Tests for the PyTorch longest_increasing_subsequence_length function.""" + + @pytest.mark.parametrize("device", DEVICES) + def test_empty_array(self, device): + """Empty array should return 0.""" + arr = torch.tensor([], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 0 + + @pytest.mark.parametrize("device", DEVICES) + def test_single_element(self, device): + """Single element array should return 1.""" + arr = torch.tensor([5.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_strictly_increasing(self, device): + """Strictly increasing array - LIS is the whole array.""" + arr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 5 + + @pytest.mark.parametrize("device", DEVICES) + def test_strictly_decreasing(self, device): + """Strictly decreasing array - LIS is length 1.""" + arr = torch.tensor([5.0, 4.0, 3.0, 2.0, 1.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_classic_example(self, device): + """Classic LIS example.""" + arr = torch.tensor([10.0, 9.0, 2.0, 5.0, 3.0, 7.0, 101.0, 18.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 4 + + @pytest.mark.parametrize("device", DEVICES) + def test_all_same_elements(self, device): + """All same elements - LIS is length 1.""" + arr = torch.tensor([5.0, 5.0, 5.0, 5.0, 5.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 1 + + @pytest.mark.parametrize("device", DEVICES) + def test_alternating_sequence(self, device): + """Alternating high-low sequence.""" + arr = torch.tensor([1.0, 10.0, 2.0, 9.0, 3.0, 8.0, 4.0, 7.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 5 + + @pytest.mark.parametrize("device", DEVICES) + def test_two_elements_increasing(self, device): + """Two elements in increasing order.""" + arr = torch.tensor([1.0, 2.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 2 + + @pytest.mark.parametrize("device", DEVICES) + def test_two_elements_decreasing(self, device): + """Two elements in decreasing order.""" + arr = torch.tensor([2.0, 1.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 1 + + 
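The remaining cases pin down the strict-LIS contract on longer inputs. For orientation, here is a minimal sketch of the O(n^2) recurrence the suite assumes (illustrative only; `lis_length_sketch` is a hypothetical name, not the `longest_increasing_subsequence_length_torch` implementation under test in sample_code.py):

import torch

def lis_length_sketch(arr: torch.Tensor) -> int:
    # dp[i] = length of the longest strictly increasing subsequence ending at i
    n = arr.numel()
    if n == 0:
        return 0
    dp = torch.ones(n, dtype=torch.int64, device=arr.device)
    for i in range(1, n):
        mask = arr[:i] < arr[i]  # strict comparison: equal values do not extend a subsequence
        if mask.any():
            dp[i] = dp[:i][mask].max() + 1
    return int(dp.max().item())

# e.g. the alternating case above: lis_length_sketch(torch.tensor([1., 10., 2., 9., 3., 8., 4., 7.])) == 5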
@pytest.mark.parametrize("device", DEVICES) + def test_longer_sequence(self, device): + """Test with a longer sequence.""" + arr = torch.tensor([0.0, 8.0, 4.0, 12.0, 2.0, 10.0, 6.0, 14.0, 1.0, 9.0, 5.0, 13.0, 3.0, 11.0, 7.0, 15.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 6 + + @pytest.mark.parametrize("device", DEVICES) + def test_negative_numbers(self, device): + """Test with negative numbers.""" + arr = torch.tensor([-5.0, -2.0, -8.0, -1.0, -6.0, 0.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 4 + + @pytest.mark.parametrize("device", DEVICES) + def test_float_values(self, device): + """Test with floating point values.""" + arr = torch.tensor([1.5, 2.3, 1.8, 3.1, 2.9, 4.0], dtype=get_dtype(device), device=device) + assert longest_increasing_subsequence_length_torch(arr) == 4 \ No newline at end of file diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 9eb1906b4..e4ed074fd 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -129,6 +129,7 @@ def optimize_python_code( # noqa: D417 *, is_async: bool = False, n_candidates: int = 5, + is_numerical_code: bool | None = None, ) -> list[OptimizedCandidate]: """Optimize the given python code for performance by making a request to the Django endpoint. @@ -164,6 +165,7 @@ def optimize_python_code( # noqa: D417 "is_async": is_async, "call_sequence": self.get_next_sequence(), "n_candidates": n_candidates, + "is_numerical_code": is_numerical_code, } logger.debug(f"Sending optimize request: trace_id={trace_id}, n_candidates={payload['n_candidates']}") @@ -191,6 +193,58 @@ def optimize_python_code( # noqa: D417 console.rule() return [] + def get_jit_rewritten_code( # noqa: D417 + self, source_code: str, trace_id: str + ) -> list[OptimizedCandidate]: + """Rewrite the given python code for performance via JIT compilation by making a request to the Django endpoint. + + Parameters + ---------- + - source_code (str): The python code to optimize. + - trace_id (str): Trace ID of the optimization run. + + Returns + ------- + - list[OptimizedCandidate]: A list of optimized candidates.
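+ Returns an empty list if the request fails or the endpoint reports an error.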
+ + """ + start_time = time.perf_counter() + git_repo_owner, git_repo_name = safe_get_repo_owner_and_name() + + payload = { + "source_code": source_code, + "trace_id": trace_id, + "dependency_code": "", # dummy value to please the api endpoint + "python_version": "3.12.1", # dummy value to please the api endpoint + "current_username": get_last_commit_author_if_pr_exists(None), + "repo_owner": git_repo_owner, + "repo_name": git_repo_name, + } + + logger.info("!lsp|Rewriting as a JIT function…") + console.rule() + try: + response = self.make_ai_service_request("/rewrite_jit", payload=payload, timeout=60) + except requests.exceptions.RequestException as e: + logger.exception(f"Error generating jit rewritten candidate: {e}") + ph("cli-jit-rewrite-error-caught", {"error": str(e)}) + return [] + + if response.status_code == 200: + optimizations_json = response.json()["optimizations"] + console.rule() + end_time = time.perf_counter() + logger.debug(f"!lsp|Generating jit rewritten code took {end_time - start_time:.2f} seconds.") + return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.JIT_REWRITE) + try: + error = response.json()["error"] + except Exception: + error = response.text + logger.error(f"Error generating jit rewritten candidate: {response.status_code} - {error}") + ph("cli-jit-rewrite-error-response", {"response_status_code": response.status_code, "error": error}) + console.rule() + return [] + def optimize_python_code_line_profiler( # noqa: D417 self, source_code: str, @@ -199,6 +253,7 @@ def optimize_python_code_line_profiler( # noqa: D417 line_profiler_results: str, n_candidates: int, experiment_metadata: ExperimentMetadata | None = None, + is_numerical_code: bool | None = None, # noqa: FBT001 ) -> list[OptimizedCandidate]: """Optimize the given python code for performance using line profiler results. @@ -233,6 +288,7 @@ def optimize_python_code_line_profiler( # noqa: D417 "experiment_metadata": experiment_metadata, "codeflash_version": codeflash_version, "call_sequence": self.get_next_sequence(), + "is_numerical_code": is_numerical_code, } try: @@ -578,6 +634,7 @@ def generate_regression_tests( # noqa: D417 test_timeout: int, trace_id: str, test_index: int, + is_numerical_code: bool | None = None, # noqa: FBT001 ) -> tuple[str, str, str] | None: """Generate regression tests for the given function by making a request to the Django endpoint. @@ -614,6 +671,7 @@ def generate_regression_tests( # noqa: D417 "codeflash_version": codeflash_version, "is_async": function_to_optimize.is_async, "call_sequence": self.get_next_sequence(), + "is_numerical_code": is_numerical_code, } try: response = self.make_ai_service_request("/testgen", payload=payload, timeout=self.timeout) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 935d0a369..66dfd5eb4 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -1298,7 +1298,7 @@ def _find_function_node(tree: ast.Module, name_parts: list[str]) -> ast.Function return None -def is_numerical_code(code_string: str, function_name: str) -> bool: +def is_numerical_code(code_string: str, function_name: str | None = None) -> bool: """Check if a function uses numerical computing libraries. 
Detects usage of numpy, torch, numba, jax, tensorflow, scipy, and math libraries @@ -1339,6 +1339,13 @@ def is_numerical_code(code_string: str, function_name: str) -> bool: except SyntaxError: return False + # Collect names that reference numerical modules from imports + numerical_names, modules_used = _collect_numerical_imports(tree) + + if not function_name: + # Return True if modules used and (numba available or modules don't all require numba) + return bool(modules_used) and (has_numba or not modules_used.issubset(NUMBA_REQUIRED_MODULES)) + # Split the function name to handle class methods name_parts = function_name.split(".") @@ -1347,9 +1354,6 @@ def is_numerical_code(code_string: str, function_name: str) -> bool: if target_function is None: return False - # Collect names that reference numerical modules from imports - numerical_names, modules_used = _collect_numerical_imports(tree) - # Check if the function body uses any numerical library checker = NumericalUsageChecker(numerical_names) checker.visit(target_function) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 133299589..d850e3827 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -488,6 +488,7 @@ class OptimizedCandidateSource(str, Enum): REFINE = "REFINE" REPAIR = "REPAIR" ADAPTIVE = "ADAPTIVE" + JIT_REWRITE = "JIT_REWRITE" @dataclass(frozen=True) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 15d4d8b42..761b8ea0c 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Callable import libcst as cst +import sentry_sdk from rich.console import Group from rich.panel import Panel from rich.syntax import Syntax @@ -24,7 +25,7 @@ from codeflash.benchmarking.utils import process_benchmark_data from codeflash.cli_cmds.console import code_print, console, logger, lsp_log, progress_bar from codeflash.code_utils import env_utils -from codeflash.code_utils.code_extractor import get_opt_review_metrics +from codeflash.code_utils.code_extractor import get_opt_review_metrics, is_numerical_code from codeflash.code_utils.code_replacer import ( add_custom_marker_to_all_tests, modify_autouse_fixture, @@ -465,6 +466,7 @@ def __init__( self.future_adaptive_optimizations: list[concurrent.futures.Future] = [] self.repair_counter = 0 # track how many repairs we did for each function self.adaptive_optimization_counter = 0 # track how many adaptive optimizations we did for each function + self.is_numerical_code: bool | None = None def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]: should_run_experiment = self.experiment_id is not None @@ -587,7 +589,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: if not is_successful(initialization_result): return Failure(initialization_result.failure()) should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() - + self.is_numerical_code = is_numerical_code(code_string=code_context.read_writable_code.flat) code_print( code_context.read_writable_code.flat, file_name=self.function_to_optimize.file_path, @@ -600,13 +602,37 @@ def optimize_function(self) -> Result[BestOptimization, str]: revert_to_print=bool(get_pr_number()), ): console.rule() + new_code_context = code_context + if self.is_numerical_code: # if the code is numerical in nature (uses numpy/tensorflow/math/pytorch/jax) + jit_compiled_opt_candidate 
= self.aiservice_client.get_jit_rewritten_code( + code_context.read_writable_code.markdown, self.function_trace_id + ) + if jit_compiled_opt_candidate: # JIT rewrite was successful + # Write the candidate to the source files: + # try to replace the function and its helpers with the rewritten code + self.replace_function_and_helpers_with_optimized_code( + code_context=code_context, + optimized_code=jit_compiled_opt_candidate[0].source_code, + original_helper_code=original_helper_code, + ) + # Recompute the code context for the rewritten function + try: + new_code_context = self.get_code_optimization_context().unwrap() + except Exception as e: + sentry_sdk.capture_exception(e) + logger.debug("!lsp|Getting new code context failed, reverting to the original one") + # Restore the original files + self.write_code_and_helpers( + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + ) # Generate tests and optimizations in parallel - future_tests = self.executor.submit(self.generate_and_instrument_tests, code_context) + future_tests = self.executor.submit(self.generate_and_instrument_tests, new_code_context) future_optimizations = self.executor.submit( self.generate_optimizations, read_writable_code=code_context.read_writable_code, read_only_context_code=code_context.read_only_context_code, run_experiment=should_run_experiment, + is_numerical_code=self.is_numerical_code, ) concurrent.futures.wait([future_tests, future_optimizations]) @@ -1106,6 +1132,7 @@ def determine_best_candidate( ) if self.experiment_id else None, + is_numerical_code=self.is_numerical_code, ) processor = CandidateProcessor( @@ -1573,6 +1600,7 @@ def generate_optimizations( read_writable_code: CodeStringsMarkdown, read_only_context_code: str, run_experiment: bool = False, # noqa: FBT001, FBT002 + is_numerical_code: bool | None = None, # noqa: FBT001 ) -> Result[tuple[OptimizationSet, str], str]: """Generate optimization candidates for the function. Backend handles multi-model diversity.""" n_candidates = get_effort_value(EffortKeys.N_OPTIMIZER_CANDIDATES, self.effort) @@ -1584,6 +1612,7 @@ def generate_optimizations( ExperimentMetadata(id=self.experiment_id, group="control") if run_experiment else None, is_async=self.function_to_optimize.is_async, n_candidates=n_candidates, + is_numerical_code=is_numerical_code, ) future_references = self.executor.submit( @@ -2461,6 +2490,7 @@ def submit_test_generation_tasks( test_index, test_path, test_perf_path, + self.is_numerical_code, ) for test_index, (test_path, test_perf_path) in enumerate( zip(generated_test_paths, generated_perf_test_paths) diff --git a/codeflash/verification/verifier.py b/codeflash/verification/verifier.py index 8d187f2b1..f60718020 100644 --- a/codeflash/verification/verifier.py +++ b/codeflash/verification/verifier.py @@ -27,6 +27,7 @@ def generate_tests( test_index: int, test_path: Path, test_perf_path: Path, + is_numerical_code: bool | None = None, # noqa: FBT001 ) -> tuple[str, str, Path] | None: # TODO: Sometimes this recreates the original Class definition. This overrides and messes up the original + # class import.
Remove the recreation of the class definition @@ -42,6 +43,7 @@ def generate_tests( test_timeout=test_timeout, trace_id=function_trace_id, test_index=test_index, + is_numerical_code=is_numerical_code, ) if response and isinstance(response, tuple) and len(response) == 3: generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source = response diff --git a/tests/test_is_numerical_code.py b/tests/test_is_numerical_code.py index 97b4b304e..5fedce8d1 100644 --- a/tests/test_is_numerical_code.py +++ b/tests/test_is_numerical_code.py @@ -691,6 +691,203 @@ def method(self): assert is_numerical_code(code, "ClassB.method") is False +@patch("codeflash.code_utils.code_extractor.has_numba", True) +class TestEmptyFunctionName: + """Test behavior when function_name is empty/None. + + When function_name is not provided, the function should just check for the + presence of numerical imports without looking at a specific function body. + """ + + def test_empty_string_with_numpy_import(self): + """Empty function_name with numpy import should return True.""" + code = """ +import numpy as np +def some_func(): + pass +""" + assert is_numerical_code(code, "") is True + + def test_none_with_numpy_import(self): + """None function_name with numpy import should return True.""" + code = """ +import numpy as np +def some_func(): + pass +""" + assert is_numerical_code(code, None) is True + + def test_empty_string_with_torch_import(self): + """Empty function_name with torch import should return True.""" + code = """ +import torch +def some_func(): + pass +""" + assert is_numerical_code(code, "") is True + + def test_empty_string_with_multiple_numerical_imports(self): + """Empty function_name with multiple numerical imports should return True.""" + code = """ +import numpy as np +import torch +from scipy import stats +def some_func(): + pass +""" + assert is_numerical_code(code, "") is True + + def test_empty_string_without_numerical_imports(self): + """Empty function_name without numerical imports should return False.""" + code = """ +import os +import json +from pathlib import Path + +def some_func(): + pass +""" + assert is_numerical_code(code, "") is False + + def test_none_without_numerical_imports(self): + """None function_name without numerical imports should return False.""" + code = """ +import os +def some_func(): + pass +""" + assert is_numerical_code(code, None) is False + + def test_empty_string_with_jax_import(self): + """Empty function_name with jax import should return True.""" + code = """ +import jax +import jax.numpy as jnp +""" + assert is_numerical_code(code, "") is True + + def test_empty_string_with_tensorflow_import(self): + """Empty function_name with tensorflow import should return True.""" + code = """ +import tensorflow as tf +""" + assert is_numerical_code(code, "") is True + + def test_empty_string_with_math_import(self): + """Empty function_name with math import should return True (numba available).""" + code = """ +import math +def calculate(x): + return math.sqrt(x) +""" + assert is_numerical_code(code, "") is True + + def test_empty_string_with_scipy_submodule(self): + """Empty function_name with scipy submodule import should return True.""" + code = """ +from scipy.stats import norm +""" + assert is_numerical_code(code, "") is True + + def test_empty_string_with_numba_import(self): + """Empty function_name with numba import should return True.""" + code = """ +from numba import jit +""" + assert is_numerical_code(code, "") is True + + def 
test_empty_code_with_empty_function_name(self): + """Empty code with empty function_name should return False.""" + assert is_numerical_code("", "") is False + + def test_syntax_error_with_empty_function_name(self): + """Syntax error code with empty function_name should return False.""" + code = """ +def broken( + import numpy +""" + assert is_numerical_code(code, "") is False + + +@patch("codeflash.code_utils.code_extractor.has_numba", False) +class TestEmptyFunctionNameWithoutNumba: + """Test empty function_name behavior when numba is NOT available. + + When numba is not installed, code using only math/numpy/scipy should return False, + since numba is required to optimize such code. Code using torch/jax/tensorflow/numba + should still return True. + """ + + def test_empty_string_numpy_returns_false_without_numba(self): + """Empty function_name with numpy should return False when numba unavailable.""" + code = """ +import numpy as np +def some_func(): + pass +""" + assert is_numerical_code(code, "") is False + + def test_empty_string_math_returns_false_without_numba(self): + """Empty function_name with math should return False when numba unavailable.""" + code = """ +import math +""" + assert is_numerical_code(code, "") is False + + def test_empty_string_scipy_returns_false_without_numba(self): + """Empty function_name with scipy should return False when numba unavailable.""" + code = """ +from scipy import stats +""" + assert is_numerical_code(code, "") is False + + def test_empty_string_torch_returns_true_without_numba(self): + """Empty function_name with torch should return True even without numba.""" + code = """ +import torch +""" + assert is_numerical_code(code, "") is True + + def test_empty_string_jax_returns_true_without_numba(self): + """Empty function_name with jax should return True even without numba.""" + code = """ +import jax +""" + assert is_numerical_code(code, "") is True + + def test_empty_string_tensorflow_returns_true_without_numba(self): + """Empty function_name with tensorflow should return True even without numba.""" + code = """ +import tensorflow as tf +""" + assert is_numerical_code(code, "") is True + + def test_empty_string_numba_import_returns_true_without_numba(self): + """Empty function_name with numba import should return True.""" + code = """ +from numba import jit +""" + assert is_numerical_code(code, "") is True + + def test_empty_string_numpy_and_torch_returns_true_without_numba(self): + """Empty function_name with numpy+torch should return True (torch doesn't need numba).""" + code = """ +import numpy as np +import torch +""" + # Returns True because torch is in modules_used and doesn't require numba + assert is_numerical_code(code, "") is True + + def test_empty_string_math_and_scipy_returns_false_without_numba(self): + """Empty function_name with only math+scipy should return False without numba.""" + code = """ +import math +from scipy import stats +""" + # Both math and scipy are in NUMBA_REQUIRED_MODULES + assert is_numerical_code(code, "") is False + + @patch("codeflash.code_utils.code_extractor.has_numba", False) class TestNumbaNotAvailable: """Test behavior when numba is NOT available in the environment.