diff --git a/code_to_optimize/sample_code.py b/code_to_optimize/sample_code.py index d356ce807..09b771754 100644 --- a/code_to_optimize/sample_code.py +++ b/code_to_optimize/sample_code.py @@ -116,7 +116,8 @@ def longest_increasing_subsequence_length(arr: np.ndarray) -> int: def _tridiagonal_forward_step_jax(carry, inputs): c_prev, d_prev = carry a_i, b_i, c_i, d_i = inputs - denom = b_i - a_i * c_prev + a_c = a_i * c_prev + denom = b_i - a_c c_new = c_i / denom d_new = (d_i - a_i * d_prev) / denom return (c_new, d_new), (c_new, d_new)