This repository hosts the experiments that illustrate the duality between diagonal state-space models (SSMs) and semi-separable attention, plus a compact suite of Mamba-vs-baseline comparisons.
```
state_space_duality/
├── README.md                        # This document
├── LICENSE
├── experiments/                     # Structured state-space duality demos
│   ├── __init__.py
│   ├── common.py                    # Shared utilities and ExperimentResult dataclass
│   ├── exp1_scalar_equivalence.py
│   ├── exp2_diagonal_equivalence.py
│   ├── exp2b_timevarying_diagonal.py
│   ├── exp3_rank_vs_state_dim.py
│   ├── exp3b_rank_vs_state_dim.py
│   ├── exp4_time_scaling.py
│   ├── exp4b_time_scaling_errorbars.py
│   ├── exp5_softmax_rank_growth.py
│   ├── exp5b_softmax_rank_growth.py
│   ├── verify_ssd.py                # Aggregate runner (uv run python -m experiments.verify_ssd)
│   ├── run_and_log.py               # CLI with logging + parameter grids
│   └── check_logs.py                # Sanity checks for logged results
├── mamba_experiments/               # Small-scale Mamba vs. baseline comparisons
├── time_series_experiments/         # Synthetic time-series regression experiments + plotting
├── pyproject.toml                   # Project metadata + optional extras
└── uv.lock                          # Reproducible dependency lockfile (uv)
```
The SSD experiments require only NumPy (plus Matplotlib for the optional plots), while `mamba_experiments/` uses PyTorch, `datasets`, `einops`, etc. Pick your preferred workflow:
- Install the dependencies (project metadata and optional extras live in `pyproject.toml`):

  ```bash
  uv sync
  ```

- Run all experiments once:

  ```bash
  uv run python -m experiments.verify_ssd
  ```

- Run a single experiment module directly:
  ```bash
  uv run python -m experiments.exp1_scalar_equivalence
  uv run python -m experiments.exp2_diagonal_equivalence
  uv run python -m experiments.exp2b_timevarying_diagonal
  uv run python -m experiments.exp3_rank_vs_state_dim
  uv run python -m experiments.exp3b_rank_vs_state_dim
  uv run python -m experiments.exp4_time_scaling
  uv run python -m experiments.exp4b_time_scaling_errorbars
  uv run python -m experiments.exp5_softmax_rank_growth
  uv run python -m experiments.exp5b_softmax_rank_growth
  ```

- Invoke the modules directly with CLI flags (prepend `uv run` if uv manages your environment):
  ```bash
  uv run python -m experiments.exp1_scalar_equivalence --T 32 --a 0.9 --seed 1
  uv run python -m experiments.exp2_diagonal_equivalence --T 40 --seed 2 --decays 0.5 0.8
  uv run python -m experiments.exp2b_timevarying_diagonal --T 16 --N 4 --seed 3
  uv run python -m experiments.exp3_rank_vs_state_dim --T-values 15 30 --seeds 0 1
  uv run python -m experiments.exp4_time_scaling --T-values 300 600 1200 --end-to-end --exp4-plot
  uv run python -m experiments.exp4b_time_scaling_errorbars --n-trials 8 --n-repeats 5
  uv run python -m experiments.run_and_log --experiment exp4 --repeats 10 --exp4-plot --exp4-plot-path outputs/exp4_time_scaling.png
  uv run python -m experiments.exp5_softmax_rank_growth --T-values 80 160 320 --d-k 16
  ```
Every script supports `--help`.
- **`exp1_scalar_equivalence.py`** – Scalar SSM ≡ 1-SS attention
  Propagates a length-`T` sequence through a single-state recurrence and matches it with a causal 1-semiseparable attention kernel built from powers of a scalar decay `a`. The experiment reports the maximum absolute difference between the recurrent and attention outputs, along with the seed and decay.
- **`exp2_diagonal_equivalence.py`** – Diagonal SSM vs. sum of 1-SS heads
  Generalizes Exp1 to an `N`-state diagonal SSM. Each diagonal decay contributes one causal head whose kernel is `(C[m]*B[m]) * a_m^{t-s}`. The script also records the Vandermonde generator rank to show that distinct eigenvalues guarantee full-rank semiseparable generators. Control the state dimension via `--exp2-n-list` when using `run_and_log.py`.
- **`exp2b_timevarying_diagonal.py`** – Time-varying diagonal decays
  Allows both the recurrence weights and the input/output couplings to change per timestep. `_build_time_varying_mask` constructs the causal kernel `∏_{k=s+1}^t A_k`. The experiment confirms that the resulting masked attention still reproduces the exact recurrent outputs even under non-stationary dynamics.
- **`exp3_rank_vs_state_dim.py`** – Generator rank studies
  Sweeps over sequence lengths, hand-crafted decay sets, and optional random configurations to compare the theoretical semiseparable generator rank with the empirical matrix rank of the induced attention kernel. Use `--exp3-t-list`, `--exp3-seeds`, and `--exp3-random-n` in `run_and_log.py` to extend the grid.
- **`exp3b_rank_vs_state_dim.py`** – Rank vs. state dimension plot
  Wraps Exp3 to sweep `N` and seeds with spaced decays, then plots the mean ± std generator rank against `N` with a `y = N` reference line. Saves to `outputs/exp3b_rank_vs_state_dim.png` by default.
- **`exp4_time_scaling.py`** – O(T) recurrence vs. O(T²) attention cost
  Benchmarks the wall-clock time of running the diagonal recurrence directly against explicitly forming the `T × T` kernel and multiplying the inputs. Toggle `--exp4-end-to-end` to include kernel construction time, and `--exp4-plot --exp4-plot-path ...` to save the Matplotlib visualization generated by `plot_results`.
- **`exp4b_time_scaling_errorbars.py`** – Runtime scaling with 95% CIs
  Runs multiple random trials per sequence length and plots the mean ± 95% confidence interval for both recurrence and attention costs (styled via the Flow-KL settings) into `outputs/exp4b_time_scaling.png`. Configure the number of trials via `--n-trials`, repeats per trial via `--n-repeats`, and switch to end-to-end timing with `--end-to-end`.
- **`exp5_softmax_rank_growth.py`** – Softmax attention negative test
  Generates random query/key pairs, applies numerically stable causal softmax attention, and measures the matrix rank as `T` grows. The rank growth illustrates why unrestricted softmax attention does not stay low-rank, in contrast to the semiseparable kernels above.
- **`exp5b_softmax_rank_growth.py`** – Softmax rank growth with variance bands
  Runs Exp5 across multiple seeds, plots the mean ± std rank vs. `T`, and optionally saves a rank-gap (`T - rank`) curve. Defaults to `outputs/exp5b_rank_growth.png` and `outputs/exp5b_rank_growth_gap.png`.

Each experiment exposes a `run(...)` function that returns an `ExperimentResult` (name, human-readable details, metadata dict) for easy downstream logging.
`mamba_experiments/` compares two small Mamba variants (`mamba_simple`, `mamba_SSD_diag_exp`) on synthetic recall data and WikiText-2 windows.
- Recall task example:

  ```bash
  uv run python -m mamba_experiments.train \
      --model mamba_simple \
      --dataset recall \
      --seq-len 64 --gap 8 --epochs 5 \
      --save-plots --plot-dir logs/recall
  ```
- WikiText-2 example:

  ```bash
  uv run python -m mamba_experiments.train \
      --model mamba_SSD_diag_exp \
      --dataset wikitext2 \
      --seq-len 128 \
      --max-vocab 20000 \
      --train-max-samples 10000 \
      --val-max-samples 2000
  ```
- Compare Original vs. SSD Mamba with variance bands (WikiText-2, seeds 0–9):

  ```bash
  RUN_DIR=outputs/mamba_experiments/compare_wikitext2_s128_d64_n16_l2_e10_b64_lr1e-3_m5k1k
  for model in mamba_simple mamba_SSD_diag_exp; do
    for seed in $(seq 0 9); do
      uv run python -m mamba_experiments.train \
          --model $model \
          --dataset wikitext2 \
          --seq-len 128 \
          --max-vocab 20000 \
          --train-max-samples 5000 \
          --val-max-samples 1000 \
          --d-model 64 \
          --n-state 16 \
          --n-layers 2 \
          --epochs 10 \
          --batch-size 64 \
          --lr 1e-3 \
          --seed $seed \
          --save-plots \
          --plot-dir $RUN_DIR
    done
  done

  MPLCONFIGDIR=outputs/.mplconfig uv run python -m mamba_experiments.cmds.plot_variance_bands \
      --log-path $RUN_DIR/train_runs.jsonl \
      --out-dir $RUN_DIR/plots_compare \
      --dataset wikitext2 \
      --models mamba_simple mamba_SSD_diag_exp \
      --metric val_loss \
      --require-all-models
  ```
Helpful flags (`--help` for the full list):

- `--model {mamba_simple,mamba_SSD_diag_exp}`
- `--dataset {recall,wikitext2}` with dataset-specific args (`--gap`, `--max-vocab`, etc.)
- Architecture/training knobs: `--d-model`, `--n-state`, `--n-layers`, `--epochs`, `--batch-size`, `--lr`
- Logging knobs: `--save-plots`, `--plot-dir` (plots and JSONL appended to `<plot_dir>/train_runs.jsonl`)
Additional modules (`mamba_simple.py`, `mamba_SSD.py`, etc.) live in the same folder if you want to extend the blocks or plug them into other projects.
`time_series_experiments/` contains a small synthetic regression benchmark (mixture-of-decays) plus JSONL logging and plotting helpers.

Run the synthetic N-sweep (diagonal SSD-style block, `N = d_state`):

```bash
uv run python -m time_series_experiments.exp_synthetic_ndecay --models mamba_diag_exp --N_values 1 2 --epochs 60 --seed 42 --T 100 --lambdas 0.9 0.5 --coeffs 1.0 0.7 --noise_std 1e-4
```

Plot the latest results (reads `outputs/time_series_experiments/ts_runs.jsonl`):

```bash
uv run python -m time_series_experiments.plot_runs --task synthetic_decays --latest_only --log_val_curves
```

Plots are written to `outputs/time_series_experiments/` by default.
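For intuition about the synthetic task, a mixture-of-decays target can be generated along the following lines. This is a hypothetical sketch using the CLI's example values (`--lambdas 0.9 0.5 --coeffs 1.0 0.7 --noise_std 1e-4`); the actual generator in `time_series_experiments/` may define the target differently:

```python
import numpy as np

rng = np.random.default_rng(42)
T = 100
lambdas = np.array([0.9, 0.5])   # per-component decay rates (cf. --lambdas)
coeffs = np.array([1.0, 0.7])    # mixture coefficients (cf. --coeffs)
noise_std = 1e-4                 # observation noise (cf. --noise_std)

x = rng.normal(size=T)           # random input sequence

# Target: causal convolution of x with a mixture-of-exponentials kernel,
# y_t = Σ_k c_k Σ_{s ≤ t} λ_k^{t-s} x_s + ε_t,
# computed recurrently with one scalar state per decay component.
y = np.zeros(T)
state = np.zeros_like(lambdas)
for t in range(T):
    state = lambdas * state + x[t]
    y[t] = coeffs @ state
y += noise_std * rng.normal(size=T)
```

A target of this form is exactly representable by a diagonal SSM with `d_state = len(lambdas)`, which is what makes the N-sweep informative: models with `N` below the number of mixture components should underfit.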
- The SSD scripts in `experiments/` run purely on CPU with double-precision NumPy; `mamba_experiments/` relies on PyTorch (and optionally GPUs) plus its extra dependencies.
- Logged JSONL/CSV entries include the experiment name, seed, human-readable summary, and a `meta` dict (decays, ranks, timings, tolerances, etc.).
- Matplotlib is only required for the Exp4 plots and the optional Mamba training curves.
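Consuming those logs downstream is a matter of parsing one JSON object per line. The sketch below uses invented field values shaped like the entries described above; the exact keys in the real logs may differ:

```python
import io
import json

# Two illustrative log lines in the described shape (values invented for this example).
log = io.StringIO(
    '{"name": "exp1_scalar_equivalence", "seed": 1, '
    '"summary": "max |diff| = 3e-16", "meta": {"a": 0.9, "T": 32}}\n'
    '{"name": "exp4_time_scaling", "seed": 0, '
    '"summary": "recurrence beats attention", "meta": {"T_values": [300, 600]}}\n'
)

# Parse every line, then filter by experiment name.
entries = [json.loads(line) for line in log]
exp1 = [e for e in entries if e["name"] == "exp1_scalar_equivalence"]
print(exp1[0]["meta"]["a"])  # -> 0.9
```

With a real run, replace the `StringIO` with `open("<plot_dir>/train_runs.jsonl")` or the corresponding experiments log path.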
This repository is licensed under the Apache License; see `LICENSE` for details.