Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion transformer_engine/common/normalization/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ void rocm_norm_mxfp8_quantize(LaunchParams<ForwardKernelParams> &launch_params)
TRANSFORMER_ENGINE_SWITCH_CONDITION(
!(cols % (32 * sizeof(compute_t))), IS_ALIGNED,
cast_mxfp8_2D_kernel<false, false, false, Empty, {}, compute_t, OType,
SCALE_DIM_Y, scale_dim_X_rowwise, IS_ALIGNED, true><<<grid, block, 0, launch_params.stream>>>(
SCALE_DIM_Y, scale_dim_X_rowwise, IS_ALIGNED><<<grid, block, 0, launch_params.stream>>>(
reinterpret_cast<const compute_t*>(launch_params.params.z),
nullptr,
reinterpret_cast<OType *>(launch_params.z_tensor->data.dptr),
Expand Down
6 changes: 3 additions & 3 deletions transformer_engine/common/util/rocm_cast_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; //

template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &), typename IType, typename OType, size_t SCALE_DIM_Y,
size_t SCALE_DIM_X, bool IS_ALIGNED, bool IS_NORM = false>
size_t SCALE_DIM_X, bool IS_ALIGNED>
__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
cast_mxfp8_2D_kernel(const IType *input_ptr,
const IType *act_input_ptr,
Expand Down Expand Up @@ -221,7 +221,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)

const float subwarp_amax = subwarp_reduce_max_broadcast<SUBWARP_WIDTH>(thread_amax);
const e8m0_t biased_exponent =
float_to_e8m0(subwarp_amax * Quantized_Limits<OType>::max_norm_rcp) + (IS_NORM ? 1 : 0); // Normalization requires a +1 scale to avoid saturation
float_to_e8m0(subwarp_amax * Quantized_Limits<OType>::max_norm_rcp);

// Only single thread writes the computed scaling factor
if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) {
Expand Down Expand Up @@ -278,7 +278,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
__builtin_assume(amax >= 0);
block_amax = fmaxf(block_amax, amax);

const e8m0_t biased_exponent = float_to_e8m0(amax * Quantized_Limits<OType>::max_norm_rcp) + (IS_NORM ? 1 : 0); // Normalization requires a +1 scale to avoid saturation
const e8m0_t biased_exponent = float_to_e8m0(amax * Quantized_Limits<OType>::max_norm_rcp);

const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter;
const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X;
Expand Down