Detailed Description
Following the merge of Efficient Sequential Batching (PR #187), the memory bottleneck related to graph edges has been resolved.
However, the model still consumes significant GPU memory storing intermediate activations for backpropagation, particularly in the deep GraphProcessor (typically 9 layers); this activation memory scales linearly with the number of processor layers.
This issue proposes implementing gradient checkpointing, trading a small amount of extra compute (re-computing forward activations during the backward pass) for substantial memory savings.
Context
This optimization complements the recent batching fixes and brings the architecture closer to implementations like NVIDIA's GraphCast https://github.com/NVIDIA/physicsnemo/blob/main/physicsnemo/models/graphcast/graph_cast_net.py
At the moment the 1-degree model cannot be run on a 4 GB GPU; this change is expected to significantly improve support for low-VRAM systems.
Possible Implementation
Add a use_checkpointing boolean flag to GraphWeatherForecaster and Processor.
In Processor.forward, wrap the iterative block updates in the checkpoint function (see the sketch below).
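A minimal sketch of what this could look like, assuming the processor keeps its message-passing layers in an nn.ModuleList called blocks and that each block takes (x, edge_index); the actual attribute names and signatures in graph_weather will likely differ:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Processor(nn.Module):
    """Hypothetical processor wrapper illustrating gradient checkpointing."""

    def __init__(self, blocks: nn.ModuleList, use_checkpointing: bool = False):
        super().__init__()
        self.blocks = blocks  # e.g. 9 message-passing layers
        self.use_checkpointing = use_checkpointing

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.use_checkpointing and self.training:
                # Do not store this block's intermediate activations; recompute
                # them during the backward pass, trading compute for memory.
                x = checkpoint(block, x, edge_index, use_reentrant=False)
            else:
                x = block(x, edge_index)
        return x
```

GraphWeatherForecaster would then only need to accept the same flag and pass it through to the Processor it constructs. Passing use_reentrant=False uses PyTorch's non-reentrant checkpointing path, which avoids some limitations of the legacy implementation.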