From ad748dadd66f6e0e9620d95dfa5b172ed67f28b0 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 9 Dec 2025 15:58:03 -0600 Subject: [PATCH 1/8] GEMM reference HIP implementation --- tests/cpp/operator/test_cublaslt_gemm.cu | 309 ++++++++++++++++++----- 1 file changed, 245 insertions(+), 64 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 071470bdf..e1e0b9316 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -51,11 +51,224 @@ using TShape = std::vector; } // namespace -float ref_gelu(float x){ +__device__ __host__ __forceinline__ float ref_gelu(float x){ float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); return x * cdf; } +template +__global__ void compute_ref_kernel( + const A_Type* __restrict__ a_data, + const B_Type* __restrict__ b_data, + float a_scale_inv_scalar, // used when mxfp8 == false + float b_scale_inv_scalar, + const fp8e8m0* __restrict__ a_scale_inv_mxfp8, // used when mxfp8 == true + const fp8e8m0* __restrict__ b_scale_inv_mxfp8, + const Bias_Type* __restrict__ bias_data, + float d_scale, + size_t m, size_t k, size_t n, + D_Type* __restrict__ d_data, + float* __restrict__ d_amax, + Gelu_Type* __restrict__ gelu_data, + bool transa, + bool transb, + bool is_fp8_output) +{ + const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; + const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; + + if (ii >= m || jj >= n) + return; + + float val = 0.0f; + + for (size_t kk = 0; kk < k; ++kk) { + const size_t a_idx = transa ? (ii * k + kk) : (kk * m + ii); + const size_t b_idx = transb ? (kk * n + jj) : (jj * k + kk); + + float a_scale_inv_val = a_scale_inv_scalar; + float b_scale_inv_val = b_scale_inv_scalar; + + if (a_scale_inv_mxfp8) { + const size_t a_scale_idx = + transa ? (a_idx / 32) : ((kk / 32) * m + ii); + const size_t b_scale_idx = + transb ? 
((kk / 32) * n + jj) : (b_idx / 32); + + const float a_byte = static_cast(a_scale_inv_mxfp8[a_scale_idx]); + const float b_byte = static_cast(b_scale_inv_mxfp8[b_scale_idx]); + + a_scale_inv_val = exp2f(a_byte - 127.0f); + b_scale_inv_val = exp2f(b_byte - 127.0f); + } + + const float a_val = a_data[a_idx]; + const float b_val = b_data[b_idx]; + + val += a_scale_inv_val * a_val * b_scale_inv_val * b_val; + } + + if (bias_data) { + val += (float)bias_data[ii]; + } + + if (gelu_data) { + gelu_data[ii + jj * m] = val; + val = ref_gelu(val); + } + + const float scaled = val * d_scale; + d_data[ii + jj * m] = scaled; + + if (is_fp8_output && d_amax) { + atomicMax(d_amax, fabsf(val)); + } +} + +// Common implementation used by both tensor-wise and MXFP8 frontends +template +static void compute_ref_impl( + const A_Type* a_data, + const B_Type* b_data, + float a_scale_inv_scalar, // used when mxfp8 == false + float b_scale_inv_scalar, + const fp8e8m0* a_scale_inv_mxfp8, // used when mxfp8 == true + const fp8e8m0* b_scale_inv_mxfp8, + const Bias_Type* bias_data, + float d_scale, + size_t m, size_t k, size_t n, + D_Type* d_data, + float* d_amax_host, + Gelu_Type* gelu_data, + bool transa, + bool transb) +{ + using transformer_engine::DType; + using ::TypeInfo; + using ::isFp8Type; + + const bool use_mxfp8 = (a_scale_inv_mxfp8 != nullptr); + + const DType dtype = TypeInfo::dtype; + const bool is_fp8_output = isFp8Type(dtype); + + const size_t lenA = m * k; + const size_t lenB = k * n; + const size_t lenD = m * n; + const size_t lenBias = m; + const size_t lenGelu = m * n; + + const size_t lenA_scale = use_mxfp8 ? (lenA + 31) / 32 : 0; + const size_t lenB_scale = use_mxfp8 ? (lenB + 31) / 32 : 0; + + A_Type* dA = nullptr; + B_Type* dB = nullptr; + Bias_Type* dBias = nullptr; + D_Type* dD = nullptr; + Gelu_Type* dGelu = nullptr; + float* dAmax = nullptr; + fp8e8m0* dA_scale = nullptr; + fp8e8m0* dB_scale = nullptr; + + // Allocations and H2D transfers + NVTE_CHECK_CUDA(cudaMalloc(&dA, lenA * sizeof(A_Type))); + NVTE_CHECK_CUDA(cudaMalloc(&dB, lenB * sizeof(B_Type))); + NVTE_CHECK_CUDA(cudaMalloc(&dD, lenD * sizeof(D_Type))); + + NVTE_CHECK_CUDA(cudaMemcpy( + dA, a_data, lenA * sizeof(A_Type), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy( + dB, b_data, lenB * sizeof(B_Type), cudaMemcpyHostToDevice)); + + if (bias_data) { + NVTE_CHECK_CUDA(cudaMalloc(&dBias, lenBias * sizeof(Bias_Type))); + NVTE_CHECK_CUDA(cudaMemcpy( + dBias, bias_data, lenBias * sizeof(Bias_Type), + cudaMemcpyHostToDevice)); + } + + if (gelu_data) { + NVTE_CHECK_CUDA(cudaMalloc(&dGelu, lenGelu * sizeof(Gelu_Type))); + NVTE_CHECK_CUDA(cudaMemset(dGelu, 0, lenGelu * sizeof(Gelu_Type))); + } + + if (use_mxfp8) { + NVTE_CHECK_CUDA(cudaMalloc(&dA_scale, lenA_scale * sizeof(fp8e8m0))); + NVTE_CHECK_CUDA(cudaMalloc(&dB_scale, lenB_scale * sizeof(fp8e8m0))); + NVTE_CHECK_CUDA(cudaMemcpy( + dA_scale, a_scale_inv_mxfp8, lenA_scale * sizeof(fp8e8m0), + cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy( + dB_scale, b_scale_inv_mxfp8, lenB_scale * sizeof(fp8e8m0), + cudaMemcpyHostToDevice)); + } + + if (is_fp8_output && d_amax_host) { + NVTE_CHECK_CUDA(cudaMalloc(&dAmax, sizeof(float))); + NVTE_CHECK_CUDA(cudaMemset(dAmax, 0, sizeof(float))); + } + + // Kernel launch + dim3 block(16, 16); + dim3 grid((n + block.x - 1) / block.x, (m + block.y - 1) / block.y); + + compute_ref_kernel + <<>>( + dA, + dB, + a_scale_inv_scalar, + b_scale_inv_scalar, + dA_scale, + dB_scale, + dBias, + d_scale, + m, k, n, + dD, + dAmax, + dGelu, + transa, + 
transb, + is_fp8_output); + + NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + // D2H copies + NVTE_CHECK_CUDA(cudaMemcpy( + d_data, dD, lenD * sizeof(D_Type), cudaMemcpyDeviceToHost)); + + if (gelu_data) { + NVTE_CHECK_CUDA(cudaMemcpy( + gelu_data, dGelu, lenGelu * sizeof(Gelu_Type), + cudaMemcpyDeviceToHost)); + } + + if (is_fp8_output && d_amax_host) { + NVTE_CHECK_CUDA(cudaMemcpy( + d_amax_host, dAmax, sizeof(float), cudaMemcpyDeviceToHost)); + } else if (d_amax_host) { + *d_amax_host = 0.0f; + } + + // cleanup + NVTE_CHECK_CUDA(cudaFree(dA)); + NVTE_CHECK_CUDA(cudaFree(dB)); + NVTE_CHECK_CUDA(cudaFree(dD)); + if (dBias) + NVTE_CHECK_CUDA(cudaFree(dBias)); + if (dGelu) + NVTE_CHECK_CUDA(cudaFree(dGelu)); + if (dAmax) + NVTE_CHECK_CUDA(cudaFree(dAmax)); + if (dA_scale) + NVTE_CHECK_CUDA(cudaFree(dA_scale)); + if (dB_scale) + NVTE_CHECK_CUDA(cudaFree(dB_scale)); +} + + template void compute_ref( const A_Type* a_data, @@ -71,36 +284,21 @@ void compute_ref( bool transa, bool transb){ - float ref_d_amax = 0; - - #pragma omp parallel for schedule(static) collapse(2) reduction(max: ref_d_amax) proc_bind(spread) - for(size_t ii = 0; ii < m; ii++){ - for(size_t jj = 0; jj < n; jj++){ - float val = 0; - for(size_t kk = 0; kk < k; kk++){ - float a_val = transa ? a_data[kk + ii*k] : a_data[ii + kk*m]; - float b_val = transb ? b_data[jj + kk*n] : b_data[kk + jj*k]; - val += a_scale_inv*a_val*b_scale_inv*b_val; - } - if(bias_data){ - val += (float)bias_data[ii]; - } - if(ref_gelu_data){ - ref_gelu_data[ii + jj*m] = (Gelu_Type)(val); - val = ref_gelu(val); - } - ref_d_data[ii+jj*m] = (D_Type)(val*d_scale); - // update ref_d_amax if in fp8 - DType dtype = TypeInfo::dtype; - if(isFp8Type(dtype)){ - ref_d_amax = std::max(ref_d_amax, std::fabs(val)); - } - } - } - if (ref_d_amax_ptr) - { - *ref_d_amax_ptr = ref_d_amax; - } + compute_ref_impl( + a_data, + b_data, + /*a_scale_inv_scalar=*/a_scale_inv, + /*b_scale_inv_scalar=*/b_scale_inv, + /*a_scale_inv_mxfp8=*/nullptr, + /*b_scale_inv_mxfp8=*/nullptr, + bias_data, + d_scale, + m, k, n, + ref_d_data, + ref_d_amax_ptr, + ref_gelu_data, + transa, + transb); } template @@ -118,38 +316,21 @@ void compute_mxfp8_ref( bool transa, bool transb){ - float ref_d_amax = 0; - - #pragma omp parallel for schedule(static) collapse(2) reduction(max: ref_d_amax) proc_bind(spread) - for(size_t ii = 0; ii < m; ii++){ - for(size_t jj = 0; jj < n; jj++){ - float val = 0; - for(size_t kk = 0; kk < k; kk++){ - size_t a_idx = transa ? (ii*k + kk) : (kk*m + ii); - size_t b_idx = transb ? (kk*n + jj) : (jj*k + kk); - float a_scale_inv_val = std::exp2f(a_scale_inv_data[transa ? a_idx/32 : (kk/32 * m + ii)] - 127); - float b_scale_inv_val = std::exp2f(b_scale_inv_data[transb ? 
(kk/32 * n + jj) : b_idx/32] - 127); - val += a_scale_inv_val * (float)a_data[a_idx] * b_scale_inv_val * (float)b_data[b_idx]; - } - if(bias_data){ - val += (float)bias_data[ii]; - } - if(ref_gelu_data){ - ref_gelu_data[ii + jj*m] = (Gelu_Type)(val); - val = ref_gelu(val); - } - ref_d_data[ii+jj*m] = (D_Type)(val*d_scale); - // update ref_d_amax if in fp8 - DType dtype = TypeInfo::dtype; - if(isFp8Type(dtype)){ - ref_d_amax = std::max(ref_d_amax, std::fabs(val)); - } - } - } - if (ref_d_amax_ptr) - { - *ref_d_amax_ptr = ref_d_amax; - } + compute_ref_impl( + a_data, + b_data, + /*a_scale_inv_scalar=*/1.0f, + /*b_scale_inv_scalar=*/1.0f, + /*a_scale_inv_mxfp8=*/a_scale_inv_data, + /*b_scale_inv_mxfp8=*/b_scale_inv_data, + bias_data, + d_scale, + m, k, n, + ref_d_data, + ref_d_amax_ptr, + ref_gelu_data, + transa, + transb); } template @@ -371,7 +552,7 @@ void performTest(const TestParams& params) { pre_gelu_out.to_cpu(); } - //perform the gemm in CPU + //perform the reference gemm on GPU std::unique_ptr ref_D = std::make_unique(params.m*params.n); std::unique_ptr ref_pre_gelu_out; if(params.use_gelu){ From 11e090b9e34f0fc792122e232af4e2b863122ef6 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 11 Dec 2025 15:14:53 -0600 Subject: [PATCH 2/8] blockwise amax --- tests/cpp/operator/test_cublaslt_gemm.cu | 86 +++++++++++++++--------- 1 file changed, 55 insertions(+), 31 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index e1e0b9316..0c5f9a759 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -78,51 +78,72 @@ __global__ void compute_ref_kernel( const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; - if (ii >= m || jj >= n) - return; + const bool in_range = (ii < m) && (jj < n); float val = 0.0f; - for (size_t kk = 0; kk < k; ++kk) { - const size_t a_idx = transa ? (ii * k + kk) : (kk * m + ii); - const size_t b_idx = transb ? (kk * n + jj) : (jj * k + kk); + if (in_range) { + for (size_t kk = 0; kk < k; ++kk) { + const size_t a_idx = transa ? (ii * k + kk) : (kk * m + ii); + const size_t b_idx = transb ? (kk * n + jj) : (jj * k + kk); - float a_scale_inv_val = a_scale_inv_scalar; - float b_scale_inv_val = b_scale_inv_scalar; + float a_scale_inv_val = a_scale_inv_scalar; + float b_scale_inv_val = b_scale_inv_scalar; - if (a_scale_inv_mxfp8) { - const size_t a_scale_idx = - transa ? (a_idx / 32) : ((kk / 32) * m + ii); - const size_t b_scale_idx = - transb ? ((kk / 32) * n + jj) : (b_idx / 32); + if (a_scale_inv_mxfp8) { + const size_t a_scale_idx = + transa ? (a_idx / 32) : ((kk / 32) * m + ii); + const size_t b_scale_idx = + transb ? 
((kk / 32) * n + jj) : (b_idx / 32); - const float a_byte = static_cast(a_scale_inv_mxfp8[a_scale_idx]); - const float b_byte = static_cast(b_scale_inv_mxfp8[b_scale_idx]); + const float a_byte = static_cast(a_scale_inv_mxfp8[a_scale_idx]); + const float b_byte = static_cast(b_scale_inv_mxfp8[b_scale_idx]); - a_scale_inv_val = exp2f(a_byte - 127.0f); - b_scale_inv_val = exp2f(b_byte - 127.0f); + a_scale_inv_val = exp2f(a_byte - 127.0f); + b_scale_inv_val = exp2f(b_byte - 127.0f); + } + + const float a_val = static_cast(a_data[a_idx]); + const float b_val = static_cast(b_data[b_idx]); + + val += a_scale_inv_val * a_val * b_scale_inv_val * b_val; } - const float a_val = a_data[a_idx]; - const float b_val = b_data[b_idx]; + if (bias_data) { + val += static_cast(bias_data[ii]); + } - val += a_scale_inv_val * a_val * b_scale_inv_val * b_val; - } + if (gelu_data) { + gelu_data[ii + jj * m] = static_cast(val); + val = ref_gelu(val); + } - if (bias_data) { - val += (float)bias_data[ii]; + const float scaled = val * d_scale; + d_data[ii + jj * m] = static_cast(scaled); } - if (gelu_data) { - gelu_data[ii + jj * m] = val; - val = ref_gelu(val); - } + // Blockwise reduction for amax + if (is_fp8_output && d_amax) { + const int tid = threadIdx.y * blockDim.x + threadIdx.x; + const int nthreads = blockDim.x * blockDim.y; - const float scaled = val * d_scale; - d_data[ii + jj * m] = scaled; + extern __shared__ float s_amax[]; - if (is_fp8_output && d_amax) { - atomicMax(d_amax, fabsf(val)); + // Out-of-range threads contribute 0 + s_amax[tid] = in_range ? fabsf(val) : 0.0f; + __syncthreads(); + + for (int offset = nthreads / 2; offset > 0; offset /= 2) { + if (tid < offset) { + s_amax[tid] = fmaxf(s_amax[tid], s_amax[tid + offset]); + } + __syncthreads(); + } + + if (tid == 0) { + const float block_max = s_amax[0]; + atomicMax(d_amax, block_max); + } } } @@ -214,8 +235,11 @@ static void compute_ref_impl( dim3 block(16, 16); dim3 grid((n + block.x - 1) / block.x, (m + block.y - 1) / block.y); + const int nthreads = block.x * block.y; + size_t shmem_bytes = nthreads * sizeof(float); + compute_ref_kernel - <<>>( + <<>>( dA, dB, a_scale_inv_scalar, From 3ecea7fb11748bc6c99250e3de46ebec68dfc778 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 13 Jan 2026 17:13:48 -0600 Subject: [PATCH 3/8] Change to use Tensor arguments, combine mxfp8/non-mxfp8 paths --- tests/cpp/operator/test_cublaslt_gemm.cu | 343 +++++++++-------------- tests/cpp/test_common.h | 14 +- 2 files changed, 137 insertions(+), 220 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 3f5249a6a..631c06c51 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -73,7 +73,9 @@ __global__ void compute_ref_kernel( Gelu_Type* __restrict__ gelu_data, bool transa, bool transb, - bool is_fp8_output) + bool is_fp8_output, + bool a_is_colwise, + bool b_is_colwise) { const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; @@ -84,17 +86,26 @@ __global__ void compute_ref_kernel( if (in_range) { for (size_t kk = 0; kk < k; ++kk) { - const size_t a_idx = transa ? (ii * k + kk) : (kk * m + ii); - const size_t b_idx = transb ? (kk * n + jj) : (jj * k + kk); + // Indexing depends on which backing buffer we passed in + const size_t a_idx = + a_is_colwise ? (ii * k + kk) + : (transa ? (ii * k + kk) : (kk * m + ii)); + + const size_t b_idx = + b_is_colwise ? (jj * k + kk) + : (transb ? 
(kk * n + jj) : (jj * k + kk)); float a_scale_inv_val = a_scale_inv_scalar; float b_scale_inv_val = b_scale_inv_scalar; if (a_scale_inv_mxfp8) { const size_t a_scale_idx = - transa ? (a_idx / 32) : ((kk / 32) * m + ii); + a_is_colwise ? (a_idx / 32) + : (transa ? (a_idx / 32) : ((kk / 32) * m + ii)); + const size_t b_scale_idx = - transb ? ((kk / 32) * n + jj) : (b_idx / 32); + b_is_colwise ? (b_idx / 32) + : (transb ? ((kk / 32) * n + jj) : (b_idx / 32)); const float a_byte = static_cast(a_scale_inv_mxfp8[a_scale_idx]); const float b_byte = static_cast(b_scale_inv_mxfp8[b_scale_idx]); @@ -147,216 +158,145 @@ __global__ void compute_ref_kernel( } } -// Common implementation used by both tensor-wise and MXFP8 frontends + +struct TestParams { + size_t m; + size_t k; + size_t n; + bool use_bias; + bool use_gelu; + bool transa; + bool transb; + NVTEScalingMode scaling_mode; +}; + + template -static void compute_ref_impl( - const A_Type* a_data, - const B_Type* b_data, - float a_scale_inv_scalar, // used when mxfp8 == false - float b_scale_inv_scalar, - const fp8e8m0* a_scale_inv_mxfp8, // used when mxfp8 == true - const fp8e8m0* b_scale_inv_mxfp8, - const Bias_Type* bias_data, - float d_scale, - size_t m, size_t k, size_t n, - D_Type* d_data, - float* d_amax_host, - Gelu_Type* gelu_data, - bool transa, - bool transb) +static void run_reference( + const TestParams& params, + const Tensor& A, + const Tensor& B, + const Tensor* Bias, // nullable + float d_scale, + std::unique_ptr& ref_D, // m*n + float* ref_amax_d, + std::unique_ptr& ref_pre_gelu_out) // nullable { - using transformer_engine::DType; - using ::TypeInfo; - using ::isFp8Type; + const bool use_mxfp8 = (params.scaling_mode == NVTE_MXFP8_1D_SCALING); - const bool use_mxfp8 = (a_scale_inv_mxfp8 != nullptr); + Gelu_Type* ref_gelu_host = (params.use_gelu ? ref_pre_gelu_out.get() : nullptr); - const DType dtype = TypeInfo::dtype; - const bool is_fp8_output = isFp8Type(dtype); + const bool is_fp8_output = test::isFp8Type(test::TypeInfo::dtype); - const size_t lenA = m * k; - const size_t lenB = k * n; - const size_t lenD = m * n; - const size_t lenBias = m; - const size_t lenGelu = m * n; + const bool a_use_colwise = (!params.transa) && A.columnwise(); + const bool b_use_colwise = ( params.transb) && B.columnwise(); - const size_t lenA_scale = use_mxfp8 ? (lenA + 31) / 32 : 0; - const size_t lenB_scale = use_mxfp8 ? (lenB + 31) / 32 : 0; + const A_Type* a_dev = static_cast( + a_use_colwise ? A.columnwise_dptr() : A.rowwise_dptr()); - A_Type* dA = nullptr; - B_Type* dB = nullptr; - Bias_Type* dBias = nullptr; - D_Type* dD = nullptr; - Gelu_Type* dGelu = nullptr; - float* dAmax = nullptr; - fp8e8m0* dA_scale = nullptr; - fp8e8m0* dB_scale = nullptr; + const B_Type* b_dev = static_cast( + b_use_colwise ? 
B.columnwise_dptr() : B.rowwise_dptr()); - // Allocations and H2D transfers - NVTE_CHECK_CUDA(cudaMalloc(&dA, lenA * sizeof(A_Type))); - NVTE_CHECK_CUDA(cudaMalloc(&dB, lenB * sizeof(B_Type))); - NVTE_CHECK_CUDA(cudaMalloc(&dD, lenD * sizeof(D_Type))); + // scaling inputs + float a_scale_inv_scalar = 1.0f; + float b_scale_inv_scalar = 1.0f; - NVTE_CHECK_CUDA(cudaMemcpy( - dA, a_data, lenA * sizeof(A_Type), cudaMemcpyHostToDevice)); - NVTE_CHECK_CUDA(cudaMemcpy( - dB, b_data, lenB * sizeof(B_Type), cudaMemcpyHostToDevice)); + const fp8e8m0* a_scale_dev = nullptr; + const fp8e8m0* b_scale_dev = nullptr; - if (bias_data) { - NVTE_CHECK_CUDA(cudaMalloc(&dBias, lenBias * sizeof(Bias_Type))); - NVTE_CHECK_CUDA(cudaMemcpy( - dBias, bias_data, lenBias * sizeof(Bias_Type), - cudaMemcpyHostToDevice)); - } + if (use_mxfp8) { + a_scale_dev = params.transa + ? (const fp8e8m0*) A.rowwise_scale_inv_dptr() + : (const fp8e8m0*) A.columnwise_scale_inv_dptr(); - if (gelu_data) { - NVTE_CHECK_CUDA(cudaMalloc(&dGelu, lenGelu * sizeof(Gelu_Type))); - NVTE_CHECK_CUDA(cudaMemset(dGelu, 0, lenGelu * sizeof(Gelu_Type))); + b_scale_dev = params.transb + ? (const fp8e8m0*) B.columnwise_scale_inv_dptr() + : (const fp8e8m0*) B.rowwise_scale_inv_dptr(); + } else { + a_scale_inv_scalar = A.rowwise_scale_inv(); + b_scale_inv_scalar = B.rowwise_scale_inv(); } - if (use_mxfp8) { - NVTE_CHECK_CUDA(cudaMalloc(&dA_scale, lenA_scale * sizeof(fp8e8m0))); - NVTE_CHECK_CUDA(cudaMalloc(&dB_scale, lenB_scale * sizeof(fp8e8m0))); - NVTE_CHECK_CUDA(cudaMemcpy( - dA_scale, a_scale_inv_mxfp8, lenA_scale * sizeof(fp8e8m0), - cudaMemcpyHostToDevice)); - NVTE_CHECK_CUDA(cudaMemcpy( - dB_scale, b_scale_inv_mxfp8, lenB_scale * sizeof(fp8e8m0), - cudaMemcpyHostToDevice)); + // optional bias device pointer + const Bias_Type* bias_dev = nullptr; + if (Bias) { + bias_dev = static_cast(Bias->rowwise_dptr()); } - if (is_fp8_output && d_amax_host) { - NVTE_CHECK_CUDA(cudaMalloc(&dAmax, sizeof(float))); - NVTE_CHECK_CUDA(cudaMemset(dAmax, 0, sizeof(float))); + // allocate device outputs + const size_t lenD = params.m * params.n; + const size_t bytesD = lenD * sizeof(D_Type); + + D_Type* d_refD = nullptr; + Gelu_Type* d_refGelu = nullptr; + float* d_refAmax = nullptr; + + NVTE_CHECK_CUDA(cudaMalloc(&d_refD, bytesD)); + if (ref_gelu_host) { + NVTE_CHECK_CUDA(cudaMalloc(&d_refGelu, lenD * sizeof(Gelu_Type))); + } + if (is_fp8_output && ref_amax_d) { + NVTE_CHECK_CUDA(cudaMalloc(&d_refAmax, sizeof(float))); + NVTE_CHECK_CUDA(cudaMemset(d_refAmax, 0, sizeof(float))); } // Kernel launch dim3 block(16, 16); - dim3 grid((n + block.x - 1) / block.x, (m + block.y - 1) / block.y); + dim3 grid((unsigned)((params.n + block.x - 1) / block.x), + (unsigned)((params.m + block.y - 1) / block.y)); - const int nthreads = block.x * block.y; - size_t shmem_bytes = nthreads * sizeof(float); + const size_t shmem_bytes = size_t(block.x) * size_t(block.y) * sizeof(float); compute_ref_kernel <<>>( - dA, - dB, + a_dev, + b_dev, a_scale_inv_scalar, b_scale_inv_scalar, - dA_scale, - dB_scale, - dBias, + a_scale_dev, + b_scale_dev, + bias_dev, d_scale, - m, k, n, - dD, - dAmax, - dGelu, - transa, - transb, - is_fp8_output); + params.m, params.k, params.n, + d_refD, + d_refAmax, + d_refGelu, + params.transa, + params.transb, + is_fp8_output, + a_use_colwise, + b_use_colwise); NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaDeviceSynchronize()); - // D2H copies - NVTE_CHECK_CUDA(cudaMemcpy( - d_data, dD, lenD * sizeof(D_Type), cudaMemcpyDeviceToHost)); + // copy outputs 
back + NVTE_CHECK_CUDA(cudaMemcpy(ref_D.get(), d_refD, bytesD, cudaMemcpyDeviceToHost)); - if (gelu_data) { - NVTE_CHECK_CUDA(cudaMemcpy( - gelu_data, dGelu, lenGelu * sizeof(Gelu_Type), - cudaMemcpyDeviceToHost)); + if (ref_gelu_host) { + NVTE_CHECK_CUDA(cudaMemcpy(ref_gelu_host, d_refGelu, lenD * sizeof(Gelu_Type), + cudaMemcpyDeviceToHost)); } - if (is_fp8_output && d_amax_host) { - NVTE_CHECK_CUDA(cudaMemcpy( - d_amax_host, dAmax, sizeof(float), cudaMemcpyDeviceToHost)); - } else if (d_amax_host) { - *d_amax_host = 0.0f; + if (ref_amax_d) { + if (is_fp8_output) { + NVTE_CHECK_CUDA(cudaMemcpy(ref_amax_d, d_refAmax, sizeof(float), + cudaMemcpyDeviceToHost)); + } else { + *ref_amax_d = 0.0f; + } } // cleanup - NVTE_CHECK_CUDA(cudaFree(dA)); - NVTE_CHECK_CUDA(cudaFree(dB)); - NVTE_CHECK_CUDA(cudaFree(dD)); - if (dBias) - NVTE_CHECK_CUDA(cudaFree(dBias)); - if (dGelu) - NVTE_CHECK_CUDA(cudaFree(dGelu)); - if (dAmax) - NVTE_CHECK_CUDA(cudaFree(dAmax)); - if (dA_scale) - NVTE_CHECK_CUDA(cudaFree(dA_scale)); - if (dB_scale) - NVTE_CHECK_CUDA(cudaFree(dB_scale)); + NVTE_CHECK_CUDA(cudaFree(d_refD)); + if (d_refGelu) + NVTE_CHECK_CUDA(cudaFree(d_refGelu)); + if (d_refAmax) + NVTE_CHECK_CUDA(cudaFree(d_refAmax)); } -template -void compute_ref( - const A_Type* a_data, - const B_Type* b_data, - const float a_scale_inv, - const float b_scale_inv, - const Bias_Type* bias_data, //bias is of dim m - const float d_scale, - size_t m, size_t k, size_t n, - D_Type* ref_d_data, - float* ref_d_amax_ptr, - Gelu_Type* ref_gelu_data, - bool transa, - bool transb){ - - compute_ref_impl( - a_data, - b_data, - /*a_scale_inv_scalar=*/a_scale_inv, - /*b_scale_inv_scalar=*/b_scale_inv, - /*a_scale_inv_mxfp8=*/nullptr, - /*b_scale_inv_mxfp8=*/nullptr, - bias_data, - d_scale, - m, k, n, - ref_d_data, - ref_d_amax_ptr, - ref_gelu_data, - transa, - transb); -} - -template -void compute_mxfp8_ref( - const A_Type* a_data, - const B_Type* b_data, - const fp8e8m0* a_scale_inv_data, - const fp8e8m0* b_scale_inv_data, - const Bias_Type* bias_data, //bias is of dim m - const float d_scale, - size_t m, size_t k, size_t n, - D_Type* ref_d_data, - float* ref_d_amax_ptr, - Gelu_Type* ref_gelu_data, - bool transa, - bool transb){ - - compute_ref_impl( - a_data, - b_data, - /*a_scale_inv_scalar=*/1.0f, - /*b_scale_inv_scalar=*/1.0f, - /*a_scale_inv_mxfp8=*/a_scale_inv_data, - /*b_scale_inv_mxfp8=*/b_scale_inv_data, - bias_data, - d_scale, - m, k, n, - ref_d_data, - ref_d_amax_ptr, - ref_gelu_data, - transa, - transb); -} - template void cpu_rowwise_to_columnwise( size_t m, size_t n, @@ -396,16 +336,6 @@ std::pair getTestTolerances(const DType type, bool use_fp8, bool return {atol, rtol}; } -struct TestParams { - size_t m; - size_t k; - size_t n; - bool use_bias; - bool use_gelu; - bool transa; - bool transb; - NVTEScalingMode scaling_mode; -}; template void performTest(const TestParams& params) { @@ -588,40 +518,17 @@ void performTest(const TestParams& params) { } float ref_amax_d; - if (use_mxfp8) { - const A_Type *a_data; - const B_Type *b_data; - const fp8e8m0 *a_scale_inv_data, *b_scale_inv_data; - if (params.transa) { - a_data = A.rowwise_cpu_dptr(); - a_scale_inv_data = A.rowwise_cpu_scale_inv_ptr(); - } else { - a_data = A.columnwise_cpu_dptr(); - a_scale_inv_data = A.columnwise_cpu_scale_inv_ptr(); - } - if (params.transb) { - b_data = B.columnwise_cpu_dptr(); - b_scale_inv_data = B.columnwise_cpu_scale_inv_ptr(); - } else { - b_data = B.rowwise_cpu_dptr(); - b_scale_inv_data = B.rowwise_cpu_scale_inv_ptr(); - } - 
compute_mxfp8_ref( - a_data, b_data, a_scale_inv_data, b_scale_inv_data, - params.use_bias ? bias.rowwise_cpu_dptr() : nullptr, - D.scale(), params.m, params.k, params.n, ref_D.get(), &ref_amax_d, - params.use_gelu ? ref_pre_gelu_out.get() : nullptr, - params.transa, params.transb); - } else { - compute_ref( - A.rowwise_cpu_dptr(), B.rowwise_cpu_dptr(), - A.rowwise_scale_inv(), B.rowwise_scale_inv(), - params.use_bias ? bias.rowwise_cpu_dptr() : nullptr, - D.scale(), params.m, params.k, params.n, ref_D.get(), &ref_amax_d, - params.use_gelu ? ref_pre_gelu_out.get() : nullptr, - params.transa, params.transb); - } + run_reference( + params, + A, + B, + params.use_bias ? &bias : nullptr, + D.scale(), + ref_D, + &ref_amax_d, + ref_pre_gelu_out); + // check if error message happens in running (void)cudaDeviceSynchronize(); auto err = cudaGetLastError(); diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index bfb46f8a0..8892ff097 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -224,6 +224,16 @@ class Tensor { return reinterpret_cast(cpu_data_columnwise_.get()); } + void *rowwise_scale_inv_dptr() const { + NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); + return tensor_.get_rowwise_scale_inv().data_ptr; + } + + void *columnwise_scale_inv_dptr() const { + NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); + return tensor_.get_columnwise_scale_inv().data_ptr; + } + float amax() const { if(amax_cpu_data_) { to_cpu(); @@ -244,7 +254,7 @@ class Tensor { } template - T *rowwise_cpu_scale_inv_ptr(){ + T *rowwise_cpu_scale_inv_ptr() const { if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { @@ -269,7 +279,7 @@ class Tensor { return reinterpret_cast(columnwise_scale_inv_cpu_data_.get()); } - float rowwise_scale_inv(){ + float rowwise_scale_inv() const { if(rowwise_scale_inv_cpu_data_) { float scale_inv = rowwise_cpu_scale_inv_ptr()[0]; return scale_inv; From 86fbbac87113f00341062e4a9b150a855207acd6 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 14 Jan 2026 14:17:24 -0600 Subject: [PATCH 4/8] skip on SwizzleScale limitation on gfx950 --- tests/cpp/operator/test_cublaslt_gemm.cu | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 560218575..da59a8dee 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -427,6 +427,11 @@ void performTest(const TestParams& params) { GTEST_SKIP() << "FP8 GEMM with bias is not supported in current config"; } } + + if (use_mxfp8 && (isFp8Type(atype) || isFp8Type(btype)) && (params.transa != true || params.transb != false)) { + GTEST_SKIP() << "On gfx950, MXFP8 FP8/BF8 GEMM currently requires TN (SwizzleScale limitation)."; + } + } if (prop.major == 9 && prop.minor == 4) //gfx942 specific hipblasLt limitations { From 54de3dbd3891e0a0d0f0962fe3ccc4a9eaac759f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 14 Jan 2026 21:44:57 +0000 Subject: [PATCH 5/8] Revert "skip on SwizzleScale limitation on gfx950" This reverts commit 86fbbac87113f00341062e4a9b150a855207acd6. 
--- tests/cpp/operator/test_cublaslt_gemm.cu | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index da59a8dee..560218575 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -427,11 +427,6 @@ void performTest(const TestParams& params) { GTEST_SKIP() << "FP8 GEMM with bias is not supported in current config"; } } - - if (use_mxfp8 && (isFp8Type(atype) || isFp8Type(btype)) && (params.transa != true || params.transb != false)) { - GTEST_SKIP() << "On gfx950, MXFP8 FP8/BF8 GEMM currently requires TN (SwizzleScale limitation)."; - } - } if (prop.major == 9 && prop.minor == 4) //gfx942 specific hipblasLt limitations { From 311ddfe66bbe738ab550b74dccaf5fb8d885438d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 14 Jan 2026 17:21:58 -0600 Subject: [PATCH 6/8] MXFP8 fix --- tests/cpp/operator/test_cublaslt_gemm.cu | 8 +++----- tests/cpp/test_common.h | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 560218575..3d15ac3d4 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -107,11 +107,9 @@ __global__ void compute_ref_kernel( b_is_colwise ? (b_idx / 32) : (transb ? ((kk / 32) * n + jj) : (b_idx / 32)); - const float a_byte = static_cast(a_scale_inv_mxfp8[a_scale_idx]); - const float b_byte = static_cast(b_scale_inv_mxfp8[b_scale_idx]); - - a_scale_inv_val = exp2f(a_byte - 127.0f); - b_scale_inv_val = exp2f(b_byte - 127.0f); + // scale_inv is stored as an e8m0 biased exponent; convert to 2^(127-exp) + a_scale_inv_val = exp2f_rcp(a_scale_inv_mxfp8[a_scale_idx]); + b_scale_inv_val = exp2f_rcp(b_scale_inv_mxfp8[b_scale_idx]); } const float a_val = static_cast(a_data[a_idx]); diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 8892ff097..2114feacc 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -446,7 +446,7 @@ inline fp8e8m0 float_to_e8m0(float val) { return exponent; } -inline float exp2f_rcp(fp8e8m0 biased_exp) { +__device__ __host__ __forceinline__ float exp2f_rcp(fp8e8m0 biased_exp) { return (biased_exp == 0) ? 
1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); } From 445e64fbce9060bfe5d0f23dedf5de209bcf353f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 15 Jan 2026 14:14:53 -0600 Subject: [PATCH 7/8] =?UTF-8?q?correct=20scale=5Finv=20packing=20and=20exp?= =?UTF-8?q?2(biased=E2=88=92127)=20conversion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/cpp/operator/test_cublaslt_gemm.cu | 99 ++++++++++++++++++------ tests/cpp/test_common.h | 2 +- 2 files changed, 75 insertions(+), 26 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 3d15ac3d4..376a5fc26 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -75,8 +75,10 @@ __global__ void compute_ref_kernel( bool transb, bool is_fp8_output, bool a_is_colwise, - bool b_is_colwise) + bool b_is_colwise, + bool use_mxfp8) { + const size_t k_chunks = k / 32; const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; @@ -86,30 +88,33 @@ __global__ void compute_ref_kernel( if (in_range) { for (size_t kk = 0; kk < k; ++kk) { - // Indexing depends on which backing buffer we passed in - const size_t a_idx = - a_is_colwise ? (ii * k + kk) - : (transa ? (ii * k + kk) : (kk * m + ii)); - - const size_t b_idx = - b_is_colwise ? (jj * k + kk) - : (transb ? (kk * n + jj) : (jj * k + kk)); + size_t a_idx = 0; + size_t b_idx = 0; + + if (use_mxfp8) { + a_idx = transa ? (ii * k + kk) : (kk * m + ii); + b_idx = transb ? (kk * n + jj) : (jj * k + kk); + } else { + // Non-MXFP8 FP8 path may use explicit transpose buffers (cpu_rowwise_to_columnwise), + // so indexing depends on which backing buffer is passed in. + a_idx = a_is_colwise ? (ii * k + kk) + : (transa ? (ii * k + kk) : (kk * m + ii)); + + b_idx = b_is_colwise ? (jj * k + kk) + : (transb ? (kk * n + jj) : (jj * k + kk)); + } float a_scale_inv_val = a_scale_inv_scalar; float b_scale_inv_val = b_scale_inv_scalar; if (a_scale_inv_mxfp8) { - const size_t a_scale_idx = - a_is_colwise ? (a_idx / 32) - : (transa ? (a_idx / 32) : ((kk / 32) * m + ii)); + const size_t kc = kk / 32; - const size_t b_scale_idx = - b_is_colwise ? (b_idx / 32) - : (transb ? ((kk / 32) * n + jj) : (b_idx / 32)); + const size_t a_scale_idx = ii * k_chunks + kc; + const size_t b_scale_idx = jj * k_chunks + kc; - // scale_inv is stored as an e8m0 biased exponent; convert to 2^(127-exp) - a_scale_inv_val = exp2f_rcp(a_scale_inv_mxfp8[a_scale_idx]); - b_scale_inv_val = exp2f_rcp(b_scale_inv_mxfp8[b_scale_idx]); + a_scale_inv_val = exp2f(a_scale_inv_mxfp8[a_scale_idx] - 127.0f); + b_scale_inv_val = exp2f(b_scale_inv_mxfp8[b_scale_idx] - 127.0f); } const float a_val = static_cast(a_data[a_idx]); @@ -183,6 +188,8 @@ static void run_reference( { const bool use_mxfp8 = (params.scaling_mode == NVTE_MXFP8_1D_SCALING); + const size_t k_chunks = params.k / 32; + Gelu_Type* ref_gelu_host = (params.use_gelu ? ref_pre_gelu_out.get() : nullptr); const bool is_fp8_output = test::isFp8Type(test::TypeInfo::dtype); @@ -203,14 +210,51 @@ static void run_reference( const fp8e8m0* a_scale_dev = nullptr; const fp8e8m0* b_scale_dev = nullptr; + // If MXFP8, pack scale_inv into tight [row][kc] buffers on host, then transfer to device + std::vector a_scale_packed; + std::vector b_scale_packed; + fp8e8m0* d_a_scale_packed = nullptr; + fp8e8m0* d_b_scale_packed = nullptr; + if (use_mxfp8) { - a_scale_dev = params.transa - ? 
(const fp8e8m0*) A.rowwise_scale_inv_dptr() - : (const fp8e8m0*) A.columnwise_scale_inv_dptr(); + const fp8e8m0* a_scale_cpu = params.transa + ? A.rowwise_cpu_scale_inv_ptr() + : A.columnwise_cpu_scale_inv_ptr(); + const fp8e8m0* b_scale_cpu = params.transb + ? B.columnwise_cpu_scale_inv_ptr() + : B.rowwise_cpu_scale_inv_ptr(); + + // Pack into row-major [row][kc]: + // A_packed[ii, kc] and B_packed[jj, kc] + a_scale_packed.resize(params.m * k_chunks); + b_scale_packed.resize(params.n * k_chunks); + + for (size_t ii = 0; ii < params.m; ++ii) { + for (size_t kc = 0; kc < k_chunks; ++kc) { + const size_t src_idx = params.transa ? (ii * k_chunks + kc) : (kc * params.m + ii); + a_scale_packed[ii * k_chunks + kc] = a_scale_cpu[src_idx]; + } + } + + for (size_t jj = 0; jj < params.n; ++jj) { + for (size_t kc = 0; kc < k_chunks; ++kc) { + const size_t src_idx = params.transb ? (kc * params.n + jj) : (jj * k_chunks + kc); + b_scale_packed[jj * k_chunks + kc] = b_scale_cpu[src_idx]; + } + } + + NVTE_CHECK_CUDA(cudaMalloc(&d_a_scale_packed, a_scale_packed.size() * sizeof(fp8e8m0))); + NVTE_CHECK_CUDA(cudaMalloc(&d_b_scale_packed, b_scale_packed.size() * sizeof(fp8e8m0))); + + NVTE_CHECK_CUDA(cudaMemcpy(d_a_scale_packed, a_scale_packed.data(), + a_scale_packed.size() * sizeof(fp8e8m0), + cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(d_b_scale_packed, b_scale_packed.data(), + b_scale_packed.size() * sizeof(fp8e8m0), + cudaMemcpyHostToDevice)); - b_scale_dev = params.transb - ? (const fp8e8m0*) B.columnwise_scale_inv_dptr() - : (const fp8e8m0*) B.rowwise_scale_inv_dptr(); + a_scale_dev = d_a_scale_packed; + b_scale_dev = d_b_scale_packed; } else { a_scale_inv_scalar = A.rowwise_scale_inv(); b_scale_inv_scalar = B.rowwise_scale_inv(); @@ -264,7 +308,8 @@ static void run_reference( params.transb, is_fp8_output, a_use_colwise, - b_use_colwise); + b_use_colwise, + use_mxfp8); NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaDeviceSynchronize()); @@ -292,6 +337,10 @@ static void run_reference( NVTE_CHECK_CUDA(cudaFree(d_refGelu)); if (d_refAmax) NVTE_CHECK_CUDA(cudaFree(d_refAmax)); + if (d_a_scale_packed) + NVTE_CHECK_CUDA(cudaFree(d_a_scale_packed)); + if (d_b_scale_packed) + NVTE_CHECK_CUDA(cudaFree(d_b_scale_packed)); } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 2114feacc..7596bcf06 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -267,7 +267,7 @@ class Tensor { } template - T *columnwise_cpu_scale_inv_ptr(){ + T *columnwise_cpu_scale_inv_ptr() const { if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { From 462945fc299deca92a99e783fb1f71f4ae034252 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 15 Jan 2026 15:27:42 -0600 Subject: [PATCH 8/8] cleanups --- tests/cpp/operator/test_cublaslt_gemm.cu | 2 +- tests/cpp/test_common.h | 12 +----------- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 376a5fc26..21e4d4be6 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -203,7 +203,7 @@ static void run_reference( const B_Type* b_dev = static_cast( b_use_colwise ? 
B.columnwise_dptr() : B.rowwise_dptr()); - // scaling inputs + // scaling inputs float a_scale_inv_scalar = 1.0f; float b_scale_inv_scalar = 1.0f; diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 7596bcf06..07b4cd9bf 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -224,16 +224,6 @@ class Tensor { return reinterpret_cast(cpu_data_columnwise_.get()); } - void *rowwise_scale_inv_dptr() const { - NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); - return tensor_.get_rowwise_scale_inv().data_ptr; - } - - void *columnwise_scale_inv_dptr() const { - NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); - return tensor_.get_columnwise_scale_inv().data_ptr; - } - float amax() const { if(amax_cpu_data_) { to_cpu(); @@ -446,7 +436,7 @@ inline fp8e8m0 float_to_e8m0(float val) { return exponent; } -__device__ __host__ __forceinline__ float exp2f_rcp(fp8e8m0 biased_exp) { +inline float exp2f_rcp(fp8e8m0 biased_exp) { return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); }
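
For reference (not part of the patches above): the e8m0 handling that PATCH 6/8 and 7/8 converge on decodes each scale_inv byte as a biased power-of-two exponent, 2^(e - 127), and consumes one scale per 32-element chunk along K from the packed [row][k/32] buffers built in run_reference. Below is a minimal host-side sketch of that convention, under the assumption that fp8e8m0 is an unsigned byte as in test_common.h; e8m0_to_float and mxfp8_dot are illustrative names, not functions from the series.

    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    using fp8e8m0 = std::uint8_t;  // biased power-of-two exponent

    // Decode an e8m0 byte into the scale it encodes: 2^(biased_exp - 127).
    // This mirrors the exp2f(scale_byte - 127.0f) conversion in compute_ref_kernel.
    inline float e8m0_to_float(fp8e8m0 biased_exp) {
      return std::exp2(static_cast<float>(biased_exp) - 127.0f);
    }

    // One output element of the MXFP8 reference dot product. a_scales points at
    // the ii-th row of the packed [m][k/32] buffer and b_scales at the jj-th row
    // of the packed [n][k/32] buffer; a and b hold already-dequantized operands.
    inline float mxfp8_dot(const float* a, const float* b,
                           const fp8e8m0* a_scales, const fp8e8m0* b_scales,
                           std::size_t k) {
      float acc = 0.0f;
      for (std::size_t kk = 0; kk < k; ++kk) {
        const std::size_t kc = kk / 32;  // one scale per 32 consecutive K elements
        acc += e8m0_to_float(a_scales[kc]) * a[kk] *
               e8m0_to_float(b_scales[kc]) * b[kk];
      }
      return acc;
    }

    int main() {
      // 127 encodes 1.0, 128 encodes 2.0, 126 encodes 0.5.
      std::printf("%g %g %g\n",
                  e8m0_to_float(127), e8m0_to_float(128), e8m0_to_float(126));
      return 0;
    }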
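
Also for reference, the blockwise amax reduction from PATCH 2/8 in isolation. The sketch assumes a power-of-two thread count per block (the patch uses 16x16) and sidesteps the question of float atomicMax availability by doing the final atomic on the integer bit pattern, which orders identically to the float for non-negative values; amax_reduce_sketch and atomic_max_nonneg are illustrative names only.

    #include <cuda_runtime.h>  // <hip/hip_runtime.h> when building with hipcc

    // Non-negative IEEE-754 floats compare identically to their bit patterns,
    // so an integer atomicMax stands in for a float atomicMax here.
    __device__ void atomic_max_nonneg(float* addr, float val) {
      atomicMax(reinterpret_cast<int*>(addr), __float_as_int(val));
    }

    // Per-block reduction of |data[i]| into shared memory, then one atomic per block.
    __global__ void amax_reduce_sketch(const float* __restrict__ data, size_t len,
                                       float* __restrict__ amax /* pre-zeroed */) {
      extern __shared__ float s_amax[];
      const int tid      = threadIdx.y * blockDim.x + threadIdx.x;
      const int nthreads = blockDim.x * blockDim.y;  // assumed power of two
      const size_t idx   = blockIdx.x * static_cast<size_t>(nthreads) + tid;

      // Out-of-range threads contribute 0, as the in_range guard does in the patch.
      s_amax[tid] = (idx < len) ? fabsf(data[idx]) : 0.0f;
      __syncthreads();

      // Tree reduction in shared memory.
      for (int offset = nthreads / 2; offset > 0; offset /= 2) {
        if (tid < offset) s_amax[tid] = fmaxf(s_amax[tid], s_amax[tid + offset]);
        __syncthreads();
      }

      if (tid == 0) atomic_max_nonneg(amax, s_amax[0]);
    }

A launch of the form amax_reduce_sketch<<<grid, block, block.x * block.y * sizeof(float)>>>(...) sizes the dynamic shared memory the same way the patch does (shmem_bytes = nthreads * sizeof(float)).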
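
Finally, a compact CPU cross-check of the indexing and epilogue conventions the device reference keeps: D and the pre-GELU buffer are column-major m x n (element ii + jj*m), bias has length m, and D is written as val * d_scale. Per-tensor and MXFP8 scaling are omitted here for brevity; gemm_epilogue_ref_cpu and gelu_ref are illustrative names for debugging small shapes, not part of the series.

    #include <cmath>
    #include <cstddef>

    // tanh-approximation GELU, same formula as ref_gelu in the test.
    inline float gelu_ref(float x) {
      const float cdf = 0.5f * (1.0f + std::tanh(0.7978845608028654f *
                                                 (x + 0.044715f * x * x * x)));
      return x * cdf;
    }

    // D (and the pre-GELU buffer) are column-major m x n; bias has length m.
    void gemm_epilogue_ref_cpu(const float* a, const float* b, const float* bias,
                               float* d, float* pre_gelu, float d_scale,
                               std::size_t m, std::size_t k, std::size_t n,
                               bool transa, bool transb) {
      for (std::size_t jj = 0; jj < n; ++jj) {
        for (std::size_t ii = 0; ii < m; ++ii) {
          float val = 0.0f;
          for (std::size_t kk = 0; kk < k; ++kk) {
            const std::size_t a_idx = transa ? (ii * k + kk) : (kk * m + ii);
            const std::size_t b_idx = transb ? (kk * n + jj) : (jj * k + kk);
            val += a[a_idx] * b[b_idx];
          }
          if (bias)     val += bias[ii];
          if (pre_gelu) { pre_gelu[ii + jj * m] = val; val = gelu_ref(val); }
          d[ii + jj * m] = val * d_scale;
        }
      }
    }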