From e73210538d7aeb88205864dbcafc3746153f9333 Mon Sep 17 00:00:00 2001
From: jianbinc
Date: Wed, 26 Nov 2025 11:07:08 +0800
Subject: [PATCH 1/4] FusedAdam: replace zeros(param.shape)/empty(param.shape)
 with zeros_like(param)/empty_like(param) to support DTensor

Signed-off-by: jianbinc
---
 transformer_engine/pytorch/optimizers/fused_adam.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py
index b5c87b4815c..935af8ee0e3 100644
--- a/transformer_engine/pytorch/optimizers/fused_adam.py
+++ b/transformer_engine/pytorch/optimizers/fused_adam.py
@@ -373,9 +373,9 @@ def _initialize_state(
         """
         dtype = self.name_to_dtype_map[state_name]
         if store_param_remainders:
-            data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
+            data = torch.zeros_like(param, dtype=torch.int16, device=param.device)
         else:
-            data = torch.empty(param.shape, dtype=dtype, device=param.device)
+            data = torch.empty_like(param, dtype=dtype, device=param.device)
         if zero_buffer:
             data.zero_()
 

From 67202574db9f7bf4e86e992d8d802cea87fe509f Mon Sep 17 00:00:00 2001
From: jianbinc
Date: Thu, 25 Dec 2025 11:18:14 +0800
Subject: [PATCH 2/4] add fused_adam functional test in test_torch_fsdp2

Signed-off-by: jianbinc
---
 tests/pytorch/distributed/run_fsdp2_model.py  | 10 +++++++++-
 tests/pytorch/distributed/test_torch_fsdp2.py |  8 +++++---
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py
index c3432992429..9b8d2e7cef9 100644
--- a/tests/pytorch/distributed/run_fsdp2_model.py
+++ b/tests/pytorch/distributed/run_fsdp2_model.py
@@ -15,6 +15,7 @@
     Float8CurrentScaling,
     MXFP8BlockScaling,
 )
+from transformer_engine.pytorch.optimizers.fused_adam import FusedAdam
 
 import torch
 import torch.distributed as dist
@@ -84,6 +85,10 @@ def _parse_args(argv=None, namespace=None):
         nargs="+",
         help='FSDP/HSDP sharding dimensions ("replicate", "shard")',
     )
+    parser.add_argument(
+        "--adam", type=str, choices=["fused", "torch"], default="fused",
+        help="Optimizer type."
+    )
     args = parser.parse_args(argv, namespace)
     if args.sharding_dims:
         assert len(args.sharding_dims) <= 2
@@ -322,7 +327,10 @@ def _train(args):
         f"FSDP2 model in cuda, memory allocated: {torch.cuda.memory_allocated(device)/1e6} MB"
     )
 
-    optimizer = optim.Adam(model.parameters(), lr=1e-3)
+    if args.adam == "fused":
+        optimizer = FusedAdam(model.parameters(), lr=1e-3)
+    else:
+        optimizer = optim.Adam(model.parameters(), lr=1e-3)
 
     for iteration in range(args.iter):
         # Zero the parameter gradients
diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py
index 91d6fc6ed11..075d6309dae 100644
--- a/tests/pytorch/distributed/test_torch_fsdp2.py
+++ b/tests/pytorch/distributed/test_torch_fsdp2.py
@@ -16,7 +16,7 @@
 NUM_PROCS: int = torch.cuda.device_count()
 
 
-def _run_test(fp_init, sharding_dims, recipe, layer_type):
+def _run_test(fp_init, sharding_dims, recipe, layer_type, optim_type="fused"):
     test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py"
     test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)]
 
@@ -31,6 +31,7 @@ def _run_test(fp_init, sharding_dims, recipe, layer_type):
         assert False
     test_cmd += ["--recipe", recipe]
     test_cmd += ["--layer-type", layer_type]
+    test_cmd += ["--adam", optim_type]
 
     result = subprocess.run(test_cmd, env=os.environ, check=True)
 
@@ -42,7 +43,8 @@ def _run_test(fp_init, sharding_dims, recipe, layer_type):
 @pytest.mark.parametrize("fp8_init", (False, True))
 @pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
 @pytest.mark.parametrize("layer_type", ("LayerNormLinear", "TransformerLayer"))
-def test_distributed(fp8_init, sharding_dims, recipe, layer_type):
+@pytest.mark.parametrize("optim_type", ("fused", "torch"))
+def test_distributed(fp8_init, sharding_dims, recipe, layer_type, optim_type):
 
     # Skip invalid configurations
     if torch.cuda.device_count() < 4:
@@ -53,7 +55,7 @@ def test_distributed(fp8_init, sharding_dims, recipe, layer_type):
     elif not fp8_available:
         pytest.skip(reason_for_no_fp8)
 
-    _run_test(fp8_init, sharding_dims, recipe, layer_type)
+    _run_test(fp8_init, sharding_dims, recipe, layer_type, optim_type)
 
 
 def test_dummy() -> None:

From 1b08d236fdc828454bcd96131932c04325e5d0b6 Mon Sep 17 00:00:00 2001
From: jianbinc
Date: Thu, 25 Dec 2025 15:11:34 +0800
Subject: [PATCH 3/4] Skip the fused_adam + mxfp8 + fp8_init combination in the
 FSDP2 test and default run_fsdp2_model.py to torch Adam

1. Skip the fused_adam + mxfp8 + fp8_init combination in the FSDP2 unit test,
   because this combination is currently problematic.
2. Make run_fsdp2_model.py use torch Adam by default.

Signed-off-by: jianbinc
---
 tests/pytorch/distributed/run_fsdp2_model.py  | 2 +-
 tests/pytorch/distributed/test_torch_fsdp2.py | 5 ++++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py
index 9b8d2e7cef9..a729da152c9 100644
--- a/tests/pytorch/distributed/run_fsdp2_model.py
+++ b/tests/pytorch/distributed/run_fsdp2_model.py
@@ -86,7 +86,7 @@ def _parse_args(argv=None, namespace=None):
         help='FSDP/HSDP sharding dimensions ("replicate", "shard")',
     )
     parser.add_argument(
-        "--adam", type=str, choices=["fused", "torch"], default="fused",
+        "--adam", type=str, choices=["fused", "torch"], default="torch",
         help="Optimizer type."
     )
     args = parser.parse_args(argv, namespace)
diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py
index 075d6309dae..ecfe5da2182 100644
--- a/tests/pytorch/distributed/test_torch_fsdp2.py
+++ b/tests/pytorch/distributed/test_torch_fsdp2.py
@@ -45,7 +45,6 @@ def _run_test(fp_init, sharding_dims, recipe, layer_type, optim_type="fused"):
 @pytest.mark.parametrize("layer_type", ("LayerNormLinear", "TransformerLayer"))
 @pytest.mark.parametrize("optim_type", ("fused", "torch"))
 def test_distributed(fp8_init, sharding_dims, recipe, layer_type, optim_type):
-
     # Skip invalid configurations
     if torch.cuda.device_count() < 4:
         pytest.skip("FSDP2 test requires at least 4 GPUs")
@@ -55,6 +54,10 @@ def test_distributed(fp8_init, sharding_dims, recipe, layer_type, optim_type):
     elif not fp8_available:
         pytest.skip(reason_for_no_fp8)
 
+    # Skip incompatible optimizer + recipe combinations
+    if optim_type == "fused" and recipe in ["mx_fp8_block_scaling"] and fp8_init:
+        pytest.skip("Fused Adam does not support FP8 with MX FP8 Block Scaling")
+
     _run_test(fp8_init, sharding_dims, recipe, layer_type, optim_type)
 
 

From c4ec3fc8e87f4fd995ddc6e595baebd7a1352ac0 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 25 Dec 2025 07:15:51 +0000
Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/pytorch/distributed/run_fsdp2_model.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py
index a729da152c9..841c7761e59 100644
--- a/tests/pytorch/distributed/run_fsdp2_model.py
+++ b/tests/pytorch/distributed/run_fsdp2_model.py
@@ -86,8 +86,7 @@ def _parse_args(argv=None, namespace=None):
         help='FSDP/HSDP sharding dimensions ("replicate", "shard")',
     )
     parser.add_argument(
-        "--adam", type=str, choices=["fused", "torch"], default="torch",
-        help="Optimizer type."
+        "--adam", type=str, choices=["fused", "torch"], default="torch", help="Optimizer type."
     )
     args = parser.parse_args(argv, namespace)
     if args.sharding_dims:
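
Reviewer note (not part of the patches above): the reason PATCH 1 switches to the
*_like constructors is that under FSDP2 each parameter is a DTensor, and
torch.zeros(param.shape, ...) allocates a plain, fully materialized tensor that
drops the device mesh and sharding placement, while torch.zeros_like(param, ...)
dispatches through the tensor subclass so the optimizer state stays sharded the
same way as the parameter. The snippet below is a minimal illustrative sketch,
not code from the patches: it assumes a 2-rank job launched with torchrun, an
already-initialized process group, one CUDA device per rank, and the public
torch.distributed.tensor API (the exact import path can differ across PyTorch
versions).

    # Hypothetical illustration of why zeros_like() preserves DTensor metadata.
    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard, distribute_tensor

    # Assumes torchrun started 2 ranks and torch.distributed is initialized.
    mesh = init_device_mesh("cuda", (2,))
    param = distribute_tensor(torch.randn(16, 16), mesh, placements=[Shard(0)])

    # Plain zeros(): the result is an ordinary dense tensor of the full global
    # shape, so optimizer state would no longer line up with the sharded param.
    plain_state = torch.zeros(param.shape, dtype=torch.float32, device=param.device)
    print(type(plain_state))         # <class 'torch.Tensor'>

    # zeros_like(): dispatch goes through the DTensor subclass, so the state
    # inherits the same device mesh and Shard(0) placement as the parameter.
    sharded_state = torch.zeros_like(param, dtype=torch.float32, device=param.device)
    print(type(sharded_state))       # a DTensor, not a plain torch.Tensor
    print(sharded_state.placements)  # (Shard(dim=0),)

Because the *_like calls inherit this behavior automatically, the FusedAdam
state-initialization path in PATCH 1 does not appear to need any
DTensor-specific branching.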