diff --git a/code_to_optimize/sample_code.py b/code_to_optimize/sample_code.py index d356ce807..ca393dd0d 100644 --- a/code_to_optimize/sample_code.py +++ b/code_to_optimize/sample_code.py @@ -4,7 +4,7 @@ import numpy as np import tensorflow as tf import torch -from jax import lax +from jax import jit, lax def tridiagonal_solve(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray: @@ -187,6 +187,7 @@ def _leapfrog_step_jax(carry, _, masses, softening, dt): return (pos, vel), None +@partial(jit, static_argnums=(4,)) def leapfrog_integration_jax( positions, velocities,