From e951f60b62649c166bbd056c765ca41a74a5e2ef Mon Sep 17 00:00:00 2001
From: Alex Magro
Date: Fri, 16 Jan 2026 15:04:00 -0600
Subject: [PATCH] Remove IS_NORM template parameter

---
 transformer_engine/common/normalization/common.h     | 2 +-
 transformer_engine/common/util/rocm_cast_kernels.cuh | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h
index 31d2c0b74..c4a0a6f8f 100644
--- a/transformer_engine/common/normalization/common.h
+++ b/transformer_engine/common/normalization/common.h
@@ -464,7 +464,7 @@ void rocm_norm_mxfp8_quantize(LaunchParams &launch_params)
   TRANSFORMER_ENGINE_SWITCH_CONDITION(
       !(cols % (32 * sizeof(compute_t))), IS_ALIGNED,
       cast_mxfp8_2D_kernel<compute_t, fp8e4m3,
-          SCALE_DIM_Y, scale_dim_X_rowwise, true, IS_ALIGNED><<<grid, block, 0, launch_params.stream>>>(
+          SCALE_DIM_Y, scale_dim_X_rowwise, IS_ALIGNED><<<grid, block, 0, launch_params.stream>>>(
           reinterpret_cast<const compute_t *>(launch_params.params.z), nullptr,
           reinterpret_cast<fp8e4m3 *>(launch_params.z_tensor->data.dptr),
diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh
index d62350e0a..ac0ce2174 100644
--- a/transformer_engine/common/util/rocm_cast_kernels.cuh
+++ b/transformer_engine/common/util/rocm_cast_kernels.cuh
@@ -45,7 +45,7 @@ constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y;  //
 
 template <typename IType, typename OType, size_t SCALE_DIM_Y,
-          size_t SCALE_DIM_X, bool IS_NORM, bool IS_ALIGNED>
+          size_t SCALE_DIM_X, bool IS_ALIGNED>
 __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
     cast_mxfp8_2D_kernel(const IType *input_ptr, const IType *act_input_ptr,
@@ -221,7 +221,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
 
       const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax);
       const e8m0_t biased_exponent =
-          float_to_e8m0(subwarp_amax * Quantized_Limits<OType>::max_norm_rcp) + (IS_NORM ? 1 : 0);  // Normalization requires a +1 scale to avoid saturation
+          float_to_e8m0(subwarp_amax * Quantized_Limits<OType>::max_norm_rcp);
 
       // Only single thread writes the computed scaling factor
       if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) {
@@ -278,7 +278,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
       __builtin_assume(amax >= 0);
       block_amax = fmaxf(block_amax, amax);
 
-      const e8m0_t biased_exponent = float_to_e8m0(amax * Quantized_Limits<OType>::max_norm_rcp) + (IS_NORM ? 1 : 0);  // Normalization requires a +1 scale to avoid saturation
+      const e8m0_t biased_exponent = float_to_e8m0(amax * Quantized_Limits<OType>::max_norm_rcp);
 
       const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter;
       const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X;
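
Note for reviewers: the removed "+ (IS_NORM ? 1 : 0)" bumped the biased E8M0
exponent by one, i.e. doubled the per-block decode scale so a normalized block
could not land on the FP8 maximum. The standalone host-C++ sketch below is not
part of this patch; it assumes float_to_e8m0 returns the bias-127 exponent of
its argument rounded up (matching the upstream Transformer Engine helper) and
uses 448.0f as the FP8 E4M3 max normal. It only illustrates the bump's effect:

    // Standalone sketch; the helper's rounding behavior and the 448.0f
    // limit are assumptions here, not code taken from this patch.
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    using e8m0_t = uint8_t;

    // Bias-127 base-2 exponent of x, rounded up (assumed helper behavior).
    static e8m0_t float_to_e8m0(float x) {
      return static_cast<e8m0_t>(static_cast<int>(std::ceil(std::log2(x))) + 127);
    }

    int main() {
      const float max_norm_rcp = 1.0f / 448.0f;  // 448 = FP8 E4M3 max normal
      const float amax = 3.1f;                   // example per-block amax

      const e8m0_t kept = float_to_e8m0(amax * max_norm_rcp);  // what this patch keeps
      const e8m0_t norm = kept + 1;                            // the removed IS_NORM bump

      // Decode scale is 2^(exponent - 127); the bump halves every quantized
      // value, trading one bit of dynamic range for saturation headroom.
      std::printf("exponent: %d (with bump: %d)\n", kept - 127, norm - 127);
      std::printf("amax / scale: %.1f vs %.1f (E4M3 max 448)\n",
                  amax * std::exp2f(127 - kept), amax * std::exp2f(127 - norm));
      return 0;
    }

With round-up rounding, amax / scale already stays at or below 448 without the
bump, which is consistent with dropping the extra +1 here.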