MAGICS-LAB/state_space_duality

Structured State-Space Duality Experiments

This repository hosts the experiments that illustrate the duality between diagonal state-space models (SSMs) and semi-separable attention, plus a compact suite of Mamba-vs-baseline comparisons.

Repository Layout

state_space_duality/
├── README.md                 # This document
├── LICENSE
├── experiments/              # Structured state-space duality demos
│   ├── __init__.py
│   ├── common.py             # Shared utilities and ExperimentResult dataclass
│   ├── exp1_scalar_equivalence.py
│   ├── exp2_diagonal_equivalence.py
│   ├── exp2b_timevarying_diagonal.py
│   ├── exp3_rank_vs_state_dim.py
│   ├── exp3b_rank_vs_state_dim.py
│   ├── exp4_time_scaling.py
│   ├── exp4b_time_scaling_errorbars.py
│   ├── exp5_softmax_rank_growth.py
│   ├── exp5b_softmax_rank_growth.py
│   ├── verify_ssd.py         # Aggregate runner (uv run python -m experiments.verify_ssd)
│   ├── run_and_log.py        # CLI with logging + grids
│   └── check_logs.py         # Sanity checks for logged results
├── mamba_experiments/        # Small-scale Mamba vs. baseline comparisons
├── time_series_experiments/  # Synthetic time-series regression experiments + plotting
├── pyproject.toml            # Project metadata + optional extras
└── uv.lock                   # Reproducible dependency lockfile (uv)

Setup

The SSD experiments require only NumPy (plus Matplotlib for the optional Exp4 plots), while mamba_experiments/ additionally uses PyTorch, datasets, einops, etc. Pick your preferred workflow:

  • Install via uv (optional extras declared in pyproject.toml can be pulled in with uv sync --all-extras)

    uv sync

Quick Start

Run all experiments once:

uv run python -m experiments.verify_ssd

Run a single experiment module directly:

uv run python -m experiments.exp1_scalar_equivalence
uv run python -m experiments.exp2_diagonal_equivalence
uv run python -m experiments.exp2b_timevarying_diagonal
uv run python -m experiments.exp3_rank_vs_state_dim
uv run python -m experiments.exp3b_rank_vs_state_dim
uv run python -m experiments.exp4_time_scaling
uv run python -m experiments.exp4b_time_scaling_errorbars
uv run python -m experiments.exp5_softmax_rank_growth
uv run python -m experiments.exp5b_softmax_rank_growth

Running Individual Experiments

Invoke the modules directly (drop the uv run prefix if you manage the Python environment yourself):

  • uv run python -m experiments.exp1_scalar_equivalence --T 32 --a 0.9 --seed 1
  • uv run python -m experiments.exp2_diagonal_equivalence --T 40 --seed 2 --decays 0.5 0.8
  • uv run python -m experiments.exp2b_timevarying_diagonal --T 16 --N 4 --seed 3
  • uv run python -m experiments.exp3_rank_vs_state_dim --T-values 15 30 --seeds 0 1
  • uv run python -m experiments.exp4_time_scaling --T-values 300 600 1200 --end-to-end --exp4-plot
  • uv run python -m experiments.exp4b_time_scaling_errorbars --n-trials 8 --n-repeats 5
  • uv run python -m experiments.run_and_log --experiment exp4 --repeats 10 --exp4-plot --exp4-plot-path outputs/exp4_time_scaling.png
  • uv run python -m experiments.exp5_softmax_rank_growth --T-values 80 160 320 --d-k 16

Every script supports --help.

Experiment Overview

  1. exp1_scalar_equivalence.py – Scalar SSM ≡ 1-SS attention
    Propagates a length-T sequence through a single-state recurrence and matches it with a causal 1-semiseparable attention kernel built from powers of a scalar decay a. The experiment reports the maximum absolute difference between the recurrent and attention outputs, along with the seed and decay.
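In essence (a minimal NumPy sketch of the same check with made-up parameters, not the repo's exact code):

```python
import numpy as np

# Scalar SSM y_t = c * h_t, h_t = a * h_{t-1} + b * x_t, versus the causal
# 1-semiseparable kernel M[t, s] = c * b * a^(t-s). Parameters are illustrative.
rng = np.random.default_rng(1)
T, a, b, c = 32, 0.9, 1.0, 1.0
x = rng.standard_normal(T)

# Recurrent form: one O(T) sequential scan.
h, y_rec = 0.0, np.empty(T)
for t in range(T):
    h = a * h + b * x[t]
    y_rec[t] = c * h

# Attention form: build the T x T lower-triangular kernel explicitly.
t_idx, s_idx = np.meshgrid(np.arange(T), np.arange(T), indexing="ij")
M = np.where(t_idx >= s_idx, c * b * a ** (t_idx - s_idx), 0.0)
y_att = M @ x

print(np.max(np.abs(y_rec - y_att)))  # machine-precision agreement
```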

  2. exp2_diagonal_equivalence.py – Diagonal SSM vs. sum of 1-SS heads
    Generalizes Exp1 to an N-state diagonal SSM. Each diagonal decay contributes one causal head whose kernel is (C[m]*B[m]) * a_m^{t-s}. The script also records the Vandermonde generator rank to show that distinct eigenvalues guarantee full-rank semiseparable generators. Control the state dimension via --exp2-n-list when using run_and_log.py.
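A hedged sketch of the same construction (random decays and couplings, not the script's actual configuration):

```python
import numpy as np

# N-state diagonal SSM versus the sum of N causal 1-SS heads, where head m
# has kernel (C[m] * B[m]) * a_m^(t-s). All values here are illustrative.
rng = np.random.default_rng(2)
T, N = 40, 4
a = rng.uniform(0.3, 0.95, N)      # diagonal decays (distinct w.h.p.)
B = rng.standard_normal(N)         # input couplings
C = rng.standard_normal(N)         # output couplings
x = rng.standard_normal(T)

# Recurrent form: elementwise diagonal update.
h = np.zeros(N)
y_rec = np.empty(T)
for t in range(T):
    h = a * h + B * x[t]
    y_rec[t] = C @ h

# Sum of 1-SS attention heads, one per diagonal state.
t_idx, s_idx = np.meshgrid(np.arange(T), np.arange(T), indexing="ij")
M = sum(
    np.where(t_idx >= s_idx, C[m] * B[m] * a[m] ** (t_idx - s_idx), 0.0)
    for m in range(N)
)
y_att = M @ x
print(np.max(np.abs(y_rec - y_att)))  # machine-precision agreement
```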

  3. exp2b_timevarying_diagonal.py – Time-varying diagonal decays
    Allows both the recurrence weights and the input/output couplings to change per timestep. _build_time_varying_mask constructs the causal kernel ∏_{k=s+1}^t A_k. The experiment confirms that the resulting masked attention still reproduces the exact recurrent outputs even under non-stationary dynamics.
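The idea can be sketched with a single scalar state (illustrative only, not the repo's _build_time_varying_mask):

```python
import numpy as np

# Time-varying recurrence h_t = A_t h_{t-1} + B_t x_t, y_t = C_t h_t, versus
# the causal kernel M[t, s] = C_t * (prod_{k=s+1}^t A_k) * B_s.
rng = np.random.default_rng(3)
T = 16
A = rng.uniform(0.5, 0.99, T)      # per-step decay for one scalar state
B = rng.standard_normal(T)         # per-step input couplings
C = rng.standard_normal(T)         # per-step output couplings
x = rng.standard_normal(T)

# Recurrent form with non-stationary weights.
h, y_rec = 0.0, np.empty(T)
for t in range(T):
    h = A[t] * h + B[t] * x[t]
    y_rec[t] = C[t] * h

# Kernel form via cumulative products: with P[t] = A_0 * ... * A_t,
# prod_{k=s+1}^t A_k = P[t] / P[s] (empty product = 1 on the diagonal).
P = np.cumprod(A)
t_idx, s_idx = np.meshgrid(np.arange(T), np.arange(T), indexing="ij")
decay = np.where(t_idx >= s_idx, P[t_idx] / P[s_idx], 0.0)
M = C[:, None] * decay * B[None, :]
y_att = M @ x
print(np.max(np.abs(y_rec - y_att)))  # exact match up to float rounding
```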

  4. exp3_rank_vs_state_dim.py – Generator rank studies
    Sweeps over sequence lengths, hand-crafted decay sets, and optional random configurations to compare the theoretical semiseparable generator rank with the empirical matrix rank of the induced attention kernel. Use --exp3-t-list, --exp3-seeds, and --exp3-random-n in run_and_log.py to extend the grid.
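The semiseparable property being measured can be illustrated directly (made-up decays and couplings, not the script's grids):

```python
import numpy as np

# Every off-diagonal block of the kernel induced by an N-state diagonal SSM
# factors as a sum of N rank-1 terms (C_m a_m^t)(B_m a_m^{-s}), so its rank
# is at most N regardless of the sequence length T.
T, N = 30, 3
a = np.array([0.3, 0.6, 0.9])      # distinct decays
B = np.ones(N)
C = np.ones(N)

t_idx, s_idx = np.meshgrid(np.arange(T), np.arange(T), indexing="ij")
M = sum(
    np.where(t_idx >= s_idx, C[m] * B[m] * a[m] ** (t_idx - s_idx), 0.0)
    for m in range(N)
)

# Bottom-left block: strictly causal part of the kernel.
block = M[T // 2:, : T // 2]
print(np.linalg.matrix_rank(block))  # at most N
```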

  5. exp3b_rank_vs_state_dim.py – Rank vs. state dimension plot
    Wraps Exp3 to sweep N and seeds with spaced decays, then plots the generator rank mean ± std against N with a y = N reference line. Saves to outputs/exp3b_rank_vs_state_dim.png by default.

  6. exp4_time_scaling.py – O(T) recurrence vs. O(T²) attention cost
    Benchmarks the wall-clock time of directly running the diagonal recurrence against explicitly forming the T × T kernel and multiplying inputs. Toggle --exp4-end-to-end to include kernel construction time, and --exp4-plot --exp4-plot-path ... to save the matplotlib visualization generated by plot_results.
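The two costs being compared can be sketched as a micro-benchmark (illustrative, not the script's harness; timings are indicative only):

```python
import time
import numpy as np

# O(T) sequential recurrence versus explicitly forming the T x T kernel and
# multiplying (O(T^2) work, dominated by kernel construction for large T).
def recurrence(a, x):
    h, y = 0.0, np.empty_like(x)
    for t, xt in enumerate(x):
        h = a * h + xt
        y[t] = h
    return y

def attention(a, x):
    T = len(x)
    t_idx, s_idx = np.meshgrid(np.arange(T), np.arange(T), indexing="ij")
    M = np.where(t_idx >= s_idx, a ** (t_idx - s_idx), 0.0)  # kernel build included
    return M @ x

rng = np.random.default_rng(0)
for T in (300, 600, 1200):
    x = rng.standard_normal(T)
    t0 = time.perf_counter()
    y_rec = recurrence(0.9, x)
    t1 = time.perf_counter()
    y_att = attention(0.9, x)
    t2 = time.perf_counter()
    print(f"T={T}: recurrence {t1 - t0:.4f}s, attention {t2 - t1:.4f}s, "
          f"max diff {np.max(np.abs(y_rec - y_att)):.2e}")
```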

  7. exp4b_time_scaling_errorbars.py – Runtime scaling with 95% CIs
    Runs multiple random trials per sequence length and plots the mean ± 95% confidence interval for both recurrence and attention costs (styled via the Flow-KL settings) into outputs/exp4b_time_scaling.png. Configure the number of trials via --n-trials, repeats per trial via --n-repeats, and switch to end-to-end timing with --end-to-end.

  8. exp5_softmax_rank_growth.py – Softmax attention negative test
    Generates random query/key pairs, applies stable causal softmax attention, and measures matrix rank as T grows. The rank growth illustrates why unrestricted softmax attention does not stay low-rank in contrast to the semiseparable kernels above.
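A condensed sketch of the measurement (the shapes and seed here are made up):

```python
import numpy as np

# Causal softmax attention from random queries/keys: unlike the semiseparable
# kernels above, its numerical rank grows with T rather than staying at d_k.
rng = np.random.default_rng(0)
T, d_k = 80, 16
Q = rng.standard_normal((T, d_k))
K = rng.standard_normal((T, d_k))

scores = Q @ K.T / np.sqrt(d_k)
causal = np.tril(np.ones((T, T), dtype=bool))
scores = np.where(causal, scores, -np.inf)          # mask future positions
scores -= scores.max(axis=1, keepdims=True)          # numerically stable softmax
A = np.exp(scores)
A /= A.sum(axis=1, keepdims=True)

# Without the softmax, rank(QK^T) <= d_k; with it, the rank climbs with T.
print(np.linalg.matrix_rank(A))
```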

  9. exp5b_softmax_rank_growth.py – Softmax rank growth with variance bands
    Runs Exp5 across multiple seeds, plots the mean ± std rank vs. T, and optionally saves a rank-gap (T − rank) curve. Defaults to outputs/exp5b_rank_growth.png and outputs/exp5b_rank_growth_gap.png.

Each experiment exposes a run(...) function that returns an ExperimentResult (name, human-readable details, metadata dict) for easy downstream logging.

Mamba Experiments

mamba_experiments/ compares two small Mamba variants (mamba_simple, mamba_SSD_diag_exp) on synthetic recall data and WikiText-2 windows.

  • Recall task example:

    uv run python -m mamba_experiments.train \
      --model mamba_simple \
      --dataset recall \
      --seq-len 64 --gap 8 --epochs 5 \
      --save-plots --plot-dir logs/recall
  • WikiText-2 example:

    uv run python -m mamba_experiments.train \
      --model mamba_SSD_diag_exp \
      --dataset wikitext2 \
      --seq-len 128 \
      --max-vocab 20000 \
      --train-max-samples 10000 \
      --val-max-samples 2000
  • Compare Original vs SSD Mamba with variance bands (WikiText-2, seeds 0–9):

    RUN_DIR=outputs/mamba_experiments/compare_wikitext2_s128_d64_n16_l2_e10_b64_lr1e-3_m5k1k
    for model in mamba_simple mamba_SSD_diag_exp; do
      for seed in $(seq 0 9); do
        uv run python -m mamba_experiments.train \
          --model $model \
          --dataset wikitext2 \
          --seq-len 128 \
          --max-vocab 20000 \
          --train-max-samples 5000 \
          --val-max-samples 1000 \
          --d-model 64 \
          --n-state 16 \
          --n-layers 2 \
          --epochs 10 \
          --batch-size 64 \
          --lr 1e-3 \
          --seed $seed \
          --save-plots \
          --plot-dir $RUN_DIR
      done
    done
    
    MPLCONFIGDIR=outputs/.mplconfig uv run python -m mamba_experiments.cmds.plot_variance_bands \
      --log-path $RUN_DIR/train_runs.jsonl \
      --out-dir $RUN_DIR/plots_compare \
      --dataset wikitext2 \
      --models mamba_simple mamba_SSD_diag_exp \
      --metric val_loss \
      --require-all-models

Helpful flags (--help for full list):

  • --model {mamba_simple,mamba_SSD_diag_exp}
  • --dataset {recall,wikitext2} with dataset-specific args (--gap, --max-vocab, etc.)
  • Architecture/training knobs: --d-model, --n-state, --n-layers, --epochs, --batch-size, --lr
  • Logging knobs: --save-plots, --plot-dir (plots and JSONL appended to <plot_dir>/train_runs.jsonl)

Additional modules (mamba_simple.py, mamba_SSD.py, etc.) live in the same folder if you want to extend the blocks or plug them into other projects.

Time-Series Experiments

time_series_experiments/ contains a small synthetic regression benchmark (mixture-of-decays) plus JSONL logging and plotting helpers.

Run the synthetic N-sweep (diagonal SSD-style block, N=d_state):

uv run python -m time_series_experiments.exp_synthetic_ndecay --models mamba_diag_exp --N_values 1 2 --epochs 60 --seed 42 --T 100 --lambdas 0.9 0.5 --coeffs 1.0 0.7 --noise_std 1e-4

Plot the latest results (reads outputs/time_series_experiments/ts_runs.jsonl):

uv run python -m time_series_experiments.plot_runs --task synthetic_decays --latest_only --log_val_curves

Plots are written to outputs/time_series_experiments/ by default.

Notes

  • The SSD scripts in experiments/ run purely on CPU with double-precision NumPy.
  • mamba_experiments/ relies on PyTorch (and optionally GPUs) plus its extra dependencies.
  • Logged JSONL/CSV entries include the experiment name, seed, human-readable summary, and a meta dict (decays, ranks, timings, tolerances, etc.).
  • Matplotlib is only required for Exp4 plots or optional Mamba training curves.

License

This repository is licensed under the Apache License. See LICENSE for details.