FFT-guided period selection + 2D Inception CNN + embedding-aware adaptive probabilistic head for multivariate time-series forecasting.
Keeps the canonical [B, T, N] → [B, H, N] interface while adding robust contextualization and calibrated uncertainty.
- Beyond vanilla TimesNet. We preserve TimesNet’s FFT-to-2D idea but add: channel-robust shared period search, PeriodGrouper (merge penalized/duplicate periods), rich embeddings (IDs + statics + low-rank temporal context), and an adaptive probabilistic head (Negative Binomial with dispersion floors) for stability.
- CNN synergy. Embeddings lift semantic signal; the 2D Inception CNN captures intra-/inter-period structure; the probabilistic head calibrates rate/dispersion per series and horizon — three gears meshing for accuracy and robustness.
- Direct & recursive. Train once, forecast direct (all H steps at once) or recursive (rolling one-step) with the same checkpoints.
- Modular data I/O. Input schemas and test loaders are fully pluggable. Swap CSV layouts, feature sets, and eval folds via config — no code surgery required.
Why this matters
Injecting static per-series information (IDs, categories) as a dynamic signal that evolves over time is notoriously hard.
LowRankTemporalContext solves this with a low-rank approximation: a compact, principled way to compose per-ID temporal background signals from a tiny set of shared basis waves.
We add to each series a learned, time-varying context and feed $x_{t,n} + S_{t,n}$ to the backbone:

$S_{t,n} = \sum_{r=1}^{R} w_{n,r}\, b_r(t)$,

where $b_r(t)$ are shared basis waves over the lookback window and $w_{n,r}$ are per-series mixing coefficients derived from the ID embedding.
Naive plan: learn every entry $S_{t,n}$ as a free parameter (a per-timestep, per-ID table). The alternatives compare as follows:
| Approach | Parameter scale | Generalization | Notes |
|---|---|---|---|
| Per-timestep per-ID table | $\mathcal O(LN)$ | ❌ poor | Memorizes, no sharing |
| One shared temporal vector | $\mathcal O(L)$ | ❌ ignores ID | Misses heterogeneity |
| Low-rank (ours) | $\mathcal O(LR+NR)$ | ✅ strong | Shares basis, personalizes via coeffs |
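To make the parameter-scale column concrete, here is a quick count using sizes that appear elsewhere in this README (a 336-step lookback, 193 series, rank 8); the numbers are purely illustrative.

```python
# Back-of-the-envelope parameter counts for the table above.
# L = lookback length, N = number of series, R = rank (illustrative sizes).
L, N, R = 336, 193, 8

per_timestep_per_id = L * N          # one free value per (t, n): 64,848
one_shared_vector = L                # a single curve shared by every ID: 336
low_rank = L * R + N * R             # shared basis + per-ID coefficients: 4,232

print(per_timestep_per_id, one_shared_vector, low_rank)
```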
With rank $R \ll \min(L, N)$, assume each per-ID context is a linear mix of a small set of shared basis signals $b_1(t), \dots, b_R(t)$.
Matrix view (dimensions in brackets):

$\mathbf S\,[L \times N] \;=\; \mathbf B\,[L \times R]\;\; \mathbf W^\top\,[R \times N]$

Here $\mathbf B$ stacks the shared basis waves as columns and $\mathbf W$ holds the per-series mixing coefficients.
For small $R$ this is a low-rank factorization of the full context matrix: the basis is shared across all series, and each series contributes only its $R$ coefficients.
Why cosine/DCT? Energy compaction, near-orthogonality, FFT-friendly numerics.
We center each basis column-wise, $\tilde b_r(t) = b_r(t) - \frac{1}{L}\sum_{t'=1}^{L} b_r(t')$, so every basis wave (and therefore any mixture of them) sums to zero over the window.
Effect: the context modulates patterns without drifting the global scale/level of the input $x$.
Tip: Use zero‑mean when you want to shape temporal texture (seasonality, pulse, curvature) but not bias levels. Turn it off only if you explicitly want baseline shifts.
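A minimal sketch of the centering step, assuming a DCT-II-style cosine basis like the one constructed in the class further below: after column-wise centering, every basis wave and any mixture of them sums to (numerically) zero over the window, so adding the context cannot shift a series' average level.

```python
import math

import torch

L, R = 64, 4
t = torch.arange(L).float().unsqueeze(1)         # [L, 1]
r = torch.arange(1, R + 1).float().unsqueeze(0)  # [1, R]
B = torch.cos(math.pi / L * (t + 0.5) * r)       # cosine basis, [L, R]

B = B - B.mean(dim=0, keepdim=True)              # column-wise centering
print(B.mean(dim=0))                             # ~0 for every column

w = torch.randn(R)                               # arbitrary per-series mixture
S = B @ w                                        # [L] context for one series
print(S.mean())                                  # ~0: no baseline shift
```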
Each series has a static embedding $\mathbf e_n \in \mathbb R^{d}$; a linear coefficient head maps it to the $R$ mixing weights, $\mathbf w_n = \mathrm{Linear}(\mathbf e_n) \in \mathbb R^{R}$.
Minimal, fast, and expressive enough to map “who you are” → “how your context sounds.”
Let $\mathbf B \in \mathbb R^{L \times R}$ be the shared basis and let the coefficient head produce $\mathbf W \in \mathbb R^{B \times N \times R}$ from the ID embeddings; the forward pass is then a single einsum plus an additive skip:
# LowRankTemporalContext.forward (conceptual)
context = torch.einsum("lr,bnr->bln", basis, coeff) # S_{t,n} = sum_r w_{n,r} * b_r(t)
x_out = x_in + context

- Dims: `x` and `context` are `[B, L, N]`; `basis` is `[L, R]`; `coeff` is `[B, N, R]`.
- Params: $L\times R$ (basis, fixed or learnable) + affine $d\times R$ (coeff head).
- FLOPs: $\mathcal O(B \cdot L \cdot N \cdot R)$ (einsum), linear in rank $R$.
- Memory: $\mathbf B$ is shared; $\mathbf W$ is computed on the fly from embeddings.
| Knob | Role | Rule of thumb |
|---|---|---|
| Rank $R$ | richness of the temporal palette | start around $R=8$; sweep 4–16 |
| Zero‑mean | level‑stability | on for pattern‑only; off to adjust baselines |
| Basis type | prior over shapes | Cosine/DCT for smoothness; Spline/learned for flexibility |
| Coeff head | ID→mix mapping | Linear is robust; MLP if highly nonlinear IDs |
LowRankTemporalContext pre‑conditions the input with ID‑aware, time‑varying structure:
flowchart LR
E["ID Embedding (e_n)"] --> H["Coeff Head"]
H -- R weights per ID --> W["W (R x N)"]
B["Cosine Basis B (L x R)"] --> S{{"einsum lr,bnr->bln"}}
W --> S
X["Input X (B x L x N)"] --> A["Add"]
S --> A
A --> C["TimesNet / CNN"]
C --> Y["Forecast"]
- Before temporal blocks: inject smooth, ID‑specific rhythms that align with periodic convolutions/2D kernels.
- Effect: CNN/TimesNet layers spend less capacity rediscovering obvious ID‑periodicity; they focus on higher‑order interactions and residuals.
import math

import torch
import torch.nn as nn


class LowRankTemporalContext(nn.Module):
def __init__(self, length: int, rank: int, embed_dim: int,
zero_mean: bool = True, learn_basis: bool = False):
super().__init__()
self.length, self.rank = length, rank
# (A) Shared temporal basis B ∈ ℝ^{L×R}
B = self._cosine_basis(length, rank) # DCT-II grid
if zero_mean:
B = B - B.mean(dim=0, keepdim=True)
self.basis = nn.Parameter(B, requires_grad=learn_basis)
# (B) ID-embedding → R weights
self.context_coeff = nn.Linear(embed_dim, rank)
@staticmethod
def _cosine_basis(L, R):
t = torch.arange(L).float().unsqueeze(1) # [L,1]
r = torch.arange(1, R+1).float().unsqueeze(0) # [1,R]
return torch.cos(math.pi / float(L) * (t + 0.5) * r) # [L,R]
def forward(self, x: torch.Tensor, id_embed: torch.Tensor):
"""
x: [B,L,N] (time-major)
id_embed: [B,N,D] (static embeddings per series)
"""
coeff = self.context_coeff(id_embed) # [B,N,R]
context = torch.einsum("lr,bnr->bln", self.basis, coeff)
        return x + context, context

- Initialization: keep `learn_basis=False` first; let the linear head learn mixtures on a fixed palette.
- Regularization: mild weight decay on the coeff head; optional $\ell_2$ on the coefficients to avoid over-energetic contexts.
- Sanity checks: plot a few $b_r(t)$ and sampled $S_{t,n}$; verify $\mathrm{mean}_t(S_{t,n}) \approx 0$ when zero-mean is on; ablate $R$ for diminishing returns.
- When to increase $R$: multi-scale seasonality (daily/weekly), heterogeneous venues/menus, long horizons.
- When to learn the basis: if domain rhythms differ from cosines (e.g., holiday pulses, regime switches). Consider piecewise or spline bases.
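A small smoke test in the spirit of the sanity checks above: it assumes the `LowRankTemporalContext` class defined earlier is in scope and verifies shapes plus the zero-mean property of the injected context.

```python
import torch

B_, L_, N_, D_, R_ = 4, 28, 6, 16, 8
module = LowRankTemporalContext(length=L_, rank=R_, embed_dim=D_, zero_mean=True)

x = torch.randn(B_, L_, N_)         # [B, L, N] input windows
id_embed = torch.randn(B_, N_, D_)  # [B, N, D] static/ID embeddings

x_out, context = module(x, id_embed)
assert x_out.shape == (B_, L_, N_) and context.shape == (B_, L_, N_)

# With zero_mean=True the basis columns are centered, so the context is
# (numerically) mean-free along time and cannot shift series levels.
print(context.mean(dim=1).abs().max())
```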
- Baseline: model w/o context.
- + Low-rank context: $R=8$, zero-mean on, fixed DCT basis.
- Rank sweep: $R \in \{4, 8, 12, 16\}$.
- Basis ablation: fixed DCT vs. learnable basis.
- Downstream impact: check sMAPE/NLL deltas, especially on sparse/volatile IDs.
| Setting | sMAPE ↓ | NLL ↓ | Notes |
|---|---|---|---|
| Baseline | — | — | reference |
| + LowRank (R=8) | ↓ | ↓ | best cost/benefit |
| + Learnable basis (R=8) | ↓ | ↓ | may improve, watch overfit |
| + High rank (R=32) | ~ | ~ | risk: diminishing returns |
- Parsimony: replaces $\mathcal O(LN)$ free knobs with $\mathcal O(LR+NR)$ structured ones.
- Inductive bias: smooth, near-orthogonal atoms match the physics of seasonal/slow dynamics.
- Composability: cleanly adds to any sequence backbone (TimesNet, CNNs, Transformers).
- Controllability: zero‑mean switch separates pattern shaping from level shifting.
Vanilla TimesNet (paper-style): reshape a 1D series into a 2D period-phase grid chosen by FFT, then apply 2D CNN blocks to capture intra-period (phase) and inter-period (cycle) patterns.
This repo extends that design in three principled ways:
- Embedding-aware context
- Value + positional + (optional) time features with configurable normalization (LayerNorm/RMSNorm/decoupled).
- Series ID embeddings and static covariates projected/fused into a context vector.
- Low-Rank Temporal Context (LRTC) injects a compact learned basis over time so static info can modulate temporal dynamics.
- Robust, shared FFT period search
- Channel-median magnitude with batch averaging, DC removal, and log-penalty for long periods to avoid spurious peaks.
- PeriodGrouper merges near-duplicate periods (log buckets, min-cycle guards), producing stable, soft-weighted candidates for CNN processing.
- Adaptive probabilistic head
  - Negative Binomial (`rate`, `dispersion`) with `min_sigma`/per-series floors to keep dispersion positive on sparse demand.
  - Heads are AMP-safe; outputs are masked at invalid points.
Net effect: Embeddings raise semantic SNR, 2D Inception CNN exploits phase-by-cycle structure, and the probabilistic head adapts level/dispersion per series — a synergistic trio that outperforms naively stacked modules.
flowchart TD
%% --- 1. Input Layer ---
subgraph "Input Data"
direction LR
X["<b>Time-Series Input</b><br>[Batch, Lookback, Channels]"]
Meta["<b>Series Metadata</b><br><i>IDs & Static Features</i>"]
end
%% --- 2. Context Generation (from Metadata) ---
subgraph "A. Context Generation"
direction TB
ContextEmbed["<b>Static/ID Embedding</b>"]
LRTC["<b>Low-Rank Temporal Context (LRTC)</b><br>Generates time-varying 'Shape' signal"]
LateBias["<b>Late Bias Head</b><br>Generates time-invariant 'Scale' signal"]
Meta --> ContextEmbed
ContextEmbed --> LRTC
ContextEmbed --> LateBias
end
%% --- 3. Input Conditioning & Embedding ---
subgraph "B. Input Conditioning & Embedding"
ConditionedInput["(+) <b>Input Conditioning</b>"]
DataEmbed["<b>DataEmbedding Block</b><br>Value, Positional, Time Features"]
X --> ConditionedInput
LRTC -- Injects 'Shape' Signal --> ConditionedInput
ConditionedInput --> DataEmbed
end
%% --- 4. TimesNet Core Backbone ---
subgraph "C. TimesNet Backbone"
direction LR
FFT["<b>FFT-guided Period Selection</b>"]
Loop["TimesBlock Loop (N layers)"]
DataEmbed --> FFT
FFT -- Guides --> Loop
DataEmbed -- Residual Link --> Loop
end
%% --- 5. Probabilistic Forecasting Head ---
subgraph "D. Forecasting Head"
direction TB
Projection["<b>Temporal Projection</b><br>Maps to Prediction Horizon"]
ContextNote["<i>Shared temporal context from LRTC<br>via DataEmbedding feeds both heads</i>"]
subgraph "Rate Path (μ)"
direction TB
RateBias["(+) <b>Bias Injection</b>"]
RateHead["<b>Rate Head</b><br>Negative Binomial μ"]
end
subgraph "Dispersion Path (α)"
direction TB
DispersionHead["<b>Dispersion Head</b><br>Negative Binomial α"]
end
Loop --> Projection
Projection --> ContextNote
ContextNote -.-> RateBias
ContextNote -.-> DispersionHead
Projection --> RateBias
Projection --> DispersionHead
LateBias -- Injects 'Scale' Signal --> RateBias
RateBias --> RateHead
end
%% --- 6. Output ---
Forecast["<b>Final Forecast Distribution</b><br>[Batch, Horizon, Channels]"]
RateHead --> Forecast
DispersionHead --> Forecast
%% --- Styling for emphasis ---
style LRTC fill:#e8f5e9,stroke:#388e3c
style LateBias fill:#e8f5e9,stroke:#388e3c
style FFT fill:#fffde7,stroke:#fbc02d
style RateBias fill:#fce4ec,stroke:#c2185b
style ContextNote fill:#e3f2fd,stroke:#1976d2,stroke-dasharray: 5 5
- DataEmbedding: value + positional + optional time features; integrates ID & static embeddings and LRTC.
- FFTPeriodSelector: channel-robust spectrum summary → top-k frequencies (DC removed, long-period damped) → period lengths (≥2 cycles).
- PeriodGrouper: merges close periods, yields logits for softmax weighting.
- TimesBlock (2D Inception CNN): reshape `[B,T,N]` to period grids, apply multi-kernel Inception with bottlenecks, compute residuals, then weighted sum across periods.
- Forecast head: time projection to horizon `H`, plus Negative Binomial rate/dispersion heads with stability floors.
- Training: NB-NLL (default) with AMP-safe masking; supports direct and recursive decoding; logs sMAPE/NLL and coverage.
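For intuition, here is a condensed sketch of the channel-robust, shared period search described in the FFTPeriodSelector bullet above; the function name and the exact penalty are illustrative, and the repo's implementation adds further guards (PeriodGrouper merging, soft weighting).

```python
import torch

def select_periods(x: torch.Tensor, k: int = 2, min_cycles: int = 2):
    """x: [B, T, N] → (periods, weights) from a channel-robust amplitude spectrum."""
    B, T, N = x.shape
    spec = torch.fft.rfft(x, dim=1).abs()           # [B, T//2 + 1, N]
    amp = spec.median(dim=-1).values.mean(dim=0)    # median over channels, mean over batch
    amp[0] = 0.0                                    # remove the DC component

    freqs = torch.arange(amp.shape[0])
    penalty = torch.log1p(T / freqs.clamp(min=1).float())  # damp very long periods
    scores = amp / penalty
    scores[freqs < min_cycles] = 0.0                # require ≥ min_cycles repetitions

    top = torch.topk(scores, k)
    periods = (T // top.indices.clamp(min=1)).tolist()  # frequency index → period length
    weights = torch.softmax(top.values, dim=0)           # soft weights for the CNN blocks
    return periods, weights
```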
# Python ≥ 3.10; PyTorch ≥ 2.1 recommended
pip install -r requirements.txt
# Optional: CUDA/cuDNN for GPUs; AMP is supported

Initiate (Colab)
!git clone https://github.com/ShinDongWoon/Recursive-TimesNet.git
%cd Recursive-TimesNet
!pip install -r requirements.txt
!pip install -e .
# Train
python -m timesnet_forecast.cli train \
--config configs/default.yaml \
--override train.lr=1e-3 window.input_len=336 window.pred_len=24
# Predict (direct or recursive; controlled by config)
python -m timesnet_forecast.cli predict --config configs/default.yaml
# Hyperparameter search (Optuna)
python -m timesnet_forecast.cli tune \
--config configs/default.yaml \
--space configs/search_space.yaml

- Structure: a single CSV file consisting of three columns: `영업일자` (business date), `영업장명_메뉴명` (store-menu identifier), and `매출수량` (sales quantity).
- Each row represents the sales quantity for a specific (date, store-menu) combination, with the date in `YYYY-MM-DD` format and an integer sales quantity.
- The file is saved in UTF-8 with BOM, so BOM handling should be considered when reading it in environments like Python (see the example below).
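For example, pandas strips the BOM transparently when you read with the `utf-8-sig` encoding (column names as described above; the path matches the default config):

```python
import pandas as pd

# "utf-8-sig" removes the BOM so the first column name parses cleanly.
train = pd.read_csv("data/train.csv", encoding="utf-8-sig")
train["영업일자"] = pd.to_datetime(train["영업일자"])        # YYYY-MM-DD business date
train = train.sort_values(["영업장명_메뉴명", "영업일자"])    # one row per (store-menu, date)
```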
- Consists of 10 files in total, from `TEST_00.csv` to `TEST_09.csv`, all using the same schema.
- The column structure is identical to the training data: `영업일자`, `영업장명_메뉴명`, and `매출수량`.
- Each file contains 5,404 rows (193 store-menu combinations × 28 days of records), providing the most recent 4 weeks of sales history for the subsequent 7-day forecast.
- For analysis, you can read a single test file and sort it by the required store-menu combination and date to preprocess it in the same manner as the training data.
- Composed of 194 columns in total. The first column is `영업일자` (e.g., `TEST_00+1일`), and the following 193 columns correspond to each store-menu combination.
- It consists of 70 rows: 7 days of predictions (+1 day to +7 day) for each test set from `TEST_00` to `TEST_09`.
- The sample submission file is also saved in UTF-8 with BOM. Overwrite the `매출수량` prediction values while keeping the same encoding for submission.
- The model's output must be non-negative sales quantity predictions, and the column order and headers must exactly match the sample submission file.
- Schema: map your `date`, `target`, `series_id` columns via config (candidates are auto-inferred if unspecified).
- Loaders: swap the train CSV, test directory (e.g., `TEST_00.csv … TEST_09.csv`), and sample submission without code changes.
- Features: enable/disable calendar time covariates (day-of-week/day-of-month/month/day-of-year) with configurable cyclical/one-hot/numeric encodings.
- Evaluation: choose holdout or rolling CV, horizon H, and sMAPE/NLL aggregation rules.
- Augmentation: add Gaussian noise and/or time shifts to input windows via `data.augment`.
- Your validation holdout must span at least `input_len + pred_len` days.
This ensures each eval window has enough history and produces a full horizon.
- Training writes `artifacts/metadata.json` (`meta_version=1`) with:
  - window sizes (`input_len`, `pred_len`),
  - inferred schema (date/target/id),
  - enabled time-feature set,
  - names of static features aligned to `series_ids`.
- The prediction CLI compares the runtime config with this metadata and fails fast on drift (mismatched schema, window sizes, or feature toggles).
- The first column is now the business date column (`submission.date_col`, default `영업일자`), not an abstract row key. Downstream graders and dashboards can join on the calendar directly.
- Default: `preprocess.normalize: "none"`.
- If no scaler is used, the saved scaler artifact is `None` and the pipeline stays on original units.
- Windows are exact length `input_len`, with no zero-padding.
- The forward pass crops to the first `input_len` steps, so extra history at inference does not change the output shape.
from timesnet_forecast.config import PipelineConfig
from timesnet_forecast.train import train_once
from timesnet_forecast.predict import predict_once
cfg = PipelineConfig.from_files(
"configs/default.yaml",
overrides={"window.input_len": 336, "window.pred_len": 24},
)
val_nll, artifacts = train_once(cfg) # (best_nll, paths for checkpoints/scalers/schema/etc.)
submission_path = predict_once(cfg)  # CSV written to submission.output_path/out_path

data:
train_csv: "data/train.csv"
test_dir: "data/test"
sample_submission: "data/sample_submission.csv"
date_col: "영업일자"
target_col: "매출수량"
id_col: "영업장명_메뉴명"
fill_missing_dates: true
augment:
add_noise_std: 0.005
time_shift: 2
preprocess:
normalize: "none"
normalize_per_series: true
clip_negative: true
train:
device: "cuda"
epochs: 70
early_stopping_patience: 5
batch_size: 128
lr: 1.0e-4
amp: true
cuda_graphs: false
compile: false
val:
strategy: "rolling"
holdout_days: 35
rolling_folds: 3
rolling_step_days: 14
model:
mode: "direct"
input_len: 28
pred_len: 7
d_model: 128
d_ff: 256
n_layers: 2
k_periods: 2
kernel_set:
- [3, 3]
- [5, 5]
- [7, 7]
bottleneck_ratio: 4.0
id_embed_dim: 32
static_proj_dim: 32
static_layernorm: true
window:
input_len: 28
  pred_len: 7

- `train.cuda_graphs: true` → model is captured in eval mode; dropout disabled.
- `train.use_checkpoint: true` → reduces memory, slower; auto-disabled when `cuda_graphs: true`.
- `train.compile: true` cannot be combined with `cuda_graphs: true`.
- `train.deterministic: true` seeds RNGs, disables cuDNN benchmarking, and enables deterministic algorithms — ideal for integration tests and CI.
- During training, the pipeline computes simple statics per series via `compute_series_features` and stores them in the scaler artifact (or alongside it if normalization is disabled).
- These are aligned with `series_ids`, so inference can reuse them without recomputation.
- Supply your own statics as a precomputed tensor `[num_series, feature_dim]` and pass them through the data-loader hooks in `train_once` (see the `series_static` arguments to `_build_dataloader`).
- `series_ids`: per-series integer identifiers (for the ID embedding).
- `series_static`: optional static covariates `[num_series, feature_dim]`.
- `model.id_embed_dim` (default 32) — width of the learned ID embedding. Set `0` to disable when IDs lack signal.
- `model.static_proj_dim` (default 32) — projection width applied to static covariates before concatenation; `null` keeps the raw dimensionality.
- `model.static_layernorm` — toggles a LayerNorm after the static projection (recommended when mixing disparate scales).
- ID embedding params ≈ `num_series × id_embed_dim`
  - Example: `1,000` series × `64` → 64,000 parameters.
- Static projection params ≈ `static_input_dim × static_proj_dim + static_proj_dim`.
Plan GPU memory when increasing these knobs or sweeping them with Optuna.
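As a rough sizing aid (fp32, four bytes per parameter; `static_input_dim` here is an assumed, illustrative value):

```python
num_series, id_embed_dim = 1_000, 64
static_input_dim, static_proj_dim = 12, 32   # static_input_dim is illustrative

id_embed_params = num_series * id_embed_dim                                 # 64,000
static_proj_params = static_input_dim * static_proj_dim + static_proj_dim  # 416
print(f"{(id_embed_params + static_proj_params) * 4 / 1024:.1f} KiB")       # ≈ 251.6 KiB
```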
# Override via CLI
timesnet-forecast train --override model.id_embed_dim=16 model.static_proj_dim=null

# Optuna search-space example (search_space.yaml)
model.id_embed_dim:
type: categorical
  choices: [0, 16, 32]

- Embeddings → raise semantic SNR
- ID embeddings disambiguate per-series regimes.
- Static/meta features tilt the representation toward series-specific baselines.
- LRTC lets static info modulate temporal filters with a tiny rank budget.
- 2D Inception CNN → structured expressivity
- Period-phase grids expose intra-/inter-period patterns.
- Multi-kernel branches act like “band-pass microscopes” over phase and cycle.
- Adaptive Probabilistic Head → calibrated outputs
  - Negative Binomial (`rate`, `dispersion`) forecasts by default for count-style demand.
  - Softplus + dispersion floors (driven by `train.min_sigma` or per-series buffers) stabilize training at low signal levels.
- Together
- Context sets the playing field; CNN plays the structure; the head keeps score honestly.
Negative-Binomial NLL (default)
Training and validation call `negative_binomial_nll(rate, dispersion, y, mask)` with masks from `negative_binomial_mask` so only finite targets contribute. Dispersion floors (from `train.min_sigma` or per-series vectors) keep the likelihood well-behaved.
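For reference, a masked Negative Binomial NLL in the mean/dispersion parameterization (Var = μ + α·μ²) can be written directly with `lgamma`; this is a sketch of the quantity being minimized, not the project's exact `negative_binomial_nll` implementation.

```python
import torch

def nb_nll_sketch(rate, dispersion, y, mask, eps: float = 1e-8):
    """Masked NB NLL with mean `rate` (μ) and dispersion α, so Var = μ + α·μ²."""
    mu = rate.clamp(min=eps)
    alpha = dispersion.clamp(min=eps)
    r = 1.0 / alpha  # NB "total count" parameter
    log_prob = (
        torch.lgamma(y + r) - torch.lgamma(r) - torch.lgamma(y + 1.0)
        + r * torch.log(r / (r + mu))
        + y * torch.log(mu / (r + mu))
    )
    return -(log_prob * mask).sum() / mask.sum().clamp(min=1.0)
```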
sMAPE (reported)
sMAPE = mean( 2|y - ŷ| / (|y| + |ŷ| + ε) ) over valid targets.
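The same metric in code, with a validity mask and a small ε for numerical safety:

```python
import torch

def smape(y, y_hat, mask, eps: float = 1e-8):
    """Masked sMAPE matching the formula above."""
    ratio = 2.0 * (y - y_hat).abs() / (y.abs() + y_hat.abs() + eps)
    return (ratio * mask).sum() / mask.sum().clamp(min=1.0)
```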
- Direct: a single forward pass yields `H`-step forecasts (lower error accumulation).
- Recursive: repeated 1-step predictions rolled over the horizon (more flexible with covariates); see the sketch below.
- Switch via config: set `model.mode: direct` or `model.mode: recursive` (can override via `--override model.mode=recursive`).
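The recursive mode boils down to a rolling loop like the one below; `one_step_model` is a placeholder for any callable that maps a `[B, input_len, N]` window to its next step, and the project's predict CLI performs the equivalent internally.

```python
import torch

@torch.no_grad()
def recursive_forecast(one_step_model, window: torch.Tensor, horizon: int) -> torch.Tensor:
    """window: [B, input_len, N] → [B, horizon, N] by rolling one-step predictions."""
    preds = []
    for _ in range(horizon):
        step = one_step_model(window)                     # [B, 1, N] next-step forecast
        preds.append(step)
        window = torch.cat([window[:, 1:], step], dim=1)  # drop oldest, append newest
    return torch.cat(preds, dim=1)
```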
- Periods look unstable?
  - Increase `k_periods` slightly and enable log-bucket merging.
  - Ensure `min_cycles ≥ 2` and DC removal is on.
- Over/under-dispersion or poor calibration?
  - Adjust `train.min_sigma`/`min_sigma_scale`; enrich statics/ID embeddings; review normalization choices.
- Intermittent zeros dominate?
- Stay with the NB head but add calendar covariates or richer statics to stabilize baselines.
- AMP overflow/NaNs?
- Clamp logits/residuals; keep Softplus β modest; verify mixed-precision safe ops.
- Memory pressure?
  - Enable `train.use_checkpoint`; reduce `d_model`/`d_ff`; shrink `kernel_set`; raise `bottleneck_ratio`.
- Reproducibility?
  - Set `train.deterministic: true`; fix the seed; pin versions; avoid `cuda_graphs` under variant kernels.
- Dataset: Walmart retail demand (Kaggle)
- Split: rolling CV (7-day horizon), seeded
- Metric: sMAPE (reported), NB NLL (selection)
- Score: ≈ 0.14
- Notes: Negative Binomial head (rate/dispersion) with dispersion floors; embeddings + LRTC on; PeriodGrouper enabled; AMP on GPU.
Reproduce by fixing seed, window.pred_len=7, keeping the NB head active with train.min_sigma, and mirroring the rolling CV profile described in your configuration (this repo does not ship a configs/benchmarks/walmart.yaml).
- FFTPeriodSelector (shared across channels)
  Use `torch.fft.rfft` with channel-median and batch averaging. Remove DC, apply a log-penalty on long-period indices, take the top `k_periods`, map them to period lengths, and require ≥ 2 cycles.
- TimesBlock
  Gridify per period, run Inception branches with bottlenecks, compute residuals, softmax-weight across grouped periods, and fuse with skip connections (see the gridify sketch after this list).
- Performance features
  Activation checkpointing (toggle): ↓ memory at ↑ time; auto-off with CUDA graphs.
  CUDA graphs (toggle): throughput wins; disables dropout; avoid with `train.compile`.
  `channels-last` and AMP are supported end-to-end.
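A minimal sketch of the "gridify per period" step, assuming a `[B, T, N]` layout as elsewhere in this README: pad the sequence to a whole number of cycles and fold it into a (cycles × period) grid that 2D kernels can scan; the repo's TimesBlock wraps this with Inception branches and residual fusion.

```python
import torch
import torch.nn.functional as F

def gridify(x: torch.Tensor, period: int) -> torch.Tensor:
    """x: [B, T, N] → [B, N, cycles, period] grid, zero-padded to a full last cycle."""
    B, T, N = x.shape
    pad = (-T) % period
    if pad:
        x = F.pad(x, (0, 0, 0, pad))  # pad the time dimension (dim 1)
    cycles = (T + pad) // period
    # Rows index inter-period position (which cycle), columns intra-period phase.
    return x.permute(0, 2, 1).reshape(B, N, cycles, period)
```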
- Channel-specific period sub-selection (per-series attention over FFT bins)
- Lightweight exogenous encoder for promotions/events
- Quantile and CRPS training objectives
- Dynamic routing across Inception branches
This project is built upon the foundational concepts and architecture introduced in the original TimesNet paper. The core implementation of the TimesNet model is inspired by the official source code provided by the authors.
- Original Paper: Haixu Wu, et al. “TimesNet: Temporal 2D-Variation Modeling for General Time Series Analysis.” ICLR 2023.
- Official Repository: https://github.com/thuml/TimesNet
- The original TimesNet source code is licensed under the Apache License 2.0. A copy of the license can be found in the NOTICE file within this repository.