diff --git a/code_to_optimize/sample_code.py b/code_to_optimize/sample_code.py index d356ce807..5d31dc139 100644 --- a/code_to_optimize/sample_code.py +++ b/code_to_optimize/sample_code.py @@ -333,32 +333,62 @@ def tridiagonal_solve_tf(a, b, c, d): n = tf.shape(b)[0] dtype = b.dtype - c_prime = tf.zeros([n - 1], dtype=dtype) - d_prime = tf.zeros([n], dtype=dtype) + c0 = c[0] / b[0] + d0 = d[0] / b[0] + + def _forward_scan(): + a_prev = a[: n - 2] + b_mid = b[1 : n - 1] + c_mid = c[1 : n - 1] + d_mid = d[1 : n - 1] + + def fn(acc, elems): + c_prev, d_prev = acc + a_e, b_e, c_e, d_e = elems + denom = b_e - a_e * c_prev + c_new = c_e / denom + d_new = (d_e - a_e * d_prev) / denom + return (c_new, d_new) + + c_vals, d_vals = tf.scan(fn, (a_prev, b_mid, c_mid, d_mid), initializer=(c0, d0)) + c_prime = tf.concat([tf.reshape(c0, [1]), c_vals], axis=0) + d_prime_partial = tf.concat([tf.reshape(d0, [1]), d_vals], axis=0) + return c_prime, d_prime_partial + + def _forward_simple(): + return tf.reshape(c0, [1]), tf.reshape(d0, [1]) + + c_prime, d_prime_partial = tf.cond(tf.greater(n, 2), _forward_scan, _forward_simple) + + c_last = c_prime[-1] + d_prev = d_prime_partial[-1] + denom = b[n - 1] - a[n - 2] * c_last + d_last = (d[n - 1] - a[n - 2] * d_prev) / denom + d_prime = tf.concat([d_prime_partial, tf.reshape(d_last, [1])], axis=0) - c_prime = tf.tensor_scatter_nd_update(c_prime, [[0]], tf.reshape(c[0] / b[0], [1])) - d_prime = tf.tensor_scatter_nd_update(d_prime, [[0]], tf.reshape(d[0] / b[0], [1])) + x_last = d_last - _, c_prime, d_prime, _, _, _, _, _ = tf.while_loop( - _tridiagonal_forward_cond_tf, - _tridiagonal_forward_body_tf, - [1, c_prime, d_prime, n, a, b, c, d] - ) + def _back_scan(): + c_rev = tf.reverse(c_prime, axis=[0]) + d_rev = tf.reverse(d_prime[:-1], axis=[0]) - c_last = c_prime[n - 2] - d_prev = d_prime[n - 2] - denom = b[n - 1] - a[n - 2] * c_last - d_last = (d[n - 1] - a[n - 2] * d_prev) / denom - d_prime = tf.tensor_scatter_nd_update(d_prime, tf.reshape(n - 1, [1, 1]), tf.reshape(d_last, [1])) + def fn(x_next, elems): + c_e, d_e = elems + x_i = d_e - c_e * x_next + return x_i - x = tf.zeros([n], dtype=dtype) - x = tf.tensor_scatter_nd_update(x, tf.reshape(n - 1, [1, 1]), tf.reshape(d_prime[n - 1], [1])) + x_seq = tf.scan(fn, (c_rev, d_rev), initializer=x_last) + x_rev = tf.reverse(x_seq, axis=[0]) + x = tf.concat([x_rev, tf.reshape(x_last, [1])], axis=0) + return x - _, x, _, _ = tf.while_loop( - _tridiagonal_back_cond_tf, - _tridiagonal_back_body_tf, - [n - 2, x, c_prime, d_prime] - ) + def _back_simple(): + def for_n_eq_2(): + x0 = d_prime_partial[0] - c_prime[0] * x_last + return tf.stack([x0, x_last]) + return tf.cond(tf.equal(n, 1), lambda: tf.reshape(d / b, [1]), for_n_eq_2) + + x = tf.cond(tf.greater(n, 1), _back_scan, _back_simple) return x