diff --git a/code_to_optimize/sample_code.py b/code_to_optimize/sample_code.py index d356ce807..1fcc767a8 100644 --- a/code_to_optimize/sample_code.py +++ b/code_to_optimize/sample_code.py @@ -1,3 +1,4 @@ +import bisect from functools import partial import jax.numpy as jnp @@ -97,19 +98,18 @@ def longest_increasing_subsequence_length(arr: np.ndarray) -> int: if n == 0: return 0 - dp = np.ones(n, dtype=np.int64) + # Use patience sorting / tails method for O(n log n) time. + tails = [] # tails[k] = smallest tail value of an increasing subsequence of length k+1 - for i in range(1, n): - for j in range(i): - if arr[j] < arr[i]: - if dp[j] + 1 > dp[i]: - dp[i] = dp[j] + 1 - - max_length = dp[0] - for i in range(1, n): - if dp[i] > max_length: - max_length = dp[i] + for x in arr: + # Find the insertion point for x in tails to maintain sorted order. + i = bisect.bisect_left(tails, x) + if i == len(tails): + tails.append(x) + else: + tails[i] = x + max_length = np.int64(len(tails)) return max_length