diff --git a/code_to_optimize/sample_code.py b/code_to_optimize/sample_code.py
index d356ce807..d651297af 100644
--- a/code_to_optimize/sample_code.py
+++ b/code_to_optimize/sample_code.py
@@ -308,13 +308,28 @@ def _tridiagonal_forward_cond_tf(i, _c_prime, _d_prime, n, _a, _b, _c, _d):
 
 
 def _tridiagonal_forward_body_tf(i, c_prime, d_prime, n, a, b, c, d):
-    c_prev = c_prime[i - 1]
-    d_prev = d_prime[i - 1]
-    denom = b[i] - a[i - 1] * c_prev
-    c_val = c[i] / denom
-    d_val = (d[i] - a[i - 1] * d_prev) / denom
-    c_prime = tf.tensor_scatter_nd_update(c_prime, tf.reshape(i, [1, 1]), tf.reshape(c_val, [1]))
-    d_prime = tf.tensor_scatter_nd_update(d_prime, tf.reshape(i, [1, 1]), tf.reshape(d_val, [1]))
+    """Compute entry ``i`` of ``c_prime`` and ``d_prime`` from entry ``i - 1``.
+
+    Returns the same tuple of arguments with ``i`` advanced by one and the
+    two updated vectors; ``n``, ``a``, ``b``, ``c`` and ``d`` are passed
+    through unchanged.
+    """
+    im1 = i - 1
+    # Read the values from the previous index once; a_prev is used twice below.
+    a_prev = tf.gather(a, im1)
+    c_prev = tf.gather(c_prime, im1)
+    d_prev = tf.gather(d_prime, im1)
+    # Shared denominator and the two new scalar entries for index i.
+    denom = tf.gather(b, i) - a_prev * c_prev
+    c_val = tf.gather(c, i) / denom
+    d_val = (tf.gather(d, i) - a_prev * d_prev) / denom
+    # Write entry i by selecting with a boolean one-hot mask. Selecting (rather
+    # than blending with `x * (1 - mask) + mask * val`) keeps a non-finite
+    # c_val/d_val confined to entry i: 0 * inf is NaN, so the arithmetic blend
+    # would turn every other entry into NaN as soon as denom is zero.
+    is_row_i = tf.one_hot(i, n, on_value=True, off_value=False, dtype=tf.bool)
+    c_prime = tf.where(is_row_i, c_val, c_prime)
+    d_prime = tf.where(is_row_i, d_val, d_prime)
     return i + 1, c_prime, d_prime, n, a, b, c, d
 
 