Description
Problem Statement
Currently, TORAX simulations write final results to NetCDF files (state_history_%Y%m%d_%H%M%S.nc) but lack the ability to save intermediate simulation states and resume from checkpoints. This limitation creates significant challenges for:
- Long-running simulations: If a simulation crashes or is interrupted after hours of computation, all progress is lost
- Iterative workflow optimization: Researchers cannot easily branch from intermediate states to test different parameters
- Resource efficiency: Computing clusters with time limits require ability to save and resume work across multiple job submissions
- Debugging and analysis: Cannot inspect and resume from specific simulation timestamps
- Parameter sweeps: Must restart entire simulations when exploring slight variations from a common baseline
Comparison with Similar Tokamak Simulation Tools
RAPTOR (RApid Plasma Transport simulatOR)
- Restart capabilities: Supports state snapshots for real-time control applications
- Use case: Enables rapid model-based control requiring frequent save/load cycles
- Reference: RAPTOR is designed for real-time execution where checkpoint/restart is essential for control workflows
TokSys Suite
- Checkpoint system: Full support for saving and restoring plasma control simulations
- Integration: Works with GSevolve for closed-loop simulations with restart capability
- Application: ITER PCS development, EAST, KSTAR, DIII-D, and NSTX-U all use TokSys checkpointing
IPS (Integrated Plasma Simulator)
- Robust restart framework: dedicated SIMULATION_MODE = RESTART configuration
- Features: RESTART_TIME parameter to specify the checkpoint timestamp, RESTART_DIRECTORY for seamless simulation continuation
- Automatic archiving of restart files per component
- Support for appending to existing simulation data
- Reference: Used extensively in fusion simulation workflows
OpenMM (Physics Simulation - Analogous Domain)
- Dual save methods: saveCheckpoint() writes a binary format with complete internal state (RNG states, etc.); saveState() writes an XML format with public data only (positions, velocities)
- Resume capability: loadCheckpoint() restores the exact simulation state
- Periodic checkpointing: automatic saves at configurable intervals during long runs
Key Gap in TORAX
TORAX currently lacks any equivalent checkpoint/restart mechanism, despite being positioned as a tool for "pulse-design, trajectory optimization, and controller design workflows" where interrupted or iterative execution is common.
Proposed Solution
Implement a comprehensive checkpoint/restart system integrated with TORAX's existing JAX-based architecture and NetCDF output format.
Core Components
1. Checkpoint Data Structure
Create a checkpoint that captures complete simulation state:
from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class ToraxCheckpoint:
    """Complete checkpoint containing all data needed to resume a simulation."""
    # Simulation metadata
    checkpoint_version: str  # For backward compatibility
    torax_version: str
    timestamp: float
    step_number: int
    # Core simulation state
    core_profiles: CoreProfiles  # All plasma profiles at checkpoint time
    core_sources: CoreSources  # Source terms and derived quantities
    geometry: Geometry  # Geometric configuration
    # Solver state
    solver_state: Any  # Internal solver state (for exact reproducibility)
    stepper_state: Any  # Time-stepping state
    # Transport model state
    transport_model_state: Dict[str, Any]  # State of QLKNN or other transport models
    # JAX-specific state
    jax_rng_state: Any  # PRNG key for any stochastic elements (JAX keys are explicit values)
    # Simulation parameters at checkpoint
    runtime_params: RuntimeParams
    config_snapshot: Dict[str, Any]  # Full config at checkpoint time
    # Compilation cache hints (optional, for performance)
    cache_metadata: Optional[Dict[str, Any]] = None

2. Save Checkpoint Function
def save_checkpoint(
    sim: Sim,
    checkpoint_path: str | Path,
    include_cache_hints: bool = False,
    compression: str = 'gzip',
) -> None:
    """Save a simulation checkpoint to disk.

    Args:
      sim: Current Sim object with all state.
      checkpoint_path: Path where the checkpoint will be saved.
      include_cache_hints: Whether to save JAX compilation cache metadata.
      compression: NetCDF compression algorithm ('gzip', 'zlib', or None).

    The checkpoint is saved as a NetCDF file for consistency with TORAX outputs,
    but with a special attribute flag identifying it as a checkpoint.
    """
    checkpoint = ToraxCheckpoint(
        checkpoint_version="1.0.0",
        torax_version=torax.__version__,
        timestamp=sim.time,
        step_number=sim.step_count,
        core_profiles=sim.core_profiles,
        core_sources=sim.core_sources,
        geometry=sim.geometry,
        solver_state=sim.solver.get_state(),  # New method needed
        stepper_state=sim.stepper.get_state(),  # New method needed
        transport_model_state=sim.transport_model.get_state(),
        jax_rng_state=sim.rng_key,  # JAX has no global RNG state; Sim would need to track its PRNG key explicitly
        runtime_params=sim.runtime_params,
        config_snapshot=serialize_config(sim.config),
    )

    # Serialize to NetCDF with xarray
    ds = checkpoint_to_xarray(checkpoint)
    ds.attrs['checkpoint_file'] = True
    ds.attrs['checkpoint_version'] = checkpoint.checkpoint_version

    encoding = {}
    if compression:
        encoding = {var: {'compression': compression, 'complevel': 4}
                    for var in ds.data_vars}
    ds.to_netcdf(checkpoint_path, encoding=encoding)

    logger.info(f"Checkpoint saved to {checkpoint_path} at t={sim.time:.3f}s, step={sim.step_count}")

3. Load Checkpoint Function
def load_checkpoint(
    checkpoint_path: str | Path,
    config: Optional[Config] = None,
    override_runtime_params: Optional[Dict] = None,
) -> Sim:
    """Load a simulation from a checkpoint and return an initialized Sim object.

    Args:
      checkpoint_path: Path to the checkpoint file.
      config: Optional new config (must be compatible). If None, the checkpointed config is used.
      override_runtime_params: Optional dict overriding specific runtime parameters.

    Returns:
      Sim object initialized from the checkpoint, ready to continue the simulation.

    Raises:
      CheckpointIncompatibleError: If the checkpoint version or structure is incompatible.
      ConfigMismatchError: If the provided config is incompatible with the checkpoint.
    """
    # Load checkpoint data
    ds = xr.open_dataset(checkpoint_path)
    if not ds.attrs.get('checkpoint_file', False):
        raise ValueError(f"{checkpoint_path} is not a valid TORAX checkpoint file")
    checkpoint = xarray_to_checkpoint(ds)

    # Version compatibility check
    if not is_compatible_version(checkpoint.checkpoint_version):
        logger.warning(f"Checkpoint version {checkpoint.checkpoint_version} may be incompatible")

    # Use the checkpointed config or validate the provided config
    if config is None:
        config = deserialize_config(checkpoint.config_snapshot)
    else:
        validate_config_compatibility(config, checkpoint.config_snapshot)

    # Apply runtime parameter overrides
    runtime_params = checkpoint.runtime_params
    if override_runtime_params:
        runtime_params = update_runtime_params(runtime_params, override_runtime_params)

    # Reconstruct the Sim object
    sim = Sim(config=config)

    # Restore complete state
    sim.time = checkpoint.timestamp
    sim.step_count = checkpoint.step_number
    sim.core_profiles = checkpoint.core_profiles
    sim.core_sources = checkpoint.core_sources
    sim.geometry = checkpoint.geometry
    sim.runtime_params = runtime_params

    # Restore solver internal state
    sim.solver.set_state(checkpoint.solver_state)
    sim.stepper.set_state(checkpoint.stepper_state)
    sim.transport_model.set_state(checkpoint.transport_model_state)

    # Restore the PRNG key if present (JAX has no global RNG state to set)
    if checkpoint.jax_rng_state is not None:
        sim.rng_key = checkpoint.jax_rng_state

    logger.info(f"Checkpoint loaded from {checkpoint_path}, resuming at t={sim.time:.3f}s, step={sim.step_count}")
    return sim

4. Automatic Checkpoint Saving
Integrate with the existing simulation loop to enable periodic automatic checkpointing:
@dataclass
class CheckpointConfig:
    """Configuration for automatic checkpointing during a simulation."""
    enabled: bool = False
    interval_steps: Optional[int] = None  # Save every N steps
    interval_time: Optional[float] = None  # Save every T seconds of simulation time
    interval_walltime: Optional[float] = None  # Save every W minutes of wall-clock time
    checkpoint_dir: str = "/tmp/torax_checkpoints"
    keep_last_n: int = 3  # Keep only the last N checkpoints to save disk space
    checkpoint_prefix: str = "checkpoint"
    compression: str = 'gzip'  # Compression for checkpoint files
    # Advanced options
    save_on_nan: bool = True  # Auto-save if NaN detected (for debugging)
    save_on_error: bool = True  # Auto-save before raising exceptions
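
The driver loop below relies on a CheckpointManager helper that does not exist in TORAX today. A minimal sketch of what it could look like (all names hypothetical; only step-, simulation-time-, and wall-clock-based intervals are handled, and the emergency/error path helpers used below are omitted):

import os
import time
from pathlib import Path

class CheckpointManager:
    """Decides when to checkpoint and where to write the files (sketch only)."""

    def __init__(self, config: CheckpointConfig):
        self.config = config
        self._last_step = 0
        self._last_sim_time = 0.0
        self._last_walltime = time.monotonic()
        Path(config.checkpoint_dir).mkdir(parents=True, exist_ok=True)

    def should_save(self, sim) -> bool:
        """Return True if any configured interval has elapsed since the last save."""
        if not self.config.enabled:
            return False
        c = self.config
        if c.interval_steps and sim.step_count - self._last_step >= c.interval_steps:
            return True
        if c.interval_time and sim.time - self._last_sim_time >= c.interval_time:
            return True
        walltime_minutes = (time.monotonic() - self._last_walltime) / 60.0
        if c.interval_walltime and walltime_minutes >= c.interval_walltime:
            return True
        return False

    def get_checkpoint_path(self, sim) -> Path:
        """Name checkpoints by simulation time and step, and reset the interval counters."""
        self._last_step = sim.step_count
        self._last_sim_time = sim.time
        self._last_walltime = time.monotonic()
        name = f"{self.config.checkpoint_prefix}_t{sim.time:.1f}_step{sim.step_count}.nc"
        return Path(self.config.checkpoint_dir) / name

    def cleanup_old_checkpoints(self) -> None:
        """Keep only the newest keep_last_n checkpoint files."""
        files = sorted(
            Path(self.config.checkpoint_dir).glob(f"{self.config.checkpoint_prefix}_*.nc"),
            key=os.path.getmtime)
        for old in files[:-self.config.keep_last_n]:
            old.unlink()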
def run_simulation_with_checkpointing(
    config: Config,
    checkpoint_config: CheckpointConfig,
    resume_from: Optional[str] = None,
) -> StateHistory:
    """Run a TORAX simulation with automatic checkpoint saving.

    Args:
      config: TORAX configuration.
      checkpoint_config: Checkpoint configuration.
      resume_from: Optional path to a checkpoint to resume from.

    Returns:
      StateHistory object with the complete simulation results.
    """
    # Initialize or resume the simulation
    if resume_from:
        sim = load_checkpoint(resume_from, config=config)
        logger.info(f"Resuming simulation from checkpoint at t={sim.time:.3f}s")
    else:
        sim = Sim(config=config)
        logger.info("Starting new simulation")

    # Set up the checkpoint manager
    checkpoint_manager = CheckpointManager(checkpoint_config)

    try:
        while sim.time < sim.config.end_time:
            # Take a simulation step
            sim.step()

            # Check whether a checkpoint should be saved
            if checkpoint_manager.should_save(sim):
                checkpoint_path = checkpoint_manager.get_checkpoint_path(sim)
                save_checkpoint(sim, checkpoint_path)
                checkpoint_manager.cleanup_old_checkpoints()

            # Check for NaN or other errors
            if checkpoint_config.save_on_nan and has_nan(sim.core_profiles):
                logger.error("NaN detected in simulation, saving emergency checkpoint")
                emergency_path = checkpoint_manager.get_emergency_checkpoint_path(sim)
                save_checkpoint(sim, emergency_path)
                raise SimulationError("NaN detected in core profiles")
    except Exception as e:
        if checkpoint_config.save_on_error:
            logger.error(f"Error during simulation: {e}, saving checkpoint")
            error_path = checkpoint_manager.get_error_checkpoint_path(sim)
            save_checkpoint(sim, error_path)
        raise

    return sim.get_state_history()

5. Command-Line Interface Updates
Extend the run_torax CLI to support checkpoint operations:
# Run with automatic checkpointing every 100 steps
run_torax --config=examples/iterhybrid_rampup.py \
  --checkpoint-interval-steps=100 \
  --checkpoint-dir=/path/to/checkpoints \
  --checkpoint-keep-last=5

# Resume from a checkpoint
run_torax --config=examples/iterhybrid_rampup.py \
  --resume-from=/path/to/checkpoints/checkpoint_t10.5_step1000.nc

# Resume and override some runtime parameters
run_torax --config=examples/iterhybrid_rampup.py \
  --resume-from=/path/to/checkpoints/checkpoint_t10.5_step1000.nc \
  --override-params='{"transport.model": "constant", "heating.power": 60e6}'

6. Configuration File Integration
Add checkpoint section to TORAX config files:
# In a config file (e.g., examples/iterhybrid_rampup.py)
from torax import checkpoint

config = {
    # ... existing config ...
    'checkpoint': {
        'enabled': True,
        'interval_time': 1.0,  # Save every 1 second of simulation time
        'checkpoint_dir': './checkpoints',
        'keep_last_n': 5,
        'compression': 'gzip',
        'save_on_nan': True,
    }
}

Implementation Phases
Phase 1: Core Checkpoint Infrastructure (Priority: High)
- Design and implement the ToraxCheckpoint dataclass
- Implement the save_checkpoint() function with NetCDF serialization
- Implement the load_checkpoint() function with validation
- Add state getters/setters to Solver, Stepper, and transport model classes
- Write comprehensive unit tests for the save/load cycle
- Ensure exact reproducibility: verify that a resumed simulation produces identical results (see the test sketch after this list)
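
A sketch of that reproducibility test, assuming the proposed APIs above plus hypothetical run_to_time and continue_simulation helpers and a baseline_config pytest fixture:

import jax
import numpy as np

def test_resume_is_bit_exact(tmp_path, baseline_config):
    """Resuming from a mid-run checkpoint should reproduce the uninterrupted run."""
    # Reference run: straight through to t = 10 s
    reference = run_to_time(baseline_config, end_time=10.0)

    # Checkpointed run: stop at t = 5 s, save, reload, continue to t = 10 s
    sim = run_to_time(baseline_config, end_time=5.0)
    ckpt_path = tmp_path / "mid.nc"
    save_checkpoint(sim, ckpt_path)
    restarted = continue_simulation(load_checkpoint(ckpt_path), end_time=10.0)

    # Every leaf of the final profiles should match bit-for-bit
    jax.tree_util.tree_map(
        np.testing.assert_array_equal,
        reference.core_profiles,
        restarted.core_profiles,
    )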
Phase 2: Automatic Checkpointing (Priority: High)
- Implement the CheckpointConfig and CheckpointManager classes
- Integrate checkpoint saving into the main simulation loop
- Add checkpoint cleanup/rotation logic (keep last N checkpoints)
- Implement wall-clock time based checkpointing for long runs
- Add emergency checkpointing on NaN/error detection
Phase 3: CLI and Configuration Integration (Priority: Medium)
- Add checkpoint-related flags to the run_torax CLI (see the flag sketch after this list)
- Add a --resume-from flag for loading checkpoints
- Add an --override-params flag for modifying resumed simulations
- Update the config file format to support checkpoint settings
- Create user-friendly checkpoint directory management
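
One way these flags could be declared, assuming the CLI keeps using absl like the rest of TORAX (the flag names are this proposal's, not existing run_torax options):

from absl import flags

_RESUME_FROM = flags.DEFINE_string(
    'resume_from', None,
    'Path to a checkpoint file to resume the simulation from.')
_CHECKPOINT_INTERVAL_STEPS = flags.DEFINE_integer(
    'checkpoint_interval_steps', None,
    'Save a checkpoint every N solver steps.')
_CHECKPOINT_DIR = flags.DEFINE_string(
    'checkpoint_dir', './checkpoints',
    'Directory where checkpoint files are written.')
_CHECKPOINT_KEEP_LAST = flags.DEFINE_integer(
    'checkpoint_keep_last', 3,
    'Number of most recent checkpoints to keep on disk.')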
Phase 4: Validation and Testing (Priority: High)
- Create test suite for checkpoint/restart functionality
- Validate exact reproducibility across different JAX configurations
- Test checkpoint version compatibility
- Benchmark checkpoint save/load performance overhead
- Test checkpoint file size and compression effectiveness
- Verify compatibility with different transport models (QLKNN_7_11, QLKNN10D, constant)
Phase 5: Documentation and Examples (Priority: Medium)
- Add checkpoint/restart documentation to readthedocs
- Create tutorial notebook demonstrating checkpoint usage
- Add examples showing parameter sweep from checkpoint
- Document checkpoint file format specification
- Add troubleshooting guide for checkpoint issues
- Document best practices for long-running HPC simulations
Technical Considerations
JAX Compatibility
- JAX's functional design and JIT compilation require careful handling of state
- Use jax.tree_util.tree_map for serializing PyTree structures (see the sketch after this list)
- Checkpoints should be agnostic to the JAX backend (CPU/GPU/TPU)
- Test that resumed simulations maintain JIT compilation cache when possible
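
For illustration, a minimal sketch of flattening a PyTree of arrays into an xarray Dataset suitable for NetCDF (pytree_to_dataset is a hypothetical helper; real TORAX state would also need dtype and coordinate handling):

import jax
import numpy as np
import xarray as xr

def pytree_to_dataset(tree) -> xr.Dataset:
    """Flatten a PyTree of arrays into an xarray Dataset keyed by key-path."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    data_vars = {}
    for path, leaf in leaves_with_paths:
        # Key paths like "['core_profiles'].temp_ion" may need sanitizing for NetCDF names
        name = jax.tree_util.keystr(path)
        arr = np.asarray(leaf)
        dims = [f"{name}_dim{i}" for i in range(arr.ndim)]
        data_vars[name] = xr.DataArray(arr, dims=dims)
    return xr.Dataset(data_vars)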
Backward Compatibility
- Implement a checkpoint versioning system (semver-style; a minimal check is sketched after this list)
- Provide migration tools for older checkpoint formats
- Gracefully handle checkpoint version mismatches with clear error messages
- Document which TORAX versions are compatible with which checkpoint versions
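
A minimal form of the compatibility check implied above (the policy shown here, same major version is compatible, is an assumption of this sketch):

def is_compatible_version(checkpoint_version: str, current_version: str = "1.0.0") -> bool:
    """Treat a checkpoint as compatible when its major version matches the current one."""
    return checkpoint_version.split(".")[0] == current_version.split(".")[0]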
Performance Optimization
- Checkpoint saves should have minimal performance impact (<5% overhead)
- Use asynchronous I/O where possible to avoid blocking the simulation (see the sketch after this list)
- Implement incremental checkpointing for very large simulations (save only changed data)
- Consider checkpoint compression trade-offs (time vs. storage)
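
One way to approach the asynchronous I/O point above, sketched with a single background writer thread (build_checkpoint and write_checkpoint_to_netcdf are hypothetical helpers; jax.device_get materializes device arrays on the host before handing them off):

import jax
from concurrent.futures import ThreadPoolExecutor

_checkpoint_executor = ThreadPoolExecutor(max_workers=1)

def save_checkpoint_async(sim, checkpoint_path):
    """Queue the NetCDF write on a background thread so the solver loop keeps running."""
    # Copy device arrays to host memory first; the write then has no dependency on the device
    checkpoint = jax.device_get(build_checkpoint(sim))
    return _checkpoint_executor.submit(
        write_checkpoint_to_netcdf, checkpoint, checkpoint_path)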
File Format Considerations
- Use NetCDF for consistency with existing TORAX outputs
- NetCDF supports:
- Efficient compression
- Partial loading (checkpoint metadata can be inspected without loading the full state; see the example after this list)
- Wide tool compatibility (xarray, Paraview, etc.)
- Append-mode for continuous time series
- Include human-readable metadata (TORAX version, timestamp, step number, etc.)
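
For example, because xarray opens NetCDF files lazily, checkpoint metadata can be read without pulling the arrays into memory (the file path is illustrative):

import xarray as xr

ds = xr.open_dataset("checkpoints/checkpoint_t10.0_step1000.nc")
print(ds.attrs.get("torax_version"), ds.attrs.get("checkpoint_version"))
print(list(ds.data_vars))  # Variable names only; data stays on disk until accessed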
Error Handling
- Validate checkpoint integrity before loading (checksums; a sketch follows this list)
- Provide clear error messages for corrupted checkpoints
- Implement checkpoint verification tool
- Handle missing/incompatible solver states gracefully
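
One simple integrity scheme is a sidecar SHA-256 file written at save time (a sketch; the ".sha256" sidecar convention is an assumption, not an existing TORAX feature):

import hashlib
from pathlib import Path

def write_checksum(checkpoint_path: str | Path) -> None:
    """Write a sidecar <checkpoint>.sha256 file next to the checkpoint."""
    path = Path(checkpoint_path)
    digest = hashlib.sha256(path.read_bytes()).hexdigest()
    Path(str(path) + ".sha256").write_text(digest)

def verify_checkpoint(checkpoint_path: str | Path) -> bool:
    """Compare the checkpoint's SHA-256 digest against its sidecar checksum."""
    path = Path(checkpoint_path)
    expected = Path(str(path) + ".sha256").read_text().strip()
    return hashlib.sha256(path.read_bytes()).hexdigest() == expected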
Testing Strategy
Unit Tests
- Test save/load cycle preserves all state variables
- Test checkpoint compression works correctly
- Test version compatibility checks
- Test error handling for corrupted checkpoints
Integration Tests
- Test complete simulation workflow with checkpointing
- Test resume from checkpoint produces identical trajectories
- Test checkpoint rotation and cleanup
- Test emergency checkpointing on errors
Performance Tests
- Measure checkpoint save time vs. simulation step time
- Measure checkpoint file size growth
- Test scalability with large simulations (>1000 time steps)
- Benchmark compression algorithms (gzip, zlib, lz4)
Validation Tests
- Verify bit-exact reproducibility when resuming
- Test across different JAX backends (CPU, GPU)
- Test with different transport models
- Test with different solver configurations
User Benefits
- Reliability: Long simulations can be safely interrupted and resumed
- Efficiency: No wasted computation if job exceeds cluster time limits
- Flexibility: Explore different parameter branches from common checkpoints
- Debugging: Save state before failures for post-mortem analysis
- Workflows: Enable complex multi-stage simulation workflows
- Reproducibility: Checkpoint includes full configuration for exact reproduction
Example Use Cases
Use Case 1: Long HPC Simulation
# Submit job that runs for 12 hours max, auto-checkpointing
run_torax --config=long_simulation.py \
  --checkpoint-interval-walltime=60 \
  --checkpoint-keep-last=3   # Save every hour, keep the last three

# If the job times out, resume from the latest checkpoint in the next job
run_torax --config=long_simulation.py \
  --resume-from=./checkpoints/checkpoint_latest.nc

Use Case 2: Parameter Sweep from Baseline
# Run the baseline simulation with checkpointing
sim = run_simulation_with_checkpointing(
    config=baseline_config,
    checkpoint_config=CheckpointConfig(
        enabled=True,
        interval_time=5.0,  # Save every 5 seconds of simulation time
    )
)

# Resume from the t=10s checkpoint and try different heating powers
checkpoint_t10 = "./checkpoints/checkpoint_t10.0.nc"
for power in [30e6, 40e6, 50e6, 60e6]:
    override_params = {'heating.power': power}
    sim = load_checkpoint(
        checkpoint_t10,
        override_runtime_params=override_params,
    )
    results = continue_simulation(sim, end_time=20.0)
    analyze_results(results, power)

Use Case 3: Debug NaN Failure
# A simulation fails with NaN at t=15.3s; an emergency checkpoint is
# automatically saved because save_on_nan=True.

# Load the emergency checkpoint and inspect the state
checkpoint = xr.open_dataset("./checkpoints/emergency_t15.3.nc")
print(checkpoint.core_profiles)  # Inspect profiles at the failure point

# Resume with modified parameters to avoid the instability
sim = load_checkpoint(
    "./checkpoints/checkpoint_t15.0.nc",  # Load from just before the failure
    override_runtime_params={
        'solver.predictor_corrector': True,  # Try a more stable solver
        'time_step_calculator.dt': 0.005,    # Reduce the time step
    }
)
continue_simulation(sim)  # Try to get past the problematic region

Use Case 4: Trajectory Optimization
# Gradient-based optimization of a heating trajectory.
# Uses checkpoints so each forward simulation starts from a common state.
def objective_function(heating_schedule):
    # Load the checkpoint at the optimization start point
    sim = load_checkpoint("./checkpoints/optimization_start.nc")
    # Apply the heating schedule
    sim.runtime_params.heating.schedule = heating_schedule
    # Run the forward simulation
    results = continue_simulation(sim, end_time=20.0)
    # Compute the objective (e.g., maximize Q while keeping beta below a threshold)
    return compute_performance_metric(results)

# Use JAX autodiff for gradient-based optimization
result = jax.scipy.optimize.minimize(
    objective_function,
    initial_heating_schedule,
    method='BFGS',
)
optimal_schedule = result.x  # minimize returns an OptimizeResults object

Open Questions for Discussion
- Checkpoint frequency defaults: What are sensible defaults for interval_steps and interval_time?
- Disk space management: Should we implement automatic cleanup policies based on disk usage?
- Distributed computing: How should checkpointing work in multi-node/multi-GPU scenarios?
- Cloud storage integration: Should we support cloud storage backends (S3, GCS) for checkpoints?
- Data security: Should sensitive simulation data in checkpoints be encrypted?
Related Issues & References
- TORAX v1.0.0 Release Notes: Mentions output structure standardization
- RAPTOR validation: TORAX was verified against RAPTOR (which has restart capability)
- Similar feature in OpenMM: https://docs.openmm.org/latest/api-python/generated/openmm.app.simulation.Simulation.html
- IPS Framework Config: https://ips-framework.readthedocs.io/en/latest/user_guides/config_file.html
Community Input Requested
- Are there specific use cases for checkpointing not covered above?
- What checkpoint metadata would be most valuable for your workflows?
- Should checkpointing be enabled by default for all simulations?
- What file naming conventions would be most useful?