diff --git a/code_to_optimize/sample_code.py b/code_to_optimize/sample_code.py index d356ce807..f13444358 100644 --- a/code_to_optimize/sample_code.py +++ b/code_to_optimize/sample_code.py @@ -4,7 +4,7 @@ import numpy as np import tensorflow as tf import torch -from jax import lax +from jax import jit, lax def tridiagonal_solve(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray: @@ -128,6 +128,7 @@ def _tridiagonal_back_step_jax(x_next, inputs): return x_i, x_i +@jit def tridiagonal_solve_jax(a, b, c, d): n = b.shape[0]