Conversation

@shjwudp shjwudp commented Nov 26, 2025

Description

Recent modifications to FusedAdam have made it incompatible with DTensor. Specifically, during optimizer state initialization, the state is now created as a regular tensor from the DTensor's global shape instead of as a DTensor with the same sharding as the parameter.

To maintain compatibility with DTensor, the state tensors should be initialized using zeros_like(param) or empty_like(param) instead of zeros(param.shape) or empty(param.shape).

Fixes #2424
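
For reference, a minimal sketch of the intended initialization pattern (illustrative only; the actual _initialize_state() in fused_adam.py handles additional cases such as FP8 states and parameter remainders):

```python
# Illustrative sketch only, not the exact TE implementation.
import torch

def init_optimizer_state(param: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Before the fix (breaks DTensor): a plain tensor built from the *global* shape.
    #   return torch.zeros(param.shape, dtype=dtype, device=param.device)
    # After the fix: *_like() preserves the parameter's DTensor mesh and placement.
    return torch.zeros_like(param, dtype=dtype)
```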

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Initialize FusedAdam optimizer states with zeros_like(param)/empty_like(param) so DTensor parameters get DTensor states
  • Parameterize the FSDP2 tests over the optimizer type and add an --adam option to run_fsdp2_model.py (skipping the known-broken fused_adam + mx_fp8_block_scaling + fp8_init combination)

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@shjwudp shjwudp force-pushed the fused_adam_dtensor_issue branch from 0b1d2db to 629c786 Compare November 26, 2025 03:11

greptile-apps bot commented Nov 26, 2025

Greptile Summary

This PR fixes a DTensor compatibility regression in FusedAdam introduced by commit d52ed47. The core issue was that optimizer states were initialized with torch.zeros(param.shape) and torch.empty(param.shape), which create regular tensors from the DTensor's global shape instead of DTensor optimizer states with the same distributed properties as the parameters.

Key Changes:

  • Modified _initialize_state() method in fused_adam.py to use zeros_like(param) and empty_like(param) instead of zeros(param.shape) and empty(param.shape) (lines 376, 378)
  • Added comprehensive test coverage for FusedAdam with FSDP2 by parameterizing the optimizer type in tests
  • Includes a skip condition for the known incompatible combination of fused_adam + mx_fp8_block_scaling + fp8_init, due to a remaining DTensor issue with FP8 quantization at line 388

Note: Line 388 still uses param.shape in the FP8 quantization path (quantizer.make_empty(param.shape)), which remains incompatible with DTensor. This is why the test explicitly skips the mx_fp8_block_scaling + fp8_init combination when using FusedAdam.

Confidence Score: 4/5

  • This PR is safe to merge with one remaining edge case limitation documented in tests
  • The fix correctly addresses the DTensor compatibility issue for the common case by using *_like() functions. However, the FP8 quantization path (line 388) still has a DTensor incompatibility that requires a test skip. This is a known limitation rather than a bug in this PR, as fixing it would require changes to the Float8Quantizer.make_empty() API.
  • Pay attention to transformer_engine/pytorch/optimizers/fused_adam.py line 388 - the FP8 quantization path still uses param.shape and remains incompatible with DTensor

Important Files Changed

  • transformer_engine/pytorch/optimizers/fused_adam.py: Fixed DTensor compatibility by replacing zeros(param.shape) and empty(param.shape) with zeros_like(param) and empty_like(param) in state initialization (lines 376, 378). This ensures optimizer states are created as DTensors when params are DTensors.
  • tests/pytorch/distributed/test_torch_fsdp2.py: Added an optim_type parameter to test both FusedAdam and torch.Adam with FSDP2. Includes a skip condition for the known incompatible combination fused_adam + mx_fp8_block_scaling + fp8_init (lines 58-59).
  • tests/pytorch/distributed/run_fsdp2_model.py: Added an --adam command-line argument for choosing between FusedAdam and torch.Adam (lines 88-90, 329-332); the default remains torch.Adam, with the flag making the optimizer configurable in tests.
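
As a rough illustration of what the optimizer switch in run_fsdp2_model.py could look like (flag name taken from the summary above; the exact wiring and defaults in the script may differ):

```python
# Illustrative only; not the actual run_fsdp2_model.py code.
import argparse
import torch
from transformer_engine.pytorch.optimizers import FusedAdam

def build_optimizer(params, impl: str, lr: float = 1e-3):
    # "torch" keeps the previous default; "fused" exercises the TE DTensor path.
    if impl == "fused":
        return FusedAdam(params, lr=lr)
    return torch.optim.Adam(params, lr=lr)

parser = argparse.ArgumentParser()
parser.add_argument("--adam", choices=["torch", "fused"], default="torch",
                    help="Which Adam implementation to train with.")
args, _ = parser.parse_known_args()
```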

Sequence Diagram

sequenceDiagram
    participant User as User/Training Script
    participant FSDP2 as FSDP2 (PyTorch)
    participant Model as TE Model with DTensor params
    participant FusedAdam as FusedAdam Optimizer
    participant StateInit as _initialize_state()
    
    User->>FSDP2: Apply fully_shard() to model
    FSDP2->>Model: Convert params to DTensor
    Note over Model: Parameters are now DTensor<br/>with distributed properties
    
    User->>FusedAdam: Create optimizer(model.parameters())
    User->>FusedAdam: optimizer.step()
    
    FusedAdam->>FusedAdam: Check if state exists for param
    alt State not initialized
        FusedAdam->>StateInit: _initialize_state(param, "exp_avg")
        
        alt Before PR (d52ed47)
            StateInit->>StateInit: torch.zeros(param.shape)
            Note over StateInit: Creates regular tensor<br/>using global shape!<br/>❌ Loses DTensor properties
        else After PR (this fix)
            StateInit->>StateInit: torch.zeros_like(param)
            Note over StateInit: Creates DTensor state<br/>preserving distribution!<br/>✅ Maintains DTensor properties
        end
        
        StateInit->>FusedAdam: Return initialized state
        
        alt FP8 quantization path (dtype==uint8)
            Note over StateInit: quantizer.make_empty(param.shape)<br/>⚠️ Still uses param.shape<br/>Incompatible with DTensor
        end
    end
    
    FusedAdam->>Model: Update parameters with optimizer step
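
A rough end-to-end sketch of the flow in the diagram (illustrative, not the test script; assumes torchrun and a recent PyTorch that exposes fully_shard under torch.distributed.fsdp):

```python
# Illustrative sketch; run under torchrun --nproc_per_node=<N>.
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard
import transformer_engine.pytorch as te
from transformer_engine.pytorch.optimizers import FusedAdam

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = te.Linear(1024, 1024, params_dtype=torch.bfloat16).cuda()
fully_shard(model)                    # parameters become DTensors
opt = FusedAdam(model.parameters(), lr=1e-3)

out = model(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16))
out.sum().backward()
opt.step()                            # states are created lazily; with this fix they are DTensors
dist.destroy_process_group()
```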

@greptile-apps greptile-apps bot left a comment

1 file reviewed, no comments

"""
dtype = self.name_to_dtype_map[state_name]
if store_param_remainders:
data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
Collaborator

@vthumbe1503 vthumbe1503 Nov 26, 2025

Could we also change run_fsdp2_model.py to use the TE FusedAdam optimizer instead of torch Adam so we don't break this again in the future?

Contributor Author

Hi, sorry for the late reply. I've added FusedAdam in run_fsdp2_model.py.

Collaborator

Thank you! Looks good. LGTM!

Collaborator

Seems like some other FusedAdam tests are failing.

Contributor Author

@shjwudp shjwudp Dec 25, 2025

The combination of fused_adam + mxfp8 + fp8_init is problematic.
I've temporarily skipped this test case, but I believe it is a bug that needs to be fixed.

Contributor Author

Could you trigger the test again?

@shjwudp shjwudp force-pushed the fused_adam_dtensor_issue branch from 572c176 to d655670 Compare December 25, 2025 03:22
Contributor

@greptile-apps greptile-apps bot left a comment

6 files reviewed, 1 comment

@shjwudp shjwudp force-pushed the fused_adam_dtensor_issue branch from 5717328 to d372019 Compare December 25, 2025 03:29
@shjwudp shjwudp requested a review from vthumbe1503 December 25, 2025 03:32
Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/pytorch/optimizers/fused_adam.py, line 388 (link)

    style: This line still uses param.shape which may cause DTensor incompatibility when dtype == torch.uint8. When param is a DTensor, param.shape returns the global shape, not the local shape. Consider whether FP8 quantized states need similar treatment as the fix on lines 376-378. Does the FP8 quantization path handle DTensor parameters, or are FP8 states only used with regular tensors?
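
To make the global-vs-local shape distinction concrete, a small hypothetical repro (not TE code; assumes two ranks under torchrun and a PyTorch version with the public torch.distributed.tensor API):

```python
# Hypothetical repro; run with: torchrun --nproc_per_node=2 repro.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (2,))
param = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])

print(param.shape)             # torch.Size([8, 4])  -> global shape
print(param.to_local().shape)  # torch.Size([4, 4])  -> per-rank shard

plain = torch.zeros(param.shape)   # regular tensor with the global shape
state = torch.zeros_like(param)    # DTensor with the same mesh/placement as param
print(type(plain), type(state))
dist.destroy_process_group()
```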

3 files reviewed, 1 comment

@vthumbe1503
Collaborator

/te-ci L1 pytorch

…ike(param)/empty_like(param) to support DTensor

Signed-off-by: jianbinc <shjwudp@gmail.com>
Signed-off-by: jianbinc <shjwudp@gmail.com>
…his combination test is problematic.

2. set run_fsdp2_model.py default use torch adam

Signed-off-by: jianbinc <shjwudp@gmail.com>
@shjwudp shjwudp force-pushed the fused_adam_dtensor_issue branch from 925c97c to 1b08d23 Compare December 25, 2025 07:15

greptile-apps bot commented Dec 25, 2025

Greptile found no issues!

From now on, if a review finishes and we haven't found any issues, we will not post anything, but you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

Successfully merging this pull request may close these issues.

Compatibility issues between FusedAdam and DTensor
