diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu
index 61ca86a1e..21e4d4be6 100644
--- a/tests/cpp/operator/test_cublaslt_gemm.cu
+++ b/tests/cpp/operator/test_cublaslt_gemm.cu
@@ -51,107 +51,299 @@ using TShape = std::vector<size_t>;
 } // namespace
 
-float ref_gelu(float x){
+__device__ __host__ __forceinline__ float ref_gelu(float x){
   float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
   return x * cdf;
 }
 
-template <typename A_Type, typename B_Type, typename D_Type, typename Bias_Type, typename Gelu_Type>
-void compute_ref(
-    const A_Type* a_data,
-    const B_Type* b_data,
-    const float a_scale_inv,
-    const float b_scale_inv,
-    const Bias_Type* bias_data, //bias is of dim m
-    const float d_scale,
+template <typename A_Type, typename B_Type, typename D_Type, typename Bias_Type, typename Gelu_Type>
+__global__ void compute_ref_kernel(
+    const A_Type* __restrict__ a_data,
+    const B_Type* __restrict__ b_data,
+    float a_scale_inv_scalar,                       // used when mxfp8 == false
+    float b_scale_inv_scalar,
+    const fp8e8m0* __restrict__ a_scale_inv_mxfp8,  // used when mxfp8 == true
+    const fp8e8m0* __restrict__ b_scale_inv_mxfp8,
+    const Bias_Type* __restrict__ bias_data,
+    float d_scale,
     size_t m, size_t k, size_t n,
-    D_Type* ref_d_data,
-    float* ref_d_amax_ptr,
-    Gelu_Type* ref_gelu_data,
+    D_Type* __restrict__ d_data,
+    float* __restrict__ d_amax,
+    Gelu_Type* __restrict__ gelu_data,
     bool transa,
-    bool transb){
+    bool transb,
+    bool is_fp8_output,
+    bool a_is_colwise,
+    bool b_is_colwise,
+    bool use_mxfp8)
+{
+  const size_t k_chunks = k / 32;
+  const size_t jj = blockIdx.x * blockDim.x + threadIdx.x;
+  const size_t ii = blockIdx.y * blockDim.y + threadIdx.y;
+
+  const bool in_range = (ii < m) && (jj < n);
+
+  float val = 0.0f;
+
+  if (in_range) {
+    for (size_t kk = 0; kk < k; ++kk) {
+      size_t a_idx = 0;
+      size_t b_idx = 0;
+
+      if (use_mxfp8) {
+        a_idx = transa ? (ii * k + kk) : (kk * m + ii);
+        b_idx = transb ? (kk * n + jj) : (jj * k + kk);
+      } else {
+        // Non-MXFP8 FP8 path may use explicit transpose buffers (cpu_rowwise_to_columnwise),
+        // so indexing depends on which backing buffer is passed in.
+        a_idx = a_is_colwise ? (ii * k + kk)
+                             : (transa ? (ii * k + kk) : (kk * m + ii));
+
+        b_idx = b_is_colwise ? (jj * k + kk)
+                             : (transb ? (kk * n + jj) : (jj * k + kk));
+      }
 
-  float ref_d_amax = 0;
+      float a_scale_inv_val = a_scale_inv_scalar;
+      float b_scale_inv_val = b_scale_inv_scalar;
 
-  #pragma omp parallel for schedule(static) collapse(2) reduction(max: ref_d_amax) proc_bind(spread)
-  for(size_t ii = 0; ii < m; ii++){
-    for(size_t jj = 0; jj < n; jj++){
-      float val = 0;
-      for(size_t kk = 0; kk < k; kk++){
-        float a_val = transa ? a_data[kk + ii*k] : a_data[ii + kk*m];
-        float b_val = transb ? b_data[jj + kk*n] : b_data[kk + jj*k];
-        val += a_scale_inv*a_val*b_scale_inv*b_val;
-      }
-      if(bias_data){
-        val += (float)bias_data[ii];
-      }
-      if(ref_gelu_data){
-        ref_gelu_data[ii + jj*m] = (Gelu_Type)(val);
-        val = ref_gelu(val);
-      }
-      ref_d_data[ii+jj*m] = (D_Type)(val*d_scale);
-      // update ref_d_amax if in fp8
-      DType dtype = TypeInfo<D_Type>::dtype;
-      if(isFp8Type(dtype)){
-        ref_d_amax = std::max(ref_d_amax, std::fabs(val));
+      if (a_scale_inv_mxfp8) {
+        const size_t kc = kk / 32;
+
+        const size_t a_scale_idx = ii * k_chunks + kc;
+        const size_t b_scale_idx = jj * k_chunks + kc;
+
+        a_scale_inv_val = exp2f(a_scale_inv_mxfp8[a_scale_idx] - 127.0f);
+        b_scale_inv_val = exp2f(b_scale_inv_mxfp8[b_scale_idx] - 127.0f);
       }
+
+      const float a_val = static_cast<float>(a_data[a_idx]);
+      const float b_val = static_cast<float>(b_data[b_idx]);
+
+      val += a_scale_inv_val * a_val * b_scale_inv_val * b_val;
+    }
+
+    if (bias_data) {
+      val += static_cast<float>(bias_data[ii]);
+    }
+
+    if (gelu_data) {
+      gelu_data[ii + jj * m] = static_cast<Gelu_Type>(val);
+      val = ref_gelu(val);
     }
+
+    const float scaled = val * d_scale;
+    d_data[ii + jj * m] = static_cast<D_Type>(scaled);
   }
-  if (ref_d_amax_ptr)
-  {
-    *ref_d_amax_ptr = ref_d_amax;
+
+  // Blockwise reduction for amax
+  if (is_fp8_output && d_amax) {
+    const int tid = threadIdx.y * blockDim.x + threadIdx.x;
+    const int nthreads = blockDim.x * blockDim.y;
+
+    extern __shared__ float s_amax[];
+
+    // Out-of-range threads contribute 0
+    s_amax[tid] = in_range ? fabsf(val) : 0.0f;
+    __syncthreads();
+
+    for (int offset = nthreads / 2; offset > 0; offset /= 2) {
+      if (tid < offset) {
+        s_amax[tid] = fmaxf(s_amax[tid], s_amax[tid + offset]);
+      }
+      __syncthreads();
+    }
+
+    if (tid == 0) {
+      const float block_max = s_amax[0];
+      // atomicMax has no float overload in CUDA; all values here are non-negative,
+      // so an integer-ordered max on the bit pattern is equivalent.
+      atomicMax(reinterpret_cast<int*>(d_amax), __float_as_int(block_max));
+    }
   }
 }
 
-template <typename A_Type, typename B_Type, typename D_Type, typename Bias_Type, typename Gelu_Type>
-void compute_mxfp8_ref(
-    const A_Type* a_data,
-    const B_Type* b_data,
-    const fp8e8m0* a_scale_inv_data,
-    const fp8e8m0* b_scale_inv_data,
-    const Bias_Type* bias_data, //bias is of dim m
-    const float d_scale,
-    size_t m, size_t k, size_t n,
-    D_Type* ref_d_data,
-    float* ref_d_amax_ptr,
-    Gelu_Type* ref_gelu_data,
-    bool transa,
-    bool transb){
-  float ref_d_amax = 0;
+struct TestParams {
+  size_t m;
+  size_t k;
+  size_t n;
+  bool use_bias;
+  bool use_gelu;
+  bool transa;
+  bool transb;
+  NVTEScalingMode scaling_mode;
+};
 
-  #pragma omp parallel for schedule(static) collapse(2) reduction(max: ref_d_amax) proc_bind(spread)
-  for(size_t ii = 0; ii < m; ii++){
-    for(size_t jj = 0; jj < n; jj++){
-      float val = 0;
-      for(size_t kk = 0; kk < k; kk++){
-        size_t a_idx = transa ? (ii*k + kk) : (kk*m + ii);
-        size_t b_idx = transb ? (kk*n + jj) : (jj*k + kk);
-        float a_scale_inv_val = std::exp2f(a_scale_inv_data[transa ? a_idx/32 : (kk/32 * m + ii)] - 127);
-        float b_scale_inv_val = std::exp2f(b_scale_inv_data[transb ? (kk/32 * n + jj) : b_idx/32] - 127);
-        val += a_scale_inv_val * (float)a_data[a_idx] * b_scale_inv_val * (float)b_data[b_idx];
-      }
-      if(bias_data){
-        val += (float)bias_data[ii];
-      }
-      if(ref_gelu_data){
-        ref_gelu_data[ii + jj*m] = (Gelu_Type)(val);
-        val = ref_gelu(val);
+
+template <typename A_Type, typename B_Type, typename D_Type, typename Bias_Type, typename Gelu_Type>
+static void run_reference(
+    const TestParams& params,
+    const Tensor& A,
+    const Tensor& B,
+    const Tensor* Bias,                              // nullable
+    float d_scale,
+    std::unique_ptr<D_Type[]>& ref_D,                // m*n
+    float* ref_amax_d,
+    std::unique_ptr<Gelu_Type[]>& ref_pre_gelu_out)  // nullable
+{
+  const bool use_mxfp8 = (params.scaling_mode == NVTE_MXFP8_1D_SCALING);
+
+  const size_t k_chunks = params.k / 32;
+
+  Gelu_Type* ref_gelu_host = (params.use_gelu ? ref_pre_gelu_out.get() : nullptr);
+
+  const bool is_fp8_output = test::isFp8Type(test::TypeInfo<D_Type>::dtype);
+
+  const bool a_use_colwise = (!params.transa) && A.columnwise();
+  const bool b_use_colwise = ( params.transb) && B.columnwise();
+
+  const A_Type* a_dev = static_cast<const A_Type*>(
+      a_use_colwise ? A.columnwise_dptr() : A.rowwise_dptr());
+
+  const B_Type* b_dev = static_cast<const B_Type*>(
+      b_use_colwise ? B.columnwise_dptr() : B.rowwise_dptr());
+
+  // scaling inputs
+  float a_scale_inv_scalar = 1.0f;
+  float b_scale_inv_scalar = 1.0f;
+
+  const fp8e8m0* a_scale_dev = nullptr;
+  const fp8e8m0* b_scale_dev = nullptr;
+
+  // If MXFP8, pack scale_inv into tight [row][kc] buffers on host, then transfer to device
+  std::vector<fp8e8m0> a_scale_packed;
+  std::vector<fp8e8m0> b_scale_packed;
+  fp8e8m0* d_a_scale_packed = nullptr;
+  fp8e8m0* d_b_scale_packed = nullptr;
+
+  if (use_mxfp8) {
+    const fp8e8m0* a_scale_cpu = params.transa
+                                     ? A.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
+                                     : A.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
+    const fp8e8m0* b_scale_cpu = params.transb
+                                     ? B.columnwise_cpu_scale_inv_ptr<fp8e8m0>()
+                                     : B.rowwise_cpu_scale_inv_ptr<fp8e8m0>();
+
+    // Pack into row-major [row][kc]:
+    //   A_packed[ii, kc] and B_packed[jj, kc]
+    a_scale_packed.resize(params.m * k_chunks);
+    b_scale_packed.resize(params.n * k_chunks);
+
+    for (size_t ii = 0; ii < params.m; ++ii) {
+      for (size_t kc = 0; kc < k_chunks; ++kc) {
+        const size_t src_idx = params.transa ? (ii * k_chunks + kc) : (kc * params.m + ii);
+        a_scale_packed[ii * k_chunks + kc] = a_scale_cpu[src_idx];
       }
-      ref_d_data[ii+jj*m] = (D_Type)(val*d_scale);
-      // update ref_d_amax if in fp8
-      DType dtype = TypeInfo<D_Type>::dtype;
-      if(isFp8Type(dtype)){
-        ref_d_amax = std::max(ref_d_amax, std::fabs(val));
+    }
+
+    for (size_t jj = 0; jj < params.n; ++jj) {
+      for (size_t kc = 0; kc < k_chunks; ++kc) {
+        const size_t src_idx = params.transb ? (kc * params.n + jj) : (jj * k_chunks + kc);
+        b_scale_packed[jj * k_chunks + kc] = b_scale_cpu[src_idx];
       }
     }
+
+    NVTE_CHECK_CUDA(cudaMalloc(&d_a_scale_packed, a_scale_packed.size() * sizeof(fp8e8m0)));
+    NVTE_CHECK_CUDA(cudaMalloc(&d_b_scale_packed, b_scale_packed.size() * sizeof(fp8e8m0)));
+
+    NVTE_CHECK_CUDA(cudaMemcpy(d_a_scale_packed, a_scale_packed.data(),
+                               a_scale_packed.size() * sizeof(fp8e8m0),
+                               cudaMemcpyHostToDevice));
+    NVTE_CHECK_CUDA(cudaMemcpy(d_b_scale_packed, b_scale_packed.data(),
+                               b_scale_packed.size() * sizeof(fp8e8m0),
+                               cudaMemcpyHostToDevice));
+
+    a_scale_dev = d_a_scale_packed;
+    b_scale_dev = d_b_scale_packed;
+  } else {
+    a_scale_inv_scalar = A.rowwise_scale_inv();
+    b_scale_inv_scalar = B.rowwise_scale_inv();
   }
-  if (ref_d_amax_ptr)
-  {
-    *ref_d_amax_ptr = ref_d_amax;
+
+  // optional bias device pointer
+  const Bias_Type* bias_dev = nullptr;
+  if (Bias) {
+    bias_dev = static_cast<const Bias_Type*>(Bias->rowwise_dptr());
+  }
+
+  // allocate device outputs
+  const size_t lenD = params.m * params.n;
+  const size_t bytesD = lenD * sizeof(D_Type);
+
+  D_Type* d_refD = nullptr;
+  Gelu_Type* d_refGelu = nullptr;
+  float* d_refAmax = nullptr;
+
+  NVTE_CHECK_CUDA(cudaMalloc(&d_refD, bytesD));
+  if (ref_gelu_host) {
+    NVTE_CHECK_CUDA(cudaMalloc(&d_refGelu, lenD * sizeof(Gelu_Type)));
+  }
+  if (is_fp8_output && ref_amax_d) {
+    NVTE_CHECK_CUDA(cudaMalloc(&d_refAmax, sizeof(float)));
+    NVTE_CHECK_CUDA(cudaMemset(d_refAmax, 0, sizeof(float)));
+  }
+
+  // Kernel launch
+  dim3 block(16, 16);
+  dim3 grid((unsigned)((params.n + block.x - 1) / block.x),
+            (unsigned)((params.m + block.y - 1) / block.y));
+
+  const size_t shmem_bytes = size_t(block.x) * size_t(block.y) * sizeof(float);
+
+  compute_ref_kernel<A_Type, B_Type, D_Type, Bias_Type, Gelu_Type>
+      <<<grid, block, shmem_bytes>>>(
+          a_dev,
+          b_dev,
+          a_scale_inv_scalar,
+          b_scale_inv_scalar,
+          a_scale_dev,
+          b_scale_dev,
+          bias_dev,
+          d_scale,
+          params.m, params.k, params.n,
+          d_refD,
+          d_refAmax,
+          d_refGelu,
+          params.transa,
+          params.transb,
+          is_fp8_output,
+          a_use_colwise,
+          b_use_colwise,
+          use_mxfp8);
+
+  NVTE_CHECK_CUDA(cudaGetLastError());
+  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
+
+  // copy outputs back
+  NVTE_CHECK_CUDA(cudaMemcpy(ref_D.get(), d_refD, bytesD, cudaMemcpyDeviceToHost));
+
+  if (ref_gelu_host) {
+    NVTE_CHECK_CUDA(cudaMemcpy(ref_gelu_host, d_refGelu, lenD * sizeof(Gelu_Type),
+                               cudaMemcpyDeviceToHost));
+  }
+
+  if (ref_amax_d) {
+    if (is_fp8_output) {
+      NVTE_CHECK_CUDA(cudaMemcpy(ref_amax_d, d_refAmax, sizeof(float),
+                                 cudaMemcpyDeviceToHost));
+    } else {
+      *ref_amax_d = 0.0f;
+    }
   }
+
+  // cleanup
+  NVTE_CHECK_CUDA(cudaFree(d_refD));
+  if (d_refGelu)
+    NVTE_CHECK_CUDA(cudaFree(d_refGelu));
+  if (d_refAmax)
+    NVTE_CHECK_CUDA(cudaFree(d_refAmax));
+  if (d_a_scale_packed)
+    NVTE_CHECK_CUDA(cudaFree(d_a_scale_packed));
+  if (d_b_scale_packed)
+    NVTE_CHECK_CUDA(cudaFree(d_b_scale_packed));
 }
+
 template <typename T>
 void cpu_rowwise_to_columnwise(
     size_t m, size_t n,
@@ -191,16 +383,6 @@ std::pair<double, double> getTestTolerances(const DType type, bool use_fp8, bool
   return {atol, rtol};
 }
 
-struct TestParams {
-  size_t m;
-  size_t k;
-  size_t n;
-  bool use_bias;
-  bool use_gelu;
-  bool transa;
-  bool transb;
-  NVTEScalingMode scaling_mode;
-};
 
 template <typename A_Type, typename B_Type, typename D_Type, typename Bias_Type, typename Gelu_Type>
 void performTest(const TestParams& params) {
@@ -383,7 +565,7 @@ void performTest(const TestParams& params) {
     pre_gelu_out.to_cpu();
   }
 
-  //perform the gemm in CPU
+  //perform the reference gemm on GPU
   std::unique_ptr<D_Type[]> ref_D = std::make_unique<D_Type[]>(params.m*params.n);
   std::unique_ptr<Gelu_Type[]> ref_pre_gelu_out;
   if(params.use_gelu){
@@ -391,40 +573,17 @@ void performTest(const TestParams& params) {
   }
 
   float ref_amax_d;
 
-  if (use_mxfp8) {
-    const A_Type *a_data;
-    const B_Type *b_data;
-    const fp8e8m0 *a_scale_inv_data, *b_scale_inv_data;
-    if (params.transa) {
-      a_data = A.rowwise_cpu_dptr<A_Type>();
-      a_scale_inv_data = A.rowwise_cpu_scale_inv_ptr<fp8e8m0>();
-    } else {
-      a_data = A.columnwise_cpu_dptr<A_Type>();
-      a_scale_inv_data = A.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
-    }
-    if (params.transb) {
-      b_data = B.columnwise_cpu_dptr<B_Type>();
-      b_scale_inv_data = B.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
-    } else {
-      b_data = B.rowwise_cpu_dptr<B_Type>();
-      b_scale_inv_data = B.rowwise_cpu_scale_inv_ptr<fp8e8m0>();
-    }
-    compute_mxfp8_ref(
-        a_data, b_data, a_scale_inv_data, b_scale_inv_data,
-        params.use_bias ? bias.rowwise_cpu_dptr<Bias_Type>() : nullptr,
-        D.scale(), params.m, params.k, params.n, ref_D.get(), &ref_amax_d,
-        params.use_gelu ? ref_pre_gelu_out.get() : nullptr,
-        params.transa, params.transb);
-  } else {
-    compute_ref(
-        A.rowwise_cpu_dptr<A_Type>(), B.rowwise_cpu_dptr<B_Type>(),
-        A.rowwise_scale_inv(), B.rowwise_scale_inv(),
-        params.use_bias ? bias.rowwise_cpu_dptr<Bias_Type>() : nullptr,
-        D.scale(), params.m, params.k, params.n, ref_D.get(), &ref_amax_d,
-        params.use_gelu ? ref_pre_gelu_out.get() : nullptr,
-        params.transa, params.transb);
-  }
+  run_reference<A_Type, B_Type, D_Type, Bias_Type, Gelu_Type>(
+      params,
+      A,
+      B,
+      params.use_bias ? &bias : nullptr,
+      D.scale(),
+      ref_D,
+      &ref_amax_d,
+      ref_pre_gelu_out);
+
   // check if error message happens in running
   (void)cudaDeviceSynchronize();
   auto err = cudaGetLastError();
diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h
index bfb46f8a0..07b4cd9bf 100644
--- a/tests/cpp/test_common.h
+++ b/tests/cpp/test_common.h
@@ -244,7 +244,7 @@ class Tensor {
   }
 
   template <typename T>
-  T *rowwise_cpu_scale_inv_ptr(){
+  T *rowwise_cpu_scale_inv_ptr() const {
     if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){
       NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
     } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
@@ -257,7 +257,7 @@ class Tensor {
   }
 
   template <typename T>
-  T *columnwise_cpu_scale_inv_ptr(){
+  T *columnwise_cpu_scale_inv_ptr() const {
     if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){
       NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
     } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
@@ -269,7 +269,7 @@ class Tensor {
     return reinterpret_cast<T *>(columnwise_scale_inv_cpu_data_.get());
   }
 
-  float rowwise_scale_inv(){
+  float rowwise_scale_inv() const {
     if(rowwise_scale_inv_cpu_data_) {
       float scale_inv = rowwise_cpu_scale_inv_ptr<float>()[0];
       return scale_inv;