⚡️ Speed up function tridiagonal_solve_jax by 1,995%
#1074
📄 1,995% (19.95x) speedup for `tridiagonal_solve_jax` in `code_to_optimize/sample_code.py`

⏱️ Runtime: 42.9 milliseconds → 2.05 milliseconds (best of 5 runs)

📝 Explanation and details
The optimized code achieves a ~20x speedup (1,995%) by applying JAX's just-in-time (JIT) compilation, via the `@jit` decorator, to `tridiagonal_solve_jax`.

Key optimization: the `@jit` decorator compiles the entire function into optimized XLA (Accelerated Linear Algebra) machine code on first execution, then reuses this compiled version for subsequent calls. A minimal sketch of the decorated solver is shown below.
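For concreteness, here is a sketch of a JIT-decorated Thomas-algorithm solver built on `lax.scan`. The signature `tridiagonal_solve_jax(a, b, c, d)` and all internals are assumptions reconstructed from the expressions quoted in this description (`c[0] / b[0]`, `(d_i - a_i * d_prev) / denom`, the two `lax.scan` passes); the actual function in `code_to_optimize/sample_code.py` may differ:

```python
import jax.numpy as jnp
from jax import jit, lax

@jit  # one-time XLA compilation on first call; cached for later calls
def tridiagonal_solve_jax(a, b, c, d):
    """Thomas algorithm. a: sub-diagonal (a[0] unused), b: main diagonal,
    c: super-diagonal (c[-1] unused), d: right-hand side."""
    # Forward sweep: carry the modified super-diagonal (c') and
    # right-hand side (d') from row to row.
    cp0 = c[0] / b[0]
    dp0 = d[0] / b[0]

    def forward(carry, row):
        cp_prev, dp_prev = carry
        a_i, b_i, c_i, d_i = row
        denom = b_i - a_i * cp_prev
        cp_i = c_i / denom
        dp_i = (d_i - a_i * dp_prev) / denom  # one fused kernel under JIT
        return (cp_i, dp_i), (cp_i, dp_i)

    _, (cp_rest, dp_rest) = lax.scan(
        forward, (cp0, dp0), (a[1:], b[1:], c[1:], d[1:]))
    cp = jnp.concatenate([jnp.array([cp0]), cp_rest])
    dp = jnp.concatenate([jnp.array([dp0]), dp_rest])

    # Backward substitution: reverse=True scans from the last row to the
    # first but still stacks the outputs in forward index order.
    def backward(x_next, row):
        cp_i, dp_i = row
        x_i = dp_i - cp_i * x_next
        return x_i, x_i

    _, xs = lax.scan(backward, dp[-1], (cp[:-1], dp[:-1]), reverse=True)
    return jnp.concatenate([xs, jnp.array([dp[-1]])])
```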
Why this is dramatically faster:

- Eliminates Python interpreter overhead: Without JIT, every array operation and function call goes through Python's interpreter. The line profiler shows that even a simple operation like `c[0] / b[0]` took ~1.17 seconds across 31 calls (~38 ms each). JIT removes this overhead entirely by compiling to native code.
- Operator fusion: JAX can fuse multiple array operations into a single kernel. For example, an expression like `(d_i - a_i * d_prev) / denom` becomes one fused kernel instead of separate subtract, multiply, and divide operations, reducing memory-bandwidth requirements.
- Optimizes the scan operations: The two `lax.scan` calls (the forward and backward passes) benefit significantly from JIT. The profiler shows they took ~515 ms and ~348 ms respectively in the original code. JIT enables better loop unrolling and vectorization within the scan kernels.
- Reduces array-creation overhead: Multiple `jnp.concatenate` and `jnp.array` calls (totaling ~644 ms in the profiler) are optimized away or compiled into efficient buffer operations.

Impact: This optimization is particularly effective for the tridiagonal solver because it is called repeatedly (31 times in the profile), so JIT compilation amortizes the one-time compilation cost across all invocations. It benefits any workload that calls this function multiple times with similarly shaped inputs, making it ideal for numerical PDE solvers, time-stepping algorithms, and batch-processing scenarios. The timing sketch below illustrates the warm-up-then-reuse pattern.
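A hypothetical timing harness for that pattern might look as follows; the array size, values, and call count are made up for illustration, and only the warm-up-then-reuse structure is the point:

```python
import time
import jax.numpy as jnp

n = 4096
a = jnp.ones(n)       # sub-diagonal (a[0] unused)
b = jnp.full(n, 4.0)  # diagonally dominant main diagonal
c = jnp.ones(n)       # super-diagonal (c[-1] unused)
d = jnp.ones(n)       # right-hand side

# First call: traces the function and compiles it with XLA.
tridiagonal_solve_jax(a, b, c, d).block_until_ready()

# Subsequent calls with the same shapes/dtypes hit the compile cache.
t0 = time.perf_counter()
for _ in range(31):  # mirrors the 31 calls seen in the profile
    x = tridiagonal_solve_jax(a, b, c, d)
x.block_until_ready()  # JAX dispatch is asynchronous; sync before timing
print(f"31 solves took {time.perf_counter() - t0:.4f} s")
```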
✅ Correctness verification report:
⚙️ Existing Unit Tests

- test_jax_jit_code.py::TestTridiagonalSolveJax.test_diagonal_system
- test_jax_jit_code.py::TestTridiagonalSolveJax.test_larger_system
- test_jax_jit_code.py::TestTridiagonalSolveJax.test_simple_system
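To re-run these tests locally against the optimized code, one option (assuming pytest is the runner, as the test node IDs suggest) is pytest's Python API:

```python
# Runs the existing test class; assumes pytest is installed and
# test_jax_jit_code.py is in the current working directory.
import pytest

pytest.main(["-q", "test_jax_jit_code.py::TestTridiagonalSolveJax"])
```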
To edit these changes, run `git checkout codeflash/optimize-tridiagonal_solve_jax-mkggf6xm` and push.