diff --git a/code_to_optimize/sample_code.py b/code_to_optimize/sample_code.py index d356ce807..5d15d4e7d 100644 --- a/code_to_optimize/sample_code.py +++ b/code_to_optimize/sample_code.py @@ -1,6 +1,7 @@ from functools import partial import jax.numpy as jnp +import numba import numpy as np import tensorflow as tf import torch @@ -36,6 +37,7 @@ def tridiagonal_solve(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray return x +@numba.njit(cache=True) def leapfrog_integration( positions: np.ndarray, velocities: np.ndarray,