Version: 0.0.2
linox is a Python package that provides a collection of linear operators for JAX, enabling efficient and flexible linear algebra with lazy evaluation. It is designed as a JAX alternative to probnum.linops but is still under development, so it currently offers fewer and less stable features. Its only dependencies are JAX and plum (for multiple dispatch).
Note (v0.0.2): The API has been updated to remove the "l" prefix from function names. Functions like lsolve, linverse, ldet, etc. are now available as solve, inverse, det, etc. The old "l"-prefixed functions are deprecated and will be removed in version 0.0.3.
Matrix‑free Gaussian Process predictions and posterior uncertainty on a 2D heat‑equation task using Kronecker‑structured kernels.
- Lazy Evaluation: All operators support lazy evaluation, allowing for efficient computation of complex linear transformations
- JAX Integration: Built on top of JAX, providing automatic differentiation, parallelization, JIT compilation, and GPU/TPU support
- Composable Operators: Operators can be combined to form complex linear transformations
- `Matrix`: General matrix operator
- `Identity`: Identity matrix operator
- `Diagonal`: Diagonal matrix operator
- `Scalar`: Scalar multiple of the identity
- `Zero`: Zero matrix operator
- `Ones`: Matrix of ones operator
- `BlockMatrix`: General block matrix operator
- `BlockMatrix2x2`: 2x2 block matrix operator
- `BlockDiagonal`: Block diagonal matrix operator
- `LowRank`: General low-rank operator
- `SymmetricLowRank`: Symmetric low-rank operator
- `IsotropicScalingPlusSymmetricLowRank`: Isotropic scaling plus symmetric low rank
- `PositiveDiagonalPlusSymmetricLowRank`: Positive diagonal plus symmetric low rank
- `Kronecker`: Kronecker product operator
- `Permutation`: Permutation matrix operator
- `EigenD`: Eigenvalue decomposition operator
- `Toeplitz`: Toeplitz matrix operator
- `IsotropicAdditiveLinearOperator`: Efficient operator for `s*I + A` with spectral transforms
- `ScaledLinearOperator`: Scalar multiple of an operator
- `AddLinearOperator`: Sum of multiple operators
- `ProductLinearOperator`: Product of multiple operators
- `TransposedLinearOperator`: Transpose of an operator
- `InverseLinearOperator`: Inverse of an operator
- `PseudoInverseLinearOperator`: Pseudo-inverse of an operator
- `CongruenceTransform`: Congruence transformation `A B A^T`
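A minimal sketch of building and composing a few of these operators using the Matrix and Diagonal constructors and the kron/iso helpers listed further below (shapes and values are illustrative):

import jax.numpy as jnp
from linox import Matrix, Diagonal, kron, iso

A = Matrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]))  # dense 2x2 wrapped as an operator
D = Diagonal(jnp.array([0.5, 2.0]))              # diagonal operator from its diagonal

K = kron(A, D)   # Kronecker product operator of shape (4, 4), never materialized
S = iso(0.1, K)  # 0.1 * I + K as an isotropic additive operator

x = jnp.ones((4,))
y = S @ x        # lazy composition, evaluated only by this matrix-vector product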
- `solve(A, b)`: Solve the linear system `Ax = b`
- `psolve(A, b)`: Solve using the pseudo-inverse for singular/rectangular systems
- `lu_factor(A)`: LU factorization
- `lu_solve(A, b)`: Solve using an LU factorization
- `eigh(A)`: Eigendecomposition for Hermitian matrices
- `svd(A)`: Singular value decomposition
- `qr(A)`: QR decomposition
- `cholesky(A)`: Cholesky decomposition
- `inverse(A)`: Compute the inverse `A^{-1}`
- `pinverse(A)`: Compute the pseudo-inverse `A^†`
- `sqrt(A)`: Compute the matrix square root
- `transpose(A)`: Transpose operator
- `det(A)`: Compute the determinant
- `slogdet(A)`: Compute the sign and log-determinant
- `diagonal(A)`: Extract diagonal elements
- `symmetrize(A)`: Symmetrize an operator, `(A + A^T)/2`
- `congruence_transform(A, B)`: Compute `A B A^T`
- `kron(A, B)`: Kronecker product
- `iso(s, A)`: Create the isotropic additive operator `s*I + A`
- `add(A, B)`: Add two operators
- `sub(A, B)`: Subtract operators
- `mul(scalar, A)`: Scalar multiplication
- `matmul(A, B)`: Matrix multiplication
- `neg(A)`: Negate an operator
- `div(A, B)`: Division (for diagonal operators)
- `is_square(A)`: Check whether an operator is square
- `is_symmetric(A)`: Check symmetry without densification (randomized)
- `is_hermitian(A)`: Check the Hermitian property without densification (randomized)
- `todense(A)`: Convert to a dense array
- `allclose(A, B)`: Compare operators
- `set_debug(enabled)`: Enable/disable densification warnings
- `is_debug()`: Check the debug mode status
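A short example exercising a few of these functions on a small diagonal operator (values are illustrative; slogdet is assumed to return a (sign, logabsdet) pair as in jax.numpy):

import jax.numpy as jnp
from linox import Diagonal, solve, slogdet, diagonal, todense, allclose

D = Diagonal(jnp.array([2.0, 3.0, 4.0]))

x = solve(D, jnp.ones((3,)))     # solves D x = b without densifying D
sign, logabsdet = slogdet(D)     # sign and log-determinant
d = diagonal(D)                  # extracts the diagonal lazily
dense = todense(D)               # explicit dense array, only when needed
assert allclose(D, Diagonal(d))  # compare two operators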
- Automatic Differentiation: Compute gradients automatically through operator compositions
- JIT Compilation: Speed up computations with just-in-time compilation
- Vectorization: Efficient batch processing of linear operations, e.g. via jax.vmap
- GPU/TPU Support: Run computations on accelerators without code changes
- Functional Programming: Pure functions enable better optimization and parallelization
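A minimal sketch of gradients and JIT through an operator composition, assuming the operators are constructed inside the traced function (the loss and its parameterization are illustrative):

import jax
import jax.numpy as jnp
from linox import Diagonal, Matrix, matmul

def loss(d):
    # Build the composition inside the differentiated function so JAX can trace it.
    A = Matrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
    D = Diagonal(d)
    y = matmul(A, D) @ jnp.ones((2,))  # lazy product, evaluated by the matvec
    return jnp.sum(y ** 2)

grad_fn = jax.jit(jax.grad(loss))  # differentiate and JIT-compile in one step
g = grad_fn(jnp.array([1.0, 2.0]))

The quick-start example below composes Matrix, Diagonal, and BlockMatrix operators and batches the matrix-vector product with jax.vmap.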
import jax
import jax.numpy as jnp
from linox import Matrix, Diagonal, BlockMatrix, inverse, solve, det
# Create operators
A = Matrix(jnp.array([[1, 2], [3, 4]], dtype=jnp.float32))
D = Diagonal(jnp.array([1, 2], dtype=jnp.float32))
# Compose operators
B = BlockMatrix([[A, D], [D, A]])
# Apply to vector
x = jnp.ones((4,), dtype=jnp.float32)
y = B @ x # Lazy evaluation
# Solve linear system
b = jnp.ones((4,), dtype=jnp.float32)
x_solved = solve(B, b)
# Compute inverse and determinant
B_inv = inverse(B)
det_B = det(B)
# Parallelize over batch of vectors
x_batched = jnp.ones((10, 4), dtype=jnp.float32)
y_batched = jax.vmap(B)(x_batched)

linox makes it easy to build Gaussian Process (GP) operators that factorize across function and spatial dimensions. This leverages Kronecker structure and preserves matrix‑free behavior, so you can compose large kernels without materializing massive dense arrays.
Example: a modular GP prior with a function kernel ⊗ spatial kernel
import jax
import jax.numpy as jnp
from helper.new_gp import (
CombinationConfig,
DimensionSpec,
ModularGPPrior,
StructureConfig,
params_from_structure,
)
from helper.gp import KernelType, CombinationStrategy
# Enable double precision for numerical stability (optional)
jax.config.update("jax_enable_x64", True)
# 2D setup (one function dim u, two spatial dims x,y)
structure = StructureConfig(
spatial_dims=[
DimensionSpec(name="x", kernel_type=KernelType.RBF),
DimensionSpec(name="y", kernel_type=KernelType.RBF),
],
function_dims=[DimensionSpec(name="u", kernel_type=KernelType.L2)],
)
combo = CombinationConfig(strategy=CombinationStrategy.ADDITIVE, output_scale=1.0)
prior = ModularGPPrior(structure, combo)
params = params_from_structure(structure)
# Training data (N_train functions, evaluated on an (nx, ny) grid)
N_train, N_test = 25, 3
nx, ny = 15, 15
nx_plot, ny_plot = 25, 25
# See helper.plotting.generate_preprocess_data_2d for data creation
from helper.plotting import generate_preprocess_data_2d
(
operator_inputs, # (N_train, nx, ny)
spatial_inputs, # (nx, ny, 2)
outputs, # (N_train * nx * ny,)
operator_inputs_test, # (N_test, nx, ny)
spatial_inputs_test, # (nx, ny, 2)
outputs_test, # (N_test * nx * ny,)
spatial_inputs_plot, # (nx_plot, ny_plot, 2)
) = generate_preprocess_data_2d(
x_range=(0.0, jnp.pi), y_range=(0.0, jnp.pi),
nx=nx, ny=ny, T=0.1, alpha=0.5,
N_train=N_train, N_test=N_test,
nx_plot=nx_plot, ny_plot=ny_plot,
)
# Build the Kronecker‑structured kernel and run predictions
pred_mean_flat, pred_cov = prior.predict(
operator_inputs,
outputs,
spatial_inputs,
operator_inputs_test,
spatial_inputs_plot,
params,
)
# pred_mean_flat has shape (N_test * nx_plot * ny_plot,)
# pred_cov is a LinearOperator (matrix‑free) you can densify only for plotting

Why this is fast and memory‑efficient
- Kronecker structure: The prior kernel is built as `K_function ⊗ K_spatial` using `linox.Kronecker`, so large grids are handled as compositions rather than dense matrices.
- Matrix-free algebra: Solves and products are done via LinearOperators (e.g., `IsotropicAdditiveLinearOperator`, `inverse`, `solve`) without forming dense blocks.
- Lazy properties: Many operations (like `diagonal`) propagate into factors and avoid densification unless explicitly required (see "Densification Warnings").
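A rough, self-contained sketch of the same pattern using only linox primitives (the Gram matrices here are placeholders, not the helper's actual kernels):

import jax.numpy as jnp
from linox import Matrix, kron, iso, solve, diagonal

K_func = Matrix(jnp.eye(25) + 0.1)   # stand-in for K_function, shape (25, 25)
K_spat = Matrix(jnp.eye(225) * 2.0)  # stand-in for K_spatial, shape (225, 225)

K_prior = kron(K_func, K_spat)  # (5625, 5625) operator, never densified
K_noisy = iso(1e-3, K_prior)    # 1e-3 * I + K_prior, still matrix-free

alpha = solve(K_noisy, jnp.ones((25 * 225,)))  # matrix-free solve against the data
prior_var = diagonal(K_prior)                  # diagonal propagates into the factors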
Illustrative outputs (2D heat‑equation demo)
See the example notebook for a walkthrough: examples/gp_operator_walkthrough.ipynb.
Some operations fall back to dense computations when a lazy, structure‑preserving
path is not available (e.g., diagonal of a general product of non‑diagonal factors,
explicit inverse materialization). To help diagnose performance, linox can emit
warnings whenever an operation densifies.
By default, these warnings are suppressed. Enable them via the API or an environment variable:
from linox import set_debug
# Turn on debug warnings
set_debug(True)
# Turn them off again
set_debug(False)

Or set an environment variable before running Python:
export LINOX_DEBUG=1 # enables densification warnings
python your_script.py

Examples of operations that may warn when debug is enabled:
- `diagonal(op)` when it must convert an operator to dense to compute the diagonal.
- Decompositions like `eigh`, `svd`, `qr` falling back to dense.
- `InverseLinearOperator.todense()` and pseudo-inverse matmul paths that need dense arrays.
- `Matrix.todense()` when explicitly materializing the dense array.
Note: Many structure‑aware paths remain lazy (e.g., diagonals of Kronecker products and of diagonal‑like products). The warnings help ensure large operators aren't accidentally densified.
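For example, to temporarily enable the warnings around a suspect code path (a sketch using only the documented set_debug/is_debug API; whether a given call actually warns depends on the operator's structure):

import jax.numpy as jnp
from linox import Matrix, inverse, todense, set_debug, is_debug

previous = is_debug()  # remember the current debug state
set_debug(True)        # enable densification warnings
try:
    A = Matrix(jnp.array([[2.0, 0.0], [0.0, 3.0]]))
    A_inv_dense = todense(inverse(A))  # materializing the inverse may warn
finally:
    set_debug(previous)  # restore the previous state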
linox draws inspiration from and complements matfree by Nicholas Krämer, which provides matrix-free linear algebra methods in JAX including randomized and deterministic methods for trace estimation, functions of matrices, and matrix factorizations.
If you use matrix-free methods or differentiable linear algebra iterations in your work, consider citing the matfree library:
For differentiable Lanczos or Arnoldi iterations:
@article{kraemer2024gradients,
title={Gradients of functions of large matrices},
author={Krämer, Nicholas and Moreno-Muñoz, Pablo and Roy, Hrittik and Hauberg, Søren},
journal={Advances in Neural Information Processing Systems},
volume={37},
pages={49484--49518},
year={2024}
}

For a differentiable LSMR implementation:
@article{roy2025matrix,
title={Matrix-Free Least Squares Solvers: Values, Gradients, and What to Do With Them},
author={Roy, Hrittik and Hauberg, Søren and Krämer, Nicholas},
journal={arXiv preprint arXiv:2510.19634},
year={2025}
}

- probnum.linops: The original inspiration for linox, providing linear operators in Python/NumPy
- matfree: Specialized matrix-free methods for large-scale problems
pip install linox

Or install from source:
git clone https://github.com/2bys/linox.git
cd linox
pip install -e .

Contributions are welcome! Please feel free to submit pull requests or open issues on the GitHub repository.
This project is licensed under the MIT License.
