From 84ec02a866891663269878bcd7b654a8b9187205 Mon Sep 17 00:00:00 2001
From: Ilya Panfilov
Date: Fri, 16 Jan 2026 21:32:50 -0500
Subject: [PATCH] Disable Pytorch MXFP8 scale swizzling

---
 transformer_engine/pytorch/csrc/extensions/gemm.cpp | 10 +++++++++-
 transformer_engine/pytorch/csrc/util.cpp            |  6 ++++++
 transformer_engine/pytorch/csrc/util.h              |  6 ++++++
 3 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
index 7b51e8f8f..d8696c14d 100644
--- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
@@ -1,6 +1,6 @@
 /*************************************************************************
  * This file was modified for portability to AMDGPU
- * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -177,14 +177,18 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   const int sm_count = transformer_engine::cuda::sm_count(device_id);
   int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
 
+#ifndef USE_ROCM
   // Keep the swizzled scaling factor tensors alive during the GEMM.
   std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
+#endif
   auto main_stream = at::cuda::getCurrentCUDAStream();
   if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
+#ifndef USE_ROCM
     // Optionally swizzle the scaling factors
     swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa)));
     swizzled_scale_inverses_list.emplace_back(
         std::move(swizzle_scaling_factors(B_tensor, !transb)));
+#endif
 
     if (comm_overlap) {
 #ifndef USE_ROCM
@@ -334,8 +338,10 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
       te_pre_gelu_out_vector, te_workspace_vector;
   std::vector<TensorWrapper> wrappers;
   std::vector<at::Tensor> D_vectors;
+#ifndef USE_ROCM
   // Keep the swizzled scaling factor tensors alive during the GEMMs.
   std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
+#endif
 
   auto none = py::none();
 
@@ -402,9 +408,11 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
       continue;
     }
 
+#ifndef USE_ROCM
     // Optionally swizzle the scaling factors
     swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_A, transa)));
     swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_B, !transb)));
+#endif
 
     auto te_D = makeTransformerEngineTensor(out_tensor);
     auto te_bias = makeTransformerEngineTensor(bias[i]);
diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp
index a878345ff..b58573419 100644
--- a/transformer_engine/pytorch/csrc/util.cpp
+++ b/transformer_engine/pytorch/csrc/util.cpp
@@ -1,9 +1,13 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 
+#ifndef USE_ROCM
+
 #include "util.h"
 
 #include "common.h"
@@ -75,3 +79,5 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
 
   return swizzled_scale_inv;
 }
+
+#endif //!USE_ROCM
diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h
index 0cfeb81f5..97f22ae18 100644
--- a/transformer_engine/pytorch/csrc/util.h
+++ b/transformer_engine/pytorch/csrc/util.h
@@ -1,4 +1,6 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -7,6 +9,8 @@
 #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
 #define TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
 
+#ifndef USE_ROCM
+
 #include
 #include
 
@@ -20,4 +24,6 @@
 std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper &input,
                                                   bool trans);
 
+#endif //!USE_ROCM
+
 #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_