Optimization: Gradient checkpointing for deep processor layers #189

@Sidharth1743

Description

Detailed Description

Following the merge of Efficient Sequential Batching (PR #187), the memory bottleneck related to graph edges has been resolved.
However, the model still consumes significant GPU memory storing intermediate activations for backpropagation, particularly in the deep GraphProcessor (typically 9 layers). I believe this activation memory scales linearly with the number of processor layers.

This issue proposes implementing gradient checkpointing to trade a small amount of compute (re-computing forward passes during the backward pass) for substantial memory savings.
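
As a rough illustration of the trade-off (a toy example, not the actual GraphProcessor), PyTorch's torch.utils.checkpoint utilities implement exactly this: activations are kept only at segment boundaries and recomputed inside each segment during backward.

```python
import torch
from torch.utils.checkpoint import checkpoint_sequential

# Toy stand-in for a deep 9-layer stack (not the real GraphProcessor blocks).
layers = torch.nn.Sequential(*[torch.nn.Linear(256, 256) for _ in range(9)])
x = torch.randn(1024, 256, requires_grad=True)

# Split into 3 segments: only segment-boundary activations are stored,
# everything inside a segment is recomputed during the backward pass.
out = checkpoint_sequential(layers, 3, x, use_reentrant=False)
out.sum().backward()
```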

Context

This optimization complements the recent batching fixes and brings the architecture closer to implementations like NVIDIA's GraphCast (https://github.com/NVIDIA/physicsnemo/blob/main/physicsnemo/models/graphcast/graph_cast_net.py).
Until now I have not been able to run the 1-degree resolution on my system (4 GB GPU); after this change, support for low-VRAM systems is expected to improve significantly.

Possible Implementation

- Add a use_checkpointing boolean flag to GraphWeatherForecaster and Processor
- In Processor.forward, wrap the iterative block updates in torch.utils.checkpoint.checkpoint (see the sketch below)
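
A minimal sketch of what that could look like. The names below (self.blocks, a forward signature taking x and edge_index, the constructor arguments) are assumptions for illustration, not the actual graph_weather Processor API:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class Processor(nn.Module):
    """Illustrative processor with an optional gradient-checkpointing path."""

    def __init__(self, blocks: nn.ModuleList, use_checkpointing: bool = False):
        super().__init__()
        self.blocks = blocks          # hypothetical stack of message-passing layers
        self.use_checkpointing = use_checkpointing

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.use_checkpointing and self.training:
                # Do not store this block's intermediate activations;
                # recompute them during the backward pass instead.
                x = checkpoint(block, x, edge_index, use_reentrant=False)
            else:
                x = block(x, edge_index)
        return x
```

GraphWeatherForecaster would then simply pass its own use_checkpointing argument down to the Processor constructor, so existing configurations keep the current (non-checkpointed) behaviour by default.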

Metadata

Labels

enhancement (New feature or request)
