Fix FusedAdam DTensor compatibility issue #2425
base: main
Conversation
Force-pushed from 0b1d2db to 629c786
Greptile Summary
This PR fixes a DTensor compatibility regression in FusedAdam introduced by commit d52ed47. The core issue was that optimizer states were being initialized with `torch.zeros(param.shape)`/`torch.empty(param.shape)`, which for a DTensor parameter allocates a plain tensor of the global shape and drops the sharding information.
Key Changes: optimizer state initialization now uses `torch.zeros_like(param)`/`torch.empty_like(param)`, so DTensor parameters get DTensor states with matching placement.
Note: Line 388 still uses `param.shape` in the FP8 (`dtype == torch.uint8`) quantization path. Confidence Score: 4/5
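To make the contrast concrete, here is a minimal sketch of the two initialization patterns the summary describes (illustrative only, not the exact code in fused_adam.py):

```python
import torch

def init_state_before(param):
    # Allocates from the shape attribute: for a DTensor parameter this is the
    # *global* shape, so the result is a plain, full-size tensor that has lost
    # the sharding information.
    return torch.zeros(param.shape, dtype=param.dtype, device=param.device)

def init_state_after(param):
    # zeros_like dispatches through the parameter's tensor subclass, so a
    # DTensor parameter yields a DTensor state with the same placement.
    return torch.zeros_like(param)
```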
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User/Training Script
    participant FSDP2 as FSDP2 (PyTorch)
    participant Model as TE Model with DTensor params
    participant FusedAdam as FusedAdam Optimizer
    participant StateInit as _initialize_state()
    User->>FSDP2: Apply fully_shard() to model
    FSDP2->>Model: Convert params to DTensor
    Note over Model: Parameters are now DTensor<br/>with distributed properties
    User->>FusedAdam: Create optimizer(model.parameters())
    User->>FusedAdam: optimizer.step()
    FusedAdam->>FusedAdam: Check if state exists for param
    alt State not initialized
        FusedAdam->>StateInit: _initialize_state(param, "exp_avg")
        alt Before PR (d52ed47)
            StateInit->>StateInit: torch.zeros(param.shape)
            Note over StateInit: Creates regular tensor<br/>using global shape!<br/>❌ Loses DTensor properties
        else After PR (this fix)
            StateInit->>StateInit: torch.zeros_like(param)
            Note over StateInit: Creates DTensor state<br/>preserving distribution!<br/>✅ Maintains DTensor properties
        end
        StateInit->>FusedAdam: Return initialized state
        alt FP8 quantization path (dtype==uint8)
            Note over StateInit: quantizer.make_empty(param.shape)<br/>⚠️ Still uses param.shape<br/>Incompatible with DTensor
        end
    end
    FusedAdam->>Model: Update parameters with optimizer step
```
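As a companion to the diagram, a hedged sketch of the scenario it walks through, assuming a recent PyTorch that exposes fully_shard under torch.distributed.fsdp and that transformer_engine.pytorch.optimizers exports FusedAdam (import paths may differ by version); a launcher such as torchrun is assumed to have set up the process group and a CUDA device per rank:

```python
import torch
import transformer_engine.pytorch as te
from torch.distributed.fsdp import fully_shard  # older releases expose this elsewhere
from transformer_engine.pytorch.optimizers import FusedAdam

# Assumes an initialized process group and one CUDA device per rank.
model = te.Linear(1024, 1024)
fully_shard(model)                        # FSDP2: parameters become DTensors
optimizer = FusedAdam(model.parameters(), lr=1e-3)

out = model(torch.randn(8, 1024, device="cuda"))
out.sum().backward()
optimizer.step()                          # exp_avg/exp_avg_sq must remain DTensors here
```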
1 file reviewed, no comments
```python
"""
dtype = self.name_to_dtype_map[state_name]
if store_param_remainders:
    data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
```
Could we also change run_fsdp2_model.py to use TE's FusedAdam optimizer instead of torch Adam, so we don't break this again in the future?
Hi, sorry for the late reply; I've added FusedAdam to run_fsdp2_model.py.
Thank you! Looks good. LGTM!
Seems like some other fused adam tests are failing
The combination of fused_adam + mxfp8 + fp8_init is problematic.
I've temporarily skipped the tests for this combination, but I believe it is a bug that needs to be fixed.
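For illustration, a hypothetical sketch of how such a combination might be skipped in a parametrized test; the test and parameter names are invented here, not the actual ones in the TE test suite:

```python
import pytest

@pytest.mark.parametrize("fp8_init", [False, True])
@pytest.mark.parametrize("recipe", ["delayed", "mxfp8"])
def test_fused_adam_fsdp2(recipe, fp8_init):
    # Skip the known-bad combination until the underlying bug is fixed.
    if recipe == "mxfp8" and fp8_init:
        pytest.skip("FusedAdam + MXFP8 + fp8_init is currently broken; see discussion above")
    ...  # the actual FSDP2 + FusedAdam training checks would go here
```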
Could you trigger the test again?
Force-pushed from 572c176 to d655670
6 files reviewed, 1 comment
Force-pushed from 5717328 to d372019
Additional Comments (1)
- transformer_engine/pytorch/optimizers/fused_adam.py, line 388 (link): style: This line still uses `param.shape`, which may cause DTensor incompatibility when `dtype == torch.uint8`. When `param` is a DTensor, `param.shape` returns the global shape, not the local shape. Consider whether FP8 quantized states need similar treatment as the fix on lines 376-378. Does the FP8 quantization path handle DTensor parameters, or are FP8 states only used with regular tensors?
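Purely as a hypothetical direction for the question above (not code from this PR), one way the uint8/FP8 branch could become DTensor-aware is to size the quantized buffer from the local shard rather than the global shape; `to_local()` is the standard DTensor accessor, while `quantizer.make_empty` is the call the comment refers to:

```python
from torch.distributed.tensor import DTensor  # public in recent PyTorch; older versions use torch.distributed._tensor

def fp8_state_shape(param):
    # A DTensor's .shape is the global shape; a per-rank quantized state
    # buffer would presumably need the local shard's shape instead.
    return param.to_local().shape if isinstance(param, DTensor) else param.shape

# Hypothetical usage, mirroring the line the comment points at:
# data = quantizer.make_empty(fp8_state_shape(param))
```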
3 files reviewed, 1 comment
/te-ci L1 pytorch
…ike(param)/empty_like(param) to support DTensor Signed-off-by: jianbinc <shjwudp@gmail.com>
Signed-off-by: jianbinc <shjwudp@gmail.com>
…his combination test is problematic. 2. set run_fsdp2_model.py default use torch adam Signed-off-by: jianbinc <shjwudp@gmail.com>
Force-pushed from 925c97c to 1b08d23
for more information, see https://pre-commit.ci
Greptile found no issues! From now on, if a review finishes and we haven't found any issues, we will not post anything, but you can confirm that we reviewed your changes in the status check section. This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".
Description
Recent modifications to FusedAdam have made it incompatible with DTensor. Specifically, during optimizer state initialization, the state is now created from the DTensor's global shape instead of as a DTensor state that matches the parameter's sharding.
To maintain compatibility with DTensor, the state tensors should be initialized using `torch.zeros_like(param)` or `torch.empty_like(param)` instead of `torch.zeros(param.shape)` or `torch.empty(param.shape)`.
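A minimal demonstration of that behavioral difference, assuming a PyTorch new enough that torch.distributed.tensor is public (older versions expose it as torch.distributed._tensor) and a process group set up by a launcher such as torchrun; one rank is enough to see the effect:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

if not dist.is_initialized():
    dist.init_process_group("gloo")  # rendezvous env vars provided by the launcher

mesh = init_device_mesh("cpu", (dist.get_world_size(),))
param = DTensor.from_local(torch.randn(16, 16), mesh, [Shard(0)])

state_like = torch.zeros_like(param)     # DTensor with the same sharding as param
state_shape = torch.zeros(param.shape)   # plain tensor allocated at the *global* shape

print(type(state_like).__name__, type(state_shape).__name__)  # DTensor Tensor
```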
Fixes #2424
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: