diff --git a/code_to_optimize/sample_code.py b/code_to_optimize/sample_code.py index d356ce807..54ade853a 100644 --- a/code_to_optimize/sample_code.py +++ b/code_to_optimize/sample_code.py @@ -4,7 +4,7 @@ import numpy as np import tensorflow as tf import torch -from jax import lax +from jax import jit, lax def tridiagonal_solve(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray: @@ -200,6 +200,7 @@ def leapfrog_integration_jax( return final_pos, final_vel +@jit def _lis_inner_body_jax(j, dp_inner, arr, i): condition = (arr[j] < arr[i]) & (dp_inner[j] + 1 > dp_inner[i]) new_val = jnp.where(condition, dp_inner[j] + 1, dp_inner[i])