diff --git a/code_to_optimize/sample_code.py b/code_to_optimize/sample_code.py index d356ce807..b2f1d4040 100644 --- a/code_to_optimize/sample_code.py +++ b/code_to_optimize/sample_code.py @@ -225,6 +225,7 @@ def longest_increasing_subsequence_length_jax(arr): return int(jnp.max(dp)) +@torch.compile(mode="reduce-overhead") def tridiagonal_solve_torch(a, b, c, d): device = b.device dtype = b.dtype