Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions code_to_optimize/sample_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,20 @@ def _tridiagonal_forward_cond_tf(i, _c_prime, _d_prime, n, _a, _b, _c, _d):


def _tridiagonal_forward_body_tf(i, c_prime, d_prime, n, a, b, c, d):
c_prev = c_prime[i - 1]
d_prev = d_prime[i - 1]
denom = b[i] - a[i - 1] * c_prev
c_val = c[i] / denom
d_val = (d[i] - a[i - 1] * d_prev) / denom
c_prime = tf.tensor_scatter_nd_update(c_prime, tf.reshape(i, [1, 1]), tf.reshape(c_val, [1]))
d_prime = tf.tensor_scatter_nd_update(d_prime, tf.reshape(i, [1, 1]), tf.reshape(d_val, [1]))
im1 = i - 1
# Read previous values explicitly
c_prev = tf.gather(c_prime, im1)
d_prev = tf.gather(d_prime, im1)
# Compute denominator and new scalar values
denom = tf.gather(b, i) - tf.gather(a, im1) * c_prev
c_val = tf.gather(c, i) / denom
d_val = (tf.gather(d, i) - tf.gather(a, im1) * d_prev) / denom
# Update the i-th entry using a one-hot mask instead of tensor_scatter_nd_update.
# This avoids creating and reshaping index tensors and reduces graph/node overhead.
mask = tf.one_hot(i, n, dtype=c_prime.dtype)
inv_mask = 1 - mask
c_prime = c_prime * inv_mask + mask * c_val
d_prime = d_prime * inv_mask + mask * d_val
return i + 1, c_prime, d_prime, n, a, b, c, d


Expand Down
Loading