⚡️ Speed up function leapfrog_integration_jax by 19,337%
#1075
+2
−1
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
📄 19,337% (193.37x) speedup for
leapfrog_integration_jaxincode_to_optimize/sample_code.py⏱️ Runtime :
4.30 seconds→22.1 milliseconds(best of5runs)📝 Explanation and details
The optimized code achieves a massive 19,336% speedup (from 4.30 seconds to 22.1 milliseconds) by adding JIT (Just-In-Time) compilation to the
leapfrog_integration_jaxfunction.Key Optimization: JIT Compilation
The core change is adding
@partial(jit, static_argnums=(4,))decorator toleapfrog_integration_jax. This triggers JAX's XLA compiler to:lax.scanloop, eliminating intermediate array allocationsThe
static_argnums=(4,)parameter tells JAX thatn_stepsis a compile-time constant, allowing the compiler to unroll or optimize the scan loop structure based on the known length.Why This Creates Such a Large Speedup
Looking at the line profiler results, the original code spends 100% of execution time (6.94 seconds) in the
lax.scancall. Without JIT:With JIT compilation:
Test Results Analysis
The annotated tests show consistent speedups across all scenarios:
The optimization is particularly effective for:
Impact Considerations
Since function references are not available, this optimization would be most beneficial if
leapfrog_integration_jaxis:The optimization maintains identical numerical behavior and all test correctness, making it a safe drop-in replacement.
✅ Correctness verification report:
⚙️ Click to see Existing Unit Tests
test_jax_jit_code.py::TestLeapfrogIntegrationJax.test_momentum_conservationtest_jax_jit_code.py::TestLeapfrogIntegrationJax.test_single_moving_particletest_jax_jit_code.py::TestLeapfrogIntegrationJax.test_single_stationary_particletest_jax_jit_code.py::TestLeapfrogIntegrationJax.test_two_particles_approach🌀 Click to see Generated Regression Tests
To edit these changes
git checkout codeflash/optimize-leapfrog_integration_jax-mkghvhl2and push.