From bf3ebc2ccf98a016ff61f859df7fa2686f36114d Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 10 Dec 2025 15:29:37 +0100 Subject: [PATCH 01/17] code drop Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/CMakeLists.txt | 1 + tests/cpp/operator/test_grouped_gemm.cu | 511 ++++++++++++++++++ .../common/gemm/cublaslt_gemm.cu | 484 +++++++++++++++++ .../common/include/transformer_engine/gemm.h | 36 ++ 4 files changed, 1032 insertions(+) create mode 100644 tests/cpp/operator/test_grouped_gemm.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index b2f14b1892d..1392ffdadc3 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -30,6 +30,7 @@ add_executable(test_operator test_causal_softmax.cu test_swizzle.cu test_swap_first_dims.cu + test_grouped_gemm.cu ../test_common.cu) # Find required packages diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu new file mode 100644 index 00000000000..0e9c6c6a4d6 --- /dev/null +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -0,0 +1,511 @@ +/*********************************************************************** + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. + * + * See LICENSE for license information. + **********************************************************************/ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum class InputCase { + kFP8Delayed, + kFP8Current, + kBF16, +}; + +enum class ShapeCase { + kAllSame, + kSameFirst, + kSameLast, + kAllDifferent, +}; + +// Helper owning GPU buffers that back NVTEGroupedTensor. +// NVTEGroupedTensor does not own memory; data/offsets/scales +// must be allocated and freed by the test. 
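+
+// Illustrative sketch (added for reference, not used by the tests): the create/attach/destroy
+// contract behind the helper below, using only the grouped-tensor C API exercised in this file.
+// The handle stores pointers and metadata; the caller allocates and frees the device memory.
+[[maybe_unused]] inline void grouped_tensor_lifetime_sketch() {
+  size_t dims[2] = {4, 8};
+  const NVTEShape shape = nvte_make_shape(dims, 2);
+  NVTEGroupedTensor t = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING,
+                                                   /*num_tensors=*/1, shape);
+  void* data = nullptr;
+  NVTE_CHECK_CUDA(cudaMalloc(&data, 4 * 8 * sizeof(float)));  // caller-owned buffer
+  NVTEBasicTensor view{data, kNVTEFloat32, shape};
+  nvte_set_grouped_tensor_param(&t, kNVTEGroupedRowwiseData, &view);  // handle keeps only the pointer
+  nvte_destroy_grouped_tensor(t);   // releases the handle, not the caller's buffer
+  NVTE_CHECK_CUDA(cudaFree(data));  // the test must free what it allocated
+}
+
+// GroupedBuffers below wraps this pattern and releases every allocation in its destructor.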
+struct GroupedBuffers { + NVTEGroupedTensor handle{nullptr}; + void* data{nullptr}; + void* scale_inv{nullptr}; + int64_t* first_dims_dev{nullptr}; + int64_t* last_dims_dev{nullptr}; + int64_t* offsets_dev{nullptr}; + void* columnwise_data{nullptr}; + NVTEShape logical_shape{}; + std::vector offsets_host; + std::vector tensor_bytes; + size_t num_tensors{0}; + size_t elem_size{0}; + DType dtype{DType::kFloat32}; + NVTEScalingMode scaling_mode{NVTE_DELAYED_TENSOR_SCALING}; + + GroupedBuffers() = default; + GroupedBuffers(const GroupedBuffers&) = delete; + GroupedBuffers& operator=(const GroupedBuffers&) = delete; + GroupedBuffers(GroupedBuffers&& other) noexcept { + *this = std::move(other); + } + GroupedBuffers& operator=(GroupedBuffers&& other) noexcept { + if (this == &other) return *this; + handle = other.handle; + data = other.data; + scale_inv = other.scale_inv; + first_dims_dev = other.first_dims_dev; + last_dims_dev = other.last_dims_dev; + offsets_dev = other.offsets_dev; + logical_shape = other.logical_shape; + offsets_host = std::move(other.offsets_host); + tensor_bytes = std::move(other.tensor_bytes); + num_tensors = other.num_tensors; + elem_size = other.elem_size; + dtype = other.dtype; + scaling_mode = other.scaling_mode; + + other.handle = nullptr; + other.data = nullptr; + other.scale_inv = nullptr; + other.first_dims_dev = nullptr; + other.last_dims_dev = nullptr; + other.offsets_dev = nullptr; + other.num_tensors = 0; + return *this; + } + + ~GroupedBuffers() { + if (data) { + cudaFree(data); + data = nullptr; + } + if (scale_inv) { + cudaFree(scale_inv); + scale_inv = nullptr; + } + if (columnwise_data) { + cudaFree(columnwise_data); + columnwise_data = nullptr; + } + if (first_dims_dev) { + cudaFree(first_dims_dev); + first_dims_dev = nullptr; + } + if (last_dims_dev) { + cudaFree(last_dims_dev); + last_dims_dev = nullptr; + } + if (offsets_dev) { + cudaFree(offsets_dev); + offsets_dev = nullptr; + } + if (handle) { + nvte_destroy_grouped_tensor(handle); + handle = nullptr; + } + } +}; + +size_t grouped_setup_workspace_size(const size_t num_tensors) { + const size_t ptr_bytes = num_tensors * sizeof(void*); + const size_t int_bytes = num_tensors * sizeof(int); + size_t size = 4 * ptr_bytes + 3 * int_bytes + 2 * ptr_bytes; + const size_t alignment = 256; + size = ((size + alignment - 1) / alignment) * alignment; + return size; +} + +GroupedBuffers build_grouped_tensor(const std::vector& tensors, + const NVTEScalingMode scaling_mode) { + NVTE_CHECK(!tensors.empty(), "No tensors provided for grouped tensor build."); + const NVTEShape shape = tensors[0]->rowwise_shape(); + const DType dtype = tensors[0]->dtype(); + const size_t num_tensors = tensors.size(); + const size_t elem_size = typeToSize(dtype); + GroupedBuffers grouped; + grouped.elem_size = elem_size; + grouped.num_tensors = num_tensors; + grouped.dtype = dtype; + grouped.scaling_mode = scaling_mode; + grouped.tensor_bytes.resize(num_tensors); + grouped.offsets_host.resize(num_tensors, 0); + + std::vector first_dims(num_tensors); + std::vector last_dims(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + const auto s = tensors[i]->rowwise_shape(); + NVTE_CHECK(s.ndim == 2, "Grouped GEMM test expects 2D tensors."); + first_dims[i] = static_cast(s.data[0]); + last_dims[i] = static_cast(s.data[1]); + grouped.tensor_bytes[i] = bytes(s, dtype); + } + + const bool same_first = std::all_of(first_dims.begin(), first_dims.end(), + [&](int64_t v) { return v == first_dims[0]; }); + const bool same_last = 
std::all_of(last_dims.begin(), last_dims.end(), + [&](int64_t v) { return v == last_dims[0]; }); + + std::vector offsets(num_tensors, 0); + auto random_padding = [&]() -> int64_t { + static std::mt19937 gen(12345); + std::uniform_int_distribution dist(0, 3); + return dist(gen); + }; + + auto numel = [&](size_t idx) -> int64_t { + return first_dims[idx] * last_dims[idx]; + }; + + const bool need_offsets = !same_first || !same_last; + if (need_offsets) { + offsets[0] = 0; + for (size_t i = 1; i < num_tensors; ++i) { + offsets[i] = offsets[i - 1] + numel(i - 1) + random_padding(); + } + } else { + for (size_t i = 0; i < num_tensors; ++i) { + offsets[i] = static_cast(i) * numel(0); + } + } + grouped.offsets_host = offsets; + + int64_t logical_first = 0; + int64_t logical_last = 0; + if (same_first && same_last) { + logical_first = first_dims[0] * static_cast(num_tensors); + logical_last = last_dims[0]; + } else if (same_first && !same_last) { + logical_first = first_dims[0]; + logical_last = std::accumulate(last_dims.begin(), last_dims.end(), int64_t{0}); + } else if (!same_first && same_last) { + logical_first = std::accumulate(first_dims.begin(), first_dims.end(), int64_t{0}); + logical_last = last_dims[0]; + } else { + logical_first = 1; + logical_last = 0; + for (size_t i = 0; i < num_tensors; ++i) { + logical_last += first_dims[i] * last_dims[i]; + } + } + size_t logical_data[2] = {static_cast(logical_first), + static_cast(logical_last)}; + grouped.logical_shape = nvte_make_shape(logical_data, 2); + grouped.handle = nvte_create_grouped_tensor(scaling_mode, num_tensors, grouped.logical_shape); + + const int64_t last_idx = static_cast(num_tensors - 1); + const int64_t total_elems = need_offsets + ? (offsets[last_idx] + numel(last_idx)) + : (logical_first * logical_last); + const size_t total_bytes = static_cast(total_elems) * elem_size; + + NVTE_CHECK_CUDA(cudaMalloc(&grouped.data, total_bytes)); + for (size_t i = 0; i < num_tensors; ++i) { + const size_t offset_bytes = static_cast(offsets[i]) * elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.data) + offset_bytes, + tensors[i]->rowwise_dptr(), + grouped.tensor_bytes[i], + cudaMemcpyDeviceToDevice)); + } + + NVTEBasicTensor data_tensor{grouped.data, static_cast(dtype), grouped.logical_shape}; + nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedRowwiseData, &data_tensor); + + const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype); + if (include_columnwise) { + NVTE_CHECK_CUDA(cudaMalloc(&grouped.columnwise_data, total_bytes)); + for (size_t i = 0; i < num_tensors; ++i) { + const size_t offset_bytes = static_cast(offsets[i]) * elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.columnwise_data) + offset_bytes, + tensors[i]->columnwise_dptr(), + grouped.tensor_bytes[i], + cudaMemcpyDeviceToDevice)); + } + NVTEBasicTensor col_tensor{grouped.columnwise_data, + static_cast(dtype), + grouped.logical_shape}; + nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedColumnwiseData, &col_tensor); + } + + if (!same_first) { + NVTE_CHECK_CUDA(cudaMalloc(&grouped.first_dims_dev, num_tensors * sizeof(int64_t))); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.first_dims_dev, first_dims.data(), + num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); + NVTEShape fd_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor fd_tensor{grouped.first_dims_dev, kNVTEInt64, fd_shape}; + nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedFirstDims, &fd_tensor); + } + + if (!same_last) { + 
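+    // N varies across the group, so attach a per-tensor array of last dims; consumers read it
+    // back through the kNVTEGroupedLastDims parameter to recover each tensor's width.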
NVTE_CHECK_CUDA(cudaMalloc(&grouped.last_dims_dev, num_tensors * sizeof(int64_t))); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.last_dims_dev, last_dims.data(), + num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); + NVTEShape ld_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor ld_tensor{grouped.last_dims_dev, kNVTEInt64, ld_shape}; + nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedLastDims, &ld_tensor); + } + + if (!same_first || !same_last) { + NVTE_CHECK_CUDA(cudaMalloc(&grouped.offsets_dev, num_tensors * sizeof(int64_t))); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev, offsets.data(), + num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); + NVTEShape off_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor off_tensor{grouped.offsets_dev, kNVTEInt64, off_shape}; + nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedTensorOffsets, &off_tensor); + } + + if (isFp8Type(dtype)) { + std::vector scale_inv_cpu(num_tensors, 1.f); + for (size_t i = 0; i < num_tensors; ++i) { + tensors[i]->to_cpu(); + scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr()[0]; + } + NVTE_CHECK_CUDA(cudaMalloc(&grouped.scale_inv, sizeof(float) * num_tensors)); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv, scale_inv_cpu.data(), + sizeof(float) * num_tensors, cudaMemcpyHostToDevice)); + NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor scale_tensor{grouped.scale_inv, kNVTEFloat32, scale_shape}; + nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedRowwiseScaleInv, &scale_tensor); + nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedColumnwiseScaleInv, &scale_tensor); + } + + return grouped; +} + +Tensor make_fp8_operand(const std::string& name, const std::vector& shape) { + Tensor input_fp32(name + "_fp32", shape, DType::kFloat32); + fillUniform(&input_fp32); + + Tensor fp8(name, shape, TypeInfo::dtype, true, true, NVTE_DELAYED_TENSOR_SCALING); + + nvte_compute_amax(input_fp32.data(), fp8.data(), 0); + QuantizationConfigWrapper config; + nvte_compute_scale_from_amax(fp8.data(), config, 0); + nvte_quantize(input_fp32.data(), fp8.data(), 0); + return fp8; +} + +Tensor make_bf16_operand(const std::string& name, const std::vector& shape) { + Tensor t(name, shape, DType::kBFloat16); + fillUniform(&t); + return t; +} + +struct TestParams { + InputCase input_case; + bool transa; + bool transb; + ShapeCase shape_case; +}; + +std::vector> make_shapes(ShapeCase scase) { + switch (scase) { + case ShapeCase::kAllSame: + return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}}; + case ShapeCase::kSameFirst: // M wspólne, N/K zróżnicowane + return {{64, 64, 32}, {64, 96, 32}, {64, 80, 48}}; + case ShapeCase::kSameLast: // N wspólne, M/K zróżnicowane + return {{48, 80, 32}, {96, 80, 48}, {72, 80, 40}}; + case ShapeCase::kAllDifferent: + default: + return {{48, 80, 32}, {96, 64, 48}, {40, 72, 24}}; + } +} + +void run_grouped_gemm_case(const TestParams& params) { + if (params.input_case != InputCase::kBF16 && + getDeviceComputeCapability() < hopperComputeCapability) { + GTEST_SKIP() << "FP8 grouped GEMM requires Hopper or newer."; + } + + const std::vector> shapes = make_shapes(params.shape_case); + + const size_t num_gemms = shapes.size(); + std::vector A_tensors; + std::vector B_tensors; + std::vector D_multi; + + A_tensors.reserve(num_gemms); + B_tensors.reserve(num_gemms); + D_multi.reserve(num_gemms); + + for (size_t i = 0; i < num_gemms; ++i) { + const auto [M, N, K] = shapes[i]; + const std::vector a_shape = params.transa ? 
std::vector{K, M} + : std::vector{M, K}; + const std::vector b_shape = params.transb ? std::vector{N, K} + : std::vector{K, N}; + switch (params.input_case) { + case InputCase::kFP8Current: { + A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape)); + B_tensors.emplace_back(make_fp8_operand("B" + std::to_string(i), b_shape)); + break; + } + case InputCase::kBF16: { + A_tensors.emplace_back(make_bf16_operand("A" + std::to_string(i), a_shape)); + B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape)); + break; + } + } + D_multi.emplace_back(Tensor("D_multi" + std::to_string(i), + std::vector{M, N}, + DType::kBFloat16)); + } + + std::vector A_ptrs(num_gemms); + std::vector B_ptrs(num_gemms); + std::vector D_ptrs(num_gemms); + std::vector bias_ptrs(num_gemms, nullptr); + std::vector gelu_ptrs(num_gemms, nullptr); + std::vector workspaces(num_gemms); + std::vector workspace_ptrs(num_gemms, nullptr); + + const size_t cublas_ws_bytes = 32ull * 1024 * 1024; + + for (size_t i = 0; i < num_gemms; ++i) { + A_ptrs[i] = A_tensors[i].data(); + B_ptrs[i] = B_tensors[i].data(); + D_ptrs[i] = D_multi[i].data(); + workspaces[i] = Tensor("workspace" + std::to_string(i), std::vector{cublas_ws_bytes}, DType::kByte); + workspace_ptrs[i] = workspaces[i].data(); + } + + nvte_multi_tensor_gemm(A_ptrs.data(), + B_ptrs.data(), + D_ptrs.data(), + bias_ptrs.data(), + gelu_ptrs.data(), + static_cast(num_gemms), + params.transa, + params.transb, + false, + workspace_ptrs.data(), + false, + false, + 0, + 0); + + GroupedBuffers grouped_A = build_grouped_tensor(A_tensors, A_tensors[0].scaling_mode()); + GroupedBuffers grouped_B = build_grouped_tensor(B_tensors, B_tensors[0].scaling_mode()); + + std::vector C_tensors; + std::vector D_group_tensors; + C_tensors.reserve(num_gemms); + D_group_tensors.reserve(num_gemms); + for (size_t i = 0; i < num_gemms; ++i) { + const auto [M, N, K] = shapes[i]; + (void)K; + C_tensors.emplace_back(Tensor("C" + std::to_string(i), + std::vector{static_cast(M), static_cast(N)}, + DType::kBFloat16)); + D_group_tensors.emplace_back(Tensor("D_group" + std::to_string(i), + std::vector{static_cast(M), static_cast(N)}, + DType::kBFloat16)); + NVTE_CHECK_CUDA(cudaMemset(D_group_tensors.back().rowwise_dptr(), 0, bytes(D_group_tensors.back().rowwise_shape(), D_group_tensors.back().dtype()))); + } + + std::vector C_views, D_views; + for (size_t i = 0; i < num_gemms; ++i) { + C_views.push_back(&C_tensors[i]); + D_views.push_back(&D_group_tensors[i]); + } + + GroupedBuffers grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING); + GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING); + + Tensor alpha_tensor("alpha", std::vector{1}, DType::kFloat32); + Tensor beta_tensor("beta", std::vector{1}, DType::kFloat32); + const float alpha_val = 1.f; + const float beta_val = 0.f; + NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), &alpha_val, sizeof(float), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), &beta_val, sizeof(float), cudaMemcpyHostToDevice)); + + const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms); + Tensor setup_ws("setup_ws", std::vector{setup_ws_bytes}, DType::kByte); + Tensor cublas_ws("cublas_ws", std::vector{cublas_ws_bytes}, DType::kByte); + + nvte_grouped_gemm(params.transa, + params.transb, + alpha_tensor.data(), + grouped_A.handle, + grouped_B.handle, + beta_tensor.data(), + grouped_C.handle, + grouped_D.handle, + setup_ws.data(), + 
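+                    // setup_ws: device scratch where the per-GEMM pointer and M/N/K arrays are
+                    // built (sized by grouped_setup_workspace_size); the next argument is the
+                    // 32 MiB cuBLAS algorithm workspace.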
cublas_ws.data(), + nullptr, + 0, + nullptr, + nullptr, + nullptr); + + for (size_t i = 0; i < num_gemms; ++i) { + Tensor grouped_split("grouped_D" + std::to_string(i), + std::vector{static_cast(std::get<0>(shapes[i])), + static_cast(std::get<1>(shapes[i]))}, + D_multi[i].dtype()); + const size_t offset_bytes = static_cast(grouped_D.offsets_host[i]) * grouped_D.elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(), + static_cast(grouped_D.data) + offset_bytes, + grouped_D.tensor_bytes[i], + cudaMemcpyDeviceToDevice)); + grouped_split.to_cpu(); + D_multi[i].to_cpu(); + auto [atol, rtol] = getTolerances(D_multi[i].dtype()); + compareResults("grouped_vs_multi", + grouped_split, + D_multi[i].rowwise_cpu_dptr(), + true, + atol, + rtol); + } +} + +class GroupedGemmTest : public ::testing::TestWithParam {}; + +TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) { + run_grouped_gemm_case(GetParam()); +} + +std::string MakeGroupedGemmTestName(const testing::TestParamInfo& info) { + constexpr const char* kInputNames[] = {"FP8Delayed", "FP8Current", "BF16"}; + constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"}; + const std::string layout = std::string("ta") + (info.param.transa ? "T" : "N") + + "tb" + (info.param.transb ? "T" : "N"); + return std::string(kInputNames[static_cast(info.param.input_case)]) + "_" + + kShapeNames[static_cast(info.param.shape_case)] + "_" + layout; +} + +const std::vector kTestParams = { + {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent}, + {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent}, + {InputCase::kFP8Current, false, false, ShapeCase::kAllSame}, + {InputCase::kBF16, true, false, ShapeCase::kSameFirst}, + {InputCase::kBF16, false, true, ShapeCase::kSameLast}, + {InputCase::kBF16, false, false, ShapeCase::kAllSame}, + {InputCase::kBF16, true, true, ShapeCase::kAllDifferent}, +}; + +INSTANTIATE_TEST_SUITE_P(OperatorTest, + GroupedGemmTest, + ::testing::ValuesIn(kTestParams), + MakeGroupedGemmTestName); + +} // namespace + + diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 97e8ec9a3ea..53be59cc00c 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1104,3 +1104,487 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor cublas_path(); } } + + +// Helper struct to pass per-tensor shape/offset info (pointer or uniform value) +struct TensorShapeInfo { + const int64_t *first_dims; // nullptr if uniform + const int64_t *last_dims; // nullptr if uniform + const int64_t *offsets; // nullptr if need to compute + int64_t uniform_first; // used if first_dims == nullptr + int64_t uniform_last; // used if last_dims == nullptr + + // Create from GroupedTensor + static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) { + return { + t->first_dims.has_data() ? static_cast(t->first_dims.dptr) : nullptr, + t->last_dims.has_data() ? static_cast(t->last_dims.dptr) : nullptr, + t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) : nullptr, + t->get_common_first_dim(), + t->get_common_last_dim()}; + } + + // Create for C tensor (uses D's dimensions, only has offsets) + static TensorShapeInfo for_C(const transformer_engine::GroupedTensor *C, + const transformer_engine::GroupedTensor *D) { + return { + nullptr, + nullptr, + C->tensor_offsets.has_data() ? 
static_cast(C->tensor_offsets.dptr) : nullptr, + D->get_common_first_dim(), + D->get_common_last_dim()}; + } +}; + +// Helper functions to compute average dimensions from logical_shape for heuristics +// These are hints for cuBLASLt algorithm selection, don't need to be exact +inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor* t) { + // logical_shape[0] is either num_tensors*M (uniform) or sum_of_M (varying first) + // In both cases, dividing by num_tensors gives the average + return static_cast(t->logical_shape.data[0]) / static_cast(t->num_tensors); +} + +inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor* t) { + if (t->all_same_last_dim()) { + // logical_shape[1] is the common N + return static_cast(t->logical_shape.data[1]); + } else { + // logical_shape[1] is sum_of_N, divide by num_tensors + return static_cast(t->logical_shape.data[1]) / static_cast(t->num_tensors); + } +} + +// Workspace layout for grouped GEMM +struct GroupedGemmSetupWorkspace { + void **A_ptrs; + void **B_ptrs; + void **C_ptrs; + void **D_ptrs; + int *M; + int *N; + int *K; + float **alpha_ptrs; + float **beta_ptrs; + + // Initialize from workspace buffer + static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors, size_t alignment) { + GroupedGemmSetupWorkspace ws; + size_t offset = 0; + const size_t ptr_size = num_tensors * sizeof(void *); + const size_t int_size = num_tensors * sizeof(int); + + ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.M = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; + ws.N = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; + ws.K = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; + ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + + offset = ((offset + alignment - 1) / alignment) * alignment; + + return ws; + } + + // Calculate required size for setup workspace (pointer arrays + M/N/K + alpha/beta ptrs) + static size_t required_setup_size(size_t num_tensors, size_t alignment) { + const size_t ptr_size = num_tensors * sizeof(void *); + const size_t int_size = num_tensors * sizeof(int); + size_t size = 4 * ptr_size + 3 * int_size + 2 * ptr_size; // M, N, K only (no LDA/LDB/LDC/LDD) + size = ((size + alignment - 1) / alignment) * alignment; + return size; + } +}; + +// ----------------------------------------------------------------------------- +// Helper routines to keep nvte_grouped_gemm readable +// ----------------------------------------------------------------------------- +inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor* inputA, + const transformer_engine::GroupedTensor* inputB, + const transformer_engine::GroupedTensor* inputC, + const transformer_engine::GroupedTensor* outputD) { + const size_t num_tensors = inputA->num_tensors; + NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: num_tensors must be at least 1"); + NVTE_CHECK(inputB->num_tensors == num_tensors, + "Grouped GEMM: A and B must have the same num_tensors"); + NVTE_CHECK(inputC->num_tensors == num_tensors, + "Grouped GEMM: A and C must have the same num_tensors"); + NVTE_CHECK(outputD->num_tensors == num_tensors, + 
"Grouped GEMM: A and D must have the same num_tensors"); + + auto is_fp8_or_16bit = [](DType dtype) { + return dtype == DType::kFloat8E4M3 || dtype == DType::kFloat8E5M2 || + dtype == DType::kBFloat16 || dtype == DType::kFloat16; + }; + auto is_output_dtype = [](DType dtype) { + return dtype == DType::kBFloat16 || dtype == DType::kFloat16 || dtype == DType::kFloat32; + }; + NVTE_CHECK(is_fp8_or_16bit(inputA->dtype()) && is_fp8_or_16bit(inputB->dtype()), + "Grouped GEMM inputs must be FP8, BF16, or FP16."); + NVTE_CHECK(is_output_dtype(inputC->dtype()) && is_output_dtype(outputD->dtype()), + "Grouped GEMM outputs must be BF16, FP16, or FP32."); + NVTE_CHECK(inputA->has_data() || inputA->has_columnwise_data(), + "Grouped GEMM: A tensor is missing both row-wise and column-wise data"); + NVTE_CHECK(inputB->has_data() || inputB->has_columnwise_data(), + "Grouped GEMM: B tensor is missing both row-wise and column-wise data"); +} + +// Select row-wise vs column-wise storage and adjust transpose flag for grouped GEMM. +// Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and +// fallback to column-wise data when row-wise is absent. +struct GroupedOperandSelection { + const char* base = nullptr; + transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; + bool trans = false; + bool use_columnwise = false; +}; + +inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor* t, + bool trans, bool is_A) { + using namespace transformer_engine; + const bool has_row = t->has_data(); + const bool has_col = t->has_columnwise_data(); + NVTE_CHECK(has_row || has_col, "Grouped GEMM operand is missing both row-wise and column-wise data"); + + // Not yet supported in grouped GEMM: block scaling, MXFP8, NVFP4 specialized layouts. + const auto sm = t->scaling_mode; + NVTE_CHECK(sm != NVTE_BLOCK_SCALING_1D && sm != NVTE_BLOCK_SCALING_2D && + !is_mxfp_scaling(sm) && !is_nvfp_scaling(sm), + "Grouped GEMM does not yet support NVFP4/MXFP8/block scaling operand selection"); + + const DType row_dtype = t->data.dtype; + const DType col_dtype = t->columnwise_data.dtype; + GroupedOperandSelection sel; + sel.trans = trans; + + const DType rep_dtype = has_row ? row_dtype : col_dtype; + const bool is_fp8 = is_fp8_dtype(rep_dtype); + const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported(); + + // Hopper-style TN-only FP8: force TN by switching layout and flipping transpose when needed. + if (is_fp8 && !non_tn_fp8_ok) { + if (is_A) { + if (!sel.trans) { + NVTE_CHECK(has_col, "Grouped GEMM: A is missing column-wise data needed for FP8 TN layout"); + sel.base = static_cast(t->columnwise_data.dptr); + sel.dtype = col_dtype; + sel.trans = true; // using pre-transposed storage + sel.use_columnwise = true; + return sel; + } + } else { // B + if (sel.trans) { + NVTE_CHECK(has_col, "Grouped GEMM: B is missing column-wise data needed for FP8 TN layout"); + sel.base = static_cast(t->columnwise_data.dptr); + sel.dtype = col_dtype; + sel.trans = false; // using pre-transposed storage + sel.use_columnwise = true; + return sel; + } + } + } + + // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). + if (!has_row && has_col) { + sel.base = static_cast(t->columnwise_data.dptr); + sel.dtype = col_dtype; + sel.trans = !sel.trans; + sel.use_columnwise = true; + return sel; + } + + // Default: use row-wise data (or column-wise if row-wise absent, covered above). + sel.base = static_cast(has_row ? 
t->data.dptr : t->columnwise_data.dptr); + sel.dtype = has_row ? row_dtype : col_dtype; + sel.use_columnwise = !has_row && has_col; + return sel; +} + +inline void* validate_and_get_workspace_ptr(transformer_engine::Tensor* ws, size_t required_size, + const char* workspace_name) { + NVTE_CHECK(ws != nullptr, workspace_name, " tensor is null."); + const size_t provided_size = get_buffer_size_bytes(ws->data.numel(), ws->data.dtype); + NVTE_CHECK(provided_size >= required_size, + "Grouped GEMM: Insufficient ", workspace_name, ". Required: ", required_size, + " bytes, Available: ", provided_size, " bytes."); + return ws->data.dptr; +} + +inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t& descA, + cublasLtMatrixLayoutOpaque_t& descB, + cublasLtMatrixLayoutOpaque_t& descC, + cublasLtMatrixLayoutOpaque_t& descD, + const GroupedGemmWorkspace& ws, bool transa, bool transb, + bool a_columnwise, bool b_columnwise, + size_t num_tensors, cudaDataType_t A_type, cudaDataType_t B_type, + cudaDataType_t D_type) { + // For column-major layout: leading dimension is the number of rows in storage. + // If columnwise data was chosen, storage is already transposed. + const int* rowa = a_columnwise ? ws.M : (transa ? ws.K : ws.M); + const int* cola = a_columnwise ? ws.K : (transa ? ws.M : ws.K); + const int* lda = rowa; + const int* rowb = b_columnwise ? ws.N : (transb ? ws.N : ws.K); + const int* colb = b_columnwise ? ws.K : (transb ? ws.K : ws.N); + const int* ldb = rowb; + + NVTE_CHECK_CUBLAS( + cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, (void*)rowa, (void*)cola, (void*)lda)); + NVTE_CHECK_CUBLAS( + cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, (void*)rowb, (void*)colb, (void*)ldb)); + NVTE_CHECK_CUBLAS( + cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, (void*)ws.M, (void*)ws.N, (void*)ws.M)); + NVTE_CHECK_CUBLAS( + cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, (void*)ws.M, (void*)ws.N, (void*)ws.M)); +} + +inline void init_matmul_desc(cublasLtMatmulDescOpaque_t& matmulDesc, cublasOperation_t op_A, + cublasOperation_t op_B) { + NVTE_CHECK_CUBLAS(cublasLtMatmulDescInit(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); + + NVTE_CHECK_CUBLAS( + cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(op_A))); + NVTE_CHECK_CUBLAS( + cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(op_B))); + + cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, sizeof(pointer_mode))); + + int64_t alphabeta_batch_stride = 1; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE, + &alphabeta_batch_stride, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, + &alphabeta_batch_stride, sizeof(int64_t))); +} + +inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, + cublasLtMatmulDescOpaque_t& matmulDesc, + cublasLtMatrixLayoutOpaque_t& descA, + cublasLtMatrixLayoutOpaque_t& descB, + cublasLtMatrixLayoutOpaque_t& descC, + cublasLtMatrixLayoutOpaque_t& descD, int64_t avg_m, + int64_t avg_n, int64_t avg_k) { + cublasLtMatmulPreferenceOpaque_t preference; + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceInit(&preference)); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, 
&kGroupedGemmCublasWorkspaceSize, + sizeof(size_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_ROWS, &avg_m, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_COLS, &avg_n, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_AVERAGE_REDUCTION_DIM, &avg_k, sizeof(int64_t))); + + cublasLtMatmulHeuristicResult_t heuristicResult; + int returnedResults = 0; + auto status = cublasLtMatmulAlgoGetHeuristic(handle, &matmulDesc, &descA, &descB, &descC, &descD, + &preference, 1, &heuristicResult, &returnedResults); + NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, "Unable to find suitable cuBLAS grouped GEMM algorithm"); + NVTE_CHECK_CUBLAS(status); + NVTE_CHECK(returnedResults > 0, "No suitable algorithm found for grouped GEMM"); + return heuristicResult.algo; +} + +// Single kernel that sets up all GEMM parameters. +// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix M/N/K, +// but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes. +// We bridge the mismatch on GPU by computing per-group pointers and dims in one kernel. +__global__ void setup_grouped_gemm_kernel( + // Output arrays + void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, + int *M, int *N, int *K, + float **alpha_ptrs, float **beta_ptrs, + // Base pointers + const char *a_base, const char *b_base, const char *c_base, char *d_base, + // Dimension info (per tensor) + TensorShapeInfo A_meta, TensorShapeInfo B_meta, + TensorShapeInfo C_meta, TensorShapeInfo D_meta, + // Element sizes + size_t a_elem_size, size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, + // Alpha/beta pointers (same for all groups) + float *alpha_ptr, float *beta_ptr, + // Transpose flags + bool transa, bool transb, + // Number of tensors + size_t num_tensors) { + + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= num_tensors) return; + + // Get dimensions for this tensor (from array or uniform value) + int64_t a_first = A_meta.first_dims ? A_meta.first_dims[idx] : A_meta.uniform_first; + int64_t a_last = A_meta.last_dims ? A_meta.last_dims[idx] : A_meta.uniform_last; + int64_t b_first = B_meta.first_dims ? B_meta.first_dims[idx] : B_meta.uniform_first; + int64_t b_last = B_meta.last_dims ? B_meta.last_dims[idx] : B_meta.uniform_last; + + // Compute offsets (from array or compute from uniform dims) + int64_t a_offset = A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last); + int64_t b_offset = B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last); + int64_t c_offset = C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last); + int64_t d_offset = D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); + + // Compute data pointers + A_ptrs[idx] = const_cast(a_base) + a_offset * a_elem_size; + B_ptrs[idx] = const_cast(b_base) + b_offset * b_elem_size; + C_ptrs[idx] = const_cast(c_base) + c_offset * c_elem_size; + D_ptrs[idx] = d_base + d_offset * d_elem_size; + + // Compute M, N, K dimensions + M[idx] = static_cast(transa ? a_last : a_first); + K[idx] = static_cast(transa ? a_first : a_last); + N[idx] = static_cast(transb ? 
b_first : b_last); + + // Fill alpha/beta pointers (same for all groups) + alpha_ptrs[idx] = alpha_ptr; + beta_ptrs[idx] = beta_ptr; +} + +// Launch the setup kernel to populate workspace arrays +inline void launch_grouped_gemm_setup( + const GroupedGemmWorkspace &ws, + const transformer_engine::GroupedTensor *A, + const transformer_engine::GroupedTensor *B, + const transformer_engine::GroupedTensor *C, + const transformer_engine::GroupedTensor *D, + const transformer_engine::Tensor *alpha_tensor, + const transformer_engine::Tensor *beta_tensor, + const char *a_base, const char *b_base, + size_t a_elem_size, size_t b_elem_size, + bool transa, bool transb, + size_t num_tensors, cudaStream_t stream) { + + TensorShapeInfo A_meta = TensorShapeInfo::from_tensor(A); + TensorShapeInfo B_meta = TensorShapeInfo::from_tensor(B); + TensorShapeInfo C_meta = TensorShapeInfo::for_C(C, D); + TensorShapeInfo D_meta = TensorShapeInfo::from_tensor(D); + + const char *c_base = static_cast(C->data.dptr); + char *d_base = static_cast(D->data.dptr); + + const size_t c_elem_size = transformer_engine::typeToSize(C->dtype()); + const size_t d_elem_size = transformer_engine::typeToSize(D->dtype()); + + const int threads_per_block = 256; + const int num_blocks = (num_tensors + threads_per_block - 1) / threads_per_block; + + setup_grouped_gemm_kernel<<>>( + ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, + ws.M, ws.N, ws.K, + ws.alpha_ptrs, ws.beta_ptrs, + a_base, b_base, c_base, d_base, + A_meta, B_meta, C_meta, D_meta, + a_elem_size, b_elem_size, c_elem_size, d_elem_size, + static_cast(alpha_tensor->data.dptr), + static_cast(beta_tensor->data.dptr), + transa, transb, num_tensors); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +// Constants for grouped GEMM workspace +static constexpr size_t kGroupedGemmAlignment = 256; +static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB + +inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { + return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment); +} + +void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, + const NVTEGroupedTensor A, const NVTEGroupedTensor B, + const NVTETensor beta, const NVTEGroupedTensor C, NVTEGroupedTensor D, + NVTETensor workspace_setup, NVTETensor workspace_cublas, + NVTEMatmulConfig config, cudaStream_t stream, + const int64_t* avg_m, const int64_t* avg_n, const int64_t* avg_k) { + NVTE_API_CALL(nvte_grouped_gemm); + using namespace transformer_engine; + + // Convert to internal types + const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); + const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B); + const GroupedTensor *inputC = convertNVTEGroupedTensorCheck(C); + GroupedTensor *outputD = convertNVTEGroupedTensorCheck(D); + const Tensor *alpha_tensor = convertNVTETensorCheck(alpha); + const Tensor *beta_tensor = convertNVTETensorCheck(beta); + Tensor *wspace_setup = convertNVTETensor(workspace_setup); + Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); + + // Validate inputs and num_tensors + validate_grouped_gemm_inputs(inputA, inputB, inputC, outputD); + const size_t num_tensors = inputA->num_tensors; + + // Select operand storage (row-wise vs column-wise) and adjust transpose flags to + // mirror the non-grouped GEMM logic for FP8 layout constraints. 
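+  // On architectures where only TN FP8 GEMM is supported, select_grouped_operand substitutes the
+  // column-wise (pre-transposed) storage and flips the corresponding transpose flag, so e.g. an
+  // NN FP8 GEMM is issued to cuBLASLt as TN without an explicit transpose. For 16-bit inputs, or
+  // when nvte_is_non_tn_fp8_gemm_supported() returns true, the row-wise data and the caller's
+  // flags are used as-is, falling back to column-wise storage only when row-wise data is absent.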
+ bool transa_flag = static_cast(transa); + bool transb_flag = static_cast(transb); + const auto A_sel = select_grouped_operand(inputA, transa_flag, /*is_A=*/true); + const auto B_sel = select_grouped_operand(inputB, transb_flag, /*is_A=*/false); + transa_flag = A_sel.trans; + transb_flag = B_sel.trans; + const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype); + const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype); + + // Workspaces: setup (pointer arrays) and cuBLAS + const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors); + const size_t cublas_workspace_size = kGroupedGemmCublasWorkspaceSize; + + void* setup_workspace_ptr = + validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size, "Grouped GEMM setup workspace"); + void* cublas_workspace_ptr = + validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, "Grouped GEMM cuBLAS workspace"); + + NVTE_CHECK(cublas_workspace_ptr != nullptr, "Grouped GEMM: cuBLAS workspace pointer is null"); + + auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( + static_cast(setup_workspace_ptr), num_tensors, kGroupedGemmAlignment); + launch_grouped_gemm_setup(setup_workspace, inputA, inputB, inputC, outputD, + alpha_tensor, beta_tensor, + A_sel.base, B_sel.base, a_elem_size, b_elem_size, + transa_flag, transb_flag, + num_tensors, stream); + + // Get cuBLAS handle + using cublasHandleManager = detail::HandleManager; + cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); + + // Get data types + const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); + const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); + const cudaDataType_t D_type = get_cuda_dtype(outputD->dtype()); + + // Setup cuBLAS operations + cublasOperation_t op_A = transa_flag ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t op_B = transb_flag ? CUBLAS_OP_T : CUBLAS_OP_N; + + // Create grouped matrix layouts + cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD; + init_matrix_layouts(descA, descB, descC, descD, setup_workspace, + transa_flag, transb_flag, A_sel.use_columnwise, B_sel.use_columnwise, + num_tensors, A_type, B_type, D_type); + + // Create matmul descriptor + cublasLtMatmulDescOpaque_t matmulDesc; + init_matmul_desc(matmulDesc, op_A, op_B); + + // Compute average dimensions for heuristics + // K dimension: if transa, K is A's first dim; if not, K is A's last dim + int64_t avg_m_val = avg_m ? *avg_m : compute_avg_first_dim(outputD); + int64_t avg_n_val = avg_n ? *avg_n : compute_avg_last_dim(outputD); + int64_t avg_k_val = + avg_k ? *avg_k : (transa_flag ? 
compute_avg_first_dim(inputA) : compute_avg_last_dim(inputA)); + + // Heuristic selection + cublasLtMatmulAlgo_t algo = + select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, descD, avg_m_val, avg_n_val, + avg_k_val); + + // Execute the grouped GEMM + NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, &matmulDesc, setup_workspace.alpha_ptrs, + setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB, + setup_workspace.beta_ptrs, setup_workspace.C_ptrs, + &descC, setup_workspace.D_ptrs, &descD, + &algo, cublas_workspace_ptr, + kGroupedGemmCublasWorkspaceSize, stream)); +} diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 950014cc9be..51241aef6b3 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -228,6 +228,42 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor bool transa, bool transb, bool grad, NVTETensor *workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream); + +/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ +/*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C + * + * Performs batched GEMM on a collection of matrices with potentially different shapes. + * All tensors in the group must have compatible dimensions for matrix multiplication. + * Uses NVTEGroupedTensor to efficiently handle collections of tensors with contiguous + * memory layout and shape metadata. + * + * \param[in] transa Whether to transpose A matrices. + * \param[in] transb Whether to transpose B matrices. + * \param[in] alpha Scale multiplier for A @ B (NVTETensor with num_tensors elements, + * or single element for uniform alpha). + * \param[in] A Input grouped tensor A. + * \param[in] B Input grouped tensor B. + * \param[in] beta Scale multiplier for C (NVTETensor with num_tensors elements, + * or single element for uniform beta). + * \param[in] C Input grouped tensor C (can be NULL for beta=0). + * \param[out] D Output grouped tensor D. + * \param[in] workspace Workspace tensor for intermediate computations. + * \param[in] config Matrix multiplication configuration. + * \param[in] stream CUDA stream for the operation. 
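+ * \param[in] avg_m Optional pointer to the average M across the group, used only as a hint
+ *                  for cuBLASLt algorithm selection (may be NULL; derived from D otherwise).
+ * \param[in] avg_n Optional pointer to the average N across the group (hint only; may be NULL).
+ * \param[in] avg_k Optional pointer to the average K across the group (hint only; may be NULL;
+ *                  derived from A otherwise).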
+ * + * Requirements: + * - A, B, C (if provided), D must have the same num_tensors + * - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i] + * - Shape compatibility: if transa=false, transb=false: + * - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i]) + */ +void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, + const NVTEGroupedTensor A, const NVTEGroupedTensor B, + const NVTETensor beta, const NVTEGroupedTensor C, NVTEGroupedTensor D, + NVTETensor workspace_setup, NVTETensor workspace_cublas, + NVTEMatmulConfig config, cudaStream_t stream, + const int64_t* avg_m, const int64_t* avg_n, const int64_t* avg_k); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus From 76293d4dc9ebb8a7e1c7ba2ae47f866d56998d33 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Dec 2025 14:32:15 +0000 Subject: [PATCH 02/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/cpp/operator/test_grouped_gemm.cu | 2 - .../common/gemm/cublaslt_gemm.cu | 279 +++++++++--------- .../common/include/transformer_engine/gemm.h | 11 +- 3 files changed, 141 insertions(+), 151 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 0e9c6c6a4d6..d346e068879 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -507,5 +507,3 @@ INSTANTIATE_TEST_SUITE_P(OperatorTest, MakeGroupedGemmTestName); } // namespace - - diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 53be59cc00c..2c8c2093c6f 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1105,46 +1105,42 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor } } - // Helper struct to pass per-tensor shape/offset info (pointer or uniform value) struct TensorShapeInfo { - const int64_t *first_dims; // nullptr if uniform - const int64_t *last_dims; // nullptr if uniform - const int64_t *offsets; // nullptr if need to compute - int64_t uniform_first; // used if first_dims == nullptr - int64_t uniform_last; // used if last_dims == nullptr + const int64_t *first_dims; // nullptr if uniform + const int64_t *last_dims; // nullptr if uniform + const int64_t *offsets; // nullptr if need to compute + int64_t uniform_first; // used if first_dims == nullptr + int64_t uniform_last; // used if last_dims == nullptr // Create from GroupedTensor static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) { - return { - t->first_dims.has_data() ? static_cast(t->first_dims.dptr) : nullptr, - t->last_dims.has_data() ? static_cast(t->last_dims.dptr) : nullptr, - t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) : nullptr, - t->get_common_first_dim(), - t->get_common_last_dim()}; + return {t->first_dims.has_data() ? static_cast(t->first_dims.dptr) : nullptr, + t->last_dims.has_data() ? static_cast(t->last_dims.dptr) : nullptr, + t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) + : nullptr, + t->get_common_first_dim(), t->get_common_last_dim()}; } // Create for C tensor (uses D's dimensions, only has offsets) static TensorShapeInfo for_C(const transformer_engine::GroupedTensor *C, const transformer_engine::GroupedTensor *D) { - return { - nullptr, - nullptr, - C->tensor_offsets.has_data() ? 
static_cast(C->tensor_offsets.dptr) : nullptr, - D->get_common_first_dim(), - D->get_common_last_dim()}; + return {nullptr, nullptr, + C->tensor_offsets.has_data() ? static_cast(C->tensor_offsets.dptr) + : nullptr, + D->get_common_first_dim(), D->get_common_last_dim()}; } }; // Helper functions to compute average dimensions from logical_shape for heuristics // These are hints for cuBLASLt algorithm selection, don't need to be exact -inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor* t) { +inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor *t) { // logical_shape[0] is either num_tensors*M (uniform) or sum_of_M (varying first) // In both cases, dividing by num_tensors gives the average return static_cast(t->logical_shape.data[0]) / static_cast(t->num_tensors); } -inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor* t) { +inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) { if (t->all_same_last_dim()) { // logical_shape[1] is the common N return static_cast(t->logical_shape.data[1]); @@ -1167,21 +1163,31 @@ struct GroupedGemmSetupWorkspace { float **beta_ptrs; // Initialize from workspace buffer - static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors, size_t alignment) { + static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors, + size_t alignment) { GroupedGemmSetupWorkspace ws; size_t offset = 0; const size_t ptr_size = num_tensors * sizeof(void *); const size_t int_size = num_tensors * sizeof(int); - ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - ws.M = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; - ws.N = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; - ws.K = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; - ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.M = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.N = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.K = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; offset = ((offset + alignment - 1) / alignment) * alignment; @@ -1201,10 +1207,10 @@ struct GroupedGemmSetupWorkspace { // ----------------------------------------------------------------------------- // Helper routines to keep nvte_grouped_gemm readable // ----------------------------------------------------------------------------- -inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor* inputA, - const transformer_engine::GroupedTensor* inputB, - const transformer_engine::GroupedTensor* inputC, - const transformer_engine::GroupedTensor* outputD) { +inline void 
validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor *inputA, + const transformer_engine::GroupedTensor *inputB, + const transformer_engine::GroupedTensor *inputC, + const transformer_engine::GroupedTensor *outputD) { const size_t num_tensors = inputA->num_tensors; NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: num_tensors must be at least 1"); NVTE_CHECK(inputB->num_tensors == num_tensors, @@ -1235,23 +1241,24 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor // Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and // fallback to column-wise data when row-wise is absent. struct GroupedOperandSelection { - const char* base = nullptr; + const char *base = nullptr; transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; bool trans = false; bool use_columnwise = false; }; -inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor* t, +inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor *t, bool trans, bool is_A) { using namespace transformer_engine; const bool has_row = t->has_data(); const bool has_col = t->has_columnwise_data(); - NVTE_CHECK(has_row || has_col, "Grouped GEMM operand is missing both row-wise and column-wise data"); + NVTE_CHECK(has_row || has_col, + "Grouped GEMM operand is missing both row-wise and column-wise data"); // Not yet supported in grouped GEMM: block scaling, MXFP8, NVFP4 specialized layouts. const auto sm = t->scaling_mode; - NVTE_CHECK(sm != NVTE_BLOCK_SCALING_1D && sm != NVTE_BLOCK_SCALING_2D && - !is_mxfp_scaling(sm) && !is_nvfp_scaling(sm), + NVTE_CHECK(sm != NVTE_BLOCK_SCALING_1D && sm != NVTE_BLOCK_SCALING_2D && !is_mxfp_scaling(sm) && + !is_nvfp_scaling(sm), "Grouped GEMM does not yet support NVFP4/MXFP8/block scaling operand selection"); const DType row_dtype = t->data.dtype; @@ -1268,7 +1275,7 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: if (is_A) { if (!sel.trans) { NVTE_CHECK(has_col, "Grouped GEMM: A is missing column-wise data needed for FP8 TN layout"); - sel.base = static_cast(t->columnwise_data.dptr); + sel.base = static_cast(t->columnwise_data.dptr); sel.dtype = col_dtype; sel.trans = true; // using pre-transposed storage sel.use_columnwise = true; @@ -1277,7 +1284,7 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: } else { // B if (sel.trans) { NVTE_CHECK(has_col, "Grouped GEMM: B is missing column-wise data needed for FP8 TN layout"); - sel.base = static_cast(t->columnwise_data.dptr); + sel.base = static_cast(t->columnwise_data.dptr); sel.dtype = col_dtype; sel.trans = false; // using pre-transposed storage sel.use_columnwise = true; @@ -1288,7 +1295,7 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). if (!has_row && has_col) { - sel.base = static_cast(t->columnwise_data.dptr); + sel.base = static_cast(t->columnwise_data.dptr); sel.dtype = col_dtype; sel.trans = !sel.trans; sel.use_columnwise = true; @@ -1296,81 +1303,81 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: } // Default: use row-wise data (or column-wise if row-wise absent, covered above). - sel.base = static_cast(has_row ? t->data.dptr : t->columnwise_data.dptr); + sel.base = static_cast(has_row ? t->data.dptr : t->columnwise_data.dptr); sel.dtype = has_row ? 
row_dtype : col_dtype; - sel.use_columnwise = !has_row && has_col; + sel.use_columnwise = !has_row && has_col; return sel; } -inline void* validate_and_get_workspace_ptr(transformer_engine::Tensor* ws, size_t required_size, - const char* workspace_name) { +inline void *validate_and_get_workspace_ptr(transformer_engine::Tensor *ws, size_t required_size, + const char *workspace_name) { NVTE_CHECK(ws != nullptr, workspace_name, " tensor is null."); const size_t provided_size = get_buffer_size_bytes(ws->data.numel(), ws->data.dtype); - NVTE_CHECK(provided_size >= required_size, - "Grouped GEMM: Insufficient ", workspace_name, ". Required: ", required_size, - " bytes, Available: ", provided_size, " bytes."); + NVTE_CHECK(provided_size >= required_size, "Grouped GEMM: Insufficient ", workspace_name, + ". Required: ", required_size, " bytes, Available: ", provided_size, " bytes."); return ws->data.dptr; } -inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t& descA, - cublasLtMatrixLayoutOpaque_t& descB, - cublasLtMatrixLayoutOpaque_t& descC, - cublasLtMatrixLayoutOpaque_t& descD, - const GroupedGemmWorkspace& ws, bool transa, bool transb, - bool a_columnwise, bool b_columnwise, +inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, + cublasLtMatrixLayoutOpaque_t &descB, + cublasLtMatrixLayoutOpaque_t &descC, + cublasLtMatrixLayoutOpaque_t &descD, const GroupedGemmWorkspace &ws, + bool transa, bool transb, bool a_columnwise, bool b_columnwise, size_t num_tensors, cudaDataType_t A_type, cudaDataType_t B_type, cudaDataType_t D_type) { // For column-major layout: leading dimension is the number of rows in storage. // If columnwise data was chosen, storage is already transposed. - const int* rowa = a_columnwise ? ws.M : (transa ? ws.K : ws.M); - const int* cola = a_columnwise ? ws.K : (transa ? ws.M : ws.K); - const int* lda = rowa; - const int* rowb = b_columnwise ? ws.N : (transb ? ws.N : ws.K); - const int* colb = b_columnwise ? ws.K : (transb ? ws.K : ws.N); - const int* ldb = rowb; - - NVTE_CHECK_CUBLAS( - cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, (void*)rowa, (void*)cola, (void*)lda)); - NVTE_CHECK_CUBLAS( - cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, (void*)rowb, (void*)colb, (void*)ldb)); - NVTE_CHECK_CUBLAS( - cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, (void*)ws.M, (void*)ws.N, (void*)ws.M)); - NVTE_CHECK_CUBLAS( - cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, (void*)ws.M, (void*)ws.N, (void*)ws.M)); + const int *rowa = a_columnwise ? ws.M : (transa ? ws.K : ws.M); + const int *cola = a_columnwise ? ws.K : (transa ? ws.M : ws.K); + const int *lda = rowa; + const int *rowb = b_columnwise ? ws.N : (transb ? ws.N : ws.K); + const int *colb = b_columnwise ? ws.K : (transb ? 
ws.K : ws.N); + const int *ldb = rowb; + + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, (void *)rowa, + (void *)cola, (void *)lda)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, (void *)rowb, + (void *)colb, (void *)ldb)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, (void *)ws.M, + (void *)ws.N, (void *)ws.M)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, (void *)ws.M, + (void *)ws.N, (void *)ws.M)); } -inline void init_matmul_desc(cublasLtMatmulDescOpaque_t& matmulDesc, cublasOperation_t op_A, +inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A, cublasOperation_t op_B) { NVTE_CHECK_CUBLAS(cublasLtMatmulDescInit(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); - NVTE_CHECK_CUBLAS( - cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(op_A))); - NVTE_CHECK_CUBLAS( - cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(op_B))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A, + sizeof(op_A))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B, + sizeof(op_B))); cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode))); int64_t alphabeta_batch_stride = 1; - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE, + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE, &alphabeta_batch_stride, sizeof(int64_t))); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, &alphabeta_batch_stride, sizeof(int64_t))); } inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, - cublasLtMatmulDescOpaque_t& matmulDesc, - cublasLtMatrixLayoutOpaque_t& descA, - cublasLtMatrixLayoutOpaque_t& descB, - cublasLtMatrixLayoutOpaque_t& descC, - cublasLtMatrixLayoutOpaque_t& descD, int64_t avg_m, - int64_t avg_n, int64_t avg_k) { + cublasLtMatmulDescOpaque_t &matmulDesc, + cublasLtMatrixLayoutOpaque_t &descA, + cublasLtMatrixLayoutOpaque_t &descB, + cublasLtMatrixLayoutOpaque_t &descC, + cublasLtMatrixLayoutOpaque_t &descD, + int64_t avg_m, int64_t avg_n, int64_t avg_k) { cublasLtMatmulPreferenceOpaque_t preference; NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceInit(&preference)); - NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &kGroupedGemmCublasWorkspaceSize, - sizeof(size_t))); + NVTE_CHECK_CUBLAS( + cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &kGroupedGemmCublasWorkspaceSize, sizeof(size_t))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_ROWS, &avg_m, sizeof(int64_t))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( @@ -1382,7 +1389,8 @@ inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, int returnedResults = 0; auto status = cublasLtMatmulAlgoGetHeuristic(handle, &matmulDesc, &descA, &descB, &descC, &descD, &preference, 1, &heuristicResult, 
&returnedResults); - NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, "Unable to find suitable cuBLAS grouped GEMM algorithm"); + NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, + "Unable to find suitable cuBLAS grouped GEMM algorithm"); NVTE_CHECK_CUBLAS(status); NVTE_CHECK(returnedResults > 0, "No suitable algorithm found for grouped GEMM"); return heuristicResult.algo; @@ -1394,14 +1402,12 @@ inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, // We bridge the mismatch on GPU by computing per-group pointers and dims in one kernel. __global__ void setup_grouped_gemm_kernel( // Output arrays - void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, - int *M, int *N, int *K, + void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, int *M, int *N, int *K, float **alpha_ptrs, float **beta_ptrs, // Base pointers const char *a_base, const char *b_base, const char *c_base, char *d_base, // Dimension info (per tensor) - TensorShapeInfo A_meta, TensorShapeInfo B_meta, - TensorShapeInfo C_meta, TensorShapeInfo D_meta, + TensorShapeInfo A_meta, TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, // Element sizes size_t a_elem_size, size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, // Alpha/beta pointers (same for all groups) @@ -1410,7 +1416,6 @@ __global__ void setup_grouped_gemm_kernel( bool transa, bool transb, // Number of tensors size_t num_tensors) { - size_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= num_tensors) return; @@ -1421,10 +1426,14 @@ __global__ void setup_grouped_gemm_kernel( int64_t b_last = B_meta.last_dims ? B_meta.last_dims[idx] : B_meta.uniform_last; // Compute offsets (from array or compute from uniform dims) - int64_t a_offset = A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last); - int64_t b_offset = B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last); - int64_t c_offset = C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last); - int64_t d_offset = D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); + int64_t a_offset = + A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last); + int64_t b_offset = + B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last); + int64_t c_offset = + C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last); + int64_t d_offset = + D_meta.offsets ? 
D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); // Compute data pointers A_ptrs[idx] = const_cast(a_base) + a_offset * a_elem_size; @@ -1444,18 +1453,12 @@ __global__ void setup_grouped_gemm_kernel( // Launch the setup kernel to populate workspace arrays inline void launch_grouped_gemm_setup( - const GroupedGemmWorkspace &ws, - const transformer_engine::GroupedTensor *A, - const transformer_engine::GroupedTensor *B, - const transformer_engine::GroupedTensor *C, - const transformer_engine::GroupedTensor *D, - const transformer_engine::Tensor *alpha_tensor, - const transformer_engine::Tensor *beta_tensor, - const char *a_base, const char *b_base, - size_t a_elem_size, size_t b_elem_size, - bool transa, bool transb, - size_t num_tensors, cudaStream_t stream) { - + const GroupedGemmWorkspace &ws, const transformer_engine::GroupedTensor *A, + const transformer_engine::GroupedTensor *B, const transformer_engine::GroupedTensor *C, + const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor, + const transformer_engine::Tensor *beta_tensor, const char *a_base, const char *b_base, + size_t a_elem_size, size_t b_elem_size, bool transa, bool transb, size_t num_tensors, + cudaStream_t stream) { TensorShapeInfo A_meta = TensorShapeInfo::from_tensor(A); TensorShapeInfo B_meta = TensorShapeInfo::from_tensor(B); TensorShapeInfo C_meta = TensorShapeInfo::for_C(C, D); @@ -1471,15 +1474,10 @@ inline void launch_grouped_gemm_setup( const int num_blocks = (num_tensors + threads_per_block - 1) / threads_per_block; setup_grouped_gemm_kernel<<>>( - ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, - ws.M, ws.N, ws.K, - ws.alpha_ptrs, ws.beta_ptrs, - a_base, b_base, c_base, d_base, - A_meta, B_meta, C_meta, D_meta, - a_elem_size, b_elem_size, c_elem_size, d_elem_size, - static_cast(alpha_tensor->data.dptr), - static_cast(beta_tensor->data.dptr), - transa, transb, num_tensors); + ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.M, ws.N, ws.K, ws.alpha_ptrs, ws.beta_ptrs, + a_base, b_base, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, b_elem_size, + c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), + static_cast(beta_tensor->data.dptr), transa, transb, num_tensors); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -1492,12 +1490,11 @@ inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment); } -void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, - const NVTEGroupedTensor A, const NVTEGroupedTensor B, - const NVTETensor beta, const NVTEGroupedTensor C, NVTEGroupedTensor D, - NVTETensor workspace_setup, NVTETensor workspace_cublas, - NVTEMatmulConfig config, cudaStream_t stream, - const int64_t* avg_m, const int64_t* avg_n, const int64_t* avg_k) { +void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, + const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, + NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, + NVTEMatmulConfig config, cudaStream_t stream, const int64_t *avg_m, + const int64_t *avg_n, const int64_t *avg_k) { NVTE_API_CALL(nvte_grouped_gemm); using namespace transformer_engine; @@ -1530,20 +1527,18 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors); const size_t cublas_workspace_size = 
kGroupedGemmCublasWorkspaceSize; - void* setup_workspace_ptr = - validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size, "Grouped GEMM setup workspace"); - void* cublas_workspace_ptr = - validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, "Grouped GEMM cuBLAS workspace"); + void *setup_workspace_ptr = validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size, + "Grouped GEMM setup workspace"); + void *cublas_workspace_ptr = validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, + "Grouped GEMM cuBLAS workspace"); NVTE_CHECK(cublas_workspace_ptr != nullptr, "Grouped GEMM: cuBLAS workspace pointer is null"); auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( - static_cast(setup_workspace_ptr), num_tensors, kGroupedGemmAlignment); - launch_grouped_gemm_setup(setup_workspace, inputA, inputB, inputC, outputD, - alpha_tensor, beta_tensor, - A_sel.base, B_sel.base, a_elem_size, b_elem_size, - transa_flag, transb_flag, - num_tensors, stream); + static_cast(setup_workspace_ptr), num_tensors, kGroupedGemmAlignment); + launch_grouped_gemm_setup(setup_workspace, inputA, inputB, inputC, outputD, alpha_tensor, + beta_tensor, A_sel.base, B_sel.base, a_elem_size, b_elem_size, + transa_flag, transb_flag, num_tensors, stream); // Get cuBLAS handle using cublasHandleManager = detail::HandleManager; @@ -1560,9 +1555,9 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, // Create grouped matrix layouts cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD; - init_matrix_layouts(descA, descB, descC, descD, setup_workspace, - transa_flag, transb_flag, A_sel.use_columnwise, B_sel.use_columnwise, - num_tensors, A_type, B_type, D_type); + init_matrix_layouts(descA, descB, descC, descD, setup_workspace, transa_flag, transb_flag, + A_sel.use_columnwise, B_sel.use_columnwise, num_tensors, A_type, B_type, + D_type); // Create matmul descriptor cublasLtMatmulDescOpaque_t matmulDesc; @@ -1576,15 +1571,13 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, avg_k ? *avg_k : (transa_flag ? 
compute_avg_first_dim(inputA) : compute_avg_last_dim(inputA)); // Heuristic selection - cublasLtMatmulAlgo_t algo = - select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, descD, avg_m_val, avg_n_val, - avg_k_val); + cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, + descD, avg_m_val, avg_n_val, avg_k_val); // Execute the grouped GEMM NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, &matmulDesc, setup_workspace.alpha_ptrs, - setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB, - setup_workspace.beta_ptrs, setup_workspace.C_ptrs, - &descC, setup_workspace.D_ptrs, &descD, - &algo, cublas_workspace_ptr, + setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB, + setup_workspace.beta_ptrs, setup_workspace.C_ptrs, &descC, + setup_workspace.D_ptrs, &descD, &algo, cublas_workspace_ptr, kGroupedGemmCublasWorkspaceSize, stream)); } diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 51241aef6b3..948058295ee 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -257,12 +257,11 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * - Shape compatibility: if transa=false, transb=false: * - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i]) */ -void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, - const NVTEGroupedTensor A, const NVTEGroupedTensor B, - const NVTETensor beta, const NVTEGroupedTensor C, NVTEGroupedTensor D, - NVTETensor workspace_setup, NVTETensor workspace_cublas, - NVTEMatmulConfig config, cudaStream_t stream, - const int64_t* avg_m, const int64_t* avg_n, const int64_t* avg_k); +void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, + const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, + NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, + NVTEMatmulConfig config, cudaStream_t stream, const int64_t *avg_m, + const int64_t *avg_n, const int64_t *avg_k); #ifdef __cplusplus } // extern "C" From 296d77362099c52fa8e19a299f4a4134dc184096 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 10 Dec 2025 18:25:39 +0100 Subject: [PATCH 03/17] Add FP8 scale support and fix alignment for grouped GEMM - Add FP8 scale_inv pointer handling in nvte_grouped_gemm for proper FP8 GEMM - Fix random padding in tests to ensure 16-byte alignment for all dtypes - Reorder GroupedGemmSetupWorkspace members for natural alignment - Remove debug prints Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 55 +++++--- .../common/gemm/cublaslt_gemm.cu | 119 +++++++++++++----- .../common/include/transformer_engine/gemm.h | 2 + 3 files changed, 131 insertions(+), 45 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index d346e068879..bff175f405a 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -1,8 +1,8 @@ -/*********************************************************************** - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
- **********************************************************************/
+ ************************************************************************/
 
 #include 
 #include 
@@ -16,6 +16,8 @@
 #include 
 #include 
+#include 
+#include 
 #include 
 
 #include "../test_common.h"
@@ -136,7 +138,7 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors,
   const NVTEShape shape = tensors[0]->rowwise_shape();
   const DType dtype = tensors[0]->dtype();
   const size_t num_tensors = tensors.size();
-  const size_t elem_size = typeToSize(dtype);
+  const size_t elem_size = typeToNumBits(dtype) / 8;
   GroupedBuffers grouped;
   grouped.elem_size = elem_size;
   grouped.num_tensors = num_tensors;
@@ -162,9 +164,13 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors,
   std::vector offsets(num_tensors, 0);
 
   auto random_padding = [&]() -> int64_t {
+    // Random padding ensuring 16-byte alignment regardless of element size
+    // cuBLAS requires aligned pointers for vectorized loads
     static std::mt19937 gen(12345);
     std::uniform_int_distribution dist(0, 3);
-    return dist(gen);
+    // Calculate elements needed for 16-byte alignment
+    const size_t align_elements = (16 * 8) / typeToNumBits(dtype);  // 16 bytes / element_size
+    return dist(gen) * static_cast(align_elements);
   };
 
   auto numel = [&](size_t idx) -> int64_t {
@@ -301,7 +307,12 @@ Tensor make_fp8_operand(const std::string& name, const std::vector& shap
 
 Tensor make_bf16_operand(const std::string& name, const std::vector& shape) {
   Tensor t(name, shape, DType::kBFloat16);
-  fillUniform(&t);
+  // Fill with ones for easier debugging
+  //fillUniform(&t);
+  const size_t numel = shape[0] * shape[1];
+  std::vector<__nv_bfloat16> ones(numel, __float2bfloat16(1.0f));
+  NVTE_CHECK_CUDA(cudaMemcpy(t.rowwise_dptr(), ones.data(),
+                             numel * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice));
   return t;
 }
 
@@ -312,17 +323,21 @@ struct TestParams {
   ShapeCase shape_case;
 };
 
+// Returns a vector of (M, N, K) tuples for each GEMM in the group.
+// M - number of rows in output D
+// N - number of columns in output D
+// K - reduction dimension shared between A and B
 std::vector> make_shapes(ShapeCase scase) {
   switch (scase) {
     case ShapeCase::kAllSame:
       return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}};
-    case ShapeCase::kSameFirst:  // same M, varying N/K
-      return {{64, 64, 32}, {64, 96, 32}, {64, 80, 48}};
-    case ShapeCase::kSameLast:  // same N, varying M/K
-      return {{48, 80, 32}, {96, 80, 48}, {72, 80, 40}};
+    case ShapeCase::kSameFirst:
+      return {{64, 80, 32}, {64, 80, 48}, {64, 80, 64}};
+    case ShapeCase::kSameLast:
+      return {{64, 80, 32}, {64, 80, 48}, {64, 80, 64}};
     case ShapeCase::kAllDifferent:
     default:
-      return {{48, 80, 32}, {96, 64, 48}, {40, 72, 24}};
+      return {{64, 96, 32}, {64, 96, 48}, {64, 96, 64}};
   }
 }
 
@@ -345,10 +360,10 @@ void run_grouped_gemm_case(const TestParams& params) {
 
   for (size_t i = 0; i < num_gemms; ++i) {
     const auto [M, N, K] = shapes[i];
-    const std::vector a_shape = params.transa ? std::vector{K, M}
-                                              : std::vector{M, K};
-    const std::vector b_shape = params.transb ? std::vector{N, K}
-                                              : std::vector{K, N};
+    const std::vector a_shape = params.transa ? std::vector{M, K}
+                                              : std::vector{K, M};
+    const std::vector b_shape = params.transb ? 
std::vector{K, N} + : std::vector{N, K}; switch (params.input_case) { case InputCase::kFP8Current: { A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape)); @@ -373,6 +388,10 @@ void run_grouped_gemm_case(const TestParams& params) { std::vector gelu_ptrs(num_gemms, nullptr); std::vector workspaces(num_gemms); std::vector workspace_ptrs(num_gemms, nullptr); + std::vector A_views; + std::vector B_views; + A_views.reserve(num_gemms); + B_views.reserve(num_gemms); const size_t cublas_ws_bytes = 32ull * 1024 * 1024; @@ -382,6 +401,8 @@ void run_grouped_gemm_case(const TestParams& params) { D_ptrs[i] = D_multi[i].data(); workspaces[i] = Tensor("workspace" + std::to_string(i), std::vector{cublas_ws_bytes}, DType::kByte); workspace_ptrs[i] = workspaces[i].data(); + A_views.push_back(&A_tensors[i]); + B_views.push_back(&B_tensors[i]); } nvte_multi_tensor_gemm(A_ptrs.data(), @@ -399,8 +420,8 @@ void run_grouped_gemm_case(const TestParams& params) { 0, 0); - GroupedBuffers grouped_A = build_grouped_tensor(A_tensors, A_tensors[0].scaling_mode()); - GroupedBuffers grouped_B = build_grouped_tensor(B_tensors, B_tensors[0].scaling_mode()); + GroupedBuffers grouped_A = build_grouped_tensor(A_views, A_tensors[0].scaling_mode()); + GroupedBuffers grouped_B = build_grouped_tensor(B_views, B_tensors[0].scaling_mode()); std::vector C_tensors; std::vector D_group_tensors; diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 2c8c2093c6f..bb29d58de47 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1115,20 +1115,50 @@ struct TensorShapeInfo { // Create from GroupedTensor static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) { - return {t->first_dims.has_data() ? static_cast(t->first_dims.dptr) : nullptr, - t->last_dims.has_data() ? static_cast(t->last_dims.dptr) : nullptr, + const bool has_first = t->first_dims.has_data(); + const bool has_last = t->last_dims.has_data(); + // When per-tensor dims are not provided, we must be in the uniform-shape case. + NVTE_CHECK(has_first || t->all_same_first_dim(), + "GroupedTensor is missing first_dims for varying shapes"); + NVTE_CHECK(has_last || t->all_same_last_dim(), + "GroupedTensor is missing last_dims for varying shapes"); + + const int64_t *first_ptr = has_first ? static_cast(t->first_dims.dptr) : nullptr; + const int64_t *last_ptr = has_last ? static_cast(t->last_dims.dptr) : nullptr; + + const int64_t uniform_first = has_first ? 0 : static_cast(t->get_common_first_dim()); + const int64_t uniform_last = has_last ? 0 : static_cast(t->get_common_last_dim()); + + return {first_ptr, + last_ptr, t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) : nullptr, - t->get_common_first_dim(), t->get_common_last_dim()}; + uniform_first, + uniform_last}; } // Create for C tensor (uses D's dimensions, only has offsets) static TensorShapeInfo for_C(const transformer_engine::GroupedTensor *C, const transformer_engine::GroupedTensor *D) { - return {nullptr, nullptr, + const bool has_first = D->first_dims.has_data(); + const bool has_last = D->last_dims.has_data(); + NVTE_CHECK(has_first || D->all_same_first_dim(), + "GroupedTensor D is missing first_dims for varying shapes"); + NVTE_CHECK(has_last || D->all_same_last_dim(), + "GroupedTensor D is missing last_dims for varying shapes"); + + const int64_t *first_ptr = + has_first ? 
static_cast(D->first_dims.dptr) : nullptr; + const int64_t *last_ptr = has_last ? static_cast(D->last_dims.dptr) : nullptr; + const int64_t uniform_first = has_first ? 0 : static_cast(D->get_common_first_dim()); + const int64_t uniform_last = has_last ? 0 : static_cast(D->get_common_last_dim()); + + return {first_ptr, + last_ptr, C->tensor_offsets.has_data() ? static_cast(C->tensor_offsets.dptr) : nullptr, - D->get_common_first_dim(), D->get_common_last_dim()}; + uniform_first, + uniform_last}; } }; @@ -1144,10 +1174,9 @@ inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) if (t->all_same_last_dim()) { // logical_shape[1] is the common N return static_cast(t->logical_shape.data[1]); - } else { - // logical_shape[1] is sum_of_N, divide by num_tensors - return static_cast(t->logical_shape.data[1]) / static_cast(t->num_tensors); } + // When varying, logical_shape[1] should be sum of last dims if provided; otherwise fallback to avg via division. + return static_cast(t->logical_shape.data[1]) / static_cast(t->num_tensors); } // Workspace layout for grouped GEMM @@ -1163,6 +1192,7 @@ struct GroupedGemmSetupWorkspace { float **beta_ptrs; // Initialize from workspace buffer + // Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned) static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors, size_t alignment) { GroupedGemmSetupWorkspace ws; @@ -1170,6 +1200,7 @@ struct GroupedGemmSetupWorkspace { const size_t ptr_size = num_tensors * sizeof(void *); const size_t int_size = num_tensors * sizeof(int); + // Pointer arrays first (all 8-byte aligned) ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); @@ -1178,27 +1209,30 @@ struct GroupedGemmSetupWorkspace { offset += ptr_size; ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + + // Int arrays last (4-byte aligned, always satisfied after pointer arrays) ws.M = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; ws.N = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; ws.K = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; - ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; offset = ((offset + alignment - 1) / alignment) * alignment; return ws; } - // Calculate required size for setup workspace (pointer arrays + M/N/K + alpha/beta ptrs) + // Calculate required size for setup workspace (pointer arrays + M/N/K) static size_t required_setup_size(size_t num_tensors, size_t alignment) { const size_t ptr_size = num_tensors * sizeof(void *); const size_t int_size = num_tensors * sizeof(int); - size_t size = 4 * ptr_size + 3 * int_size + 2 * ptr_size; // M, N, K only (no LDA/LDB/LDC/LDD) + // Layout: 6 ptr arrays, then 3 int arrays (no padding needed) + size_t size = 6 * ptr_size + 3 * int_size; size = ((size + alignment - 1) / alignment) * alignment; return size; } @@ -1220,12 +1254,16 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor NVTE_CHECK(outputD->num_tensors == num_tensors, "Grouped GEMM: A and D must have the same num_tensors"); - auto is_fp8_or_16bit = [](DType dtype) { - return dtype == DType::kFloat8E4M3 || dtype == DType::kFloat8E5M2 || - 
dtype == DType::kBFloat16 || dtype == DType::kFloat16; + auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { + return dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2 || + dtype == transformer_engine::DType::kBFloat16 || + dtype == transformer_engine::DType::kFloat16; }; - auto is_output_dtype = [](DType dtype) { - return dtype == DType::kBFloat16 || dtype == DType::kFloat16 || dtype == DType::kFloat32; + auto is_output_dtype = [](transformer_engine::DType dtype) { + return dtype == transformer_engine::DType::kBFloat16 || + dtype == transformer_engine::DType::kFloat16 || + dtype == transformer_engine::DType::kFloat32; }; NVTE_CHECK(is_fp8_or_16bit(inputA->dtype()) && is_fp8_or_16bit(inputB->dtype()), "Grouped GEMM inputs must be FP8, BF16, or FP16."); @@ -1321,7 +1359,8 @@ inline void *validate_and_get_workspace_ptr(transformer_engine::Tensor *ws, size inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, cublasLtMatrixLayoutOpaque_t &descB, cublasLtMatrixLayoutOpaque_t &descC, - cublasLtMatrixLayoutOpaque_t &descD, const GroupedGemmWorkspace &ws, + cublasLtMatrixLayoutOpaque_t &descD, + const GroupedGemmSetupWorkspace &ws, bool transa, bool transb, bool a_columnwise, bool b_columnwise, size_t num_tensors, cudaDataType_t A_type, cudaDataType_t B_type, cudaDataType_t D_type) { @@ -1366,6 +1405,10 @@ inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOpera &alphabeta_batch_stride, sizeof(int64_t))); } +// Constants for grouped GEMM workspace (declared early for use in heuristics) +static constexpr size_t kGroupedGemmAlignment = 256; +static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB + inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, cublasLtMatmulDescOpaque_t &matmulDesc, cublasLtMatrixLayoutOpaque_t &descA, @@ -1442,9 +1485,11 @@ __global__ void setup_grouped_gemm_kernel( D_ptrs[idx] = d_base + d_offset * d_elem_size; // Compute M, N, K dimensions - M[idx] = static_cast(transa ? a_last : a_first); - K[idx] = static_cast(transa ? a_first : a_last); - N[idx] = static_cast(transb ? b_first : b_last); + // Test stores A as {K,M} when !transa, {M,K} when transa + // Test stores B as {N,K} when !transb, {K,N} when transb + M[idx] = static_cast(transa ? a_first : a_last); + K[idx] = static_cast(transa ? a_last : a_first); + N[idx] = static_cast(transb ? 
b_last : b_first); // Fill alpha/beta pointers (same for all groups) alpha_ptrs[idx] = alpha_ptr; @@ -1453,7 +1498,7 @@ __global__ void setup_grouped_gemm_kernel( // Launch the setup kernel to populate workspace arrays inline void launch_grouped_gemm_setup( - const GroupedGemmWorkspace &ws, const transformer_engine::GroupedTensor *A, + const GroupedGemmSetupWorkspace &ws, const transformer_engine::GroupedTensor *A, const transformer_engine::GroupedTensor *B, const transformer_engine::GroupedTensor *C, const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor, const transformer_engine::Tensor *beta_tensor, const char *a_base, const char *b_base, @@ -1482,10 +1527,6 @@ inline void launch_grouped_gemm_setup( NVTE_CHECK_CUDA(cudaGetLastError()); } -// Constants for grouped GEMM workspace -static constexpr size_t kGroupedGemmAlignment = 256; -static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB - inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment); } @@ -1563,6 +1604,28 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT cublasLtMatmulDescOpaque_t matmulDesc; init_matmul_desc(matmulDesc, op_A, op_B); + // Set FP8 scale pointers if needed + const bool is_fp8_a = is_fp8_dtype(A_sel.dtype); + const bool is_fp8_b = is_fp8_dtype(B_sel.dtype); + if (is_fp8_a || is_fp8_b) { + // For FP8 grouped GEMM, we need to pass scale_inv pointers + // The scale_inv arrays contain one float per tensor in the group + if (is_fp8_a) { + void *a_scale_inv = A_sel.use_columnwise ? inputA->columnwise_scale_inv.dptr + : inputA->scale_inv.dptr; + NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); + } + if (is_fp8_b) { + void *b_scale_inv = B_sel.use_columnwise ? inputB->columnwise_scale_inv.dptr + : inputB->scale_inv.dptr; + NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); + } + } + // Compute average dimensions for heuristics // K dimension: if transa, K is A's first dim; if not, K is A's last dim int64_t avg_m_val = avg_m ? 
*avg_m : compute_avg_first_dim(outputD); diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 948058295ee..246fb5fefd9 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -11,6 +11,8 @@ #ifndef TRANSFORMER_ENGINE_GEMM_H_ #define TRANSFORMER_ENGINE_GEMM_H_ +#include + #include "transformer_engine.h" #ifdef __cplusplus From 785df3440a443b72340dfdf33db7391280e3a968 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Dec 2025 17:26:49 +0000 Subject: [PATCH 04/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/gemm/cublaslt_gemm.cu | 29 +++++++++---------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index bb29d58de47..55f52a1c4d9 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1123,18 +1123,17 @@ struct TensorShapeInfo { NVTE_CHECK(has_last || t->all_same_last_dim(), "GroupedTensor is missing last_dims for varying shapes"); - const int64_t *first_ptr = has_first ? static_cast(t->first_dims.dptr) : nullptr; + const int64_t *first_ptr = + has_first ? static_cast(t->first_dims.dptr) : nullptr; const int64_t *last_ptr = has_last ? static_cast(t->last_dims.dptr) : nullptr; const int64_t uniform_first = has_first ? 0 : static_cast(t->get_common_first_dim()); const int64_t uniform_last = has_last ? 0 : static_cast(t->get_common_last_dim()); - return {first_ptr, - last_ptr, + return {first_ptr, last_ptr, t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) : nullptr, - uniform_first, - uniform_last}; + uniform_first, uniform_last}; } // Create for C tensor (uses D's dimensions, only has offsets) @@ -1153,12 +1152,10 @@ struct TensorShapeInfo { const int64_t uniform_first = has_first ? 0 : static_cast(D->get_common_first_dim()); const int64_t uniform_last = has_last ? 0 : static_cast(D->get_common_last_dim()); - return {first_ptr, - last_ptr, + return {first_ptr, last_ptr, C->tensor_offsets.has_data() ? static_cast(C->tensor_offsets.dptr) : nullptr, - uniform_first, - uniform_last}; + uniform_first, uniform_last}; } }; @@ -1360,9 +1357,9 @@ inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, cublasLtMatrixLayoutOpaque_t &descB, cublasLtMatrixLayoutOpaque_t &descC, cublasLtMatrixLayoutOpaque_t &descD, - const GroupedGemmSetupWorkspace &ws, - bool transa, bool transb, bool a_columnwise, bool b_columnwise, - size_t num_tensors, cudaDataType_t A_type, cudaDataType_t B_type, + const GroupedGemmSetupWorkspace &ws, bool transa, bool transb, + bool a_columnwise, bool b_columnwise, size_t num_tensors, + cudaDataType_t A_type, cudaDataType_t B_type, cudaDataType_t D_type) { // For column-major layout: leading dimension is the number of rows in storage. // If columnwise data was chosen, storage is already transposed. @@ -1611,15 +1608,15 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT // For FP8 grouped GEMM, we need to pass scale_inv pointers // The scale_inv arrays contain one float per tensor in the group if (is_fp8_a) { - void *a_scale_inv = A_sel.use_columnwise ? 
inputA->columnwise_scale_inv.dptr - : inputA->scale_inv.dptr; + void *a_scale_inv = + A_sel.use_columnwise ? inputA->columnwise_scale_inv.dptr : inputA->scale_inv.dptr; NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); } if (is_fp8_b) { - void *b_scale_inv = B_sel.use_columnwise ? inputB->columnwise_scale_inv.dptr - : inputB->scale_inv.dptr; + void *b_scale_inv = + B_sel.use_columnwise ? inputB->columnwise_scale_inv.dptr : inputB->scale_inv.dptr; NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); From 1329b3746abfe3f9d845e90da7945bede6e3893c Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 10 Dec 2025 22:34:16 +0100 Subject: [PATCH 05/17] fix Signed-off-by: Pawel Gadzinski --- .../common/gemm/cublaslt_gemm.cu | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 55f52a1c4d9..3662247b51f 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1217,9 +1217,6 @@ struct GroupedGemmSetupWorkspace { ws.N = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; ws.K = reinterpret_cast(setup_ws_ptr + offset); - offset += int_size; - - offset = ((offset + alignment - 1) / alignment) * alignment; return ws; } @@ -1363,21 +1360,21 @@ inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, cudaDataType_t D_type) { // For column-major layout: leading dimension is the number of rows in storage. // If columnwise data was chosen, storage is already transposed. - const int *rowa = a_columnwise ? ws.M : (transa ? ws.K : ws.M); - const int *cola = a_columnwise ? ws.K : (transa ? ws.M : ws.K); - const int *lda = rowa; - const int *rowb = b_columnwise ? ws.N : (transb ? ws.N : ws.K); - const int *colb = b_columnwise ? ws.K : (transb ? ws.K : ws.N); - const int *ldb = rowb; - - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, (void *)rowa, - (void *)cola, (void *)lda)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, (void *)rowb, - (void *)colb, (void *)ldb)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, (void *)ws.M, - (void *)ws.N, (void *)ws.M)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, (void *)ws.M, - (void *)ws.N, (void *)ws.M)); + int *rowa = a_columnwise ? ws.M : (transa ? ws.K : ws.M); + int *cola = a_columnwise ? ws.K : (transa ? ws.M : ws.K); + int *lda = rowa; + int *rowb = b_columnwise ? ws.N : (transb ? ws.N : ws.K); + int *colb = b_columnwise ? ws.K : (transb ? 
ws.K : ws.N); + int *ldb = rowb; + + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, + rowa, cola, lda)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, + rowb, colb, ldb)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, + ws.M, ws.N, ws.M)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, + ws.M, ws.N, ws.M)); } inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A, From 47c58be8ce0ee14fc26a90a2f8b3ad8035283b4c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Dec 2025 21:35:06 +0000 Subject: [PATCH 06/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/gemm/cublaslt_gemm.cu | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 3662247b51f..91405bd42f5 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1367,14 +1367,10 @@ inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, int *colb = b_columnwise ? ws.K : (transb ? ws.K : ws.N); int *ldb = rowb; - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, - rowa, cola, lda)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, - rowb, colb, ldb)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, - ws.M, ws.N, ws.M)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, - ws.M, ws.N, ws.M)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, rowa, cola, lda)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, rowb, colb, ldb)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, ws.M, ws.N, ws.M)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, ws.M, ws.N, ws.M)); } inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A, From a155a8a3dd17663c82882f64b30a5a118ba3695b Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 11 Dec 2025 11:55:44 +0100 Subject: [PATCH 07/17] Grouped GEMM: code cleanup and NULL C support - Remove unused alignment parameter from GroupedGemmSetupWorkspace::from_buffers - Simplify select_grouped_operand by removing dead code branches - Add GroupedOperandSelection.tensor field to avoid passing tensor separately - Extract set_fp8_scale_pointers and init_matrix_layouts helpers - Add safety check for FP8 on Hopper column-wise fallback - Support NULL C tensor when beta=0 (uses D as placeholder) - Remove unused get_scale_inv() from test - Add use_null_c test parameter and test case - Fix documentation: alpha/beta are single element tensors only Signed-off-by: Piotr Gadzinski Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 210 ++++++++---------- .../common/gemm/cublaslt_gemm.cu | 163 +++++++------- .../common/include/transformer_engine/gemm.h | 34 +-- 3 files changed, 203 insertions(+), 204 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index bff175f405a..5e5144fa4c8 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ 
b/tests/cpp/operator/test_grouped_gemm.cu @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -28,7 +29,6 @@ using namespace test; namespace { enum class InputCase { - kFP8Delayed, kFP8Current, kBF16, }; @@ -40,17 +40,37 @@ enum class ShapeCase { kAllDifferent, }; +// Custom deleters for RAII +struct CudaDeleter { + void operator()(void* p) const { if (p) cudaFree(p); } +}; +struct GroupedTensorDeleter { + void operator()(NVTEGroupedTensor h) const { if (h) nvte_destroy_grouped_tensor(h); } +}; + +template +using CudaPtr = std::unique_ptr; +using GroupedTensorHandle = std::unique_ptr, GroupedTensorDeleter>; + +// Helper to allocate CUDA memory into a CudaPtr +template +CudaPtr cuda_alloc(size_t bytes) { + void* ptr = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&ptr, bytes)); + return CudaPtr(static_cast(ptr)); +} + // Helper owning GPU buffers that back NVTEGroupedTensor. // NVTEGroupedTensor does not own memory; data/offsets/scales // must be allocated and freed by the test. struct GroupedBuffers { - NVTEGroupedTensor handle{nullptr}; - void* data{nullptr}; - void* scale_inv{nullptr}; - int64_t* first_dims_dev{nullptr}; - int64_t* last_dims_dev{nullptr}; - int64_t* offsets_dev{nullptr}; - void* columnwise_data{nullptr}; + GroupedTensorHandle handle; + CudaPtr<> data; + CudaPtr<> scale_inv; + CudaPtr first_dims_dev; + CudaPtr last_dims_dev; + CudaPtr offsets_dev; + CudaPtr<> columnwise_data; NVTEShape logical_shape{}; std::vector offsets_host; std::vector tensor_bytes; @@ -62,65 +82,13 @@ struct GroupedBuffers { GroupedBuffers() = default; GroupedBuffers(const GroupedBuffers&) = delete; GroupedBuffers& operator=(const GroupedBuffers&) = delete; - GroupedBuffers(GroupedBuffers&& other) noexcept { - *this = std::move(other); - } - GroupedBuffers& operator=(GroupedBuffers&& other) noexcept { - if (this == &other) return *this; - handle = other.handle; - data = other.data; - scale_inv = other.scale_inv; - first_dims_dev = other.first_dims_dev; - last_dims_dev = other.last_dims_dev; - offsets_dev = other.offsets_dev; - logical_shape = other.logical_shape; - offsets_host = std::move(other.offsets_host); - tensor_bytes = std::move(other.tensor_bytes); - num_tensors = other.num_tensors; - elem_size = other.elem_size; - dtype = other.dtype; - scaling_mode = other.scaling_mode; - - other.handle = nullptr; - other.data = nullptr; - other.scale_inv = nullptr; - other.first_dims_dev = nullptr; - other.last_dims_dev = nullptr; - other.offsets_dev = nullptr; - other.num_tensors = 0; - return *this; - } + GroupedBuffers(GroupedBuffers&&) = default; + GroupedBuffers& operator=(GroupedBuffers&&) = default; + ~GroupedBuffers() = default; - ~GroupedBuffers() { - if (data) { - cudaFree(data); - data = nullptr; - } - if (scale_inv) { - cudaFree(scale_inv); - scale_inv = nullptr; - } - if (columnwise_data) { - cudaFree(columnwise_data); - columnwise_data = nullptr; - } - if (first_dims_dev) { - cudaFree(first_dims_dev); - first_dims_dev = nullptr; - } - if (last_dims_dev) { - cudaFree(last_dims_dev); - last_dims_dev = nullptr; - } - if (offsets_dev) { - cudaFree(offsets_dev); - offsets_dev = nullptr; - } - if (handle) { - nvte_destroy_grouped_tensor(handle); - handle = nullptr; - } - } + // Convenience accessors for raw pointers + NVTEGroupedTensor get_handle() const { return handle.get(); } + void* get_data() const { return data.get(); } }; size_t grouped_setup_workspace_size(const size_t num_tensors) { @@ -211,7 +179,7 @@ GroupedBuffers build_grouped_tensor(const std::vector& 
tensors, size_t logical_data[2] = {static_cast(logical_first), static_cast(logical_last)}; grouped.logical_shape = nvte_make_shape(logical_data, 2); - grouped.handle = nvte_create_grouped_tensor(scaling_mode, num_tensors, grouped.logical_shape); + grouped.handle.reset(nvte_create_grouped_tensor(scaling_mode, num_tensors, grouped.logical_shape)); const int64_t last_idx = static_cast(num_tensors - 1); const int64_t total_elems = need_offsets @@ -219,59 +187,60 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, : (logical_first * logical_last); const size_t total_bytes = static_cast(total_elems) * elem_size; - NVTE_CHECK_CUDA(cudaMalloc(&grouped.data, total_bytes)); + grouped.data = cuda_alloc(total_bytes); for (size_t i = 0; i < num_tensors; ++i) { const size_t offset_bytes = static_cast(offsets[i]) * elem_size; - NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.data) + offset_bytes, + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.data.get()) + offset_bytes, tensors[i]->rowwise_dptr(), grouped.tensor_bytes[i], cudaMemcpyDeviceToDevice)); } - NVTEBasicTensor data_tensor{grouped.data, static_cast(dtype), grouped.logical_shape}; - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedRowwiseData, &data_tensor); + NVTEBasicTensor data_tensor{grouped.data.get(), static_cast(dtype), grouped.logical_shape}; + NVTEGroupedTensor h = grouped.handle.get(); + nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseData, &data_tensor); const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype); if (include_columnwise) { - NVTE_CHECK_CUDA(cudaMalloc(&grouped.columnwise_data, total_bytes)); + grouped.columnwise_data = cuda_alloc(total_bytes); for (size_t i = 0; i < num_tensors; ++i) { const size_t offset_bytes = static_cast(offsets[i]) * elem_size; - NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.columnwise_data) + offset_bytes, + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.columnwise_data.get()) + offset_bytes, tensors[i]->columnwise_dptr(), grouped.tensor_bytes[i], cudaMemcpyDeviceToDevice)); } - NVTEBasicTensor col_tensor{grouped.columnwise_data, + NVTEBasicTensor col_tensor{grouped.columnwise_data.get(), static_cast(dtype), grouped.logical_shape}; - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedColumnwiseData, &col_tensor); + nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseData, &col_tensor); } if (!same_first) { - NVTE_CHECK_CUDA(cudaMalloc(&grouped.first_dims_dev, num_tensors * sizeof(int64_t))); - NVTE_CHECK_CUDA(cudaMemcpy(grouped.first_dims_dev, first_dims.data(), + grouped.first_dims_dev = cuda_alloc(num_tensors * sizeof(int64_t)); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.first_dims_dev.get(), first_dims.data(), num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); NVTEShape fd_shape = nvte_make_shape(&num_tensors, 1); - NVTEBasicTensor fd_tensor{grouped.first_dims_dev, kNVTEInt64, fd_shape}; - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedFirstDims, &fd_tensor); + NVTEBasicTensor fd_tensor{grouped.first_dims_dev.get(), kNVTEInt64, fd_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedFirstDims, &fd_tensor); } if (!same_last) { - NVTE_CHECK_CUDA(cudaMalloc(&grouped.last_dims_dev, num_tensors * sizeof(int64_t))); - NVTE_CHECK_CUDA(cudaMemcpy(grouped.last_dims_dev, last_dims.data(), + grouped.last_dims_dev = cuda_alloc(num_tensors * sizeof(int64_t)); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.last_dims_dev.get(), last_dims.data(), num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); NVTEShape ld_shape = nvte_make_shape(&num_tensors, 
1); - NVTEBasicTensor ld_tensor{grouped.last_dims_dev, kNVTEInt64, ld_shape}; - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedLastDims, &ld_tensor); + NVTEBasicTensor ld_tensor{grouped.last_dims_dev.get(), kNVTEInt64, ld_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedLastDims, &ld_tensor); } if (!same_first || !same_last) { - NVTE_CHECK_CUDA(cudaMalloc(&grouped.offsets_dev, num_tensors * sizeof(int64_t))); - NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev, offsets.data(), + grouped.offsets_dev = cuda_alloc(num_tensors * sizeof(int64_t)); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev.get(), offsets.data(), num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); NVTEShape off_shape = nvte_make_shape(&num_tensors, 1); - NVTEBasicTensor off_tensor{grouped.offsets_dev, kNVTEInt64, off_shape}; - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedTensorOffsets, &off_tensor); + NVTEBasicTensor off_tensor{grouped.offsets_dev.get(), kNVTEInt64, off_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedTensorOffsets, &off_tensor); } if (isFp8Type(dtype)) { @@ -280,13 +249,13 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, tensors[i]->to_cpu(); scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr()[0]; } - NVTE_CHECK_CUDA(cudaMalloc(&grouped.scale_inv, sizeof(float) * num_tensors)); - NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv, scale_inv_cpu.data(), + grouped.scale_inv = cuda_alloc(sizeof(float) * num_tensors); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv.get(), scale_inv_cpu.data(), sizeof(float) * num_tensors, cudaMemcpyHostToDevice)); NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1); - NVTEBasicTensor scale_tensor{grouped.scale_inv, kNVTEFloat32, scale_shape}; - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedRowwiseScaleInv, &scale_tensor); - nvte_set_grouped_tensor_param(&grouped.handle, kNVTEGroupedColumnwiseScaleInv, &scale_tensor); + NVTEBasicTensor scale_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseScaleInv, &scale_tensor); + nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor); } return grouped; @@ -321,6 +290,7 @@ struct TestParams { bool transa; bool transb; ShapeCase shape_case; + bool use_null_c = false; // When true, pass nullptr for C (valid when beta=0) }; // Returns a vector of (M, N, K) tuples for each GEMM in the group. 
@@ -332,12 +302,14 @@ std::vector> make_shapes(ShapeCase scase) { case ShapeCase::kAllSame: return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}}; case ShapeCase::kSameFirst: - return {{64, 80, 32}, {64, 80, 48}, {64, 80, 64}}; + // Same M (first dim), varying N and K + return {{64, 80, 32}, {64, 96, 48}, {64, 112, 64}}; case ShapeCase::kSameLast: - return {{64, 80, 32}, {64, 80, 48}, {64, 80, 64}}; + // Same N (last dim), varying M and K + return {{64, 80, 32}, {80, 80, 48}, {96, 80, 64}}; case ShapeCase::kAllDifferent: default: - return {{64, 96, 32}, {64, 96, 48}, {64, 96, 64}}; + return {{64, 96, 32}, {80, 112, 48}, {96, 128, 64}}; } } @@ -430,9 +402,11 @@ void run_grouped_gemm_case(const TestParams& params) { for (size_t i = 0; i < num_gemms; ++i) { const auto [M, N, K] = shapes[i]; (void)K; - C_tensors.emplace_back(Tensor("C" + std::to_string(i), - std::vector{static_cast(M), static_cast(N)}, - DType::kBFloat16)); + if (!params.use_null_c) { + C_tensors.emplace_back(Tensor("C" + std::to_string(i), + std::vector{static_cast(M), static_cast(N)}, + DType::kBFloat16)); + } D_group_tensors.emplace_back(Tensor("D_group" + std::to_string(i), std::vector{static_cast(M), static_cast(N)}, DType::kBFloat16)); @@ -441,11 +415,16 @@ void run_grouped_gemm_case(const TestParams& params) { std::vector C_views, D_views; for (size_t i = 0; i < num_gemms; ++i) { - C_views.push_back(&C_tensors[i]); + if (!params.use_null_c) { + C_views.push_back(&C_tensors[i]); + } D_views.push_back(&D_group_tensors[i]); } - GroupedBuffers grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING); + std::optional grouped_C; + if (!params.use_null_c) { + grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING); + } GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING); Tensor alpha_tensor("alpha", std::vector{1}, DType::kFloat32); @@ -462,11 +441,11 @@ void run_grouped_gemm_case(const TestParams& params) { nvte_grouped_gemm(params.transa, params.transb, alpha_tensor.data(), - grouped_A.handle, - grouped_B.handle, + grouped_A.get_handle(), + grouped_B.get_handle(), beta_tensor.data(), - grouped_C.handle, - grouped_D.handle, + params.use_null_c ? nullptr : grouped_C->get_handle(), + grouped_D.get_handle(), setup_ws.data(), cublas_ws.data(), nullptr, @@ -482,7 +461,7 @@ void run_grouped_gemm_case(const TestParams& params) { D_multi[i].dtype()); const size_t offset_bytes = static_cast(grouped_D.offsets_host[i]) * grouped_D.elem_size; NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(), - static_cast(grouped_D.data) + offset_bytes, + static_cast(grouped_D.get_data()) + offset_bytes, grouped_D.tensor_bytes[i], cudaMemcpyDeviceToDevice)); grouped_split.to_cpu(); @@ -504,22 +483,25 @@ TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) { } std::string MakeGroupedGemmTestName(const testing::TestParamInfo& info) { - constexpr const char* kInputNames[] = {"FP8Delayed", "FP8Current", "BF16"}; + constexpr const char* kInputNames[] = {"FP8Current", "BF16"}; constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"}; const std::string layout = std::string("ta") + (info.param.transa ? "T" : "N") + "tb" + (info.param.transb ? "T" : "N"); + const std::string null_c = info.param.use_null_c ? 
"_NullC" : ""; return std::string(kInputNames[static_cast(info.param.input_case)]) + "_" + - kShapeNames[static_cast(info.param.shape_case)] + "_" + layout; + kShapeNames[static_cast(info.param.shape_case)] + "_" + layout + null_c; } const std::vector kTestParams = { - {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent}, - {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent}, - {InputCase::kFP8Current, false, false, ShapeCase::kAllSame}, - {InputCase::kBF16, true, false, ShapeCase::kSameFirst}, - {InputCase::kBF16, false, true, ShapeCase::kSameLast}, - {InputCase::kBF16, false, false, ShapeCase::kAllSame}, - {InputCase::kBF16, true, true, ShapeCase::kAllDifferent}, + {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false}, + {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false}, + {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false}, + {InputCase::kBF16, true, false, ShapeCase::kSameFirst, false}, + {InputCase::kBF16, false, true, ShapeCase::kSameLast, false}, + {InputCase::kBF16, false, false, ShapeCase::kAllSame, false}, + {InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false}, + // Test NULL C (valid when beta=0) + {InputCase::kBF16, false, false, ShapeCase::kAllSame, true}, }; INSTANTIATE_TEST_SUITE_P(OperatorTest, diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 91405bd42f5..9d9a5097d4c 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1190,8 +1190,7 @@ struct GroupedGemmSetupWorkspace { // Initialize from workspace buffer // Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned) - static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors, - size_t alignment) { + static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors) { GroupedGemmSetupWorkspace ws; size_t offset = 0; const size_t ptr_size = num_tensors * sizeof(void *); @@ -1243,8 +1242,11 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: num_tensors must be at least 1"); NVTE_CHECK(inputB->num_tensors == num_tensors, "Grouped GEMM: A and B must have the same num_tensors"); - NVTE_CHECK(inputC->num_tensors == num_tensors, - "Grouped GEMM: A and C must have the same num_tensors"); + // C can be NULL (will use D as C when beta=0) + if (inputC != nullptr) { + NVTE_CHECK(inputC->num_tensors == num_tensors, + "Grouped GEMM: A and C must have the same num_tensors"); + } NVTE_CHECK(outputD->num_tensors == num_tensors, "Grouped GEMM: A and D must have the same num_tensors"); @@ -1261,8 +1263,13 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor }; NVTE_CHECK(is_fp8_or_16bit(inputA->dtype()) && is_fp8_or_16bit(inputB->dtype()), "Grouped GEMM inputs must be FP8, BF16, or FP16."); - NVTE_CHECK(is_output_dtype(inputC->dtype()) && is_output_dtype(outputD->dtype()), - "Grouped GEMM outputs must be BF16, FP16, or FP32."); + // Only check C dtype if C is provided + if (inputC != nullptr) { + NVTE_CHECK(is_output_dtype(inputC->dtype()), + "Grouped GEMM: C must be BF16, FP16, or FP32."); + } + NVTE_CHECK(is_output_dtype(outputD->dtype()), + "Grouped GEMM: D must be BF16, FP16, or FP32."); NVTE_CHECK(inputA->has_data() || inputA->has_columnwise_data(), "Grouped GEMM: A tensor is missing both row-wise and column-wise data"); 
NVTE_CHECK(inputB->has_data() || inputB->has_columnwise_data(), @@ -1273,6 +1280,7 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor // Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and // fallback to column-wise data when row-wise is absent. struct GroupedOperandSelection { + const transformer_engine::GroupedTensor *tensor = nullptr; const char *base = nullptr; transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; bool trans = false; @@ -1296,6 +1304,7 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: const DType row_dtype = t->data.dtype; const DType col_dtype = t->columnwise_data.dtype; GroupedOperandSelection sel; + sel.tensor = t; sel.trans = trans; const DType rep_dtype = has_row ? row_dtype : col_dtype; @@ -1327,6 +1336,9 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). if (!has_row && has_col) { + // On Hopper FP8, this would break TN requirement - should have been handled above + NVTE_CHECK(!is_fp8 || non_tn_fp8_ok, + "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration"); sel.base = static_cast(t->columnwise_data.dptr); sel.dtype = col_dtype; sel.trans = !sel.trans; @@ -1334,10 +1346,10 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: return sel; } - // Default: use row-wise data (or column-wise if row-wise absent, covered above). - sel.base = static_cast(has_row ? t->data.dptr : t->columnwise_data.dptr); - sel.dtype = has_row ? row_dtype : col_dtype; - sel.use_columnwise = !has_row && has_col; + // Default: use row-wise data (column-wise case already handled above) + sel.base = static_cast(t->data.dptr); + sel.dtype = row_dtype; + sel.use_columnwise = false; return sel; } @@ -1354,17 +1366,22 @@ inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, cublasLtMatrixLayoutOpaque_t &descB, cublasLtMatrixLayoutOpaque_t &descC, cublasLtMatrixLayoutOpaque_t &descD, - const GroupedGemmSetupWorkspace &ws, bool transa, bool transb, - bool a_columnwise, bool b_columnwise, size_t num_tensors, - cudaDataType_t A_type, cudaDataType_t B_type, - cudaDataType_t D_type) { + const GroupedGemmSetupWorkspace &ws, + const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel, + const transformer_engine::GroupedTensor *D, + size_t num_tensors) { + const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); + const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); + const cudaDataType_t D_type = get_cuda_dtype(D->dtype()); + // For column-major layout: leading dimension is the number of rows in storage. // If columnwise data was chosen, storage is already transposed. - int *rowa = a_columnwise ? ws.M : (transa ? ws.K : ws.M); - int *cola = a_columnwise ? ws.K : (transa ? ws.M : ws.K); + int *rowa = A_sel.use_columnwise ? ws.M : (A_sel.trans ? ws.K : ws.M); + int *cola = A_sel.use_columnwise ? ws.K : (A_sel.trans ? ws.M : ws.K); int *lda = rowa; - int *rowb = b_columnwise ? ws.N : (transb ? ws.N : ws.K); - int *colb = b_columnwise ? ws.K : (transb ? ws.K : ws.N); + int *rowb = B_sel.use_columnwise ? ws.N : (B_sel.trans ? ws.N : ws.K); + int *colb = B_sel.use_columnwise ? ws.K : (B_sel.trans ? 
ws.K : ws.N); int *ldb = rowb; NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, rowa, cola, lda)); @@ -1395,6 +1412,31 @@ inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOpera &alphabeta_batch_stride, sizeof(int64_t))); } +inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, + const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel) { + const bool is_fp8_a = is_fp8_dtype(A_sel.dtype); + const bool is_fp8_b = is_fp8_dtype(B_sel.dtype); + if (!is_fp8_a && !is_fp8_b) return; + + if (is_fp8_a) { + void *a_scale_inv = A_sel.use_columnwise + ? A_sel.tensor->columnwise_scale_inv.dptr + : A_sel.tensor->scale_inv.dptr; + NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); + } + if (is_fp8_b) { + void *b_scale_inv = B_sel.use_columnwise + ? B_sel.tensor->columnwise_scale_inv.dptr + : B_sel.tensor->scale_inv.dptr; + NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); + } +} + // Constants for grouped GEMM workspace (declared early for use in heuristics) static constexpr size_t kGroupedGemmAlignment = 256; static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB @@ -1488,20 +1530,20 @@ __global__ void setup_grouped_gemm_kernel( // Launch the setup kernel to populate workspace arrays inline void launch_grouped_gemm_setup( - const GroupedGemmSetupWorkspace &ws, const transformer_engine::GroupedTensor *A, - const transformer_engine::GroupedTensor *B, const transformer_engine::GroupedTensor *C, + const GroupedGemmSetupWorkspace &ws, const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel, const transformer_engine::GroupedTensor *C, const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor, - const transformer_engine::Tensor *beta_tensor, const char *a_base, const char *b_base, - size_t a_elem_size, size_t b_elem_size, bool transa, bool transb, size_t num_tensors, - cudaStream_t stream) { - TensorShapeInfo A_meta = TensorShapeInfo::from_tensor(A); - TensorShapeInfo B_meta = TensorShapeInfo::from_tensor(B); + const transformer_engine::Tensor *beta_tensor, size_t num_tensors, cudaStream_t stream) { + TensorShapeInfo A_meta = TensorShapeInfo::from_tensor(A_sel.tensor); + TensorShapeInfo B_meta = TensorShapeInfo::from_tensor(B_sel.tensor); TensorShapeInfo C_meta = TensorShapeInfo::for_C(C, D); TensorShapeInfo D_meta = TensorShapeInfo::from_tensor(D); const char *c_base = static_cast(C->data.dptr); char *d_base = static_cast(D->data.dptr); + const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype); + const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype); const size_t c_elem_size = transformer_engine::typeToSize(C->dtype()); const size_t d_elem_size = transformer_engine::typeToSize(D->dtype()); @@ -1510,9 +1552,9 @@ inline void launch_grouped_gemm_setup( setup_grouped_gemm_kernel<<>>( ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.M, ws.N, ws.K, ws.alpha_ptrs, ws.beta_ptrs, - a_base, b_base, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, b_elem_size, - c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), - 
static_cast(beta_tensor->data.dptr), transa, transb, num_tensors); + A_sel.base, B_sel.base, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, + b_elem_size, c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), + static_cast(beta_tensor->data.dptr), A_sel.trans, B_sel.trans, num_tensors); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -1532,7 +1574,7 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT // Convert to internal types const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B); - const GroupedTensor *inputC = convertNVTEGroupedTensorCheck(C); + const GroupedTensor *inputC_raw = convertNVTEGroupedTensor(C); // Can be NULL GroupedTensor *outputD = convertNVTEGroupedTensorCheck(D); const Tensor *alpha_tensor = convertNVTETensorCheck(alpha); const Tensor *beta_tensor = convertNVTETensorCheck(beta); @@ -1540,19 +1582,16 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); // Validate inputs and num_tensors - validate_grouped_gemm_inputs(inputA, inputB, inputC, outputD); + validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD); + + // If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data) + const GroupedTensor *inputC = (inputC_raw != nullptr) ? inputC_raw : outputD; const size_t num_tensors = inputA->num_tensors; // Select operand storage (row-wise vs column-wise) and adjust transpose flags to // mirror the non-grouped GEMM logic for FP8 layout constraints. - bool transa_flag = static_cast(transa); - bool transb_flag = static_cast(transb); - const auto A_sel = select_grouped_operand(inputA, transa_flag, /*is_A=*/true); - const auto B_sel = select_grouped_operand(inputB, transb_flag, /*is_A=*/false); - transa_flag = A_sel.trans; - transb_flag = B_sel.trans; - const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype); - const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype); + const auto A_sel = select_grouped_operand(inputA, static_cast(transa), /*is_A=*/true); + const auto B_sel = select_grouped_operand(inputB, static_cast(transb), /*is_A=*/false); // Workspaces: setup (pointer arrays) and cuBLAS const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors); @@ -1563,65 +1602,35 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT void *cublas_workspace_ptr = validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, "Grouped GEMM cuBLAS workspace"); - NVTE_CHECK(cublas_workspace_ptr != nullptr, "Grouped GEMM: cuBLAS workspace pointer is null"); - auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( - static_cast(setup_workspace_ptr), num_tensors, kGroupedGemmAlignment); - launch_grouped_gemm_setup(setup_workspace, inputA, inputB, inputC, outputD, alpha_tensor, - beta_tensor, A_sel.base, B_sel.base, a_elem_size, b_elem_size, - transa_flag, transb_flag, num_tensors, stream); + static_cast(setup_workspace_ptr), num_tensors); + launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, + alpha_tensor, beta_tensor, num_tensors, stream); // Get cuBLAS handle using cublasHandleManager = detail::HandleManager; cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); - // Get data types - const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); - const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); - const cudaDataType_t 
D_type = get_cuda_dtype(outputD->dtype()); - // Setup cuBLAS operations - cublasOperation_t op_A = transa_flag ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t op_B = transb_flag ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t op_A = A_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t op_B = B_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; // Create grouped matrix layouts cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD; - init_matrix_layouts(descA, descB, descC, descD, setup_workspace, transa_flag, transb_flag, - A_sel.use_columnwise, B_sel.use_columnwise, num_tensors, A_type, B_type, - D_type); + init_matrix_layouts(descA, descB, descC, descD, setup_workspace, A_sel, B_sel, outputD, + num_tensors); // Create matmul descriptor cublasLtMatmulDescOpaque_t matmulDesc; init_matmul_desc(matmulDesc, op_A, op_B); - - // Set FP8 scale pointers if needed - const bool is_fp8_a = is_fp8_dtype(A_sel.dtype); - const bool is_fp8_b = is_fp8_dtype(B_sel.dtype); - if (is_fp8_a || is_fp8_b) { - // For FP8 grouped GEMM, we need to pass scale_inv pointers - // The scale_inv arrays contain one float per tensor in the group - if (is_fp8_a) { - void *a_scale_inv = - A_sel.use_columnwise ? inputA->columnwise_scale_inv.dptr : inputA->scale_inv.dptr; - NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); - } - if (is_fp8_b) { - void *b_scale_inv = - B_sel.use_columnwise ? inputB->columnwise_scale_inv.dptr : inputB->scale_inv.dptr; - NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); - } - } + set_fp8_scale_pointers(matmulDesc, A_sel, B_sel); // Compute average dimensions for heuristics // K dimension: if transa, K is A's first dim; if not, K is A's last dim int64_t avg_m_val = avg_m ? *avg_m : compute_avg_first_dim(outputD); int64_t avg_n_val = avg_n ? *avg_n : compute_avg_last_dim(outputD); int64_t avg_k_val = - avg_k ? *avg_k : (transa_flag ? compute_avg_first_dim(inputA) : compute_avg_last_dim(inputA)); + avg_k ? *avg_k : (A_sel.trans ? compute_avg_first_dim(A_sel.tensor) : compute_avg_last_dim(A_sel.tensor)); // Heuristic selection cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 246fb5fefd9..02cf01853df 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -239,19 +239,27 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * Uses NVTEGroupedTensor to efficiently handle collections of tensors with contiguous * memory layout and shape metadata. * - * \param[in] transa Whether to transpose A matrices. - * \param[in] transb Whether to transpose B matrices. - * \param[in] alpha Scale multiplier for A @ B (NVTETensor with num_tensors elements, - * or single element for uniform alpha). - * \param[in] A Input grouped tensor A. - * \param[in] B Input grouped tensor B. - * \param[in] beta Scale multiplier for C (NVTETensor with num_tensors elements, - * or single element for uniform beta). - * \param[in] C Input grouped tensor C (can be NULL for beta=0). 
- * \param[out] D Output grouped tensor D. - * \param[in] workspace Workspace tensor for intermediate computations. - * \param[in] config Matrix multiplication configuration. - * \param[in] stream CUDA stream for the operation. + * \param[in] transa Whether to transpose A matrices. + * \param[in] transb Whether to transpose B matrices. + * \param[in] alpha Scale multiplier for A @ B (single element NVTETensor). + * \param[in] A Input grouped tensor A. + * \param[in] B Input grouped tensor B. + * \param[in] beta Scale multiplier for C (single element NVTETensor). + * \param[in] C Input grouped tensor C (can be NULL for beta=0). + * \param[out] D Output grouped tensor D. + * \param[in] workspace_setup Workspace tensor for pointer array setup. + * \param[in] workspace_cublas Workspace tensor for cuBLAS operations. + * \param[in] config Matrix multiplication configuration. + * \param[in] stream CUDA stream for the operation. + * \param[in] avg_m Optional hint for average M dimension across all matrices in the + * group. Used by cuBLASLt for algorithm selection heuristics. + * If NULL, computed automatically from D's logical shape. + * \param[in] avg_n Optional hint for average N dimension across all matrices in the + * group. Used by cuBLASLt for algorithm selection heuristics. + * If NULL, computed automatically from D's logical shape. + * \param[in] avg_k Optional hint for average K (reduction) dimension across all + * matrices in the group. Used by cuBLASLt for algorithm selection + * heuristics. If NULL, computed automatically from A's logical shape. * * Requirements: * - A, B, C (if provided), D must have the same num_tensors From 3b2fcdf3137cec31b83dc6dc0f64e2e367aa6f9b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Dec 2025 10:57:26 +0000 Subject: [PATCH 08/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/gemm/cublaslt_gemm.cu | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 9d9a5097d4c..7f2635943b8 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1265,11 +1265,9 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor "Grouped GEMM inputs must be FP8, BF16, or FP16."); // Only check C dtype if C is provided if (inputC != nullptr) { - NVTE_CHECK(is_output_dtype(inputC->dtype()), - "Grouped GEMM: C must be BF16, FP16, or FP32."); + NVTE_CHECK(is_output_dtype(inputC->dtype()), "Grouped GEMM: C must be BF16, FP16, or FP32."); } - NVTE_CHECK(is_output_dtype(outputD->dtype()), - "Grouped GEMM: D must be BF16, FP16, or FP32."); + NVTE_CHECK(is_output_dtype(outputD->dtype()), "Grouped GEMM: D must be BF16, FP16, or FP32."); NVTE_CHECK(inputA->has_data() || inputA->has_columnwise_data(), "Grouped GEMM: A tensor is missing both row-wise and column-wise data"); NVTE_CHECK(inputB->has_data() || inputB->has_columnwise_data(), @@ -1337,8 +1335,9 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). 
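   // For example, an operand whose logical shape is R x C but which only has column-wise
   // storage is laid out as C x R in memory, so the transpose flag requested on the logical
   // matrix has to be flipped to describe the stored buffer.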
if (!has_row && has_col) { // On Hopper FP8, this would break TN requirement - should have been handled above - NVTE_CHECK(!is_fp8 || non_tn_fp8_ok, - "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration"); + NVTE_CHECK( + !is_fp8 || non_tn_fp8_ok, + "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration"); sel.base = static_cast(t->columnwise_data.dptr); sel.dtype = col_dtype; sel.trans = !sel.trans; @@ -1369,8 +1368,7 @@ inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, const GroupedGemmSetupWorkspace &ws, const GroupedOperandSelection &A_sel, const GroupedOperandSelection &B_sel, - const transformer_engine::GroupedTensor *D, - size_t num_tensors) { + const transformer_engine::GroupedTensor *D, size_t num_tensors) { const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); const cudaDataType_t D_type = get_cuda_dtype(D->dtype()); @@ -1420,17 +1418,15 @@ inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, if (!is_fp8_a && !is_fp8_b) return; if (is_fp8_a) { - void *a_scale_inv = A_sel.use_columnwise - ? A_sel.tensor->columnwise_scale_inv.dptr - : A_sel.tensor->scale_inv.dptr; + void *a_scale_inv = A_sel.use_columnwise ? A_sel.tensor->columnwise_scale_inv.dptr + : A_sel.tensor->scale_inv.dptr; NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); } if (is_fp8_b) { - void *b_scale_inv = B_sel.use_columnwise - ? B_sel.tensor->columnwise_scale_inv.dptr - : B_sel.tensor->scale_inv.dptr; + void *b_scale_inv = B_sel.use_columnwise ? B_sel.tensor->columnwise_scale_inv.dptr + : B_sel.tensor->scale_inv.dptr; NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); @@ -1604,8 +1600,8 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( static_cast(setup_workspace_ptr), num_tensors); - launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, - alpha_tensor, beta_tensor, num_tensors, stream); + launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, alpha_tensor, + beta_tensor, num_tensors, stream); // Get cuBLAS handle using cublasHandleManager = detail::HandleManager; @@ -1629,8 +1625,9 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT // K dimension: if transa, K is A's first dim; if not, K is A's last dim int64_t avg_m_val = avg_m ? *avg_m : compute_avg_first_dim(outputD); int64_t avg_n_val = avg_n ? *avg_n : compute_avg_last_dim(outputD); - int64_t avg_k_val = - avg_k ? *avg_k : (A_sel.trans ? compute_avg_first_dim(A_sel.tensor) : compute_avg_last_dim(A_sel.tensor)); + int64_t avg_k_val = avg_k ? *avg_k + : (A_sel.trans ? 
compute_avg_first_dim(A_sel.tensor) + : compute_avg_last_dim(A_sel.tensor)); // Heuristic selection cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, From 5b0582bbf0fd05773242df67836ec263014d52dd Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 11 Dec 2025 12:15:12 +0100 Subject: [PATCH 09/17] Grouped GEMM: per-matrix alpha/beta support - Change alpha/beta from single values to per-matrix arrays - Validate alpha/beta have exactly num_tensors elements - Update kernel to index alpha_ptr[idx] and beta_ptr[idx] - Move alpha/beta validation to validate_grouped_gemm_inputs - Update tests to use per-matrix alpha/beta arrays - Update documentation Signed-off-by: Piotr Gadzinski Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 15 +++++++----- .../common/gemm/cublaslt_gemm.cu | 24 ++++++++++++++----- .../common/include/transformer_engine/gemm.h | 4 ++-- 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 5e5144fa4c8..82b5bd3803e 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -427,12 +427,15 @@ void run_grouped_gemm_case(const TestParams& params) { } GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING); - Tensor alpha_tensor("alpha", std::vector{1}, DType::kFloat32); - Tensor beta_tensor("beta", std::vector{1}, DType::kFloat32); - const float alpha_val = 1.f; - const float beta_val = 0.f; - NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), &alpha_val, sizeof(float), cudaMemcpyHostToDevice)); - NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), &beta_val, sizeof(float), cudaMemcpyHostToDevice)); + // Per-matrix alpha/beta (all 1.0 and 0.0 respectively) + Tensor alpha_tensor("alpha", std::vector{num_gemms}, DType::kFloat32); + Tensor beta_tensor("beta", std::vector{num_gemms}, DType::kFloat32); + std::vector alpha_vals(num_gemms, 1.f); + std::vector beta_vals(num_gemms, 0.f); + NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), alpha_vals.data(), + num_gemms * sizeof(float), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), beta_vals.data(), + num_gemms * sizeof(float), cudaMemcpyHostToDevice)); const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms); Tensor setup_ws("setup_ws", std::vector{setup_ws_bytes}, DType::kByte); diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 7f2635943b8..caa394d5492 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1237,7 +1237,9 @@ struct GroupedGemmSetupWorkspace { inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor *inputA, const transformer_engine::GroupedTensor *inputB, const transformer_engine::GroupedTensor *inputC, - const transformer_engine::GroupedTensor *outputD) { + const transformer_engine::GroupedTensor *outputD, + const transformer_engine::Tensor *alpha_tensor, + const transformer_engine::Tensor *beta_tensor) { const size_t num_tensors = inputA->num_tensors; NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: num_tensors must be at least 1"); NVTE_CHECK(inputB->num_tensors == num_tensors, @@ -1250,6 +1252,16 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor NVTE_CHECK(outputD->num_tensors == num_tensors, "Grouped GEMM: A and D must have the same 
num_tensors"); + // Validate alpha/beta have per-matrix values + const size_t alpha_numel = alpha_tensor->data.shape.numel(); + const size_t beta_numel = beta_tensor->data.shape.numel(); + NVTE_CHECK(alpha_numel == num_tensors, + "Grouped GEMM: alpha must have num_tensors (", num_tensors, ") elements, got ", + alpha_numel); + NVTE_CHECK(beta_numel == num_tensors, + "Grouped GEMM: beta must have num_tensors (", num_tensors, ") elements, got ", + beta_numel); + auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { return dtype == transformer_engine::DType::kFloat8E4M3 || dtype == transformer_engine::DType::kFloat8E5M2 || @@ -1481,7 +1493,7 @@ __global__ void setup_grouped_gemm_kernel( TensorShapeInfo A_meta, TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, // Element sizes size_t a_elem_size, size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, - // Alpha/beta pointers (same for all groups) + // Alpha/beta pointers (per-matrix arrays) float *alpha_ptr, float *beta_ptr, // Transpose flags bool transa, bool transb, @@ -1519,9 +1531,9 @@ __global__ void setup_grouped_gemm_kernel( K[idx] = static_cast(transa ? a_last : a_first); N[idx] = static_cast(transb ? b_last : b_first); - // Fill alpha/beta pointers (same for all groups) - alpha_ptrs[idx] = alpha_ptr; - beta_ptrs[idx] = beta_ptr; + // Fill alpha/beta pointers (per-matrix) + alpha_ptrs[idx] = alpha_ptr + idx; + beta_ptrs[idx] = beta_ptr + idx; } // Launch the setup kernel to populate workspace arrays @@ -1578,7 +1590,7 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); // Validate inputs and num_tensors - validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD); + validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD, alpha_tensor, beta_tensor); // If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data) const GroupedTensor *inputC = (inputC_raw != nullptr) ? inputC_raw : outputD; diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 02cf01853df..9dfa009115e 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -241,10 +241,10 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * * \param[in] transa Whether to transpose A matrices. * \param[in] transb Whether to transpose B matrices. - * \param[in] alpha Scale multiplier for A @ B (single element NVTETensor). + * \param[in] alpha Scale multipliers for A @ B (NVTETensor with num_tensors elements). * \param[in] A Input grouped tensor A. * \param[in] B Input grouped tensor B. - * \param[in] beta Scale multiplier for C (single element NVTETensor). + * \param[in] beta Scale multipliers for C (NVTETensor with num_tensors elements). * \param[in] C Input grouped tensor C (can be NULL for beta=0). * \param[out] D Output grouped tensor D. * \param[in] workspace_setup Workspace tensor for pointer array setup. 
From 101766bcb15e9cd6a9df01eaa6e5b5b9d9989f40 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Dec 2025 11:17:48 +0000 Subject: [PATCH 10/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/gemm/cublaslt_gemm.cu | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index caa394d5492..1d63cf65cf6 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1255,12 +1255,10 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor // Validate alpha/beta have per-matrix values const size_t alpha_numel = alpha_tensor->data.shape.numel(); const size_t beta_numel = beta_tensor->data.shape.numel(); - NVTE_CHECK(alpha_numel == num_tensors, - "Grouped GEMM: alpha must have num_tensors (", num_tensors, ") elements, got ", - alpha_numel); - NVTE_CHECK(beta_numel == num_tensors, - "Grouped GEMM: beta must have num_tensors (", num_tensors, ") elements, got ", - beta_numel); + NVTE_CHECK(alpha_numel == num_tensors, "Grouped GEMM: alpha must have num_tensors (", num_tensors, + ") elements, got ", alpha_numel); + NVTE_CHECK(beta_numel == num_tensors, "Grouped GEMM: beta must have num_tensors (", num_tensors, + ") elements, got ", beta_numel); auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { return dtype == transformer_engine::DType::kFloat8E4M3 || From 1167f7539fb91a7d8cb7de2ea252e89415967073 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 11 Dec 2025 12:25:28 +0100 Subject: [PATCH 11/17] Fix alpha/beta numel - use SimpleTensor::numel() Signed-off-by: Piotr Gadzinski Signed-off-by: Pawel Gadzinski --- transformer_engine/common/gemm/cublaslt_gemm.cu | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 1d63cf65cf6..b8aa2a8ba3c 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1253,12 +1253,14 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor "Grouped GEMM: A and D must have the same num_tensors"); // Validate alpha/beta have per-matrix values - const size_t alpha_numel = alpha_tensor->data.shape.numel(); - const size_t beta_numel = beta_tensor->data.shape.numel(); - NVTE_CHECK(alpha_numel == num_tensors, "Grouped GEMM: alpha must have num_tensors (", num_tensors, - ") elements, got ", alpha_numel); - NVTE_CHECK(beta_numel == num_tensors, "Grouped GEMM: beta must have num_tensors (", num_tensors, - ") elements, got ", beta_numel); + const size_t alpha_numel = alpha_tensor->data.numel(); + const size_t beta_numel = beta_tensor->data.numel(); + NVTE_CHECK(alpha_numel == num_tensors, + "Grouped GEMM: alpha must have num_tensors (", num_tensors, ") elements, got ", + alpha_numel); + NVTE_CHECK(beta_numel == num_tensors, + "Grouped GEMM: beta must have num_tensors (", num_tensors, ") elements, got ", + beta_numel); auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { return dtype == transformer_engine::DType::kFloat8E4M3 || From e4a80a3522b8d1b29199d807a4770ebc815ca487 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 19 Dec 2025 09:57:33 +0100 Subject: [PATCH 12/17] Refactor: move grouped 
GEMM to separate file and cleanup API Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 12 +- .../common/gemm/cublaslt_gemm.cu | 549 +--------------- .../common/gemm/cublaslt_grouped_gemm.cu | 599 ++++++++++++++++++ .../common/gemm/cublaslt_grouped_gemm.cuh | 18 + .../common/include/transformer_engine/gemm.h | 12 +- 5 files changed, 635 insertions(+), 555 deletions(-) create mode 100644 transformer_engine/common/gemm/cublaslt_grouped_gemm.cu create mode 100644 transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 82b5bd3803e..0ea76946bc2 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -4,6 +4,7 @@ * See LICENSE for license information. ************************************************************************/ +#include #include #include #include @@ -314,9 +315,12 @@ std::vector> make_shapes(ShapeCase scase) { } void run_grouped_gemm_case(const TestParams& params) { - if (params.input_case != InputCase::kBF16 && - getDeviceComputeCapability() < hopperComputeCapability) { - GTEST_SKIP() << "FP8 grouped GEMM requires Hopper or newer."; +#if CUBLAS_VERSION < 130200 + GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.2+, but compile-time cuBLAS version is " + << CUBLAS_VERSION << "."; +#else + if (getDeviceComputeCapability() < hopperComputeCapability) { + GTEST_SKIP() << "Grouped GEMM requires Hopper (SM90) or newer."; } const std::vector> shapes = make_shapes(params.shape_case); @@ -451,7 +455,6 @@ void run_grouped_gemm_case(const TestParams& params) { grouped_D.get_handle(), setup_ws.data(), cublas_ws.data(), - nullptr, 0, nullptr, nullptr, @@ -477,6 +480,7 @@ void run_grouped_gemm_case(const TestParams& params) { atol, rtol); } +#endif // CUBLAS_VERSION >= 130200 } class GroupedGemmTest : public ::testing::TestWithParam {}; diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index b8aa2a8ba3c..86f517af7dc 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -23,6 +23,7 @@ #include "../util/logging.h" #include "../util/multi_stream.h" #include "./config.h" +#include "./cublaslt_grouped_gemm.cuh" #include "./cutlass_grouped_gemm.cuh" namespace { @@ -1104,551 +1105,3 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor cublas_path(); } } - -// Helper struct to pass per-tensor shape/offset info (pointer or uniform value) -struct TensorShapeInfo { - const int64_t *first_dims; // nullptr if uniform - const int64_t *last_dims; // nullptr if uniform - const int64_t *offsets; // nullptr if need to compute - int64_t uniform_first; // used if first_dims == nullptr - int64_t uniform_last; // used if last_dims == nullptr - - // Create from GroupedTensor - static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) { - const bool has_first = t->first_dims.has_data(); - const bool has_last = t->last_dims.has_data(); - // When per-tensor dims are not provided, we must be in the uniform-shape case. - NVTE_CHECK(has_first || t->all_same_first_dim(), - "GroupedTensor is missing first_dims for varying shapes"); - NVTE_CHECK(has_last || t->all_same_last_dim(), - "GroupedTensor is missing last_dims for varying shapes"); - - const int64_t *first_ptr = - has_first ? static_cast(t->first_dims.dptr) : nullptr; - const int64_t *last_ptr = has_last ? 
static_cast(t->last_dims.dptr) : nullptr; - - const int64_t uniform_first = has_first ? 0 : static_cast(t->get_common_first_dim()); - const int64_t uniform_last = has_last ? 0 : static_cast(t->get_common_last_dim()); - - return {first_ptr, last_ptr, - t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) - : nullptr, - uniform_first, uniform_last}; - } - - // Create for C tensor (uses D's dimensions, only has offsets) - static TensorShapeInfo for_C(const transformer_engine::GroupedTensor *C, - const transformer_engine::GroupedTensor *D) { - const bool has_first = D->first_dims.has_data(); - const bool has_last = D->last_dims.has_data(); - NVTE_CHECK(has_first || D->all_same_first_dim(), - "GroupedTensor D is missing first_dims for varying shapes"); - NVTE_CHECK(has_last || D->all_same_last_dim(), - "GroupedTensor D is missing last_dims for varying shapes"); - - const int64_t *first_ptr = - has_first ? static_cast(D->first_dims.dptr) : nullptr; - const int64_t *last_ptr = has_last ? static_cast(D->last_dims.dptr) : nullptr; - const int64_t uniform_first = has_first ? 0 : static_cast(D->get_common_first_dim()); - const int64_t uniform_last = has_last ? 0 : static_cast(D->get_common_last_dim()); - - return {first_ptr, last_ptr, - C->tensor_offsets.has_data() ? static_cast(C->tensor_offsets.dptr) - : nullptr, - uniform_first, uniform_last}; - } -}; - -// Helper functions to compute average dimensions from logical_shape for heuristics -// These are hints for cuBLASLt algorithm selection, don't need to be exact -inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor *t) { - // logical_shape[0] is either num_tensors*M (uniform) or sum_of_M (varying first) - // In both cases, dividing by num_tensors gives the average - return static_cast(t->logical_shape.data[0]) / static_cast(t->num_tensors); -} - -inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) { - if (t->all_same_last_dim()) { - // logical_shape[1] is the common N - return static_cast(t->logical_shape.data[1]); - } - // When varying, logical_shape[1] should be sum of last dims if provided; otherwise fallback to avg via division. 
- return static_cast(t->logical_shape.data[1]) / static_cast(t->num_tensors); -} - -// Workspace layout for grouped GEMM -struct GroupedGemmSetupWorkspace { - void **A_ptrs; - void **B_ptrs; - void **C_ptrs; - void **D_ptrs; - int *M; - int *N; - int *K; - float **alpha_ptrs; - float **beta_ptrs; - - // Initialize from workspace buffer - // Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned) - static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors) { - GroupedGemmSetupWorkspace ws; - size_t offset = 0; - const size_t ptr_size = num_tensors * sizeof(void *); - const size_t int_size = num_tensors * sizeof(int); - - // Pointer arrays first (all 8-byte aligned) - ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - - // Int arrays last (4-byte aligned, always satisfied after pointer arrays) - ws.M = reinterpret_cast(setup_ws_ptr + offset); - offset += int_size; - ws.N = reinterpret_cast(setup_ws_ptr + offset); - offset += int_size; - ws.K = reinterpret_cast(setup_ws_ptr + offset); - - return ws; - } - - // Calculate required size for setup workspace (pointer arrays + M/N/K) - static size_t required_setup_size(size_t num_tensors, size_t alignment) { - const size_t ptr_size = num_tensors * sizeof(void *); - const size_t int_size = num_tensors * sizeof(int); - // Layout: 6 ptr arrays, then 3 int arrays (no padding needed) - size_t size = 6 * ptr_size + 3 * int_size; - size = ((size + alignment - 1) / alignment) * alignment; - return size; - } -}; - -// ----------------------------------------------------------------------------- -// Helper routines to keep nvte_grouped_gemm readable -// ----------------------------------------------------------------------------- -inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor *inputA, - const transformer_engine::GroupedTensor *inputB, - const transformer_engine::GroupedTensor *inputC, - const transformer_engine::GroupedTensor *outputD, - const transformer_engine::Tensor *alpha_tensor, - const transformer_engine::Tensor *beta_tensor) { - const size_t num_tensors = inputA->num_tensors; - NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: num_tensors must be at least 1"); - NVTE_CHECK(inputB->num_tensors == num_tensors, - "Grouped GEMM: A and B must have the same num_tensors"); - // C can be NULL (will use D as C when beta=0) - if (inputC != nullptr) { - NVTE_CHECK(inputC->num_tensors == num_tensors, - "Grouped GEMM: A and C must have the same num_tensors"); - } - NVTE_CHECK(outputD->num_tensors == num_tensors, - "Grouped GEMM: A and D must have the same num_tensors"); - - // Validate alpha/beta have per-matrix values - const size_t alpha_numel = alpha_tensor->data.numel(); - const size_t beta_numel = beta_tensor->data.numel(); - NVTE_CHECK(alpha_numel == num_tensors, - "Grouped GEMM: alpha must have num_tensors (", num_tensors, ") elements, got ", - alpha_numel); - NVTE_CHECK(beta_numel == num_tensors, - "Grouped GEMM: beta must have num_tensors (", num_tensors, ") elements, got ", - beta_numel); - - auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { - return 
dtype == transformer_engine::DType::kFloat8E4M3 || - dtype == transformer_engine::DType::kFloat8E5M2 || - dtype == transformer_engine::DType::kBFloat16 || - dtype == transformer_engine::DType::kFloat16; - }; - auto is_output_dtype = [](transformer_engine::DType dtype) { - return dtype == transformer_engine::DType::kBFloat16 || - dtype == transformer_engine::DType::kFloat16 || - dtype == transformer_engine::DType::kFloat32; - }; - NVTE_CHECK(is_fp8_or_16bit(inputA->dtype()) && is_fp8_or_16bit(inputB->dtype()), - "Grouped GEMM inputs must be FP8, BF16, or FP16."); - // Only check C dtype if C is provided - if (inputC != nullptr) { - NVTE_CHECK(is_output_dtype(inputC->dtype()), "Grouped GEMM: C must be BF16, FP16, or FP32."); - } - NVTE_CHECK(is_output_dtype(outputD->dtype()), "Grouped GEMM: D must be BF16, FP16, or FP32."); - NVTE_CHECK(inputA->has_data() || inputA->has_columnwise_data(), - "Grouped GEMM: A tensor is missing both row-wise and column-wise data"); - NVTE_CHECK(inputB->has_data() || inputB->has_columnwise_data(), - "Grouped GEMM: B tensor is missing both row-wise and column-wise data"); -} - -// Select row-wise vs column-wise storage and adjust transpose flag for grouped GEMM. -// Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and -// fallback to column-wise data when row-wise is absent. -struct GroupedOperandSelection { - const transformer_engine::GroupedTensor *tensor = nullptr; - const char *base = nullptr; - transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; - bool trans = false; - bool use_columnwise = false; -}; - -inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor *t, - bool trans, bool is_A) { - using namespace transformer_engine; - const bool has_row = t->has_data(); - const bool has_col = t->has_columnwise_data(); - NVTE_CHECK(has_row || has_col, - "Grouped GEMM operand is missing both row-wise and column-wise data"); - - // Not yet supported in grouped GEMM: block scaling, MXFP8, NVFP4 specialized layouts. - const auto sm = t->scaling_mode; - NVTE_CHECK(sm != NVTE_BLOCK_SCALING_1D && sm != NVTE_BLOCK_SCALING_2D && !is_mxfp_scaling(sm) && - !is_nvfp_scaling(sm), - "Grouped GEMM does not yet support NVFP4/MXFP8/block scaling operand selection"); - - const DType row_dtype = t->data.dtype; - const DType col_dtype = t->columnwise_data.dtype; - GroupedOperandSelection sel; - sel.tensor = t; - sel.trans = trans; - - const DType rep_dtype = has_row ? row_dtype : col_dtype; - const bool is_fp8 = is_fp8_dtype(rep_dtype); - const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported(); - - // Hopper-style TN-only FP8: force TN by switching layout and flipping transpose when needed. - if (is_fp8 && !non_tn_fp8_ok) { - if (is_A) { - if (!sel.trans) { - NVTE_CHECK(has_col, "Grouped GEMM: A is missing column-wise data needed for FP8 TN layout"); - sel.base = static_cast(t->columnwise_data.dptr); - sel.dtype = col_dtype; - sel.trans = true; // using pre-transposed storage - sel.use_columnwise = true; - return sel; - } - } else { // B - if (sel.trans) { - NVTE_CHECK(has_col, "Grouped GEMM: B is missing column-wise data needed for FP8 TN layout"); - sel.base = static_cast(t->columnwise_data.dptr); - sel.dtype = col_dtype; - sel.trans = false; // using pre-transposed storage - sel.use_columnwise = true; - return sel; - } - } - } - - // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). 
- if (!has_row && has_col) { - // On Hopper FP8, this would break TN requirement - should have been handled above - NVTE_CHECK( - !is_fp8 || non_tn_fp8_ok, - "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration"); - sel.base = static_cast(t->columnwise_data.dptr); - sel.dtype = col_dtype; - sel.trans = !sel.trans; - sel.use_columnwise = true; - return sel; - } - - // Default: use row-wise data (column-wise case already handled above) - sel.base = static_cast(t->data.dptr); - sel.dtype = row_dtype; - sel.use_columnwise = false; - return sel; -} - -inline void *validate_and_get_workspace_ptr(transformer_engine::Tensor *ws, size_t required_size, - const char *workspace_name) { - NVTE_CHECK(ws != nullptr, workspace_name, " tensor is null."); - const size_t provided_size = get_buffer_size_bytes(ws->data.numel(), ws->data.dtype); - NVTE_CHECK(provided_size >= required_size, "Grouped GEMM: Insufficient ", workspace_name, - ". Required: ", required_size, " bytes, Available: ", provided_size, " bytes."); - return ws->data.dptr; -} - -inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, - cublasLtMatrixLayoutOpaque_t &descB, - cublasLtMatrixLayoutOpaque_t &descC, - cublasLtMatrixLayoutOpaque_t &descD, - const GroupedGemmSetupWorkspace &ws, - const GroupedOperandSelection &A_sel, - const GroupedOperandSelection &B_sel, - const transformer_engine::GroupedTensor *D, size_t num_tensors) { - const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); - const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); - const cudaDataType_t D_type = get_cuda_dtype(D->dtype()); - - // For column-major layout: leading dimension is the number of rows in storage. - // If columnwise data was chosen, storage is already transposed. - int *rowa = A_sel.use_columnwise ? ws.M : (A_sel.trans ? ws.K : ws.M); - int *cola = A_sel.use_columnwise ? ws.K : (A_sel.trans ? ws.M : ws.K); - int *lda = rowa; - int *rowb = B_sel.use_columnwise ? ws.N : (B_sel.trans ? ws.N : ws.K); - int *colb = B_sel.use_columnwise ? ws.K : (B_sel.trans ? 
ws.K : ws.N); - int *ldb = rowb; - - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, rowa, cola, lda)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, rowb, colb, ldb)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, ws.M, ws.N, ws.M)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, ws.M, ws.N, ws.M)); -} - -inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A, - cublasOperation_t op_B) { - NVTE_CHECK_CUBLAS(cublasLtMatmulDescInit(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); - - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A, - sizeof(op_A))); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B, - sizeof(op_B))); - - cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, - &pointer_mode, sizeof(pointer_mode))); - - int64_t alphabeta_batch_stride = 1; - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, - CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE, - &alphabeta_batch_stride, sizeof(int64_t))); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, - CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, - &alphabeta_batch_stride, sizeof(int64_t))); -} - -inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, - const GroupedOperandSelection &A_sel, - const GroupedOperandSelection &B_sel) { - const bool is_fp8_a = is_fp8_dtype(A_sel.dtype); - const bool is_fp8_b = is_fp8_dtype(B_sel.dtype); - if (!is_fp8_a && !is_fp8_b) return; - - if (is_fp8_a) { - void *a_scale_inv = A_sel.use_columnwise ? A_sel.tensor->columnwise_scale_inv.dptr - : A_sel.tensor->scale_inv.dptr; - NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); - } - if (is_fp8_b) { - void *b_scale_inv = B_sel.use_columnwise ? 
B_sel.tensor->columnwise_scale_inv.dptr - : B_sel.tensor->scale_inv.dptr; - NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); - } -} - -// Constants for grouped GEMM workspace (declared early for use in heuristics) -static constexpr size_t kGroupedGemmAlignment = 256; -static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB - -inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, - cublasLtMatmulDescOpaque_t &matmulDesc, - cublasLtMatrixLayoutOpaque_t &descA, - cublasLtMatrixLayoutOpaque_t &descB, - cublasLtMatrixLayoutOpaque_t &descC, - cublasLtMatrixLayoutOpaque_t &descD, - int64_t avg_m, int64_t avg_n, int64_t avg_k) { - cublasLtMatmulPreferenceOpaque_t preference; - NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceInit(&preference)); - NVTE_CHECK_CUBLAS( - cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &kGroupedGemmCublasWorkspaceSize, sizeof(size_t))); - NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_ROWS, &avg_m, sizeof(int64_t))); - NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_COLS, &avg_n, sizeof(int64_t))); - NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_GROUPED_AVERAGE_REDUCTION_DIM, &avg_k, sizeof(int64_t))); - - cublasLtMatmulHeuristicResult_t heuristicResult; - int returnedResults = 0; - auto status = cublasLtMatmulAlgoGetHeuristic(handle, &matmulDesc, &descA, &descB, &descC, &descD, - &preference, 1, &heuristicResult, &returnedResults); - NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, - "Unable to find suitable cuBLAS grouped GEMM algorithm"); - NVTE_CHECK_CUBLAS(status); - NVTE_CHECK(returnedResults > 0, "No suitable algorithm found for grouped GEMM"); - return heuristicResult.algo; -} - -// Single kernel that sets up all GEMM parameters. -// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix M/N/K, -// but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes. -// We bridge the mismatch on GPU by computing per-group pointers and dims in one kernel. -__global__ void setup_grouped_gemm_kernel( - // Output arrays - void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, int *M, int *N, int *K, - float **alpha_ptrs, float **beta_ptrs, - // Base pointers - const char *a_base, const char *b_base, const char *c_base, char *d_base, - // Dimension info (per tensor) - TensorShapeInfo A_meta, TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, - // Element sizes - size_t a_elem_size, size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, - // Alpha/beta pointers (per-matrix arrays) - float *alpha_ptr, float *beta_ptr, - // Transpose flags - bool transa, bool transb, - // Number of tensors - size_t num_tensors) { - size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= num_tensors) return; - - // Get dimensions for this tensor (from array or uniform value) - int64_t a_first = A_meta.first_dims ? A_meta.first_dims[idx] : A_meta.uniform_first; - int64_t a_last = A_meta.last_dims ? A_meta.last_dims[idx] : A_meta.uniform_last; - int64_t b_first = B_meta.first_dims ? 
B_meta.first_dims[idx] : B_meta.uniform_first; - int64_t b_last = B_meta.last_dims ? B_meta.last_dims[idx] : B_meta.uniform_last; - - // Compute offsets (from array or compute from uniform dims) - int64_t a_offset = - A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last); - int64_t b_offset = - B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last); - int64_t c_offset = - C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last); - int64_t d_offset = - D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); - - // Compute data pointers - A_ptrs[idx] = const_cast(a_base) + a_offset * a_elem_size; - B_ptrs[idx] = const_cast(b_base) + b_offset * b_elem_size; - C_ptrs[idx] = const_cast(c_base) + c_offset * c_elem_size; - D_ptrs[idx] = d_base + d_offset * d_elem_size; - - // Compute M, N, K dimensions - // Test stores A as {K,M} when !transa, {M,K} when transa - // Test stores B as {N,K} when !transb, {K,N} when transb - M[idx] = static_cast(transa ? a_first : a_last); - K[idx] = static_cast(transa ? a_last : a_first); - N[idx] = static_cast(transb ? b_last : b_first); - - // Fill alpha/beta pointers (per-matrix) - alpha_ptrs[idx] = alpha_ptr + idx; - beta_ptrs[idx] = beta_ptr + idx; -} - -// Launch the setup kernel to populate workspace arrays -inline void launch_grouped_gemm_setup( - const GroupedGemmSetupWorkspace &ws, const GroupedOperandSelection &A_sel, - const GroupedOperandSelection &B_sel, const transformer_engine::GroupedTensor *C, - const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor, - const transformer_engine::Tensor *beta_tensor, size_t num_tensors, cudaStream_t stream) { - TensorShapeInfo A_meta = TensorShapeInfo::from_tensor(A_sel.tensor); - TensorShapeInfo B_meta = TensorShapeInfo::from_tensor(B_sel.tensor); - TensorShapeInfo C_meta = TensorShapeInfo::for_C(C, D); - TensorShapeInfo D_meta = TensorShapeInfo::from_tensor(D); - - const char *c_base = static_cast(C->data.dptr); - char *d_base = static_cast(D->data.dptr); - - const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype); - const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype); - const size_t c_elem_size = transformer_engine::typeToSize(C->dtype()); - const size_t d_elem_size = transformer_engine::typeToSize(D->dtype()); - - const int threads_per_block = 256; - const int num_blocks = (num_tensors + threads_per_block - 1) / threads_per_block; - - setup_grouped_gemm_kernel<<>>( - ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.M, ws.N, ws.K, ws.alpha_ptrs, ws.beta_ptrs, - A_sel.base, B_sel.base, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, - b_elem_size, c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), - static_cast(beta_tensor->data.dptr), A_sel.trans, B_sel.trans, num_tensors); - - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { - return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment); -} - -void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, - const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, - NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, - NVTEMatmulConfig config, cudaStream_t stream, const int64_t *avg_m, - const int64_t *avg_n, const int64_t *avg_k) { - 
NVTE_API_CALL(nvte_grouped_gemm); - using namespace transformer_engine; - - // Convert to internal types - const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); - const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B); - const GroupedTensor *inputC_raw = convertNVTEGroupedTensor(C); // Can be NULL - GroupedTensor *outputD = convertNVTEGroupedTensorCheck(D); - const Tensor *alpha_tensor = convertNVTETensorCheck(alpha); - const Tensor *beta_tensor = convertNVTETensorCheck(beta); - Tensor *wspace_setup = convertNVTETensor(workspace_setup); - Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); - - // Validate inputs and num_tensors - validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD, alpha_tensor, beta_tensor); - - // If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data) - const GroupedTensor *inputC = (inputC_raw != nullptr) ? inputC_raw : outputD; - const size_t num_tensors = inputA->num_tensors; - - // Select operand storage (row-wise vs column-wise) and adjust transpose flags to - // mirror the non-grouped GEMM logic for FP8 layout constraints. - const auto A_sel = select_grouped_operand(inputA, static_cast(transa), /*is_A=*/true); - const auto B_sel = select_grouped_operand(inputB, static_cast(transb), /*is_A=*/false); - - // Workspaces: setup (pointer arrays) and cuBLAS - const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors); - const size_t cublas_workspace_size = kGroupedGemmCublasWorkspaceSize; - - void *setup_workspace_ptr = validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size, - "Grouped GEMM setup workspace"); - void *cublas_workspace_ptr = validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, - "Grouped GEMM cuBLAS workspace"); - - auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( - static_cast(setup_workspace_ptr), num_tensors); - launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, alpha_tensor, - beta_tensor, num_tensors, stream); - - // Get cuBLAS handle - using cublasHandleManager = detail::HandleManager; - cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); - - // Setup cuBLAS operations - cublasOperation_t op_A = A_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t op_B = B_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; - - // Create grouped matrix layouts - cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD; - init_matrix_layouts(descA, descB, descC, descD, setup_workspace, A_sel, B_sel, outputD, - num_tensors); - - // Create matmul descriptor - cublasLtMatmulDescOpaque_t matmulDesc; - init_matmul_desc(matmulDesc, op_A, op_B); - set_fp8_scale_pointers(matmulDesc, A_sel, B_sel); - - // Compute average dimensions for heuristics - // K dimension: if transa, K is A's first dim; if not, K is A's last dim - int64_t avg_m_val = avg_m ? *avg_m : compute_avg_first_dim(outputD); - int64_t avg_n_val = avg_n ? *avg_n : compute_avg_last_dim(outputD); - int64_t avg_k_val = avg_k ? *avg_k - : (A_sel.trans ? 
compute_avg_first_dim(A_sel.tensor) - : compute_avg_last_dim(A_sel.tensor)); - - // Heuristic selection - cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, - descD, avg_m_val, avg_n_val, avg_k_val); - - // Execute the grouped GEMM - NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, &matmulDesc, setup_workspace.alpha_ptrs, - setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB, - setup_workspace.beta_ptrs, setup_workspace.C_ptrs, &descC, - setup_workspace.D_ptrs, &descD, &algo, cublas_workspace_ptr, - kGroupedGemmCublasWorkspaceSize, stream)); -} diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu new file mode 100644 index 00000000000..4125bd82bff --- /dev/null +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -0,0 +1,599 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include + +#include "../common.h" +#include "../util/cuda_runtime.h" +#include "../util/handle_manager.h" +#include "../util/logging.h" +#include "./cublaslt_grouped_gemm.cuh" + +namespace { + +inline void CreateCublasHandle(cublasLtHandle_t *handle) { + NVTE_CHECK_CUBLAS(cublasLtCreate(handle)); +} + +} // namespace + +#if CUBLAS_VERSION >= 130100 + +namespace { + +// Helper struct to pass per-tensor shape/offset info (pointer or uniform value) +struct TensorShapeInfo { + const int64_t *first_dims; // nullptr if uniform + const int64_t *last_dims; // nullptr if uniform + const int64_t *offsets; // nullptr if need to compute + int64_t uniform_first; // used if first_dims == nullptr + int64_t uniform_last; // used if last_dims == nullptr + + // Create from GroupedTensor + static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) { + const bool has_first = t->first_dims.has_data(); + const bool has_last = t->last_dims.has_data(); + // When per-tensor dims are not provided, we must be in the uniform-shape case. + NVTE_CHECK(has_first || t->all_same_first_dim(), + "GroupedTensor is missing first_dims for varying shapes"); + NVTE_CHECK(has_last || t->all_same_last_dim(), + "GroupedTensor is missing last_dims for varying shapes"); + + const int64_t *first_ptr = + has_first ? static_cast(t->first_dims.dptr) : nullptr; + const int64_t *last_ptr = has_last ? static_cast(t->last_dims.dptr) : nullptr; + + const int64_t uniform_first = has_first ? 0 : static_cast(t->get_common_first_dim()); + const int64_t uniform_last = has_last ? 0 : static_cast(t->get_common_last_dim()); + + return {first_ptr, last_ptr, + t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) + : nullptr, + uniform_first, uniform_last}; + } + + // Create for C tensor (uses D's dimensions, only has offsets) + static TensorShapeInfo for_C(const transformer_engine::GroupedTensor *C, + const transformer_engine::GroupedTensor *D) { + const bool has_first = D->first_dims.has_data(); + const bool has_last = D->last_dims.has_data(); + NVTE_CHECK(has_first || D->all_same_first_dim(), + "GroupedTensor D is missing first_dims for varying shapes"); + NVTE_CHECK(has_last || D->all_same_last_dim(), + "GroupedTensor D is missing last_dims for varying shapes"); + + const int64_t *first_ptr = + has_first ? 
static_cast(D->first_dims.dptr) : nullptr; + const int64_t *last_ptr = has_last ? static_cast(D->last_dims.dptr) : nullptr; + const int64_t uniform_first = has_first ? 0 : static_cast(D->get_common_first_dim()); + const int64_t uniform_last = has_last ? 0 : static_cast(D->get_common_last_dim()); + + return {first_ptr, last_ptr, + C->tensor_offsets.has_data() ? static_cast(C->tensor_offsets.dptr) + : nullptr, + uniform_first, uniform_last}; + } +}; + +// Helper functions to compute average dimensions from logical_shape for heuristics +// These are hints for cuBLASLt algorithm selection, don't need to be exact +inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor *t) { + // logical_shape[0] is either num_tensors*M (uniform) or sum_of_M (varying first) + // In both cases, dividing by num_tensors gives the average + return static_cast(t->logical_shape.data[0]) / static_cast(t->num_tensors); +} + +inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) { + if (t->all_same_last_dim()) { + // logical_shape[1] is the common N + return static_cast(t->logical_shape.data[1]); + } + // When varying, logical_shape[1] should be sum of last dims if provided; otherwise fallback to avg via division. + return static_cast(t->logical_shape.data[1]) / static_cast(t->num_tensors); +} + +// Workspace layout for grouped GEMM +struct GroupedGemmSetupWorkspace { + void **A_ptrs; + void **B_ptrs; + void **C_ptrs; + void **D_ptrs; + int *M; + int *N; + int *K; + float **alpha_ptrs; + float **beta_ptrs; + + // Initialize from workspace buffer + // Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned) + static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors) { + GroupedGemmSetupWorkspace ws; + size_t offset = 0; + const size_t ptr_size = num_tensors * sizeof(void *); + const size_t int_size = num_tensors * sizeof(int); + + // Pointer arrays first (all 8-byte aligned) + ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + + // Int arrays last (4-byte aligned, always satisfied after pointer arrays) + ws.M = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.N = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.K = reinterpret_cast(setup_ws_ptr + offset); + + return ws; + } + + // Calculate required size for setup workspace (pointer arrays + M/N/K) + static size_t required_setup_size(size_t num_tensors, size_t alignment) { + const size_t ptr_size = num_tensors * sizeof(void *); + const size_t int_size = num_tensors * sizeof(int); + // Layout: 6 ptr arrays, then 3 int arrays (no padding needed) + size_t size = 6 * ptr_size + 3 * int_size; + size = ((size + alignment - 1) / alignment) * alignment; + return size; + } +}; + +// ----------------------------------------------------------------------------- +// Helper routines to keep nvte_grouped_gemm readable +// ----------------------------------------------------------------------------- +inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor *inputA, + const transformer_engine::GroupedTensor *inputB, 
+ const transformer_engine::GroupedTensor *inputC, + const transformer_engine::GroupedTensor *outputD, + const transformer_engine::Tensor *alpha_tensor, + const transformer_engine::Tensor *beta_tensor) { + const size_t num_tensors = inputA->num_tensors; + NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: num_tensors must be at least 1"); + NVTE_CHECK(inputB->num_tensors == num_tensors, + "Grouped GEMM: A and B must have the same num_tensors"); + // C can be NULL (will use D as C when beta=0) + if (inputC != nullptr) { + NVTE_CHECK(inputC->num_tensors == num_tensors, + "Grouped GEMM: A and C must have the same num_tensors"); + } + NVTE_CHECK(outputD->num_tensors == num_tensors, + "Grouped GEMM: A and D must have the same num_tensors"); + + // Validate alpha/beta have per-matrix values + const size_t alpha_numel = alpha_tensor->data.numel(); + const size_t beta_numel = beta_tensor->data.numel(); + NVTE_CHECK(alpha_numel == num_tensors, + "Grouped GEMM: alpha must have num_tensors (", num_tensors, ") elements, got ", + alpha_numel); + NVTE_CHECK(beta_numel == num_tensors, + "Grouped GEMM: beta must have num_tensors (", num_tensors, ") elements, got ", + beta_numel); + + auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { + return dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2 || + dtype == transformer_engine::DType::kBFloat16 || + dtype == transformer_engine::DType::kFloat16; + }; + auto is_output_dtype = [](transformer_engine::DType dtype) { + return dtype == transformer_engine::DType::kBFloat16 || + dtype == transformer_engine::DType::kFloat16 || + dtype == transformer_engine::DType::kFloat32; + }; + NVTE_CHECK(is_fp8_or_16bit(inputA->dtype()) && is_fp8_or_16bit(inputB->dtype()), + "Grouped GEMM inputs must be FP8, BF16, or FP16."); + // Only check C dtype if C is provided + if (inputC != nullptr) { + NVTE_CHECK(is_output_dtype(inputC->dtype()), "Grouped GEMM: C must be BF16, FP16, or FP32."); + } + NVTE_CHECK(is_output_dtype(outputD->dtype()), "Grouped GEMM: D must be BF16, FP16, or FP32."); + NVTE_CHECK(inputA->has_data() || inputA->has_columnwise_data(), + "Grouped GEMM: A tensor is missing both row-wise and column-wise data"); + NVTE_CHECK(inputB->has_data() || inputB->has_columnwise_data(), + "Grouped GEMM: B tensor is missing both row-wise and column-wise data"); +} + +// Select row-wise vs column-wise storage and adjust transpose flag for grouped GEMM. +// Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and +// fallback to column-wise data when row-wise is absent. +struct GroupedOperandSelection { + const transformer_engine::GroupedTensor *tensor = nullptr; + const char *dptr = nullptr; + transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; + bool trans = false; + bool use_columnwise = false; +}; + +inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor *t, + bool trans, bool is_A) { + using namespace transformer_engine; + const bool has_row = t->has_data(); + const bool has_col = t->has_columnwise_data(); + NVTE_CHECK(has_row || has_col, + "Grouped GEMM operand is missing both row-wise and column-wise data"); + + // Currently only unquantized data and tensor-scaled FP8 are supported. 
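+  //
+  // The branches below reduce to the following selection table (a summary of the
+  // code that follows, not additional behavior):
+  //   - FP8 on TN-only hardware (e.g. Hopper):
+  //       A with trans=false -> take column-wise storage, report trans=true
+  //       B with trans=true  -> take column-wise storage, report trans=false
+  //   - Only column-wise data available -> take it and flip the transpose flag
+  //   - Otherwise -> take row-wise storage with the caller's transpose flag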
+ const auto sm = t->scaling_mode; + NVTE_CHECK(sm == NVTE_DELAYED_TENSOR_SCALING, + "Grouped GEMM is only supported with unquantized data and tensor-scaled FP8 data"); + + const DType row_dtype = t->data.dtype; + const DType col_dtype = t->columnwise_data.dtype; + GroupedOperandSelection sel; + sel.tensor = t; + sel.trans = trans; + + const DType rep_dtype = has_row ? row_dtype : col_dtype; + const bool is_fp8 = is_fp8_dtype(rep_dtype); + const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported(); + + // Hopper-style TN-only FP8: force TN by switching layout and flipping transpose when needed. + if (is_fp8 && !non_tn_fp8_ok) { + if (is_A) { + if (!sel.trans) { + NVTE_CHECK(has_col, "Grouped GEMM: A is missing column-wise data needed for FP8 TN layout"); + sel.dptr = static_cast(t->columnwise_data.dptr); + sel.dtype = col_dtype; + sel.trans = true; // using pre-transposed storage + sel.use_columnwise = true; + return sel; + } + } else { // B + if (sel.trans) { + NVTE_CHECK(has_col, "Grouped GEMM: B is missing column-wise data needed for FP8 TN layout"); + sel.dptr = static_cast(t->columnwise_data.dptr); + sel.dtype = col_dtype; + sel.trans = false; // using pre-transposed storage + sel.use_columnwise = true; + return sel; + } + } + } + + // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). + if (!has_row && has_col) { + // On Hopper FP8, this would break TN requirement - should have been handled above + NVTE_CHECK( + !is_fp8 || non_tn_fp8_ok, + "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration"); + sel.dptr = static_cast(t->columnwise_data.dptr); + sel.dtype = col_dtype; + sel.trans = !sel.trans; + sel.use_columnwise = true; + return sel; + } + + // Default: use row-wise data (column-wise case already handled above) + sel.dptr = static_cast(t->data.dptr); + sel.dtype = row_dtype; + sel.use_columnwise = false; + return sel; +} + +inline void *validate_and_get_workspace_ptr(transformer_engine::Tensor *ws, size_t required_size, + const char *workspace_name) { + NVTE_CHECK(ws != nullptr, workspace_name, " tensor is null."); + const size_t provided_size = get_buffer_size_bytes(ws->data.numel(), ws->data.dtype); + NVTE_CHECK(provided_size >= required_size, "Grouped GEMM: Insufficient ", workspace_name, + ". Required: ", required_size, " bytes, Available: ", provided_size, " bytes."); + return ws->data.dptr; +} + +inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, + cublasLtMatrixLayoutOpaque_t &descB, + cublasLtMatrixLayoutOpaque_t &descC, + cublasLtMatrixLayoutOpaque_t &descD, + const GroupedGemmSetupWorkspace &ws, + const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel, + const transformer_engine::GroupedTensor *D, size_t num_tensors) { + const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); + const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); + const cudaDataType_t D_type = get_cuda_dtype(D->dtype()); + + // For column-major layout: leading dimension is the number of rows in storage. + // If columnwise data was chosen, storage is already transposed. + int *rowa = A_sel.use_columnwise ? ws.M : (A_sel.trans ? ws.K : ws.M); + int *cola = A_sel.use_columnwise ? ws.K : (A_sel.trans ? ws.M : ws.K); + int *lda = rowa; + int *rowb = B_sel.use_columnwise ? ws.N : (B_sel.trans ? ws.N : ws.K); + int *colb = B_sel.use_columnwise ? ws.K : (B_sel.trans ? 
ws.K : ws.N); + int *ldb = rowb; + + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, rowa, cola, lda)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, rowb, colb, ldb)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, ws.M, ws.N, ws.M)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, ws.M, ws.N, ws.M)); +} + +inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A, + cublasOperation_t op_B) { + NVTE_CHECK_CUBLAS(cublasLtMatmulDescInit(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); + + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A, + sizeof(op_A))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B, + sizeof(op_B))); + + cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, sizeof(pointer_mode))); + + int64_t alphabeta_batch_stride = 1; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE, + &alphabeta_batch_stride, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, + &alphabeta_batch_stride, sizeof(int64_t))); +} + +inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, + const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel) { + const bool is_fp8_a = is_fp8_dtype(A_sel.dtype); + const bool is_fp8_b = is_fp8_dtype(B_sel.dtype); + if (!is_fp8_a && !is_fp8_b) return; + + if (is_fp8_a) { + void *a_scale_inv = A_sel.use_columnwise ? A_sel.tensor->columnwise_scale_inv.dptr + : A_sel.tensor->scale_inv.dptr; + NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); + } + if (is_fp8_b) { + void *b_scale_inv = B_sel.use_columnwise ? 
B_sel.tensor->columnwise_scale_inv.dptr + : B_sel.tensor->scale_inv.dptr; + NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); + } +} + +// Constants for grouped GEMM workspace (declared early for use in heuristics) +static constexpr size_t kGroupedGemmAlignment = 256; +static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB + +inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, + cublasLtMatmulDescOpaque_t &matmulDesc, + cublasLtMatrixLayoutOpaque_t &descA, + cublasLtMatrixLayoutOpaque_t &descB, + cublasLtMatrixLayoutOpaque_t &descC, + cublasLtMatrixLayoutOpaque_t &descD, + int64_t avg_m, int64_t avg_n, int64_t avg_k) { + cublasLtMatmulPreferenceOpaque_t preference; + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceInit(&preference)); + NVTE_CHECK_CUBLAS( + cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &kGroupedGemmCublasWorkspaceSize, sizeof(size_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_ROWS, &avg_m, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_COLS, &avg_n, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_AVERAGE_REDUCTION_DIM, &avg_k, sizeof(int64_t))); + + cublasLtMatmulHeuristicResult_t heuristicResult; + int returnedResults = 0; + auto status = cublasLtMatmulAlgoGetHeuristic(handle, &matmulDesc, &descA, &descB, &descC, &descD, + &preference, 1, &heuristicResult, &returnedResults); + NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, + "Unable to find suitable cuBLAS grouped GEMM algorithm"); + NVTE_CHECK_CUBLAS(status); + NVTE_CHECK(returnedResults > 0, "No suitable algorithm found for grouped GEMM"); + return heuristicResult.algo; +} + +// Single kernel that sets up all GEMM parameters. +// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix M/N/K, +// but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes. +// We bridge the mismatch on GPU by computing per-group pointers and dims in one kernel. +__global__ void setup_grouped_gemm_kernel( + // Output arrays + void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, int *M, int *N, int *K, + float **alpha_ptrs, float **beta_ptrs, + // Base pointers + const char *a_base, const char *b_base, const char *c_base, char *d_base, + // Dimension info (per tensor) + TensorShapeInfo A_meta, TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, + // Element sizes + size_t a_elem_size, size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, + // Alpha/beta pointers (per-matrix arrays) + float *alpha_ptr, float *beta_ptr, + // Transpose flags + bool transa, bool transb, + // Number of tensors + size_t num_tensors) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= num_tensors) return; + + // Get dimensions for this tensor (from array or uniform value) + int64_t a_first = A_meta.first_dims ? A_meta.first_dims[idx] : A_meta.uniform_first; + int64_t a_last = A_meta.last_dims ? A_meta.last_dims[idx] : A_meta.uniform_last; + int64_t b_first = B_meta.first_dims ? 
B_meta.first_dims[idx] : B_meta.uniform_first; + int64_t b_last = B_meta.last_dims ? B_meta.last_dims[idx] : B_meta.uniform_last; + + // Compute offsets (from array or compute from uniform dims) + int64_t a_offset = + A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last); + int64_t b_offset = + B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last); + int64_t c_offset = + C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last); + int64_t d_offset = + D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); + + // Compute data pointers + A_ptrs[idx] = const_cast(a_base) + a_offset * a_elem_size; + B_ptrs[idx] = const_cast(b_base) + b_offset * b_elem_size; + C_ptrs[idx] = const_cast(c_base) + c_offset * c_elem_size; + D_ptrs[idx] = d_base + d_offset * d_elem_size; + + // Compute M, N, K dimensions + // Test stores A as {K,M} when !transa, {M,K} when transa + // Test stores B as {N,K} when !transb, {K,N} when transb + M[idx] = static_cast(transa ? a_first : a_last); + K[idx] = static_cast(transa ? a_last : a_first); + N[idx] = static_cast(transb ? b_last : b_first); + + // Fill alpha/beta pointers (per-matrix) + alpha_ptrs[idx] = alpha_ptr + idx; + beta_ptrs[idx] = beta_ptr + idx; +} + +// Launch the setup kernel to populate workspace arrays +inline void launch_grouped_gemm_setup( + const GroupedGemmSetupWorkspace &ws, const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel, const transformer_engine::GroupedTensor *C, + const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor, + const transformer_engine::Tensor *beta_tensor, size_t num_tensors, cudaStream_t stream) { + TensorShapeInfo A_meta = TensorShapeInfo::from_tensor(A_sel.tensor); + TensorShapeInfo B_meta = TensorShapeInfo::from_tensor(B_sel.tensor); + TensorShapeInfo C_meta = TensorShapeInfo::for_C(C, D); + TensorShapeInfo D_meta = TensorShapeInfo::from_tensor(D); + + const char *c_base = static_cast(C->data.dptr); + char *d_base = static_cast(D->data.dptr); + + const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype); + const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype); + const size_t c_elem_size = transformer_engine::typeToSize(C->dtype()); + const size_t d_elem_size = transformer_engine::typeToSize(D->dtype()); + + const int threads_per_block = 256; + const int num_blocks = (num_tensors + threads_per_block - 1) / threads_per_block; + + setup_grouped_gemm_kernel<<>>( + ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.M, ws.N, ws.K, ws.alpha_ptrs, ws.beta_ptrs, + A_sel.dptr, B_sel.dptr, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, + b_elem_size, c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), + static_cast(beta_tensor->data.dptr), A_sel.trans, B_sel.trans, num_tensors); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { + return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment); +} + +} // namespace + +void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, + const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, + NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, + cudaStream_t stream, const int64_t *avg_m, const int64_t *avg_n, + const int64_t *avg_k) { + NVTE_API_CALL(nvte_grouped_gemm); + 
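+  // Caller-side expectations enforced below (illustrative summary; the exact sizes
+  // come from grouped_gemm_setup_workspace_size and kGroupedGemmCublasWorkspaceSize):
+  //   - alpha and beta hold one FP32 value per matrix in the group.
+  //   - workspace_setup must fit 6 pointer arrays (A, B, C, D, alpha, beta) plus
+  //     3 int arrays (M, N, K), i.e. 6 * num_tensors * sizeof(void*) +
+  //     3 * num_tensors * sizeof(int), rounded up to a 256-byte multiple.
+  //   - workspace_cublas must provide at least 32 MiB.
+  //   - C may be NULL; D is then passed to cuBLAS in its place (valid when beta == 0).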
using namespace transformer_engine; + + // Grouped GEMM requires Hopper (SM90) or newer + const int current_device = cuda::current_device(); + NVTE_CHECK(cuda::sm_arch(current_device) >= 90, + "nvte_grouped_gemm requires Hopper (SM90) or newer architecture."); + + // Convert to internal types + const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); + const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B); + const GroupedTensor *inputC_raw = convertNVTEGroupedTensor(C); // Can be NULL + GroupedTensor *outputD = convertNVTEGroupedTensorCheck(D); + const Tensor *alpha_tensor = convertNVTETensorCheck(alpha); + const Tensor *beta_tensor = convertNVTETensorCheck(beta); + Tensor *wspace_setup = convertNVTETensor(workspace_setup); + Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); + + // Validate inputs and num_tensors + validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD, alpha_tensor, beta_tensor); + + // If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data) + const GroupedTensor *inputC = (inputC_raw != nullptr) ? inputC_raw : outputD; + const size_t num_tensors = inputA->num_tensors; + + // Select operand storage (row-wise vs column-wise) and adjust transpose flags to + // mirror the non-grouped GEMM logic for FP8 layout constraints. + const auto A_sel = select_grouped_operand(inputA, static_cast(transa), /*is_A=*/true); + const auto B_sel = select_grouped_operand(inputB, static_cast(transb), /*is_A=*/false); + + // Workspaces: setup (pointer arrays) and cuBLAS + const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors); + const size_t cublas_workspace_size = kGroupedGemmCublasWorkspaceSize; + + void *setup_workspace_ptr = validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size, + "Grouped GEMM setup workspace"); + void *cublas_workspace_ptr = validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, + "Grouped GEMM cuBLAS workspace"); + + auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( + static_cast(setup_workspace_ptr), num_tensors); + launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, alpha_tensor, + beta_tensor, num_tensors, stream); + + // Get cuBLAS handle + using cublasHandleManager = detail::HandleManager; + cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); + + // Setup cuBLAS operations + cublasOperation_t op_A = A_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t op_B = B_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; + + // Create grouped matrix layouts + cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD; + init_matrix_layouts(descA, descB, descC, descD, setup_workspace, A_sel, B_sel, outputD, + num_tensors); + + // Create matmul descriptor + cublasLtMatmulDescOpaque_t matmulDesc; + init_matmul_desc(matmulDesc, op_A, op_B); + set_fp8_scale_pointers(matmulDesc, A_sel, B_sel); + + // Compute average dimensions for heuristics + // K dimension: if transa, K is A's first dim; if not, K is A's last dim + int64_t avg_m_val = avg_m ? *avg_m : compute_avg_first_dim(outputD); + int64_t avg_n_val = avg_n ? *avg_n : compute_avg_last_dim(outputD); + int64_t avg_k_val = avg_k ? *avg_k + : (A_sel.trans ? 
compute_avg_first_dim(A_sel.tensor) + : compute_avg_last_dim(A_sel.tensor)); + + // Heuristic selection + cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, + descD, avg_m_val, avg_n_val, avg_k_val); + + // Execute the grouped GEMM + NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, &matmulDesc, setup_workspace.alpha_ptrs, + setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB, + setup_workspace.beta_ptrs, setup_workspace.C_ptrs, &descC, + setup_workspace.D_ptrs, &descD, &algo, cublas_workspace_ptr, + kGroupedGemmCublasWorkspaceSize, stream)); +} + +#else // CUBLAS_VERSION < 130100 + +void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, + const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, + NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, + cudaStream_t stream, const int64_t *avg_m, const int64_t *avg_n, + const int64_t *avg_k) { + NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.2+, but compile-time cuBLAS version is ", + CUBLAS_VERSION, ". Please upgrade to CUDA 13.2 or newer."); +} + +#endif // CUBLAS_VERSION >= 130100 + diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh new file mode 100644 index 00000000000..6514ba2f974 --- /dev/null +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh @@ -0,0 +1,18 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_GEMM_CUBLASLT_GROUPED_GEMM_CUH_ +#define TRANSFORMER_ENGINE_COMMON_GEMM_CUBLASLT_GROUPED_GEMM_CUH_ + +#include +#include +#include + +// nvte_grouped_gemm is declared in transformer_engine/gemm.h +// This header is for internal use only. + +#endif // TRANSFORMER_ENGINE_COMMON_GEMM_CUBLASLT_GROUPED_GEMM_CUH_ + diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 9dfa009115e..b2e42bd66fe 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -11,7 +11,7 @@ #ifndef TRANSFORMER_ENGINE_GEMM_H_ #define TRANSFORMER_ENGINE_GEMM_H_ -#include +#include #include "transformer_engine.h" @@ -233,6 +233,10 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C + * + * \note Requires cuBLAS 13.2+ (CUDA 13.2+) and Hopper (SM90) or newer GPU architecture. + * Will error at runtime if compiled with an older cuBLAS version or run on + * a pre-Hopper GPU. * * Performs batched GEMM on a collection of matrices with potentially different shapes. * All tensors in the group must have compatible dimensions for matrix multiplication. @@ -262,6 +266,8 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * heuristics. If NULL, computed automatically from A's logical shape. 
* * Requirements: + * - cuBLAS 13.2+ (CUDA 13.2+) + * - Hopper (SM90) or newer GPU architecture * - A, B, C (if provided), D must have the same num_tensors * - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i] * - Shape compatibility: if transa=false, transb=false: @@ -270,8 +276,8 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A, const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C, NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, - NVTEMatmulConfig config, cudaStream_t stream, const int64_t *avg_m, - const int64_t *avg_n, const int64_t *avg_k); + cudaStream_t stream, const int64_t *avg_m, const int64_t *avg_n, + const int64_t *avg_k); #ifdef __cplusplus } // extern "C" From 047a9f93bd5252241883077e0a904b2c7f1c6e57 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 19 Dec 2025 12:29:12 +0100 Subject: [PATCH 13/17] fix Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 5 +++-- transformer_engine/common/CMakeLists.txt | 1 + transformer_engine/common/include/transformer_engine/gemm.h | 3 +-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 0ea76946bc2..3336dbc6d56 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -137,8 +137,9 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, // cuBLAS requires aligned pointers for vectorized loads static std::mt19937 gen(12345); std::uniform_int_distribution dist(0, 3); - // Calculate elements needed for 16-byte alignment - const size_t align_elements = (16 * 8) / typeToNumBits(dtype); // 16 bytes / element_size + // Calculate elements needed for 16-byte alignment in bytes, rounded up + const size_t align_elements = + std::max(1, (16 + elem_size - 1) / elem_size); // 16 bytes / element_size return dist(gen) * static_cast(align_elements); }; diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 264f7f9a78d..e25bf024397 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -144,6 +144,7 @@ list(APPEND transformer_engine_cuda_sources fused_attn/fused_attn_fp8.cu fused_attn/utils.cu gemm/cublaslt_gemm.cu + gemm/cublaslt_grouped_gemm.cu normalization/layernorm/ln_bwd_semi_cuda_kernel.cu normalization/layernorm/ln_fwd_cuda_kernel.cu normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index b2e42bd66fe..f1e27761582 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -234,7 +234,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C * - * \note Requires cuBLAS 13.2+ (CUDA 13.2+) and Hopper (SM90) or newer GPU architecture. + * \note Requires cuBLAS 13.1+ (CUDA 13.1+) and Hopper (SM90) or newer GPU architecture. * Will error at runtime if compiled with an older cuBLAS version or run on * a pre-Hopper GPU. 
* @@ -253,7 +253,6 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * \param[out] D Output grouped tensor D. * \param[in] workspace_setup Workspace tensor for pointer array setup. * \param[in] workspace_cublas Workspace tensor for cuBLAS operations. - * \param[in] config Matrix multiplication configuration. * \param[in] stream CUDA stream for the operation. * \param[in] avg_m Optional hint for average M dimension across all matrices in the * group. Used by cuBLASLt for algorithm selection heuristics. From c490e06ab71f9919d69bfc2c67eb6b7cf6bc20ff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Dec 2025 11:32:34 +0000 Subject: [PATCH 14/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/gemm/cublaslt_grouped_gemm.cu | 11 ++++------- .../common/gemm/cublaslt_grouped_gemm.cuh | 1 - .../common/include/transformer_engine/gemm.h | 2 +- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 4125bd82bff..3647a4c39ed 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -180,12 +180,10 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor // Validate alpha/beta have per-matrix values const size_t alpha_numel = alpha_tensor->data.numel(); const size_t beta_numel = beta_tensor->data.numel(); - NVTE_CHECK(alpha_numel == num_tensors, - "Grouped GEMM: alpha must have num_tensors (", num_tensors, ") elements, got ", - alpha_numel); - NVTE_CHECK(beta_numel == num_tensors, - "Grouped GEMM: beta must have num_tensors (", num_tensors, ") elements, got ", - beta_numel); + NVTE_CHECK(alpha_numel == num_tensors, "Grouped GEMM: alpha must have num_tensors (", num_tensors, + ") elements, got ", alpha_numel); + NVTE_CHECK(beta_numel == num_tensors, "Grouped GEMM: beta must have num_tensors (", num_tensors, + ") elements, got ", beta_numel); auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { return dtype == transformer_engine::DType::kFloat8E4M3 || @@ -596,4 +594,3 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT } #endif // CUBLAS_VERSION >= 130100 - diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh index 6514ba2f974..a032e594d57 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cuh @@ -15,4 +15,3 @@ // This header is for internal use only. 
#endif // TRANSFORMER_ENGINE_COMMON_GEMM_CUBLASLT_GROUPED_GEMM_CUH_ - diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index f1e27761582..0c8d601d509 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -11,7 +11,7 @@ #ifndef TRANSFORMER_ENGINE_GEMM_H_ #define TRANSFORMER_ENGINE_GEMM_H_ -#include +#include #include "transformer_engine.h" From 59145cc2a7d4e4cb92addbd39c374541cbed5eb9 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 22 Dec 2025 10:21:19 +0100 Subject: [PATCH 15/17] fix Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_grouped_gemm.cu | 7 ++++--- .../common/gemm/cublaslt_grouped_gemm.cu | 10 +++++----- .../common/include/transformer_engine/gemm.h | 6 +++--- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 3336dbc6d56..bdcfa68a4f7 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -95,7 +95,8 @@ struct GroupedBuffers { size_t grouped_setup_workspace_size(const size_t num_tensors) { const size_t ptr_bytes = num_tensors * sizeof(void*); const size_t int_bytes = num_tensors * sizeof(int); - size_t size = 4 * ptr_bytes + 3 * int_bytes + 2 * ptr_bytes; + // Layout: 6 pointer arrays (A, B, C, D, alpha, beta) + 3 int arrays (M, N, K) + size_t size = 6 * ptr_bytes + 3 * int_bytes; const size_t alignment = 256; size = ((size + alignment - 1) / alignment) * alignment; return size; @@ -320,8 +321,8 @@ void run_grouped_gemm_case(const TestParams& params) { GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.2+, but compile-time cuBLAS version is " << CUBLAS_VERSION << "."; #else - if (getDeviceComputeCapability() < hopperComputeCapability) { - GTEST_SKIP() << "Grouped GEMM requires Hopper (SM90) or newer."; + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer."; } const std::vector> shapes = make_shapes(params.shape_case); diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 3647a4c39ed..40180fe7607 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -503,10 +503,10 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT NVTE_API_CALL(nvte_grouped_gemm); using namespace transformer_engine; - // Grouped GEMM requires Hopper (SM90) or newer + // Grouped GEMM requires Blackwell (SM100) or newer const int current_device = cuda::current_device(); - NVTE_CHECK(cuda::sm_arch(current_device) >= 90, - "nvte_grouped_gemm requires Hopper (SM90) or newer architecture."); + NVTE_CHECK(cuda::sm_arch(current_device) >= 100, + "nvte_grouped_gemm requires Blackwell (SM100) or newer architecture."); // Convert to internal types const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); @@ -589,8 +589,8 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas, cudaStream_t stream, const int64_t *avg_m, const int64_t *avg_n, const int64_t *avg_k) { - NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.2+, but compile-time cuBLAS version is ", - CUBLAS_VERSION, ". 
Please upgrade to CUDA 13.2 or newer."); + NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.1+, but compile-time cuBLAS version is ", + CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer."); } #endif // CUBLAS_VERSION >= 130100 diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 0c8d601d509..168141224c6 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -234,9 +234,9 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C * - * \note Requires cuBLAS 13.1+ (CUDA 13.1+) and Hopper (SM90) or newer GPU architecture. + * \note Requires cuBLAS 13.1+ (CUDA 13.1+) and Blackwell (SM100) or newer GPU architecture. * Will error at runtime if compiled with an older cuBLAS version or run on - * a pre-Hopper GPU. + * a pre-Blackwell GPU. * * Performs batched GEMM on a collection of matrices with potentially different shapes. * All tensors in the group must have compatible dimensions for matrix multiplication. @@ -266,7 +266,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * * Requirements: * - cuBLAS 13.2+ (CUDA 13.2+) - * - Hopper (SM90) or newer GPU architecture + * - Blackwell (SM100) or newer GPU architecture * - A, B, C (if provided), D must have the same num_tensors * - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i] * - Shape compatibility: if transa=false, transb=false: From 77b422ac8d6e33bb5d56651a2e956629c17a5db8 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 22 Dec 2025 10:47:19 +0100 Subject: [PATCH 16/17] Require Blackwell (SM100) and cuBLAS 13.1+ for grouped GEMM Signed-off-by: Pawel Gadzinski --- 3rdparty/cudnn-frontend | 2 +- tests/cpp/operator/test_grouped_gemm.cu | 4 ++-- transformer_engine/common/include/transformer_engine/gemm.h | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 0258951d4d5..be6c079be8a 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 0258951d4d512f4714eb1574496f4d57669b1b93 +Subproject commit be6c079be8aaffa0fc079fcf039887e637c289c7 diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index bdcfa68a4f7..2514f11ab39 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -317,8 +317,8 @@ std::vector> make_shapes(ShapeCase scase) { } void run_grouped_gemm_case(const TestParams& params) { -#if CUBLAS_VERSION < 130200 - GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.2+, but compile-time cuBLAS version is " +#if CUBLAS_VERSION < 130100 + GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.1+, but compile-time cuBLAS version is " << CUBLAS_VERSION << "."; #else if (getDeviceComputeCapability() < blackwellComputeCapability) { diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 168141224c6..f4c60ca3fe9 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -265,7 +265,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * heuristics. 
If NULL, computed automatically from A's logical shape.
 *
 * Requirements:
- * - cuBLAS 13.2+ (CUDA 13.2+)
+ * - cuBLAS 13.1+ (CUDA 13.1+)
 * - Blackwell (SM100) or newer GPU architecture
 * - A, B, C (if provided), D must have the same num_tensors
 * - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i]
 * - Shape compatibility: if transa=false, transb=false:

From 9c8158ee86a30699710c0dc1cb17c5d9b9aa4ced Mon Sep 17 00:00:00 2001
From: Pawel Gadzinski
Date: Mon, 22 Dec 2025 11:28:47 +0100
Subject: [PATCH 17/17] fix

Signed-off-by: Pawel Gadzinski
---
 tests/cpp/operator/test_grouped_gemm.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu
index 2514f11ab39..ada69808589 100644
--- a/tests/cpp/operator/test_grouped_gemm.cu
+++ b/tests/cpp/operator/test_grouped_gemm.cu
@@ -482,7 +482,7 @@ void run_grouped_gemm_case(const TestParams& params) {
     atol, rtol);
 }
 
-#endif  // CUBLAS_VERSION >= 130200
+#endif  // CUBLAS_VERSION >= 130100
 }
 
 class GroupedGemmTest : public ::testing::TestWithParam<TestParams> {};