diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py
index c3432992429..841c7761e59 100644
--- a/tests/pytorch/distributed/run_fsdp2_model.py
+++ b/tests/pytorch/distributed/run_fsdp2_model.py
@@ -15,6 +15,7 @@
     Float8CurrentScaling,
     MXFP8BlockScaling,
 )
+from transformer_engine.pytorch.optimizers.fused_adam import FusedAdam
 import torch
 import torch.distributed as dist
@@ -84,6 +85,9 @@ def _parse_args(argv=None, namespace=None):
         nargs="+",
         help='FSDP/HSDP sharding dimensions ("replicate", "shard")',
     )
+    parser.add_argument(
+        "--adam", type=str, choices=["fused", "torch"], default="torch", help="Optimizer type."
+    )
     args = parser.parse_args(argv, namespace)
     if args.sharding_dims:
         assert len(args.sharding_dims) <= 2
@@ -322,7 +326,10 @@ def _train(args):
         f"FSDP2 model in cuda, memory allocated: {torch.cuda.memory_allocated(device)/1e6} MB"
     )
 
-    optimizer = optim.Adam(model.parameters(), lr=1e-3)
+    if args.adam == "fused":
+        optimizer = FusedAdam(model.parameters(), lr=1e-3)
+    else:
+        optimizer = optim.Adam(model.parameters(), lr=1e-3)
 
     for iteration in range(args.iter):
         # Zero the parameter gradients
diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py
index 91d6fc6ed11..ecfe5da2182 100644
--- a/tests/pytorch/distributed/test_torch_fsdp2.py
+++ b/tests/pytorch/distributed/test_torch_fsdp2.py
@@ -16,7 +16,7 @@
 NUM_PROCS: int = torch.cuda.device_count()
 
 
-def _run_test(fp_init, sharding_dims, recipe, layer_type):
+def _run_test(fp_init, sharding_dims, recipe, layer_type, optim_type="fused"):
     test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py"
     test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)]
 
@@ -31,6 +31,7 @@ def _run_test(fp_init, sharding_dims, recipe, layer_type):
         assert False
     test_cmd += ["--recipe", recipe]
     test_cmd += ["--layer-type", layer_type]
+    test_cmd += ["--adam", optim_type]
 
     result = subprocess.run(test_cmd, env=os.environ, check=True)
 
@@ -42,8 +43,8 @@
 @pytest.mark.parametrize("fp8_init", (False, True))
 @pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
 @pytest.mark.parametrize("layer_type", ("LayerNormLinear", "TransformerLayer"))
-def test_distributed(fp8_init, sharding_dims, recipe, layer_type):
-
+@pytest.mark.parametrize("optim_type", ("fused", "torch"))
+def test_distributed(fp8_init, sharding_dims, recipe, layer_type, optim_type):
     # Skip invalid configurations
     if torch.cuda.device_count() < 4:
         pytest.skip("FSDP2 test requires at least 4 GPUs")
@@ -53,7 +54,11 @@
     elif not fp8_available:
         pytest.skip(reason_for_no_fp8)
 
-    _run_test(fp8_init, sharding_dims, recipe, layer_type)
+    # Skip incompatible optimizer + recipe combinations
+    if optim_type == "fused" and recipe in ["mx_fp8_block_scaling"] and fp8_init:
+        pytest.skip("Fused Adam does not support FP8 with MX FP8 Block Scaling")
+
+    _run_test(fp8_init, sharding_dims, recipe, layer_type, optim_type)
 
 
 def test_dummy() -> None:
diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py
index b5c87b4815c..935af8ee0e3 100644
--- a/transformer_engine/pytorch/optimizers/fused_adam.py
+++ b/transformer_engine/pytorch/optimizers/fused_adam.py
@@ -373,9 +373,9 @@ def _initialize_state(
        """
        dtype = self.name_to_dtype_map[state_name]
        if store_param_remainders:
-           data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
+           data = torch.zeros_like(param, dtype=torch.int16, device=param.device)
        else:
-           data = torch.empty(param.shape, dtype=dtype, device=param.device)
+           data = torch.empty_like(param, dtype=dtype, device=param.device)
        if zero_buffer:
            data.zero_()
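
For manual verification outside pytest, the new --adam flag can be exercised directly. The sketch below mirrors the command that _run_test assembles; the process count, recipe, and layer type are chosen purely for illustration and are not part of the patch.

# Minimal sketch (not part of the diff): run the modified script with the new flag.
import os
import subprocess

test_cmd = [
    "torchrun",
    "--nproc_per_node=4",
    "tests/pytorch/distributed/run_fsdp2_model.py",
    "--recipe", "delayed_scaling",
    "--layer-type", "TransformerLayer",
    "--adam", "fused",  # new flag introduced by this change; defaults to "torch"
]
subprocess.run(test_cmd, env=os.environ, check=True)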