diff --git a/code_to_optimize/sample_code.py b/code_to_optimize/sample_code.py index d356ce807..f567f1c82 100644 --- a/code_to_optimize/sample_code.py +++ b/code_to_optimize/sample_code.py @@ -363,6 +363,7 @@ def tridiagonal_solve_tf(a, b, c, d): return x +@tf.function(jit_compile=True) def _leapfrog_compute_accelerations_tf(pos, masses, softening, G): diff = tf.expand_dims(pos, 0) - tf.expand_dims(pos, 1) @@ -374,10 +375,11 @@ def _leapfrog_compute_accelerations_tf(pos, masses, softening, G): force_factor = G * tf.expand_dims(masses, 0) / dist_cubed - acc = tf.reduce_sum(tf.expand_dims(force_factor, -1) * diff, axis=1) + acc = tf.einsum('ij,ijk->ik', force_factor, diff) return acc +@tf.function(jit_compile=True) def _leapfrog_step_body_tf(i, pos, vel, masses, softening, dt, n_steps): G = 1.0 acc = _leapfrog_compute_accelerations_tf(pos, masses, softening, G)