⚡️ Speed up function _lis_outer_body_jax by 6,122%
#1077
+2
−1
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
📄 6,122% (61.22x) speedup for
_lis_outer_body_jaxincode_to_optimize/sample_code.py⏱️ Runtime :
4.09 seconds→65.8 milliseconds(best of14runs)📝 Explanation and details
The optimized code achieves a 62x speedup (6121%) by adding
@jitdecorators to both_lis_inner_body_jaxand_lis_outer_body_jaxfunctions. This simple change enables JAX's Just-In-Time compilation, which fundamentally transforms how the code executes.What changed:
@jitdecorator to both functionsjitto the imports fromjaxWhy this makes the code faster:
Eliminates Python interpreter overhead: Without JIT, each array operation (
arr[j],dp_inner[j], comparisons, etc.) triggers Python function calls and type checks. The line profiler shows the original_lis_inner_body_jaxspent 0.537s on just 42 iterations. With JIT, these operations are compiled once into optimized machine code.Enables operation fusion: JAX's compiler can fuse the sequence of operations in
_lis_inner_body_jax(comparison → logical AND → jnp.where → array update) into a single optimized kernel, eliminating intermediate array allocations and memory transfers.Optimizes the hot loop: The original line profiler shows
lax.fori_looptaking 5.52s (100% of_lis_outer_body_jaxtime). With JIT, JAX optimizes the entire loop body, including the partial function application, into efficient compiled code that runs directly on the accelerator (GPU/TPU) or CPU without Python overhead.Amortizes compilation cost: The first call compiles the function (visible in the ~20-110ms range for first calls in tests), but subsequent calls with same-shaped inputs reuse the compiled version. This is why tests show speedups from 1034% (large arrays) to 48000% (small arrays) - smaller inputs benefit more from eliminating per-call overhead.
Performance characteristics based on test results:
Impact on workloads:
Since this appears to be implementing a longest increasing subsequence (LIS) dynamic programming algorithm, the optimization would be particularly beneficial for:
✅ Correctness verification report:
🌀 Click to see Generated Regression Tests
To edit these changes
git checkout codeflash/optimize-_lis_outer_body_jax-mkgj0d5tand push.