This repository contains a suite of tools and scripts for training machine learning models and running inference with them. It covers data preprocessing, model training, evaluation, and deployment.
If you're used to training models in Jupyter notebooks, you might wonder why you need all this infrastructure. Here's the short version:
Jupyter is great for:
- Prototyping and experimentation
- Interactive debugging and visualization
- Quick iterations on small datasets
This setup is designed for:
- Training that takes hours/days (you can't keep a notebook open that long)
- Using multiple GPUs (distributed training)
- Running on HPC clusters with job schedulers (SLURM)
- Reproducible experiments with version-controlled configs
- Automatic recovery from crashes or time limits
Key differences from Jupyter:
| Jupyter Notebook | This Training Suite |
|---|---|
| Run cells manually | Run a single script that handles everything |
| Lose progress if kernel dies | Automatic checkpointing and resumption |
| Print statements for logging | Structured logging + WandB dashboards |
| Hardcoded hyperparameters | YAML config files (easy to track and modify) |
| Single GPU | Multi-GPU and multi-node support |
| Keep notebook open | Submit job and check results later |
The learning curve is worth it: once you understand this setup, you can train models for days on powerful hardware without babysitting them.
- Modular design for easy integration and extension
- torch.compile (with memory constraints)
- Automatic Mixed Precision (AMP) support
- Distributed training capabilities
- WandB integration for experiment tracking
- Configurable via YAML files
- Checkpointing and resuming training support
- Time keeping for graceful shutdowns and resuming on HPCs
This section explains the advanced features you'll encounter. If you've only used basic PyTorch in Jupyter, these might be new to you.
What it is: By default, PyTorch uses 32-bit floating point numbers (float32) for all computations. AMP automatically uses lower precision (float16 or bfloat16) where it's safe, while keeping float32 where precision matters.
Why use it:
- 2x faster training (or more) on modern GPUs
- Lower memory usage = larger batch sizes
- Almost no accuracy loss when done correctly
How it works in this codebase:
# Without AMP (what you might write in Jupyter):
output = model(x)
loss = criterion(output, target)
loss.backward()
# With AMP (what this codebase does):
scaler = torch.cuda.amp.GradScaler()  # created once, before the training loop
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    output = model(x)
    loss = criterion(output, target)
scaler.scale(loss).backward()  # GradScaler prevents underflow in float16
scaler.step(optimizer)         # unscales gradients, then runs optimizer.step()
scaler.update()                # adjusts the loss scale for the next iteration

Config options:
amp: true # Enable/disable AMP
precision: bfloat16   # Use bfloat16 (recommended for A100/H100) or float16

When to disable: If you see NaN losses or unstable training, try amp: false to rule out precision issues.
What it is: A PyTorch 2.0+ feature that JIT-compiles your model for faster execution. Think of it as an optimizer that fuses operations and generates optimized GPU kernels.
Why use it:
- 10-30% speedup on many models (sometimes more)
- No code changes required to your model
The catch:
- First batch is slow (compilation time)
- Not all models compile successfully
- Some dynamic operations aren't supported
Config options:
compile: true # Enable torch.compile
mem_budget: 1    # Memory budget (1 = full memory, <1 = gradient checkpointing)

Gradient checkpointing via mem_budget: When mem_budget < 1, the compiled model trades compute for memory by recomputing activations during the backward pass instead of storing them all. Set it to 0.5 to use roughly half the memory at the cost of ~20-30% more compute.
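Under the hood, enabling this boils down to a single PyTorch call. A minimal sketch using a placeholder model (the suite applies this for you based on the config):

```python
import torch
import torch.nn as nn

model = nn.Linear(128, 64).cuda()        # placeholder model; any nn.Module works
compiled_model = torch.compile(model)    # PyTorch 2.0+ JIT compilation

x = torch.randn(8, 128, device="cuda")
y = compiled_model(x)   # first call is slow (compilation); later calls reuse the kernels
```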
What it is: Training on multiple GPUs simultaneously, where each GPU processes different batches and gradients are synchronized.
Why use it:
- Near-linear speedup with more GPUs (4 GPUs ≈ 4x faster)
- Larger effective batch sizes without memory limits
How it works:
- Your model is copied to each GPU
- Each GPU gets different data (via DistributedSampler)
- Each GPU computes gradients independently
- Gradients are averaged across all GPUs (all-reduce)
- All GPUs update with the same averaged gradients
Running with multiple GPUs: Instead of python train.py, use:
# 4 GPUs on one machine
torchrun --standalone --nproc_per_node=4 run_training.py --config_path config.yaml

Important: The batch_size in config is per GPU. With 4 GPUs and batch_size=64, your effective batch size is 256.
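For reference, a minimal sketch of the DDP pattern the suite sets up for you when launched via torchrun (MyModel and train_dataset are placeholders):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

dist.init_process_group(backend="nccl")       # torchrun sets RANK/WORLD_SIZE env vars
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = DDP(MyModel().cuda(), device_ids=[local_rank])  # gradients are all-reduced automatically
sampler = DistributedSampler(train_dataset)             # each rank sees a different shard
loader = DataLoader(train_dataset, batch_size=64, sampler=sampler)
```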
What it is: A cloud service for tracking ML experiments. Think of it as a supercharged TensorBoard with collaboration features.
Why use it:
- Track metrics across runs without scrolling through terminal output
- Compare experiments side-by-side with interactive charts
- Log hyperparameters automatically for reproducibility
- Share results with collaborators via web dashboard
Setup:
- Create a free account at wandb.ai
- Run wandb login or add your API key to .env
- Set project/entity in config
What gets logged: Loss curves, learning rate, validation metrics, system stats (GPU memory, utilization).
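If you have only used print statements so far, the underlying calls are simple. A minimal sketch (the project/entity values are placeholders; the suite's wandb_logger.py wraps this for you):

```python
import wandb

wandb.init(project="my-project", entity="my-username",
           config={"learning_rate": 1e-4, "batch_size": 64})  # hyperparameters to track

wandb.log({"train/loss": 0.42, "lr": 1e-4}, step=100)  # metrics appear as live charts
```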
What it is: Saving the complete training state (model weights, optimizer state, learning rate scheduler, epoch number) so you can continue training later.
Why it matters:
- HPC time limits: Clusters often limit jobs to 24-48 hours. Checkpointing lets you chain jobs.
- Crash recovery: If training crashes at hour 23, you don't lose everything.
- Experimentation: Train for a while, then try different fine-tuning strategies.
What gets saved:
- Model weights
- Optimizer state (momentum, Adam statistics)
- Learning rate scheduler state
- Number of samples/batches trained
- GradScaler state (for AMP)
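Conceptually, a checkpoint is just a dictionary of state_dicts written with torch.save. A minimal sketch with hypothetical keys (the suite's checkpoint_utils.py defines the exact format and file naming):

```python
import torch

state = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "lr_scheduler": scheduler.state_dict(),
    "samples_seen": samples_seen,          # progress counter
    "scaler": scaler.state_dict(),         # only needed when AMP is enabled
}
torch.save(state, "latest.pt")
```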
Config options:
checkpoint:
  checkpoint_name: latest    # Load "latest", "best", or epoch number (e.g., "5")
  restart: true              # true = continue training, false = only load weights (for fine-tuning)

What it is: The training loop monitors elapsed time and gracefully stops before hitting the job time limit, saving a checkpoint first.
Why it matters: On SLURM clusters, if your job hits the time limit, it gets killed immediately (SIGKILL). Any unsaved progress is lost. This suite estimates how long epochs take and stops early to save.
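The pattern is roughly the following; a sketch with hypothetical helper names (time_keeper.py implements the real logic):

```python
import time

start = time.monotonic()
time_limit_s = 24 * 3600        # parsed from time_limit in the config
longest_epoch_s = 0.0

for epoch in range(num_epochs):
    epoch_start = time.monotonic()
    train_one_epoch()           # hypothetical helper
    longest_epoch_s = max(longest_epoch_s, time.monotonic() - epoch_start)

    # Stop early if another epoch would not fit before the job is killed.
    if time.monotonic() - start + longest_epoch_s > time_limit_s:
        save_checkpoint()       # hypothetical helper
        break
```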
Config:
time_limit: "24:00:00"   # Should be slightly less than SLURM --time

What it is: A human-readable format for configuration files, and a better home for hyperparameters than hardcoding them in Python.
Why use YAML instead of command-line args:
- Readable: Easy to see all settings at once
- Version controllable: Track config changes in git
- Reproducible: Copy the config to reproduce an experiment exactly
- Complex structures: Nested configs for model, optimizer, scheduler
Example structure:
dataset:
  name: my_dataset
  train_split: 0.8
model:
  type: transformer
  num_layers: 12
optimizer:
  name: AdamW
  learning_rate: 1e-4

What it is: A file containing machine-specific paths and secrets that shouldn't be in version control.
Why use it:
- Security: API keys stay out of git
- Portability: Same code works on laptop and HPC with different paths
- Flexibility: Easy to switch between machines
Example .env:
WANDB_API_KEY=your_secret_key_here
BASE_DIR=/home/user/ml-training-suite # Path to this repo
DATA_DIR=/scratch/datasets # Where your data lives
RESULTS_DIR=/scratch/results              # Where checkpoints go

ML-Training-Suite/
├── .env # Your environment variables (create this)
├── pyproject.toml # Python dependencies (managed by uv)
├── ml_suite/
│ ├── data/
│ │ ├── dataset.py # Dataset class (customize this)
│ │ └── dataloader.py # DataLoader factory with DDP support
│ ├── models/
│ │ ├── model_utils.py # get_model() function (customize this)
│ │ ├── loss_fns.py # Loss functions (MSE, MAE, RMSE, etc.)
│ │ └── unet.py # Example model architecture
│ └── train/
│ ├── train_base.py # Core Trainer class (don't modify unless needed)
│ ├── eval.py # Evaluator class
│ ├── run_training.py # CLI entry point (config loading, setup)
│ ├── train.yml # Example config (copy to results dir)
│ ├── scripts/
│ │ └── train_riv.sh # SLURM job script template
│ └── utils/
│ ├── checkpoint_utils.py # Save/load checkpoints
│ ├── lr_scheduler.py # Chained LR scheduler
│ ├── optimizer.py # Optimizer factory
│ ├── time_keeper.py # HPC time limit handling
│ ├── wandb_logger.py # WandB integration
│ └── logger.py # Console logging
└── tests/ # Unit tests
Understanding the separation:
The code is split into three layers:
- Pure Python logic (train_base.py, eval.py): Clean, testable classes that take Python objects as input. No file I/O, no config parsing. This is what you'd write in a well-structured Jupyter notebook.
- CLI glue code (run_training.py): Reads configs, loads checkpoints, sets up distributed training, creates objects, and passes them to the Trainer. This is the "ugly but necessary" code that connects configs to the pure logic.
- Shell scripts (scripts/train_riv.sh): SLURM job submission and environment setup. Different for each cluster.
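In rough pseudocode, the glue layer does something like the following (the function and class signatures are illustrative, not the exact API):

```python
import yaml
from ml_suite.data.dataset import get_dataset
from ml_suite.models.model_utils import get_model
from ml_suite.train.train_base import Trainer   # core Trainer class (illustrative import)

with open("train.yml") as f:
    config = yaml.safe_load(f)            # YAML -> plain Python dict

model = get_model(config["model"])        # pure-Python factories
train_ds = get_dataset(config["dataset"], split="train")

trainer = Trainer(model=model, dataset=train_ds, config=config)  # illustrative signature
trainer.train()                           # the pure-Python training loop does the rest
```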
When adapting for your project, you'll mainly modify:
- ml_suite/data/dataset.py - Your data loading logic
- ml_suite/models/model_utils.py - Your model architectures
- Config YAML files - Your hyperparameters
Fork or copy the repository:
git clone https://github.com/your-username/ML-Training-Suite.git
cd ML-Training-Suite

Install uv (Python package manager):
uv is a fast Python package manager. It's like pip + venv but much faster.
# On Linux/Mac
curl -LsSf https://astral.sh/uv/install.sh | sh
# Or with pip
pip install uv

Install dependencies:
uv sync --extra dev

This creates a virtual environment in .venv/ and installs everything, including PyTorch with CUDA support. The exact versions are locked in uv.lock for reproducibility.
Create a .env file in the repository root:
# .env
WANDB_API_KEY=your_api_key_here # Get from wandb.ai/settings
BASE_DIR=/path/to/this/repo # Absolute path to ML-Training-Suite
DATA_DIR=/path/to/your/datasets # Where your training data lives
RESULTS_DIR=/path/to/results     # Where checkpoints and logs go

Why separate directories?
- BASE_DIR: The code. Should be the same across machines (or cloned separately).
- DATA_DIR: Your datasets. Might be on a fast NVMe drive or shared filesystem.
- RESULTS_DIR: Checkpoints and logs. Often on scratch space that doesn't count against quota.
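In your own scripts you can read these values the same way; a minimal sketch assuming the python-dotenv package (or plain environment variables):

```python
import os
from dotenv import load_dotenv   # assumes python-dotenv is installed

load_dotenv()                             # reads .env from the current directory
data_dir = os.environ["DATA_DIR"]         # machine-specific paths stay out of the code
results_dir = os.environ["RESULTS_DIR"]
```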
- Create a free account at wandb.ai
- Go to Settings → API Keys and copy your key
- Add it to your .env file
- Create a new project on WandB (or the suite will create one automatically)
Edit ml_suite/models/model_utils.py:
def get_model(config: dict) -> torch.nn.Module:
    """Factory function that returns a model based on config."""
    model_type = config.get("type", "transformer")
    if model_type == "my_custom_model":
        return MyCustomModel(
            input_dim=config["input_dim"],
            hidden_dim=config["hidden_dim"],
            # ... other params from config
        )
    elif model_type == "transformer":
        # existing code...

Tip: Keep your model as a standard nn.Module. The training loop handles AMP, DDP wrapping, and compilation automatically.
Edit ml_suite/data/dataset.py:
def get_dataset(config: dict, split: str = "train") -> Dataset:
    """
    Factory function that returns a dataset.

    Args:
        config: Dataset configuration from YAML
        split: "train" or "valid"

    Returns:
        A PyTorch Dataset that returns (input, target) tuples
    """
    # Your dataset loading logic here
    # Must return (input_tensor, target_tensor) per item

Important: Your dataset must return (input, target) tuples. The training loop expects this format.
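A minimal, self-contained example of a dataset that satisfies this contract (the class name and the random placeholder data are purely illustrative):

```python
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    """Illustrative dataset that yields (input, target) tuples."""

    def __init__(self, num_samples: int = 1000):
        # Replace with real loading (files, memory-mapped arrays, etc.)
        self.inputs = torch.randn(num_samples, 128)
        self.targets = torch.randn(num_samples, 1)

    def __len__(self) -> int:
        return len(self.inputs)

    def __getitem__(self, idx: int):
        return self.inputs[idx], self.targets[idx]
```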
- Create a results directory for your experiment:
mkdir -p $RESULTS_DIR/my_experiment

- Copy and edit the config:

cp ml_suite/train/train.yml $RESULTS_DIR/my_experiment/train.yml

- Edit the config file. Key sections to modify:
# Dataset - your data parameters
dataset:
  name: my_dataset
  train_split: 0.8

# Model - must match get_model()
model:
  type: my_custom_model
  input_dim: 128
  hidden_dim: 256
  criterion: MSE

# Training parameters
batch_size: 64              # Per GPU!
total_updates: 100000       # Total gradient updates
updates_per_epoch: 1000     # Updates between evaluations

# Optimizer
optimizer:
  name: AdamW
  learning_rate: 1e-4
  weight_decay: 0.01

# WandB
wandb:
  enabled: true
  project: my-project
  entity: my-username

Local machine with single GPU:
source .venv/bin/activate
python ml_suite/train/run_training.py --config_path $RESULTS_DIR/my_experiment/train.yml

Local machine with multiple GPUs:
source .venv/bin/activate
torchrun --standalone --nproc_per_node=4 ml_suite/train/run_training.py --config_path $RESULTS_DIR/my_experiment/train.yml

On SLURM cluster:
- Edit ml_suite/train/scripts/train_riv.sh:
  - Set #SBATCH --account=your_account
  - Set #SBATCH --gres=gpu:a100:4 (or your GPU type)
  - Update sim_name to match your experiment folder name
- Submit the job:
sbatch ml_suite/train/scripts/train_riv.sh

Check logs:
tail -f $RESULTS_DIR/my_experiment/training.log

View WandB dashboard: Go to wandb.ai/your-username/your-project
Resume from checkpoint: Set in your config:
checkpoint:
  checkpoint_name: latest   # or "best" or epoch number like "5"
  restart: true

This suite uses gradient updates (batches) instead of epochs to measure training progress:
total_updates: 100000 # Stop after 100k gradient updates
updates_per_epoch: 1000    # Run validation every 1000 updates

Why? Epochs are confusing when dataset size or batch size changes:
- 10 epochs with batch_size=32 on 10k samples = 3,125 updates
- 10 epochs with batch_size=64 on 50k samples = 7,812 updates
Updates are consistent regardless of dataset/batch size. If you prefer epochs, calculate: updates = epochs * (dataset_size / batch_size).
The suite supports chained schedulers - multiple phases that run sequentially:
lr_scheduler:
  first_stage:               # Warmup: gradually increase LR
    name: LinearLR
    start_factor: 0.001
    end_factor: 1.0
    num_updates: 5000
  second_stage:              # Main training: cosine decay
    name: CosineAnnealingLR
    num_updates: -1          # -1 = use remaining updates
    end_factor: 0.01         # End at 1% of peak LR
  third_stage:               # Cooldown: decay to zero
    name: LinearLR
    end_factor: 0
    num_updates: 10

This creates the classic "warmup → cosine decay → cooldown" schedule that works well for most models.
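If you have only used a single scheduler before, chaining is conceptually similar to PyTorch's SequentialLR. A rough sketch of what the config above roughly corresponds to (assuming total_updates: 100000; CosineAnnealingLR's eta_min plays the role of end_factor here):

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)   # placeholder model

warmup   = LinearLR(optimizer, start_factor=0.001, end_factor=1.0, total_iters=5000)
cosine   = CosineAnnealingLR(optimizer, T_max=94_990, eta_min=1e-6)  # remaining updates, ends near 1% of peak LR
cooldown = LinearLR(optimizer, start_factor=1.0, end_factor=0.0, total_iters=10)

scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup, cosine, cooldown],
    milestones=[5000, 99_990],   # update counts where the next stage takes over
)
```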
The restart parameter controls what happens when loading a checkpoint:
checkpoint:
  checkpoint_name: latest
  restart: true              # Continue exactly where you left off

- restart: true: Load model, optimizer, scheduler, and training state. Training continues seamlessly.
- restart: false: Load only the model weights and use a fresh optimizer/scheduler. Good for fine-tuning with different hyperparameters.
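A minimal sketch of the difference when loading (hypothetical keys matching the earlier checkpoint sketch; checkpoint_utils.py implements the real logic):

```python
import torch

ckpt = torch.load("latest.pt", map_location="cpu")
model.load_state_dict(ckpt["model"])              # always restore the weights

if restart:                                        # restart: true -> resume training
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["lr_scheduler"])
    samples_seen = ckpt["samples_seen"]
# restart: false -> keep the fresh optimizer/scheduler for fine-tuning
```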
- Start simple: Disable compile and amp initially. Get your model training first, then enable optimizations.
- Use WandB from day one. Even for quick experiments. You'll thank yourself when you need to compare runs.
- Start with small total_updates (e.g., 1000) to verify everything works before long runs.
- Default LR schedule: Linear warmup (5% of training) + cosine annealing works for most models.
- If training is unstable: Try gradient clipping with max_grad_norm: 1.0.
- Memory issues? Reduce batch_size or set mem_budget: 0.5 for gradient checkpointing.
- Debug on CPU first: Set compile: false and use a small subset of data. GPU debugging is painful.
"CUDA out of memory"
- Reduce
batch_sizein config - Set
mem_budget: 0.5(or lower) for gradient checkpointing - Disable
compile(compiled models use more memory initially)
NaN losses
- Disable AMP: amp: false
- Lower learning rate
- Add gradient clipping: max_grad_norm: 1.0
- Check your data for NaN/Inf values
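A quick way to rule out bad data (train_loader is a placeholder for your DataLoader):

```python
import torch

x, y = next(iter(train_loader))                      # grab one batch
assert torch.isfinite(x).all(), "inputs contain NaN or Inf"
assert torch.isfinite(y).all(), "targets contain NaN or Inf"
```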
torch.compile errors
- Set compile: false - not all models compile successfully
- Check for dynamic shapes in your model (variable sequence lengths, etc.)
- Compile works best with static shapes
Training doesn't resume from checkpoint
- Verify checkpoint_name: latest is set in config
- Check that latest.pt exists in your results directory
- Make sure restart: true is set if you want to continue training
WandB not logging
- Check WANDB_API_KEY in .env
- Verify wandb.enabled: true in config
- Run wandb login manually to test
"Address already in use" with torchrun
- A previous run didn't clean up properly
- Kill orphan processes:
pkill -f torchrun - Or wait a minute for the port to be released
SLURM job immediately fails
- Check output file in results/00_slrm_logs/
- Verify paths in .env are absolute paths
- Make sure the venv exists: ls $BASE_DIR/.venv/bin/python
Test your model independently:
# In a Python shell
import torch
from ml_suite.models.model_utils import get_model
config = {"type": "my_model", "input_dim": 128}
model = get_model(config)
x = torch.randn(2, 128) # Small batch
y = model(x)
print(y.shape)  # Verify output shape

Test your dataset:
from ml_suite.data.dataset import get_dataset
config = {"name": "my_dataset"}
ds = get_dataset(config, split="train")
x, y = ds[0]
print(x.shape, y.shape)  # Verify shapes match model expectations

Run a quick sanity check:
# Single GPU, small run
python run_training.py --config_path config.yml
# Set total_updates: 100 in config for a quick test

| Term | Meaning |
|---|---|
| AMP | Automatic Mixed Precision - using float16/bfloat16 for faster training |
| Batch size | Number of samples processed before a gradient update |
| Checkpoint | Saved state of model, optimizer, scheduler for resumption |
| DDP | DistributedDataParallel - PyTorch's multi-GPU training wrapper |
| Epoch | One complete pass through the training dataset |
| Gradient clipping | Limiting gradient magnitude to prevent exploding gradients |
| GradScaler | Scales loss in AMP to prevent underflow in float16 |
| HPC | High-Performance Computing cluster |
| LR | Learning Rate |
| NCCL | NVIDIA Collective Communications Library - for GPU-to-GPU communication |
| SLURM | Job scheduler used on HPC clusters |
| torchrun | PyTorch's launcher for distributed training |
| Update | One gradient update step (forward + backward + optimizer step) |
| Warmup | Gradually increasing LR at start of training for stability |
| World size | Total number of GPUs/processes in distributed training |