diff --git a/code_to_optimize/sample_code.py b/code_to_optimize/sample_code.py index d356ce807..7d01db327 100644 --- a/code_to_optimize/sample_code.py +++ b/code_to_optimize/sample_code.py @@ -427,12 +427,26 @@ def _lis_inner_cond_tf(j, _dp_inner, _arr, i): def _lis_outer_body_tf(i, dp, arr, n): - _, dp, _, _ = tf.while_loop( - _lis_inner_cond_tf, - _lis_inner_body_tf, - [0, dp, arr, i] - ) - return i + 1, dp, arr, n + def true_fn(): + prefix = tf.slice(dp, [0], [i]) + arr_prefix = tf.slice(arr, [0], [i]) + arr_i = tf.gather(arr, i) + mask = tf.less(arr_prefix, arr_i) + fill_val = tf.reduce_min(dp) - tf.constant(1, dtype=dp.dtype) + candidates = tf.where(mask, prefix + 1, tf.fill(tf.shape(prefix), fill_val)) + max_cand = tf.reduce_max(candidates) + dp_i = tf.gather(dp, i) + new_val = tf.maximum(dp_i, max_cand) + indices = tf.reshape(i, [1, 1]) + updates = tf.reshape(new_val, [1]) + dp_updated = tf.tensor_scatter_nd_update(dp, indices, updates) + return dp_updated + + def false_fn(): + return dp + + dp_updated = tf.cond(tf.greater(i, 0), true_fn, false_fn) + return i + 1, dp_updated, arr, n def _lis_outer_cond_tf(i, _dp, _arr, n):