diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index 38b437b994d..36a9cc384b0 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -394,6 +394,10 @@ struct QuantizationConfig {
   NVTETensor rng_state = nullptr;
   bool nvfp4_2d_quantization = false;
   bool stochastic_rounding = false;
+  // Scale factor for estimating post-RHT amax from pre-RHT amax.
+  // When <= 0.0f, true post-RHT amax is used (default behavior).
+  // When > 0.0f, post-RHT amax is estimated as: pre_rht_amax * amax_estimation_scale
+  float amax_estimation_scale = 0.0f;
 
   static constexpr size_t attr_sizes[] = {
       sizeof(bool),  // force_pow_2_scales
@@ -402,7 +406,8 @@ struct QuantizationConfig {
       sizeof(Float8BlockScaleTensorFormat),  // float8_block_scale_tensor_format
       sizeof(NVTETensor),                    // rng_seed and offset
       sizeof(bool),                          // nvfp4_2d_quantization
-      sizeof(bool)                           // stochastic_rounding
+      sizeof(bool),                          // stochastic_rounding
+      sizeof(float)                          // amax_estimation_scale
   };
 };
diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu
index 12f02dba6b6..6d51f41f344 100644
--- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu
+++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu
@@ -139,7 +139,8 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
                 TSFC * SFC,
                 TiledMMA mma,
                 float const* global_amax,
-                const size_t* rng_state)
+                const size_t* rng_state,
+                float amax_scale)
 {
   using namespace cute;
   using X = Underscore;
@@ -407,7 +408,8 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
     accumulator_pipeline.producer_tail(accumulator_pipe_producer_state);
     tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns);
   } else if (is_epilogue_warp) {
-    const float global_amax_val = *global_amax;
+    // Apply amax estimation scale if provided (amax_scale > 0 means estimation is enabled)
+    const float global_amax_val = (*global_amax) * amax_scale;
     static constexpr int FragmentSize = 256 / sizeof_bits_v;
     tmem_allocation_result_barrier.arrive_and_wait();
@@ -543,7 +545,8 @@ rht_gemm_ntt_w_sfc(int m, int n,
                    const size_t* rng_state,
                    uint32_t sm_count,
                    cudaStream_t stream,
-                   int k_tile_size = 2048)
+                   int k_tile_size = 2048,
+                   float amax_scale = 1.0f)
 {
   using namespace cute;
@@ -662,7 +665,8 @@ rht_gemm_ntt_w_sfc(int m, int n,
       C, dC, sC, SFC, mma,
       global_amax,
-      rng_state);
+      rng_state,
+      amax_scale);
 }
 
 // this function is used to wrap the rht_gemm_ntt_w_sfc function
@@ -678,7 +682,8 @@ rht_gemm_ttt_wrapper(int m, int n,
                      const size_t* rng_state,
                      uint32_t sm_count,
                      cudaStream_t stream,
-                     int k_tile_size = 1024)
+                     int k_tile_size = 1024,
+                     float amax_scale = 1.0f)
 {
   // in addition to transpose the input tensor A
   // we also need to reshape m, n to at best
@@ -696,7 +701,8 @@ rht_gemm_ttt_wrapper(int m, int n,
                      SFC,
                      global_amax, rng_state,
                      sm_count, stream,
-                     k_tile_size);
+                     k_tile_size,
+                     amax_scale);
 }
 
 }  // namespace
@@ -734,6 +740,11 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out
     rng_state = reinterpret_cast(rng_state_tensor.data.dptr);
   }
 
+  // Amax estimation scale: when > 0, amax is scaled by this factor
+  // This allows estimating post-RHT amax from pre-RHT amax
+  const float amax_scale =
+      (quant_config.amax_estimation_scale > 0.0f) ?
+          quant_config.amax_estimation_scale : 1.0f;
+
   // Template arguments
   using TA = cute::bfloat16_t;
   using TB = cute::bfloat16_t;
@@ -813,7 +824,8 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out
         /*rng_state=*/rng_state,
         /*sm_count=*/sm_count,
         /*stream=*/stream,
-        /*k_tile_size=*/k_tile_size););
+        /*k_tile_size=*/k_tile_size,
+        /*amax_scale=*/amax_scale););
 }
 
 }  // namespace transformer_engine
diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h
index b2e04ba69f5..1f3f10f26e5 100644
--- a/transformer_engine/common/include/transformer_engine/transformer_engine.h
+++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h
@@ -337,6 +337,11 @@ enum NVTEQuantizationConfigAttribute {
   kNVTEQuantizationConfigNVFP42DQuantization = 5,
   /*! Whether to enable stochastic rounding */
   kNVTEQuantizationConfigStochasticRounding = 6,
+  /*! Scale factor for estimating post-RHT amax from pre-RHT amax.
+   * When <= 0.0f, true post-RHT amax is used (default behavior).
+   * When > 0.0f, post-RHT amax is estimated as: pre_rht_amax * amax_estimation_scale
+   */
+  kNVTEQuantizationConfigAmaxEstimationScale = 7,
   kNVTEQuantizationConfigNumAttributes
 };
@@ -997,6 +1002,16 @@ class QuantizationConfigWrapper {
                                            &stochastic_rounding, sizeof(bool));
   }
 
+  /*! \brief Set amax estimation scale for post-RHT amax estimation
+   *
+   * When <= 0.0f, true post-RHT amax is used (default behavior).
+   * When > 0.0f, post-RHT amax is estimated as: pre_rht_amax * amax_estimation_scale
+   */
+  void set_amax_estimation_scale(float amax_estimation_scale) {
+    nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigAmaxEstimationScale,
+                                           &amax_estimation_scale, sizeof(float));
+  }
+
  private:
  /*! \brief Wrapped NVTEQuantizationConfig. */
  NVTEQuantizationConfig config_ = nullptr;
diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py
index 98e2a29df85..4d9c043f268 100644
--- a/transformer_engine/common/recipe/__init__.py
+++ b/transformer_engine/common/recipe/__init__.py
@@ -65,6 +65,9 @@ class QParams:
     amax_epsilon: optional minimum value of abs max
     random_hadamard_transform: whether to use random hadamard transform
     stochastic_rounding: whether to use stocastic rounding
+    amax_estimation_scale: scale factor for estimating post-RHT amax from pre-RHT amax.
+        When None, true post-RHT amax is computed (default behavior).
+        When set to a float, post-RHT amax is estimated as: pre_rht_amax * amax_estimation_scale
     """
 
     power_2_scale: bool = False
@@ -72,6 +75,7 @@ class QParams:
     random_hadamard_transform: bool = False
     stochastic_rounding: bool = False
     fp4_2d_quantization: bool = False
+    amax_estimation_scale: float | None = None
 
     def __repr__(self) -> str:
         return (
@@ -79,7 +83,8 @@ def __repr__(self) -> str:
             f"amax_epsilon={self.amax_epsilon},\n"
             f"random_hadamard_transform={self.random_hadamard_transform},\n"
             f"stochastic_rounding={self.stochastic_rounding},\n"
-            f"fp4_2d_quantization={self.fp4_2d_quantization}\n)"
+            f"fp4_2d_quantization={self.fp4_2d_quantization},\n"
+            f"amax_estimation_scale={self.amax_estimation_scale}\n)"
         )
@@ -428,6 +433,16 @@ class NVFP4BlockScaling(Recipe):
         If set to `True`, stochastic rounding is disabled during quantization for all tensors.
     disable_2d_quantization : bool, default = False
         If set to `True`, 1D block scaling with block size 16 is used for all tensors.
+    use_post_rht_amax_estimation : bool, default = False
+        **EXPERIMENTAL**: If set to `True`, post-RHT amax is estimated from pre-RHT amax
+        instead of being computed by a separate RHT+amax kernel. This can reduce the
+        number of kernel launches but may affect numerical accuracy.
+    post_rht_amax_estimation_scale_fwd_inp : float, default = 2.0
+        Scale factor for estimating post-RHT amax for forward input activations.
+        Only used when `use_post_rht_amax_estimation=True`.
+    post_rht_amax_estimation_scale_bwd_grad : float, default = 1.0
+        Scale factor for estimating post-RHT amax for backward gradients.
+        Only used when `use_post_rht_amax_estimation=True`.
     """
 
     # Configuration envvars
@@ -444,10 +459,33 @@ class NVFP4BlockScaling(Recipe):
     fp8_dpa: bool = False
     fp8_mha: bool = False
 
+    # Experimental: Post-RHT amax estimation
+    use_post_rht_amax_estimation: bool = (
+        os.getenv("NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION", "0") == "1"
+    )
+    post_rht_amax_estimation_scale_fwd_inp = float(
+        os.getenv("NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION_X_SCALE", "2.0")
+    )
+    post_rht_amax_estimation_scale_bwd_grad = float(
+        os.getenv("NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION_G_SCALE", "1.0")
+    )
+
     def __post_init__(self) -> None:
         assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling"
         assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling"
 
+        # Determine amax estimation scales (None = use true post-RHT amax)
+        amax_scale_fwd_inp = (
+            self.post_rht_amax_estimation_scale_fwd_inp
+            if self.use_post_rht_amax_estimation
+            else None
+        )
+        amax_scale_bwd_grad = (
+            self.post_rht_amax_estimation_scale_bwd_grad
+            if self.use_post_rht_amax_estimation
+            else None
+        )
+
         # Quantization params
         # Note: RHT is currently only applied to column-wise usage so that
         # it can be used for wgrad GEMM.
@@ -455,6 +493,7 @@ def __post_init__(self) -> None:
             random_hadamard_transform=not self.disable_rht,
             stochastic_rounding=False,
             fp4_2d_quantization=False,
+            amax_estimation_scale=amax_scale_fwd_inp,
         )
         self.fp4_quant_fwd_weight = QParams(
             random_hadamard_transform=False,
@@ -465,6 +504,7 @@ def __post_init__(self) -> None:
             random_hadamard_transform=not self.disable_rht,
             stochastic_rounding=not self.disable_stochastic_rounding,
             fp4_2d_quantization=False,
+            amax_estimation_scale=amax_scale_bwd_grad,
         )
 
     def __repr__(self) -> str:
@@ -474,6 +514,7 @@ def __repr__(self) -> str:
             f"fp8_format={str(self.fp8_format).split('.')[1]}, "
             f"fp8_dpa={self.fp8_dpa}, "
            f"fp8_mha={self.fp8_mha}, "
+            f"use_post_rht_amax_estimation={self.use_post_rht_amax_estimation}, "
             f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, "
             f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, "
             f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, "
diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp
index 8d9563b789f..787a8dc13a3 100644
--- a/transformer_engine/common/transformer_engine.cpp
+++ b/transformer_engine/common/transformer_engine.cpp
@@ -933,6 +933,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
     case kNVTEQuantizationConfigStochasticRounding:
       std::memcpy(&config_.stochastic_rounding, buf, attr_size);
       break;
+    case kNVTEQuantizationConfigAmaxEstimationScale:
+      std::memcpy(&config_.amax_estimation_scale, buf, attr_size);
+      break;
     default:
      NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
  }
diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h
index 978bee52dc1..af0399fcaa9 100644
--- a/transformer_engine/pytorch/csrc/common.h
+++ b/transformer_engine/pytorch/csrc/common.h
@@ -296,6 +296,10 @@ class NVFP4Quantizer : public Quantizer {
   // 2D block scaling
   bool with_2d_quantization;
   bool stochastic_rounding;
+  // Scale factor for estimating post-RHT amax from pre-RHT amax.
+  // When <= 0.0f, true post-RHT amax is used (default behavior).
+  // When > 0.0f, post-RHT amax is estimated as: pre_rht_amax * amax_estimation_scale
+  float amax_estimation_scale;
 
   int rht_matrix_random_sign_mask_t;
   at::Tensor rht_matrix;
diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp
index 14cc084c0c7..d972e97059f 100644
--- a/transformer_engine/pytorch/csrc/extensions/activation.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp
@@ -42,10 +42,15 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int
   } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
     auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
     NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
-    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
-      // Post-RHT amax is handled within NVFP4 quantizer
+    // Check if amax estimation is enabled (scale > 0 means we can use pre-RHT amax)
+    const bool use_amax_estimation = nvfp4_quantizer_cpp->amax_estimation_scale > 0.0f;
+    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax &&
+        !use_amax_estimation) {
+      // Post-RHT amax is handled within NVFP4 quantizer (need true post-RHT amax)
       impl = Impl::UNFUSED;
     } else {
+      // When use_amax_estimation is true, activation kernel computes pre-RHT amax,
+      // and the quantizer will scale it to estimate post-RHT amax.
       impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4;
     }
   }
@@ -154,10 +159,15 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i
   } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
     auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
     NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
-    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
-      // Post-RHT amax is handled within NVFP4 quantizer
+    // Check if amax estimation is enabled (scale > 0 means we can use pre-RHT amax)
+    const bool use_amax_estimation = nvfp4_quantizer_cpp->amax_estimation_scale > 0.0f;
+    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax &&
+        !use_amax_estimation) {
+      // Post-RHT amax is handled within NVFP4 quantizer (need true post-RHT amax)
       impl = Impl::UNFUSED;
     } else {
+      // When use_amax_estimation is true, activation kernel computes pre-RHT amax,
+      // and the quantizer will scale it to estimate post-RHT amax.
       impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4;
     }
   }
diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp
index b0435d27230..531996515f4 100644
--- a/transformer_engine/pytorch/csrc/extensions/bias.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp
@@ -152,10 +152,15 @@ std::vector dact_dbias(
   } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) {
     auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
     NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
-    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
-      // Post-RHT amax is handled within NVFP4 quantizer
+    // Check if amax estimation is enabled (scale > 0 means we can use pre-RHT amax)
+    const bool use_amax_estimation = nvfp4_quantizer_cpp->amax_estimation_scale > 0.0f;
+    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax &&
+        !use_amax_estimation) {
+      // Post-RHT amax is handled within NVFP4 quantizer (need true post-RHT amax)
       impl = Impl::UNFUSED;
     } else {
+      // When use_amax_estimation is true, dact kernel computes pre-RHT amax,
+      // and the quantizer will scale it to estimate post-RHT amax.
       impl = Impl::FUSED_DACT_AMAX_NVFP4;
     }
   }
diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp
index 3c5c17fc6f2..a540e348701 100644
--- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp
@@ -126,11 +126,16 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe
   } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
     auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
     NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
-    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
-      // Post-RHT amax is handled within NVFP4 quantizer
+    // Check if amax estimation is enabled (scale > 0 means we can use pre-RHT amax)
+    const bool use_amax_estimation = nvfp4_quantizer_cpp->amax_estimation_scale > 0.0f;
+    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax &&
+        !use_amax_estimation) {
+      // Post-RHT amax is handled within NVFP4 quantizer (need true post-RHT amax)
       impl = Impl::UNFUSED;
     } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) {
-      // TE kernel supports amax output
+      // TE kernel supports amax output.
+      // When use_amax_estimation is true, LayerNorm computes pre-RHT amax,
+      // and the quantizer will scale it to estimate post-RHT amax.
       impl = Impl::FUSED_NORM_AMAX_NVFP4;
     }
   }
@@ -355,11 +360,16 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w
   } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
     auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
     NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
-    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
-      // Post-RHT amax is handled within NVFP4 quantizer
+    // Check if amax estimation is enabled (scale > 0 means we can use pre-RHT amax)
+    const bool use_amax_estimation = nvfp4_quantizer_cpp->amax_estimation_scale > 0.0f;
+    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax &&
+        !use_amax_estimation) {
+      // Post-RHT amax is handled within NVFP4 quantizer (need true post-RHT amax)
       impl = Impl::UNFUSED;
     } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) {
-      // TE kernel supports amax output
+      // TE kernel supports amax output.
+      // When use_amax_estimation is true, LayerNorm computes pre-RHT amax,
+      // and the quantizer will scale it to estimate post-RHT amax.
       impl = Impl::FUSED_NORM_AMAX_NVFP4;
     }
   }
diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp
index d7e8912ac74..7264c942f75 100644
--- a/transformer_engine/pytorch/csrc/quantizer.cpp
+++ b/transformer_engine/pytorch/csrc/quantizer.cpp
@@ -1140,6 +1140,15 @@ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantize
   this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast<bool>();
   this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast<bool>();
 
+  // Amax estimation scale: when > 0, post-RHT amax is estimated from pre-RHT amax
+  // Default to 0.0 (disabled) if the attribute doesn't exist
+  if (py::hasattr(quantizer, "amax_estimation_scale")) {
+    auto scale = quantizer.attr("amax_estimation_scale");
+    this->amax_estimation_scale = scale.is_none() ? 0.0f : scale.cast<float>();
+  } else {
+    this->amax_estimation_scale = 0.0f;
+  }
+
   // Get amax reduction group if needed for NVFP4 AG
   const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast<bool>();
   c10::intrusive_ptr amax_reduction_group;
@@ -1459,6 +1468,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
   }
   quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization);
   quant_config.set_stochastic_rounding(this->stochastic_rounding);
+  quant_config.set_amax_estimation_scale(this->amax_estimation_scale);
 
   // We only need RHT for columnwise usage.
   // flat first dim and last dim for multi dimensional input
@@ -1486,11 +1496,16 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
       input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0;
 
   // Compute amax.
+  // When amax_estimation_scale > 0, we estimate post-RHT amax from pre-RHT amax,
+  // so we skip the RHT+amax kernel and compute simple pre-RHT amax instead.
+  const bool use_amax_estimation = this->amax_estimation_scale > 0.0f;
+
   if (this->with_rht) {
     if (input.dtype() != DType::kBFloat16) {
       NVTE_CHECK(false, "RHT is only supported for bfloat16 input");
     }
-    if (this->with_post_rht_amax) {
+    if (this->with_post_rht_amax && !use_amax_estimation) {
+      // Compute true post-RHT amax using RHT+amax kernel
       // We need:
       // 1. Rowwise amax = amax for input
       // 2. Columnwise amax = amax for RHT(input.t)
@@ -1498,8 +1513,33 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
         nvte_hadamard_transform_amax(input.data(), out.data(), 0,
                                      this->rht_matrix_random_sign_mask_t, stream);
       });
+    } else if (use_amax_estimation) {
+      // EXPERIMENTAL: Skip RHT+amax kernel, compute pre-RHT amax instead if required.
+      // The kernel will scale this by amax_estimation_scale to estimate post-RHT amax.
+      if (compute_amax) {
+        auto rowwise_amax_ptr = out.get_amax().data_ptr;
+        auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr;
+        void* amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr;
+        NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer for estimation");
+
+        // Compute pre-RHT amax of input tensor
+        out.set_amax(amax_ptr, DType::kFloat32, std::vector<size_t>{1});
+        NVTE_SCOPED_GIL_RELEASE(
+            { nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); });
+        out.set_amax(rowwise_amax_ptr, DType::kFloat32, std::vector<size_t>{1});
+
+        // Make sure row-wise and column-wise amaxes match
+        if (rowwise_amax_ptr != amax_ptr && rowwise_amax_ptr != nullptr) {
+          NVTE_CHECK_CUDA(cudaMemcpyAsync(rowwise_amax_ptr, amax_ptr, sizeof(float),
+                                          cudaMemcpyDeviceToDevice, stream));
+        }
+        if (columnwise_amax_ptr != amax_ptr && columnwise_amax_ptr != nullptr) {
+          NVTE_CHECK_CUDA(cudaMemcpyAsync(columnwise_amax_ptr, amax_ptr, sizeof(float),
+                                          cudaMemcpyDeviceToDevice, stream));
+        }
+      }
     } else {
-      // raise error since it's not supported yet
+      // with_rht but not with_post_rht_amax and not using estimation
       NVTE_CHECK(false, "Pre-RHT amax is not supported yet");
     }
   } else {  // Without RHT
diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py
index fbe2ee6d1cf..70800d9b2cb 100644
--- a/transformer_engine/pytorch/quantization.py
+++ b/transformer_engine/pytorch/quantization.py
@@ -1332,20 +1332,23 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer:
                 with_post_rht_amax=qparams.random_hadamard_transform,
                 with_2d_quantization=qparams.fp4_2d_quantization,
                 stochastic_rounding=qparams.stochastic_rounding,
+                amax_estimation_scale=qparams.amax_estimation_scale,
             )
 
         return [_make_quantizer(idx) for idx in range(self.num_quantizers)]
 
        if self.mode == "backward":
+            qparams = self.recipe.fp4_quant_bwd_grad
            return [
                NVFP4Quantizer(
                    fp4_dtype=self.dtype,
                    rowwise=True,
                    columnwise=True,
-                    with_rht=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform,
-                    with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform,
-                    with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization,
-                    stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding,
+                    with_rht=qparams.random_hadamard_transform,
+                    with_post_rht_amax=qparams.random_hadamard_transform,
+                    with_2d_quantization=qparams.fp4_2d_quantization,
+                    stochastic_rounding=qparams.stochastic_rounding,
+                    amax_estimation_scale=qparams.amax_estimation_scale,
                )
                for _ in range(self.num_quantizers)
            ]
diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py
index 0c244628d65..ed6fa698821 100644
--- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py
+++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py
@@ -127,6 +127,12 @@ class NVFP4Quantizer(Quantizer):
     """Stochastic rounding, only applicable for gradients."""
     stochastic_rounding: bool
+    """Scale factor for estimating post-RHT amax from pre-RHT amax.
+    When None, true post-RHT amax is computed (default behavior).
+    When set to a float, post-RHT amax is estimated as: pre_rht_amax * amax_estimation_scale
+    """
+    amax_estimation_scale: Optional[float]
+
     """RHT matrix random sign mask"""
     rht_matrix_random_sign_mask_t: int
     rht_matrix: torch.Tensor
@@ -143,6 +149,7 @@ def __init__(
         with_2d_quantization: bool = False,
         stochastic_rounding: bool = False,
         with_random_sign_mask: bool = True,
+        amax_estimation_scale: Optional[float] = None,
     ) -> None:
         super().__init__(rowwise=rowwise, columnwise=columnwise)
         self.dtype = fp4_dtype
@@ -152,6 +159,7 @@ def __init__(
         self.amax_reduction_group = amax_reduction_group
         self.with_2d_quantization = with_2d_quantization
         self.stochastic_rounding = stochastic_rounding
+        self.amax_estimation_scale = amax_estimation_scale
         self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht(
             with_random_sign_mask, torch.cuda.current_device()
         )
@@ -191,6 +199,7 @@ def copy(self) -> NVFP4Quantizer:
             with_post_rht_amax=self.with_post_rht_amax,
             with_2d_quantization=self.with_2d_quantization,
             stochastic_rounding=self.stochastic_rounding,
+            amax_estimation_scale=self.amax_estimation_scale,
         )
         quantizer.internal = self.internal
         quantizer.rht_matrix = self.rht_matrix
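
Reviewer note: a minimal Python usage sketch for the experimental estimation path introduced above. It only uses names added in this diff (the NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION* env vars, use_post_rht_amax_estimation, and QParams.amax_estimation_scale) plus the existing transformer_engine.common.recipe module; it assumes NVFP4BlockScaling remains a dataclass so the new field can be passed as a keyword argument. Treat it as an illustration, not part of the patch.

import os

# Option 1: enable via environment variables. These are read when the
# NVFP4BlockScaling class body is evaluated, so set them before the recipe
# module is imported.
os.environ["NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION"] = "1"
os.environ["NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION_X_SCALE"] = "2.0"  # forward input scale
os.environ["NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION_G_SCALE"] = "1.0"  # backward gradient scale

from transformer_engine.common.recipe import NVFP4BlockScaling

# Option 2: enable it explicitly on the recipe instance (assumes the dataclass
# constructor accepts the new field as a keyword argument).
recipe = NVFP4BlockScaling(use_post_rht_amax_estimation=True)

# __post_init__ threads the scales into the per-tensor QParams, which is how they
# eventually reach NVFP4Quantizer and the C++ QuantizationConfig shown in this diff.
print(recipe.fp4_quant_fwd_inp.amax_estimation_scale)   # 2.0 with the defaults above
print(recipe.fp4_quant_bwd_grad.amax_estimation_scale)  # 1.0 with the defaults above

With estimation disabled (the default), both QParams carry amax_estimation_scale=None, the quantizer passes 0.0 to the C++ side, and the true post-RHT amax path is used unchanged.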