10 changes: 9 additions & 1 deletion transformer_engine/pytorch/csrc/extensions/gemm.cpp
@@ -1,6 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -177,14 +177,18 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
const int sm_count = transformer_engine::cuda::sm_count(device_id);
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

#ifndef USE_ROCM
// Keep the swizzled scaling factor tensors alive during the GEMM.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
#endif
auto main_stream = at::cuda::getCurrentCUDAStream();
if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
#ifndef USE_ROCM
// Optionally swizzle the scaling factors
swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa)));
swizzled_scale_inverses_list.emplace_back(
std::move(swizzle_scaling_factors(B_tensor, !transb)));
#endif

if (comm_overlap) {
#ifndef USE_ROCM
@@ -334,8 +338,10 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
te_pre_gelu_out_vector, te_workspace_vector;
std::vector<TensorWrapper> wrappers;
std::vector<at::Tensor> D_vectors;
#ifndef USE_ROCM
// Keep the swizzled scaling factor tensors alive during the GEMMs.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
#endif

auto none = py::none();

@@ -402,9 +408,11 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
continue;
}

#ifndef USE_ROCM
// Optionally swizzle the scaling factors
swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_A, transa)));
swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_B, !transb)));
#endif

auto te_D = makeTransformerEngineTensor(out_tensor);
auto te_bias = makeTransformerEngineTensor(bias[i]);
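Taken together, the gemm.cpp hunks place the scaling-factor swizzle behind #ifndef USE_ROCM: CUDA builds collect the swizzled scale-inverse tensors in a local vector so they stay alive until the GEMM has consumed them, while ROCm builds skip the swizzle entirely. The sketch below shows that pattern in isolation; the wrapping function launch_gemm_sketch and the elided GEMM launch are simplified, hypothetical stand-ins rather than the real gemm() body from this PR.

#include <optional>
#include <vector>
#include <ATen/ATen.h>
#include "common.h"  // transformer_engine::TensorWrapper, as in the PR's util.cpp
#include "util.h"    // declares swizzle_scaling_factors on CUDA builds only

void launch_gemm_sketch(transformer_engine::TensorWrapper &A_tensor, bool transa,
                        transformer_engine::TensorWrapper &B_tensor, bool transb) {
#ifndef USE_ROCM
  // Keep the swizzled scaling factor tensors alive until the GEMM finishes.
  std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
#endif
  if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
#ifndef USE_ROCM
    // Optionally swizzle the scaling factors; the returned tensors must
    // outlive the GEMM, which is why they are parked in the vector above.
    swizzled_scale_inverses_list.emplace_back(swizzle_scaling_factors(A_tensor, transa));
    swizzled_scale_inverses_list.emplace_back(swizzle_scaling_factors(B_tensor, !transb));
#endif
    // ... launch the GEMM on the current stream using A_tensor and B_tensor ...
  }
}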
6 changes: 6 additions & 0 deletions transformer_engine/pytorch/csrc/util.cpp
@@ -1,9 +1,13 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#ifndef USE_ROCM

#include "util.h"

#include "common.h"
@@ -75,3 +79,5 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap

return swizzled_scale_inv;
}

#endif //!USE_ROCM
6 changes: 6 additions & 0 deletions transformer_engine/pytorch/csrc/util.h
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -7,6 +9,8 @@
#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_

#ifndef USE_ROCM

#include <torch/extension.h>

#include <optional>
@@ -20,4 +24,6 @@
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper &input,
bool trans);

#endif //!USE_ROCM

#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
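The util.h and util.cpp hunks apply the same guard at the translation-unit level: the swizzle_scaling_factors declaration and its definition are both compiled out when USE_ROCM is defined, matching the guarded call sites in gemm.cpp. A condensed sketch of the resulting structure, not the literal files, looks like this:

// util.h (sketch): declaration exists only on CUDA builds.
#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
#ifndef USE_ROCM
#include <torch/extension.h>
#include <optional>
// (other includes from the real header omitted in this sketch)
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper &input,
                                                  bool trans);
#endif  // !USE_ROCM
#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_

// util.cpp (sketch): the definition is wrapped the same way, so ROCm builds
// contain neither the symbol nor any caller that would need it.
#ifndef USE_ROCM
#include "util.h"
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper &input,
                                                  bool trans) {
  // ... CUDA-only swizzle of the input's scale_inv; body elided in this sketch ...
  return std::nullopt;  // placeholder return for the sketch only
}
#endif  // !USE_ROCM

Keeping the guards in the header, the source file, and the gemm.cpp call sites in step is what makes the change coherent: if the call sites were left unguarded while the declaration disappeared, ROCm builds would fail to compile instead of simply skipping the swizzle.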