Feature Request: Simulation Checkpoint and Restart Capability #1679

@Aaryan-549

Description

Problem Statement

Currently, TORAX simulations write final results to NetCDF files (state_history_%Y%m%d_%H%M%S.nc) but lack the ability to save intermediate simulation states and resume from checkpoints. This limitation creates significant challenges for:

  • Long-running simulations: If a simulation crashes or is interrupted after hours of computation, all progress is lost
  • Iterative workflow optimization: Researchers cannot easily branch from intermediate states to test different parameters
  • Resource efficiency: Computing clusters with time limits require the ability to save and resume work across multiple job submissions
  • Debugging and analysis: Users cannot inspect and resume from specific simulation time points
  • Parameter sweeps: Must restart entire simulations when exploring slight variations from a common baseline

Comparison with Similar Tokamak Simulation Tools

RAPTOR (RApid Plasma Transport simulatOR)

  • Restart capabilities: Supports state snapshots for real-time control applications
  • Use case: Enables rapid model-based control requiring frequent save/load cycles
  • Reference: RAPTOR is designed for real-time execution where checkpoint/restart is essential for control workflows

TokSys Suite

  • Checkpoint system: Full support for saving and restoring plasma control simulations
  • Integration: Works with GSevolve for closed-loop simulations with restart capability
  • Application: ITER PCS development, EAST, KSTAR, DIII-D, and NSTX-U all use TokSys checkpointing

IPS (Integrated Plasma Simulator)

  • Robust restart framework: Dedicated SIMULATION_MODE = RESTART configuration
  • Features:
    • RESTART_TIME parameter to specify checkpoint timestamp
    • RESTART_DIRECTORY for seamless simulation continuation
    • Automatic archiving of restart files per component
    • Support for appending to existing simulation data
  • Reference: Used extensively in fusion simulation workflows

OpenMM (Molecular Dynamics - Analogous Domain)

  • Dual save methods:
    • saveCheckpoint(): Binary format with complete internal state (RNG states, etc.)
    • saveState(): XML format with public data only (positions, velocities)
  • Resume capability: loadCheckpoint() restores exact simulation state
  • Periodic checkpointing: Automatic saves at configurable intervals during long runs

Key Gap in TORAX

TORAX currently lacks any equivalent checkpoint/restart mechanism, despite being positioned as a tool for "pulse-design, trajectory optimization, and controller design workflows" where interrupted or iterative execution is common.

Proposed Solution

Implement a comprehensive checkpoint/restart system integrated with TORAX's existing JAX-based architecture and NetCDF output format.

Core Components

1. Checkpoint Data Structure

Create a checkpoint that captures complete simulation state:

from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class ToraxCheckpoint:
    """Complete checkpoint containing all data needed to resume simulation."""
    
    # Simulation metadata
    checkpoint_version: str  # For backward compatibility
    torax_version: str
    timestamp: float
    step_number: int
    
    # Core simulation state
    core_profiles: CoreProfiles  # All plasma profiles at checkpoint time
    core_sources: CoreSources    # Source terms and derived quantities
    geometry: Geometry           # Geometric configuration
    
    # Solver state
    solver_state: Any  # Internal solver state (for exact reproducibility)
    stepper_state: Any # Time stepping state
    
    # Transport model state
    transport_model_state: Dict[str, Any]  # State of QLKNN or other transport models
    
    # JAX-specific state
    jax_rng_state: Any  # Explicit PRNG key for stochastic elements (JAX has no global RNG state)
    
    # Simulation parameters at checkpoint
    runtime_params: RuntimeParams
    config_snapshot: Dict[str, Any]  # Full config at checkpoint time
    
    # Compilation cache hints (optional, for performance)
    cache_metadata: Optional[Dict[str, Any]] = None

2. Save Checkpoint Function

import logging
from pathlib import Path

import torax

logger = logging.getLogger(__name__)


def save_checkpoint(
    sim: Sim,
    checkpoint_path: str | Path,
    include_cache_hints: bool = False,
    compression: str | None = 'zlib',
) -> None:
    """
    Save simulation checkpoint to disk.
    
    Args:
        sim: Current Sim object with all state
        checkpoint_path: Path where checkpoint will be saved
        include_cache_hints: Whether to save JAX compilation cache metadata
        compression: NetCDF compression algorithm ('zlib' for deflate, or None)
    
    The checkpoint is saved as a NetCDF file for consistency with TORAX outputs,
    but with a special structure flag identifying it as a checkpoint.
    """
    checkpoint = ToraxCheckpoint(
        checkpoint_version="1.0.0",
        torax_version=torax.__version__,
        timestamp=sim.time,
        step_number=sim.step_count,
        core_profiles=sim.core_profiles,
        core_sources=sim.core_sources,
        geometry=sim.geometry,
        solver_state=sim.solver.get_state(),  # New method needed
        stepper_state=sim.stepper.get_state(),  # New method needed
        transport_model_state=sim.transport_model.get_state(),
        # JAX has no global RNG state; store the simulation's explicit PRNG key
        jax_rng_state=getattr(sim, 'rng_key', None),
        runtime_params=sim.runtime_params,
        config_snapshot=serialize_config(sim.config),
    )
    
    # Serialize to NetCDF with xarray
    ds = checkpoint_to_xarray(checkpoint)
    ds.attrs['checkpoint_file'] = True
    ds.attrs['checkpoint_version'] = checkpoint.checkpoint_version
    
    encoding = {}
    if compression:
        # xarray's netCDF4 encoding enables deflate via the 'zlib' flag
        encoding = {var: {'zlib': True, 'complevel': 4}
                    for var in ds.data_vars}
    
    ds.to_netcdf(checkpoint_path, encoding=encoding)
    logger.info(f"Checkpoint saved to {checkpoint_path} at t={sim.time:.3f}s, step={sim.step_count}")

3. Load Checkpoint Function

from typing import Dict, Optional

import xarray as xr


def load_checkpoint(
    checkpoint_path: str | Path,
    config: Optional[Config] = None,
    override_runtime_params: Optional[Dict] = None,
) -> Sim:
    """
    Load simulation from checkpoint and return initialized Sim object.
    
    Args:
        checkpoint_path: Path to checkpoint file
        config: Optional new config (must be compatible). If None, uses checkpointed config.
        override_runtime_params: Optional dict to override specific runtime parameters
        
    Returns:
        Sim object initialized from checkpoint, ready to continue simulation
        
    Raises:
        CheckpointIncompatibleError: If checkpoint version or structure incompatible
        ConfigMismatchError: If provided config incompatible with checkpoint
    """
    # Load checkpoint data
    ds = xr.open_dataset(checkpoint_path)
    
    if not ds.attrs.get('checkpoint_file', False):
        raise ValueError(f"{checkpoint_path} is not a valid TORAX checkpoint file")
    
    checkpoint = xarray_to_checkpoint(ds)
    
    # Version compatibility check
    if not is_compatible_version(checkpoint.checkpoint_version):
        logger.warning(f"Checkpoint version {checkpoint.checkpoint_version} may be incompatible")
    
    # Use checkpointed config or validate provided config
    if config is None:
        config = deserialize_config(checkpoint.config_snapshot)
    else:
        validate_config_compatibility(config, checkpoint.config_snapshot)
    
    # Apply runtime parameter overrides
    runtime_params = checkpoint.runtime_params
    if override_runtime_params:
        runtime_params = update_runtime_params(runtime_params, override_runtime_params)
    
    # Reconstruct Sim object
    sim = Sim(config=config)
    
    # Restore complete state
    sim.time = checkpoint.timestamp
    sim.step_count = checkpoint.step_number
    sim.core_profiles = checkpoint.core_profiles
    sim.core_sources = checkpoint.core_sources
    sim.geometry = checkpoint.geometry
    sim.runtime_params = runtime_params
    
    # Restore solver internal state
    sim.solver.set_state(checkpoint.solver_state)
    sim.stepper.set_state(checkpoint.stepper_state)
    sim.transport_model.set_state(checkpoint.transport_model_state)
    
    # Restore the simulation's PRNG key if present (JAX has no global RNG
    # state; stochastic elements must consume this key explicitly)
    if checkpoint.jax_rng_state is not None:
        sim.rng_key = checkpoint.jax_rng_state
    
    logger.info(f"Checkpoint loaded from {checkpoint_path}, resuming at t={sim.time:.3f}s, step={sim.step_count}")
    
    return sim

4. Automatic Checkpoint Saving

Integrate with the existing simulation loop to enable periodic automatic checkpointing:

@dataclass
class CheckpointConfig:
    """Configuration for automatic checkpointing during simulation."""
    
    enabled: bool = False
    interval_steps: Optional[int] = None  # Save every N steps
    interval_time: Optional[float] = None  # Save every T seconds of simulation time
    interval_walltime: Optional[float] = None  # Save every W minutes of wall clock time
    
    checkpoint_dir: str = "/tmp/torax_checkpoints"
    keep_last_n: int = 3  # Keep only last N checkpoints to save disk space
    checkpoint_prefix: str = "checkpoint"
    
    compression: str = 'zlib'  # Compression for checkpoint files (deflate)
    
    # Advanced options
    save_on_nan: bool = True  # Auto-save if NaN detected (for debugging)
    save_on_error: bool = True  # Auto-save before raising exceptions


def run_simulation_with_checkpointing(
    config: Config,
    checkpoint_config: CheckpointConfig,
    resume_from: Optional[str] = None,
) -> StateHistory:
    """
    Run TORAX simulation with automatic checkpoint saving.
    
    Args:
        config: TORAX configuration
        checkpoint_config: Checkpoint configuration
        resume_from: Optional path to checkpoint to resume from
    
    Returns:
        StateHistory object with complete simulation results
    """
    # Initialize or resume simulation
    if resume_from:
        sim = load_checkpoint(resume_from, config=config)
        logger.info(f"Resuming simulation from checkpoint at t={sim.time:.3f}s")
    else:
        sim = Sim(config=config)
        logger.info("Starting new simulation")
    
    # Setup checkpoint manager
    checkpoint_manager = CheckpointManager(checkpoint_config)
    
    try:
        while sim.time < sim.config.end_time:
            # Take simulation step
            sim.step()
            
            # Check if checkpoint should be saved
            if checkpoint_manager.should_save(sim):
                checkpoint_path = checkpoint_manager.get_checkpoint_path(sim)
                save_checkpoint(sim, checkpoint_path)
                checkpoint_manager.cleanup_old_checkpoints()
            
            # Check for NaN or other errors
            if checkpoint_config.save_on_nan and has_nan(sim.core_profiles):
                logger.error("NaN detected in simulation, saving emergency checkpoint")
                emergency_path = checkpoint_manager.get_emergency_checkpoint_path(sim)
                save_checkpoint(sim, emergency_path)
                raise SimulationError("NaN detected in core profiles")
                
    except Exception as e:
        if checkpoint_config.save_on_error:
            logger.error(f"Error during simulation: {e}, saving checkpoint")
            error_path = checkpoint_manager.get_error_checkpoint_path(sim)
            save_checkpoint(sim, error_path)
        raise
    
    return sim.get_state_history()

5. Command-Line Interface Updates

Extend the run_torax CLI to support checkpoint operations:

# Run with automatic checkpointing every 100 steps
run_torax --config=examples/iterhybrid_rampup.py \
          --checkpoint-interval-steps=100 \
          --checkpoint-dir=/path/to/checkpoints \
          --checkpoint-keep-last=5

# Resume from checkpoint
run_torax --config=examples/iterhybrid_rampup.py \
          --resume-from=/path/to/checkpoints/checkpoint_t10.5_step1000.nc

# Resume and override some runtime parameters
run_torax --config=examples/iterhybrid_rampup.py \
          --resume-from=/path/to/checkpoints/checkpoint_t10.5_step1000.nc \
          --override-params='{"transport.model": "constant", "heating.power": 60e6}'

6. Configuration File Integration

Add checkpoint section to TORAX config files:

# In config file (e.g., examples/iterhybrid_rampup.py)

from torax import checkpoint

config = {
    # ... existing config ...
    
    'checkpoint': {
        'enabled': True,
        'interval_time': 1.0,  # Save every 1 second of simulation time
        'checkpoint_dir': './checkpoints',
        'keep_last_n': 5,
        'compression': 'zlib',
        'save_on_nan': True,
    }
}

Implementation Phases

Phase 1: Core Checkpoint Infrastructure (Priority: High)

  • Design and implement ToraxCheckpoint dataclass
  • Implement save_checkpoint() function with NetCDF serialization
  • Implement load_checkpoint() function with validation
  • Add state getters/setters to Solver, Stepper, and transport model classes (see the interface sketch after this list)
  • Write comprehensive unit tests for save/load cycle
  • Ensure exact reproducibility: verify that resumed simulation produces identical results
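
To make the state getters/setters item concrete, a minimal interface sketch follows; StatefulComponent and ExampleStepper are illustrative names for this proposal, not existing TORAX classes:

from typing import Any, Dict, Protocol

import numpy as np


class StatefulComponent(Protocol):
    """Proposed interface for any component whose state must be checkpointed."""

    def get_state(self) -> Dict[str, Any]:
        """Return a serializable snapshot of internal state."""
        ...

    def set_state(self, state: Dict[str, Any]) -> None:
        """Restore internal state from a snapshot produced by get_state()."""
        ...


class ExampleStepper:
    """Illustrative stepper showing one way to implement the protocol."""

    def __init__(self) -> None:
        self._dt_prev = 0.0        # last accepted time step
        self._error_history = []   # recent error estimates for step control

    def get_state(self) -> Dict[str, Any]:
        return {
            'dt_prev': self._dt_prev,
            'error_history': np.asarray(self._error_history),
        }

    def set_state(self, state: Dict[str, Any]) -> None:
        self._dt_prev = float(state['dt_prev'])
        self._error_history = list(state['error_history'])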

Phase 2: Automatic Checkpointing (Priority: High)

  • Implement CheckpointConfig and CheckpointManager classes (a CheckpointManager sketch follows this list)
  • Integrate checkpoint saving into main simulation loop
  • Add checkpoint cleanup/rotation logic (keep last N checkpoints)
  • Implement wall-clock time based checkpointing for long runs
  • Add emergency checkpointing on NaN/error detection
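
One possible shape for the CheckpointManager referenced above; the trigger logic and file naming are assumptions of this proposal, and CheckpointConfig is the dataclass sketched in section 4:

import time
from pathlib import Path


class CheckpointManager:
    """Decides when to checkpoint and rotates old checkpoint files."""

    def __init__(self, config: CheckpointConfig):
        self.config = config
        self._last_step = 0
        self._last_sim_time = 0.0
        self._last_walltime = time.monotonic()
        Path(config.checkpoint_dir).mkdir(parents=True, exist_ok=True)

    def should_save(self, sim) -> bool:
        c = self.config
        if not c.enabled:
            return False
        if c.interval_steps and sim.step_count - self._last_step >= c.interval_steps:
            return True
        if c.interval_time and sim.time - self._last_sim_time >= c.interval_time:
            return True
        if c.interval_walltime and (
                time.monotonic() - self._last_walltime >= c.interval_walltime * 60):
            return True
        return False

    def get_checkpoint_path(self, sim) -> Path:
        # Record when this checkpoint was triggered, for the next interval check.
        self._last_step = sim.step_count
        self._last_sim_time = sim.time
        self._last_walltime = time.monotonic()
        name = f"{self.config.checkpoint_prefix}_t{sim.time:.1f}_step{sim.step_count}.nc"
        return Path(self.config.checkpoint_dir) / name

    def cleanup_old_checkpoints(self) -> None:
        # Keep only the newest keep_last_n checkpoints, by modification time.
        files = sorted(
            Path(self.config.checkpoint_dir).glob(
                f"{self.config.checkpoint_prefix}_*.nc"),
            key=lambda p: p.stat().st_mtime,
        )
        for old in files[:-self.config.keep_last_n]:
            old.unlink()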

Phase 3: CLI and Configuration Integration (Priority: Medium)

  • Add checkpoint-related flags to run_torax CLI
  • Add --resume-from flag for loading checkpoints
  • Add --override-params flag for modifying resumed simulations
  • Update config file format to support checkpoint settings
  • Create user-friendly checkpoint directory management

Phase 4: Validation and Testing (Priority: High)

  • Create test suite for checkpoint/restart functionality
  • Validate exact reproducibility across different JAX configurations
  • Test checkpoint version compatibility
  • Benchmark checkpoint save/load performance overhead
  • Test checkpoint file size and compression effectiveness
  • Verify compatibility with different transport models (QLKNN_7_11, QLKNN10D, constant)

Phase 5: Documentation and Examples (Priority: Medium)

  • Add checkpoint/restart documentation to readthedocs
  • Create tutorial notebook demonstrating checkpoint usage
  • Add examples showing parameter sweep from checkpoint
  • Document checkpoint file format specification
  • Add troubleshooting guide for checkpoint issues
  • Document best practices for long-running HPC simulations

Technical Considerations

JAX Compatibility

  • JAX's functional design and JIT compilation require careful handling of state
  • Use jax.tree_util (tree_flatten/tree_unflatten) for serializing PyTree structures; see the sketch after this list
  • Checkpoint should be agnostic to JAX backend (CPU/GPU/TPU)
  • Test that resumed simulations maintain JIT compilation cache when possible
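
As an illustration of the PyTree point above, a checkpoint serializer could flatten an arbitrary JAX state tree into named host arrays plus a tree definition, and unflatten on restore. A minimal sketch (note the treedef itself is not NetCDF-serializable and would be stored out-of-band, e.g. reconstructed from the config):

import jax
import jax.numpy as jnp
import numpy as np


def pytree_to_arrays(state):
    """Flatten a PyTree of JAX arrays into host numpy arrays + tree structure."""
    leaves, treedef = jax.tree_util.tree_flatten(state)
    # Device arrays are pulled to host memory so xarray/NetCDF can store them.
    arrays = {f'leaf_{i}': np.asarray(leaf) for i, leaf in enumerate(leaves)}
    return arrays, treedef


def arrays_to_pytree(arrays, treedef):
    """Inverse of pytree_to_arrays: rebuild the original PyTree on device."""
    leaves = [jnp.asarray(arrays[f'leaf_{i}']) for i in range(len(arrays))]
    return jax.tree_util.tree_unflatten(treedef, leaves)


# Round-trip example with a nested state structure:
state = {'psi': jnp.ones(25), 'temps': {'ion': jnp.zeros(25), 'el': jnp.zeros(25)}}
arrays, treedef = pytree_to_arrays(state)
restored = arrays_to_pytree(arrays, treedef)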

Backward Compatibility

  • Implement checkpoint versioning system (semver-style)
  • Provide migration tools for older checkpoint formats
  • Gracefully handle checkpoint version mismatches with clear error messages
  • Document which TORAX versions are compatible with which checkpoint versions

Performance Optimization

  • Checkpoint saves should have minimal performance impact (<5% overhead)
  • Use asynchronous I/O where possible to avoid blocking the simulation (see the sketch after this list)
  • Implement incremental checkpointing for very large simulations (save only changed data)
  • Consider checkpoint compression trade-offs (time vs. storage)
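
For the asynchronous I/O item above, one low-effort approach is to snapshot state synchronously and hand the NetCDF write to a background thread. A sketch, where build_checkpoint is a hypothetical helper that copies device state to host and checkpoint_to_xarray is the serializer proposed in section 2:

from concurrent.futures import ThreadPoolExecutor

# A single worker thread serializes writes so checkpoint files never interleave.
_checkpoint_executor = ThreadPoolExecutor(max_workers=1)


def save_checkpoint_async(sim, checkpoint_path):
    """Snapshot the state now; write the NetCDF file in the background."""
    # The snapshot must happen synchronously, before the next step mutates state.
    ds = checkpoint_to_xarray(build_checkpoint(sim))  # hypothetical helpers
    future = _checkpoint_executor.submit(ds.to_netcdf, checkpoint_path)
    return future  # wait on this at shutdown to guarantee the write completed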

File Format Considerations

  • Use NetCDF for consistency with existing TORAX outputs
  • NetCDF supports:
    • Efficient compression
    • Partial loading (checkpoint metadata can be inspected without loading the full state; see the example after this list)
    • Wide tool compatibility (xarray, Paraview, etc.)
    • Append-mode for continuous time series
  • Include human-readable metadata (TORAX version, timestamp, step number, etc.)
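
An example of the partial-loading point: with xarray, a checkpoint's global attributes and variable names can be read lazily, without pulling the profile arrays into memory (attribute names follow the save_checkpoint proposal above):

import xarray as xr

# open_dataset is lazy: only metadata and coordinates are read here.
with xr.open_dataset('checkpoints/checkpoint_t10.0_step500.nc') as ds:
    print(ds.attrs.get('checkpoint_version'))
    print(ds.attrs.get('torax_version'))
    print(list(ds.data_vars))  # variable names, still without loading the arrays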

Error Handling

  • Validate checkpoint integrity before loading (e.g., checksums; see the sketch after this list)
  • Provide clear error messages for corrupted checkpoints
  • Implement checkpoint verification tool
  • Handle missing/incompatible solver states gracefully
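
For integrity validation, one option is to store a SHA-256 digest in a sidecar file at save time and compare it before loading. A sketch; the .sha256 sidecar convention is an assumption of this proposal, not an existing TORAX format:

import hashlib
from pathlib import Path


def write_checksum(checkpoint_path: Path) -> None:
    """Write a sidecar file holding the SHA-256 of the checkpoint bytes."""
    digest = hashlib.sha256(checkpoint_path.read_bytes()).hexdigest()
    checkpoint_path.with_suffix('.sha256').write_text(digest)


def verify_checksum(checkpoint_path: Path) -> bool:
    """Return True if the checkpoint matches its recorded checksum."""
    sidecar = checkpoint_path.with_suffix('.sha256')
    if not sidecar.exists():
        return False  # no checksum recorded; caller decides whether to trust it
    expected = sidecar.read_text().strip()
    actual = hashlib.sha256(checkpoint_path.read_bytes()).hexdigest()
    return actual == expected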

Testing Strategy

Unit Tests

  • Test that a save/load cycle preserves all state variables (example test after this list)
  • Test checkpoint compression works correctly
  • Test version compatibility checks
  • Test error handling for corrupted checkpoints
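
A representative round-trip test might look like the following, assuming the proposed save_checkpoint/load_checkpoint API; make_test_sim is a hypothetical helper building a small, fast test simulation, and temp_ion stands in for any profile field:

import numpy as np


def test_checkpoint_roundtrip(tmp_path):
    """Saving and reloading a checkpoint must preserve all state variables."""
    sim = make_test_sim()  # hypothetical: small simulation for fast tests
    for _ in range(10):
        sim.step()

    path = tmp_path / 'checkpoint.nc'
    save_checkpoint(sim, path)
    restored = load_checkpoint(path)

    assert restored.time == sim.time
    assert restored.step_count == sim.step_count
    np.testing.assert_array_equal(
        np.asarray(restored.core_profiles.temp_ion),
        np.asarray(sim.core_profiles.temp_ion),
    )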

Integration Tests

  • Test complete simulation workflow with checkpointing
  • Test resume from checkpoint produces identical trajectories
  • Test checkpoint rotation and cleanup
  • Test emergency checkpointing on errors

Performance Tests

  • Measure checkpoint save time vs. simulation step time
  • Measure checkpoint file size growth
  • Test scalability with large simulations (>1000 time steps)
  • Benchmark NetCDF-supported compression filters (zlib levels, szip, zstd)

Validation Tests

  • Verify bit-exact reproducibility when resuming (see the comparison sketch after this list)
  • Test across different JAX backends (CPU, GPU)
  • Test with different transport models
  • Test with different solver configurations
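
The bit-exact check can be phrased as: run N+M steps straight through, then run N steps, checkpoint, restore, run M more, and compare every output array exactly. A sketch using the APIs proposed above, assuming core_profiles is a registered JAX PyTree:

import jax
import numpy as np


def check_bitwise_resume(config, n_steps=50, m_steps=50, path='/tmp/ck.nc'):
    """Compare an uninterrupted run against a checkpoint/resume run, bit for bit."""
    # Reference: one uninterrupted run of n + m steps.
    ref = Sim(config=config)
    for _ in range(n_steps + m_steps):
        ref.step()

    # Candidate: run n steps, checkpoint, restore, run the remaining m steps.
    sim = Sim(config=config)
    for _ in range(n_steps):
        sim.step()
    save_checkpoint(sim, path)
    resumed = load_checkpoint(path)
    for _ in range(m_steps):
        resumed.step()

    # Bit-exact comparison over every leaf of the profiles PyTree.
    for ref_leaf, res_leaf in zip(
            jax.tree_util.tree_leaves(ref.core_profiles),
            jax.tree_util.tree_leaves(resumed.core_profiles)):
        np.testing.assert_array_equal(np.asarray(ref_leaf), np.asarray(res_leaf))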

User Benefits

  1. Reliability: Long simulations can be safely interrupted and resumed
  2. Efficiency: No wasted computation if job exceeds cluster time limits
  3. Flexibility: Explore different parameter branches from common checkpoints
  4. Debugging: Save state before failures for post-mortem analysis
  5. Workflows: Enable complex multi-stage simulation workflows
  6. Reproducibility: Checkpoint includes full configuration for exact reproduction

Example Use Cases

Use Case 1: Long HPC Simulation

# Submit job that runs for 12 hours max, auto-checkpointing
run_torax --config=long_simulation.py \
          --checkpoint-interval-walltime=60 \
          --checkpoint-keep-last=3  # save every 60 minutes of wall time

# If job times out, resume from last checkpoint in next job
run_torax --config=long_simulation.py \
          --resume-from=./checkpoints/checkpoint_latest.nc

Use Case 2: Parameter Sweep from Baseline

# Run baseline simulation with checkpointing
history = run_simulation_with_checkpointing(
    config=baseline_config,
    checkpoint_config=CheckpointConfig(
        enabled=True,
        interval_time=5.0,  # Save every 5 seconds
    )
)

# Resume from t=10s checkpoint and try different heating power
checkpoint_t10 = "./checkpoints/checkpoint_t10.0.nc"

for power in [30e6, 40e6, 50e6, 60e6]:
    override_params = {'heating.power': power}
    sim = load_checkpoint(
        checkpoint_t10,
        override_runtime_params=override_params
    )
    results = continue_simulation(sim, end_time=20.0)
    analyze_results(results, power)

Use Case 3: Debug NaN Failure

# Simulation that fails with NaN at t=15.3s
# Emergency checkpoint automatically saved due to save_on_nan=True

# Load emergency checkpoint and inspect state
ds = xr.open_dataset("./checkpoints/emergency_t15.3.nc")
print(ds)  # inspect the stored profiles at the failure point

# Resume with modified parameters to avoid instability
sim = load_checkpoint(
    "./checkpoints/checkpoint_t15.0.nc",  # Load from just before failure
    override_runtime_params={
        'solver.predictor_corrector': True,  # Try more stable solver
        'time_step_calculator.dt': 0.005,    # Reduce time step
    }
)
continue_simulation(sim)  # Try to get past the problematic region

Use Case 4: Trajectory Optimization

# Gradient-based optimization of heating trajectory
# Uses checkpoints to enable efficient gradient computation

def objective_function(heating_schedule):
    # Load checkpoint at optimization start point
    sim = load_checkpoint("./checkpoints/optimization_start.nc")
    
    # Apply heating schedule
    sim.runtime_params.heating.schedule = heating_schedule
    
    # Run forward simulation
    results = continue_simulation(sim, end_time=20.0)
    
    # Compute objective (e.g., maximize Q while keeping beta < threshold)
    return compute_performance_metric(results)

# Use JAX autodiff for gradient-based optimization. Note: for minimize() to
# differentiate through the objective, the restored checkpoint state must live
# in memory as JAX arrays and the forward simulation must be JAX-traceable.
result = jax.scipy.optimize.minimize(
    objective_function,
    initial_heating_schedule,
    method='BFGS',
)
optimal_schedule = result.x

Open Questions for Discussion

  1. Checkpoint frequency defaults: What are sensible defaults for interval_steps and interval_time?
  2. Disk space management: Should we implement automatic cleanup policies based on disk usage?
  3. Distributed computing: How should checkpointing work in multi-node/multi-GPU scenarios?
  4. Cloud storage integration: Should we support cloud storage backends (S3, GCS) for checkpoints?
  5. Data security: Should sensitive simulation data in checkpoints be encrypted?

Community Input Requested

  • Are there specific use cases for checkpointing not covered above?
  • What checkpoint metadata would be most valuable for your workflows?
  • Should checkpointing be enabled by default for all simulations?
  • What file naming conventions would be most useful?

Issue actions