Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions code_to_optimize/sample_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,21 @@
import tensorflow as tf
import torch
from jax import lax
from numba import njit

# Placeholder so the name exists at module level before the @njit-decorated
# definition of the same name, further down this module, rebinds it.
# tridiagonal_solve checks this for None before dispatching to it.
_numba_tridiagonal_solve = None


def tridiagonal_solve(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray:
n = len(b)

# Create working copies to avoid modifying input
c_prime = np.zeros(n - 1, dtype=np.float64)
d_prime = np.zeros(n, dtype=np.float64)
x = np.zeros(n, dtype=np.float64)
c_prime = np.empty(n - 1, dtype=np.float64)
d_prime = np.empty(n, dtype=np.float64)
x = np.empty(n, dtype=np.float64)

if _numba_tridiagonal_solve is not None and n > 64:
return _numba_tridiagonal_solve(a, b, c, d)

# Forward sweep - sequential dependency: c_prime[i] depends on c_prime[i-1]
c_prime[0] = c[0] / b[0]
Expand Down Expand Up @@ -454,3 +460,32 @@ def longest_increasing_subsequence_length_tf(arr):
)

return int(tf.reduce_max(dp))

@njit(cache=True)
def _numba_tridiagonal_solve(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray:
    """Solve a tridiagonal linear system with the Thomas algorithm.

    Row ``i`` of the system is
    ``a[i - 1] * x[i - 1] + b[i] * x[i] + c[i] * x[i + 1] = d[i]``,
    with the ``a`` term absent in the first row and the ``c`` term absent
    in the last row.

    Args:
        a: Sub-diagonal coefficients; indices ``0 .. n - 2`` are read.
        b: Main-diagonal coefficients; its length defines ``n``.
        c: Super-diagonal coefficients; indices ``0 .. n - 2`` are read.
        d: Right-hand side; indices ``0 .. n - 1`` are read.

    Returns:
        A newly allocated float64 array of length ``n`` holding the solution.
        None of the input arrays is written to.

    Note:
        No pivoting is performed and pivots are not checked for zero.
    """
    n = b.shape[0]
    x = np.empty(n, dtype=np.float64)

    # Numba does not bounds-check array indexing by default, so the general
    # path below (which touches c_prime[0] and a[n - 2]) must never run with
    # fewer than two rows. Handle the degenerate sizes explicitly.
    if n == 0:
        return x
    if n == 1:
        x[0] = d[0] / b[0]
        return x

    # Scratch buffers for the modified coefficients. np.empty is safe here:
    # every element is assigned during the forward sweep before it is read.
    c_prime = np.empty(n - 1, dtype=np.float64)
    d_prime = np.empty(n, dtype=np.float64)

    # Forward sweep - sequential dependency: c_prime[i] depends on c_prime[i-1]
    c_prime[0] = c[0] / b[0]
    d_prime[0] = d[0] / b[0]

    for i in range(1, n - 1):
        denom = b[i] - a[i - 1] * c_prime[i - 1]
        c_prime[i] = c[i] / denom
        d_prime[i] = (d[i] - a[i - 1] * d_prime[i - 1]) / denom

    # Last row of forward sweep: there is no super-diagonal entry, so only
    # d_prime needs updating.
    denom = b[n - 1] - a[n - 2] * c_prime[n - 2]
    d_prime[n - 1] = (d[n - 1] - a[n - 2] * d_prime[n - 2]) / denom

    # Back substitution - sequential dependency: x[i] depends on x[i+1]
    x[n - 1] = d_prime[n - 1]
    for i in range(n - 2, -1, -1):
        x[i] = d_prime[i] - c_prime[i] * x[i + 1]

    return x
Loading