diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index b3f24e9337e..bde67e56d13 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -807,17 +807,28 @@ def _quantize_dbias_impl( use_rht = True rht_matrix = get_rht_matrix() - new_amax, post_rht_amax = calculate_post_rht_amax( - x.data, - amax_scope=amax_scope, - transpose_batch_sequence=transpose_batch_sequence, - produce_regular_amax=amax is None, - flatten_axis=flatten_axis, - ) - if amax is None: - # If amax is already calculated in a previous layer, we skip calculating it in the TE kernel - # So here we only calculate and update amax when it is not provided from a previous layer (amax is None) - amax = new_amax + use_approx_amax = True + if use_approx_amax: + if amax is None: + amax = calculate_amax( + x.data, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) + amax_scale = 2.0 + post_rht_amax = amax * amax_scale + else: + new_amax, post_rht_amax = calculate_post_rht_amax( + x.data, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + produce_regular_amax=amax is None, + flatten_axis=flatten_axis, + ) + if amax is None: + # If amax is already calculated in a previous layer, we skip calculating it in the TE kernel + # So here we only calculate and update amax when it is not provided from a previous layer (amax is None) + amax = new_amax if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: if amax is None: