diff --git a/code_to_optimize/sample_code.py b/code_to_optimize/sample_code.py index d356ce807..19f4beab4 100644 --- a/code_to_optimize/sample_code.py +++ b/code_to_optimize/sample_code.py @@ -4,7 +4,7 @@ import numpy as np import tensorflow as tf import torch -from jax import lax +from jax import jit, lax def tridiagonal_solve(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray: @@ -206,6 +206,7 @@ def _lis_inner_body_jax(j, dp_inner, arr, i): return dp_inner.at[i].set(new_val) +@jit def _lis_outer_body_jax(i, dp, arr): inner_fn = partial(_lis_inner_body_jax, arr=arr, i=i) dp = lax.fori_loop(0, i, inner_fn, dp)