diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh index a176d21b15b..78be6ab5346 100644 --- a/qa/L0_pytorch_debug_unittest/test.sh +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -28,16 +28,16 @@ mkdir -p "$XML_LOG_DIR" pip install pytest==8.2.1 || error_exit "Failed to install pytest" -pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_sanity.py" -pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_config.py" -pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_numerics.py" -pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_log.py" -NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_api_features.py" -pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py" +pytest -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_sanity.py" +pytest -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_config.py" +pytest -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_numerics.py" +pytest -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_log.py" +NVTE_TORCH_COMPILE=0 pytest -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_api_features.py" +pytest -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py" # standard sanity and numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py" -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py 
|| test_fail "debug test_numerics.py" +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py" +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu index f6e0da057ae..97c0d45b642 100644 --- a/tests/cpp/operator/test_swizzle.cu +++ b/tests/cpp/operator/test_swizzle.cu @@ -85,6 +85,7 @@ void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool row std::vector scaling_mode = {SF_MODE_X, SF_MODE_Y, 0}; Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); + output.set_with_gemm_swizzled_scales(true); fillUniform(&input); diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index b8993dfb620..771a4c2b5ca 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -286,6 +286,10 @@ class Tensor { tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape); } + void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales){ + tensor_.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); + } + void to_cpu() const; void from_cpu() const; void set_scale(float scale); diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index 153f0b7e042..ec7b4f02e6a 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -87,126 +87,6 @@ def initialize_for_many_scales( return result -@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) -@pytest.mark.parametrize( - "M, N", - [ - # full tile cases - (128, 128), - (256, 256), - (256, 1024), - (1024, 256), - # Padding required cases - (256, 272), - (303, 300), - (305, 256), - # Some larger tiles. - (2000, 2000), - (2048, 2000), - (2000, 1024), - (2048, 1024), - ], -) -@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) -@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) -@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) -@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"]) -def test_quantization_1D_block_tiling_with_compact_data_and_scales( - x_dtype: torch.dtype, - M: int, - N: int, - quant_dtype: torch.dtype, - eps: float, - pow_2_scales: bool, -) -> None: - te_dtype = TE_DType[quant_dtype] - tile_size = (1, 128) - # This test runs a comparison of the ref class versus the class using - # CUDA kernels to quantize. They should quantize identically for pixels - # that are not DC values in the scale factor shape. 
- ref_quantizer = BlockwiseQuantizerReference() - sut_quantizer = Float8BlockQuantizer( - fp8_dtype=te_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=eps, - force_pow_2_scales=pow_2_scales, - block_scaling_dim=1, - all_gather_usage=True, - ) - - # Setup device and random seed - device = "cuda" - seed = 0 - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - - # Input - x = initialize_for_many_scales((M, N), tile_size, dtype=x_dtype, device=device) - - x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False) - x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut) - x_fp8_sut_cpp_alloc = sut_quantizer(x) - - assert x_fp8_sut._rowwise_data is not None - qx: torch.Tensor = x_fp8_sut._rowwise_data.view(dtype=quant_dtype) - assert x_fp8_sut._rowwise_scale_inv is not None - sx: torch.Tensor = x_fp8_sut._rowwise_scale_inv - qx_t = x_fp8_sut._columnwise_data - sx_t = x_fp8_sut._columnwise_scale_inv - - qresult_ref = ref_quantizer.quantize( - x, - quant_dtype=quant_dtype, - return_transpose=True, - eps=eps, - pow_2_scales=pow_2_scales, - quant_tile_shape=tile_size, - munge_scale_shapes=False, - ) - qx_ref, sx_ref, qx_t_ref, sx_t_ref = ( - qresult_ref.data, - qresult_ref.scale, - qresult_ref.data_t, - qresult_ref.scale_t, - ) - - # match the reference quantize transpose output with the columnwise non-transpose method - qx_t_ref = qx_t_ref.transpose(-1, -2).contiguous() - sx_t_ref = sx_t_ref.transpose(-1, -2).contiguous() - - # Check - torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0) - torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0) - assert qx_t is not None - qx_t = qx_t.view(dtype=quant_dtype) - assert qx_t_ref is not None - assert sx_t is not None - assert sx_t_ref is not None - torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0) - torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0) - - # check that the C++ and Python allocators are equivalent - torch.testing.assert_close( - x_fp8_sut._rowwise_data, x_fp8_sut_cpp_alloc._rowwise_data, atol=0.0, rtol=0.0 - ) - torch.testing.assert_close( - x_fp8_sut._rowwise_scale_inv, x_fp8_sut_cpp_alloc._rowwise_scale_inv, atol=0.0, rtol=0.0 - ) - torch.testing.assert_close( - x_fp8_sut._columnwise_data, x_fp8_sut_cpp_alloc._columnwise_data, atol=0.0, rtol=0.0 - ) - torch.testing.assert_close( - x_fp8_sut._columnwise_scale_inv, - x_fp8_sut_cpp_alloc._columnwise_scale_inv, - atol=0.0, - rtol=0.0, - ) - - # check if the fp8 output between C++ and Python are the same - assert x_fp8_sut._data_format == x_fp8_sut_cpp_alloc._data_format - - def check_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index c59f8d8c6ac..7f18d72dd56 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -175,16 +175,12 @@ def test_quantize_dequantize_columnwise_only( ) @pytest.mark.parametrize("block_scaling_dim", [1, 2]) @pytest.mark.parametrize("dq_columnwise", [True, False]) - @pytest.mark.parametrize("all_gather_usage", [True, False]) def test_quantize_dequantize_dims( self, dims: DimsType, block_scaling_dim: int, dq_columnwise: bool, - all_gather_usage: bool, ) -> None: - if all_gather_usage and block_scaling_dim != 1: - pytest.skip("all_gather_usage only implemented for 1D block quantization.") atol = _tols[tex.DType.kFloat8E4M3]["atol"] rtol = 
_tols[tex.DType.kFloat8E4M3]["rtol"] quantizer = Float8BlockQuantizer( @@ -192,7 +188,6 @@ def test_quantize_dequantize_dims( rowwise=True, columnwise=dq_columnwise, block_scaling_dim=block_scaling_dim, - all_gather_usage=all_gather_usage, ) self._test_quantize_dequantize( quantizer=quantizer, @@ -218,7 +213,6 @@ def test_quantize_dequantize_compact_format( rowwise=True, columnwise=dq_columnwise, block_scaling_dim=block_scaling_dim, - all_gather_usage=(block_scaling_dim == 1), ) self._test_quantize_dequantize( quantizer=quantizer, @@ -283,13 +277,8 @@ def test_data_accessors(self, dims: DimsType, block_scaling_dim: int) -> None: @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) @pytest.mark.parametrize("block_scaling_dim", [1, 2]) - @pytest.mark.parametrize("all_gather_usage", [True, False]) - def test_serialization( - self, dims: DimsType, block_scaling_dim: int, all_gather_usage: bool - ) -> None: + def test_serialization(self, dims: DimsType, block_scaling_dim: int) -> None: """Test serialization of Float8BlockwiseQTensor""" - if all_gather_usage and block_scaling_dim != 1: - pytest.skip("all_gather_usage only implemented for 1D block quantization.") device = "cuda" dtype = torch.bfloat16 x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) @@ -298,7 +287,6 @@ def test_serialization( rowwise=True, columnwise=True, block_scaling_dim=block_scaling_dim, - all_gather_usage=all_gather_usage, ) # Create FP8 tensor @@ -322,7 +310,6 @@ def test_serialization( assert x_fp8_loaded._is_2D_scaled == x_fp8._is_2D_scaled assert x_fp8_loaded.dtype == x_fp8.dtype assert x_fp8_loaded._fp8_dtype == x_fp8._fp8_dtype - assert x_fp8_loaded._data_format == x_fp8._data_format # Test that dequantized values match x_fp8_dequant = x_fp8.dequantize() diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 735cc9b953f..8a0c9fa929d 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -2737,7 +2737,11 @@ def test_linear( # Check that original and loaded model match exactly tols = {"rtol": 0, "atol": 0} for param_load, param_save in zip(model_load.parameters(), model_save.parameters()): - torch.testing.assert_close(param_load, param_save, **tols) + torch.testing.assert_close( # Force dequantization by casting to FP64 + param_load.to(dtype=torch.float64, device="cpu"), + param_save.to(dtype=torch.float64, device="cpu"), + **tols, + ) torch.testing.assert_close(param_load.grad, param_save.grad, **tols) for y_load, y_save in zip(ys_load, ys_save): torch.testing.assert_close(y_load, y_save, **tols) @@ -2754,7 +2758,6 @@ def setup_class(cls) -> None: @pytest.mark.parametrize("requires_grad", (False, True)) @pytest.mark.parametrize("bias", (False, True)) - @pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm")) @pytest.mark.parametrize("quantized_compute", (False, True)) @pytest.mark.parametrize("quantized_weight", (False, True)) @pytest.mark.parametrize("dtype", _dtypes) @@ -2764,25 +2767,18 @@ def test_layernorm_mlp( *, requires_grad: bool, bias: bool, - normalization: str, quantized_compute: bool, quantized_weight: bool, dtype: torch.dtype, quantization: Optional[str], device: torch.device = "cuda", - hidden_size: int = 32, - sequence_length: int = 512, + hidden_size: int = 256, + sequence_length: int = 48, batch_size: int = 4, - ffn_hidden_size: int = 64, + ffn_hidden_size: int = 384, layernorm_epsilon: float = 1e-5, ) -> None: - """ - LayerNorm/RMSNorm + Linear + GELU + Linear - - Note that this test checks only if the 
module runs - as when chaining multiple modules it is hard to validate - numerical accuracy. - """ + """LayerNorm/RMSNorm + Linear + SwiGLU + Linear""" # Make input shape in_shape = (sequence_length, batch_size, hidden_size) @@ -2798,38 +2794,90 @@ def test_layernorm_mlp( pytest.skip("Quantization scheme is not used") # Random data - _, x_test = make_reference_and_test_tensors( + x_ref, x_test = make_reference_and_test_tensors( in_shape, quantization=quantization, test_dtype=dtype, test_device=device, requires_grad=requires_grad, ) - _, dy_test = make_reference_and_test_tensors( + norm_w_ref, norm_w_test = make_reference_and_test_tensors( + hidden_size, + test_dtype=dtype, + test_device=device, + ) + norm_b_ref, norm_b_test = make_reference_and_test_tensors( + hidden_size, + test_dtype=dtype, + test_device=device, + ) + w1_ref, w1_test = make_reference_and_test_tensors( + (ffn_hidden_size, hidden_size), + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + w2_ref, w2_test = make_reference_and_test_tensors( + (hidden_size, ffn_hidden_size // 2), + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + b1_ref, b1_test, b2_ref, b2_test = None, None, None, None + if bias: + b1_ref, b1_test = make_reference_and_test_tensors( + ffn_hidden_size, + test_dtype=dtype, + test_device=device, + ) + b2_ref, b2_test = make_reference_and_test_tensors( + hidden_size, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( in_shape, quantization=quantization, test_dtype=dtype, test_device=device, requires_grad=False, ) + with torch.no_grad(): + for t in (norm_w_ref, norm_w_test, norm_b_ref, norm_b_test): + t -= 0.5 + for t in (w1_ref, w1_test, w2_ref, w2_test): + t *= 1 / 64 + if bias: + for t in (b1_ref, b1_test, b2_ref, b2_test): + t -= 0.5 + for t in (dy_ref, dy_test): + t -= 0.5 + + # Reference implementation + x = x_ref + x = torch.nn.functional.layer_norm( + x, + (hidden_size,), + weight=norm_w_ref, + bias=norm_b_ref, + eps=layernorm_epsilon, + ) + x = torch.nn.functional.linear(x, w1_ref, bias=b1_ref) + x1, x2 = x.chunk(2, dim=-1) + x = torch.nn.functional.silu(x1) * x2 + x = torch.nn.functional.linear(x, w2_ref, bias=b2_ref) + y_ref = x + y_ref.backward(dy_ref) - # Implementation with fusible operations + # Construct operations recipe = make_recipe(quantization) with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): - if normalization == "LayerNorm": - norm = te_ops.LayerNorm( - hidden_size, - eps=layernorm_epsilon, - device=device, - dtype=dtype, - ) - else: - norm = te_ops.RMSNorm( - hidden_size, - eps=layernorm_epsilon, - device=device, - dtype=dtype, - ) + norm = te_ops.LayerNorm( + hidden_size, + eps=layernorm_epsilon, + device=device, + dtype=dtype, + ) ffn1 = te_ops.Linear( hidden_size, ffn_hidden_size, @@ -2837,15 +2885,48 @@ def test_layernorm_mlp( device=device, dtype=dtype, ) - act = te_ops.GELU() + act = te_ops.SwiGLU() ffn2 = te_ops.Linear( - ffn_hidden_size, + ffn_hidden_size // 2, hidden_size, bias=bias, device=device, dtype=dtype, ) + + # Copy weights + with torch.no_grad(): + norm.weight.copy_(norm_w_test) + norm.bias.copy_(norm_b_test) + ffn1.weight.copy_(w1_test) + ffn2.weight.copy_(w2_test) + if bias: + ffn1.bias.copy_(b1_test) + ffn2.bias.copy_(b2_test) + del norm_w_test, norm_b_test, w1_test, b1_test, w2_test, b2_test + + # Fuse ops and perform forward and backward pass forward = te_ops.Sequential(norm, ffn1, act, ffn2) with te.autocast(enabled=quantized_compute, 
recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) + + def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + """Convert to FP64 CPU tensor""" + if tensor is None: + return None + out = tensor.detach().to(dtype=torch.float64, device="cpu") + out = out.requires_grad_(requires_grad=tensor.requires_grad) + return out + + # Check values + tols = {"rtol": 0.25, "atol": 0.5} # Loose tols for sanity checking + torch.testing.assert_close(to_cpu(y_test), y_ref, **tols) + torch.testing.assert_close(to_cpu(x_test.grad), x_ref.grad, **tols) + torch.testing.assert_close(to_cpu(norm.weight.grad), norm_w_ref.grad, **tols) + torch.testing.assert_close(to_cpu(norm.bias.grad), norm_b_ref.grad, **tols) + torch.testing.assert_close(to_cpu(ffn2.weight.grad), w2_ref.grad, **tols) + torch.testing.assert_close(to_cpu(ffn1.weight.grad), w1_ref.grad, **tols) + if bias: + torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols) + torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols) diff --git a/transformer_engine/common/cast/dispatch/gated.cuh b/transformer_engine/common/cast/dispatch/gated.cuh index f08b09317e0..3494fe55b57 100644 --- a/transformer_engine/common/cast/dispatch/gated.cuh +++ b/transformer_engine/common/cast/dispatch/gated.cuh @@ -14,6 +14,7 @@ #include #include "../../common.h" +#include "../../transpose/transpose.h" #include "../../utils.cuh" #include "../fp8/gated_fp8.cuh" #include "../mxfp8/gated_mxfp8.cuh" @@ -53,6 +54,20 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp } else { fp8::cast_gated_fwd(input, output, p, stream); } + if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { + // FP8 kernel only populates row-wise data, so perform + // transpose separately if needed + Tensor transpose_in, transpose_out, dummy; + transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + transpose_in.data.dptr = output->data.dptr; + transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; + transpose_in.data.dtype = output->data.dtype; + transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + transpose_out.data.dptr = output->columnwise_data.dptr; + transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; + transpose_out.data.dtype = output->data.dtype; + detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); + } break; } case NVTE_MXFP8_1D_SCALING: { @@ -129,6 +144,20 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte } else { fp8::cast_gated_bwd(gated_input, grad, output, p, stream); } + if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { + // FP8 kernel only populates row-wise data, so perform + // transpose separately if needed + Tensor transpose_in, transpose_out, dummy; + transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + transpose_in.data.dptr = output->data.dptr; + transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; + transpose_in.data.dtype = output->data.dtype; + transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + transpose_out.data.dptr = output->columnwise_data.dptr; + transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; + transpose_out.data.dtype = output->data.dtype; + detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); + } break; } case NVTE_MXFP8_1D_SCALING: { diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh 
b/transformer_engine/common/cast/dispatch/quantize.cuh index 6d4454402c5..3efd10be9ab 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -150,17 +150,10 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; if (output_tensor->has_data()) { - bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT); - rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT - : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; } if (output_tensor->has_columnwise_data()) { - bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT); - columnwise_option = columnwise_compact - ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT - : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; } quantize_transpose_vector_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, @@ -298,14 +291,12 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; if (output_tensor->has_data()) { - bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT); + const bool rowwise_compact = !output_tensor->with_gemm_swizzled_scales; rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; } if (output_tensor->has_columnwise_data()) { - bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT); + const bool columnwise_compact = !output_tensor->with_gemm_swizzled_scales; columnwise_option = columnwise_compact ? 
FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; diff --git a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh index c56ebe172ca..1cc17e9f9af 100644 --- a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh @@ -239,6 +239,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) NVTE_CHECK(is_fp8_dtype(input.columnwise_data.dtype), "Input must have FP8 type."); } + NVTE_CHECK(!input.with_gemm_swizzled_scales, "Input must have scales in compact format."); NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match."); diff --git a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh index 4f0e1b80f70..3f4235eb1d1 100644 --- a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh @@ -49,9 +49,24 @@ constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 // Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 +// Convert compact scale indices into GEMM swizzled scale index +__device__ __forceinline__ size_t gemm_swizzled_scale_idx(size_t i, size_t j, size_t num_tiles_X) { + constexpr size_t TILE_DIM_X = 4; // Tile dim in scale buffer + constexpr size_t TILE_DIM_Y = 128; + constexpr size_t TILE_SIZE = TILE_DIM_X * TILE_DIM_Y; + const size_t tile_idx_X = j / TILE_DIM_X; + const size_t tile_idx_Y = i / TILE_DIM_Y; + const size_t idx_in_tile_X = j % TILE_DIM_X; + const size_t idx_in_tile_Y = i % TILE_DIM_Y; + size_t idx = (tile_idx_Y * num_tiles_X + tile_idx_X) * TILE_SIZE; + idx += (idx_in_tile_Y % 32) * 16 + (idx_in_tile_Y / 32) * 4 + idx_in_tile_X; + return idx; +} + template + bool ROWWISE_SCALING, bool COLWISE_SCALING, bool WITH_GEMM_SWIZZLED_SCALES, + size_t THREADS_PER_CHUNK> __global__ void __launch_bounds__(THREADS_PER_CHUNK) quantize_gated_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, const __grid_constant__ CUtensorMap tensor_map_input_act, @@ -355,14 +370,17 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // 2. 
Compute E8M0 scaling factor const e8m0_t biased_exponent_act = ptx::float_to_e8m0(thread_amax_act * Quantized_Limits::max_norm_rcp); - const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; const size_t global_scales_offset_X = scales_offset_X_colwise; - const size_t scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + size_t scale_idx; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y, + DIVUP(rows, static_cast(128))); + } else { + scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + } const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows; const bool out_of_bounds_colwise = row_out_of_bounds_colwise || col_out_of_bounds_colwise; - if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) { scales_colwise[scale_idx] = biased_exponent_act; } @@ -374,8 +392,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const e8m0_t biased_exponent_gate = ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); - // const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2; - const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise; + size_t scale_idx_gate; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + scale_idx_gate = gemm_swizzled_scale_idx( + global_scales_offset_X + gate_scale_idx_offset_colwise, global_scales_offset_Y, + DIVUP(rows, static_cast(128))); + } else { + scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise; + } if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) { scales_colwise[scale_idx_gate] = biased_exponent_gate; } @@ -557,7 +581,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::float_to_e8m0(thread_amax_act * Quantized_Limits::max_norm_rcp); const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; const size_t stage_scales_offset_X = scales_offset_X_rowwise; - const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + size_t scale_idx; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + const size_t output_cols = (IS_BWD ? 2 : 1) * cols; + scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X, + DIVUP(output_cols, static_cast(128))); + } else { + scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + } const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows; const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise; if (!out_of_bounds_rowwise) { @@ -573,7 +604,16 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (IS_BWD) { const e8m0_t biased_exponent_gate = ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); - const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise; + + size_t scale_idx_gate; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + const size_t output_cols = (IS_BWD ? 
2 : 1) * cols; + scale_idx_gate = gemm_swizzled_scale_idx( + stage_scales_offset_Y, stage_scales_offset_X + gate_scale_idx_offset_rowwise, + DIVUP(output_cols, static_cast(128))); + } else { + scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise; + } if (!out_of_bounds_rowwise) { scales_rowwise[scale_idx_gate] = biased_exponent_gate; } @@ -667,7 +707,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) parity ^= 1; destroy_barriers(mbar, is_master_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} +} // NOLINT(readability/fn_size) + } // namespace gated_kernel template has_data(); const bool USE_COLWISE_SCALING = output->has_columnwise_data(); + const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; if (USE_ROWWISE_SCALING) { NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); @@ -722,113 +764,140 @@ void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *outpu gated_input.dtype(), IType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output->dtype(), OType, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, - alignas(64) CUtensorMap tensor_map_grad{}; - alignas(64) CUtensorMap tensor_map_input_act{}; - alignas(64) CUtensorMap tensor_map_input_gate{}; - alignas(64) CUtensorMap tensor_map_output_act_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_gate_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_act_colwise{}; - alignas(64) CUtensorMap tensor_map_output_gate_colwise{}; + alignas(64) CUtensorMap tensor_map_grad{}; + alignas(64) CUtensorMap tensor_map_input_act{}; + alignas(64) CUtensorMap tensor_map_input_gate{}; + alignas(64) CUtensorMap tensor_map_output_act_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_gate_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_act_colwise{}; + alignas(64) CUtensorMap tensor_map_output_gate_colwise{}; - constexpr size_t input_type_bit_size = TypeInfo::size; - constexpr size_t output_type_bit_size = TypeInfo::size; + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; - if constexpr (IS_BWD) { - create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, - cols, 0, input_type_bit_size); - } + if constexpr (IS_BWD) { + create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, + cols, 0, input_type_bit_size); + } - const uint32_t tensor_stride_elems = output_cols; - create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, BUFF_DIM_Y, - BUFF_DIM_X, cols * 2, 0, input_type_bit_size); - create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, BUFF_DIM_Y, - BUFF_DIM_X, cols * 2, cols, input_type_bit_size); - - if (USE_ROWWISE_SCALING) { - create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols, - BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0, - output_type_bit_size); - create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols, - BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols, - output_type_bit_size); - } + const uint32_t tensor_stride_elems = output_cols; + create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols * 2, 0, input_type_bit_size); + create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols * 2, cols, input_type_bit_size); + + if (USE_ROWWISE_SCALING) { + 
create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0, + output_type_bit_size); + create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols, + output_type_bit_size); + } - if (USE_COLWISE_SCALING) { - create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, rows, cols, - BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0, - output_type_bit_size); - create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, rows, - cols, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols, - output_type_bit_size); - } + if (USE_COLWISE_SCALING) { + create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, rows, + cols, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0, + output_type_bit_size); + create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, rows, + cols, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols, + output_type_bit_size); + } - const size_t buff_elems_total = BUFFS_NUM * BUFF_DIM_Y * BUFF_DIM_X; - const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; - const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; - const size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); - const size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - - const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0); - const size_t in_act_mem = buff_size_aligned_in; - const size_t in_gate_mem = buff_size_aligned_in; - const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; - - const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = (IS_BWD ? 
buff_size_aligned_out : 0); - size_t out_mem = out_act_mem + out_gate_mem; - - if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } - - const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; - - switch (scaling_type) { - case ScalingType::ROWWISE: { - auto kernel = - quantize_gated_mxfp8_kernel; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - kernel<<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, - scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); - break; - } - case ScalingType::COLWISE: { - auto kernel = - quantize_gated_mxfp8_kernel; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - kernel<<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, - scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); - break; - } - case ScalingType::BIDIMENSIONAL: { - auto kernel = - quantize_gated_mxfp8_kernel; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - kernel<<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, - scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); - break; - } - } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) - ); // NOLINT(*) + const size_t buff_elems_total = BUFFS_NUM * BUFF_DIM_Y * BUFF_DIM_X; + const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + const size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + const size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; + + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = (IS_BWD ? 
buff_size_aligned_out : 0); + size_t out_mem = out_act_mem + out_gate_mem; + + if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } + + const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + + // Zero out swizzled scales if padding is needed + /// TODO (tmoon) Handle this within the cast kernel + if (with_gemm_swizzled_scales) { + constexpr size_t TILE_DIM_X = 128; // Tile dim in data buffer + constexpr size_t TILE_DIM_Y = 128; + if (cols % TILE_DIM_X != 0 || rows % TILE_DIM_Y != 0) { + if (USE_ROWWISE_SCALING) { + NVTE_CHECK_CUDA(cudaMemsetAsync(output->scale_inv.dptr, 0, + output->scale_inv.buffer_size_bytes(), stream)); + } + if (USE_COLWISE_SCALING) { + NVTE_CHECK_CUDA( + cudaMemsetAsync(output->columnwise_scale_inv.dptr, 0, + output->columnwise_scale_inv.buffer_size_bytes(), stream)); + } + } + } + + switch (scaling_type) { + case ScalingType::ROWWISE: { + auto kernel = + quantize_gated_mxfp8_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + kernel<<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise, p); + break; + } + case ScalingType::COLWISE: { + auto kernel = + quantize_gated_mxfp8_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + kernel<<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise, p); + break; + } + case ScalingType::BIDIMENSIONAL: { + auto kernel = + quantize_gated_mxfp8_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + kernel<<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise, p); + break; + } + } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) } } // namespace mxfp8 diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index cbb46f3f28d..f7e3f74d501 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -41,9 +41,24 @@ constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 // Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 +// Convert compact scale indices into GEMM swizzled scale index +__device__ __forceinline__ size_t gemm_swizzled_scale_idx(size_t i, size_t j, size_t num_tiles_X) { + constexpr size_t TILE_DIM_X = 4; // Tile dim in scale buffer + constexpr size_t TILE_DIM_Y = 128; + constexpr size_t TILE_SIZE = TILE_DIM_X * TILE_DIM_Y; + const size_t tile_idx_X = j / TILE_DIM_X; + const size_t tile_idx_Y = i / TILE_DIM_Y; + const size_t idx_in_tile_X = j % TILE_DIM_X; + const size_t 
idx_in_tile_Y = i % TILE_DIM_Y; + size_t idx = (tile_idx_Y * num_tiles_X + tile_idx_X) * TILE_SIZE; + idx += (idx_in_tile_Y % 32) * 16 + (idx_in_tile_Y / 32) * 4 + idx_in_tile_X; + return idx; +} + template + bool COLWISE_SCALING, bool WITH_GEMM_SWIZZLED_SCALES, size_t CHUNK_DIM_Y, + size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK> __global__ void __launch_bounds__(THREADS_PER_CHUNK) quantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_act_input, @@ -106,7 +121,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + const bool rowwise_scale_is_within_bounds = SCALE_DIM_X * scales_offset_X_rowwise < cols; // helps resolving bank conflicts in shmem const int thread_lane = threadIdx.x % THREADS_PER_WARP; @@ -263,11 +278,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // 2. Compute E8M0 scaling factor const e8m0_t biased_exponent = ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; const size_t global_scales_offset_X = scales_offset_X_colwise; - const size_t scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + size_t scale_idx; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y, + DIVUP(rows, static_cast(128))); + } else { + scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + } scales_colwise[scale_idx] = biased_exponent; const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); @@ -411,7 +430,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; const int stage_scales_offset_X = scales_offset_X_rowwise; - const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + size_t scale_idx; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X, + DIVUP(cols, static_cast(128))); + } else { + scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + } if (rowwise_scale_is_within_bounds) { scales_rowwise[scale_idx] = biased_exponent; } @@ -550,7 +575,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, bool use_colwise_scaling = output->has_columnwise_data(); NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); - if (use_rowwise_scaling) { NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); } @@ -560,17 +584,21 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, } CheckNoopTensor(*noop, "cast_noop"); + constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); + + // Tensor dimensions const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); - constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); - + // Tensor chunk handled by each CUDA block constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 
128 : 64; constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64; - constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 128 : 64; + // CUDA block config + constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 128 : 64; constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; + constexpr size_t BUFF_DIM_Y = THREADS_Y; constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; @@ -579,6 +607,8 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, const dim3 grid(blocks_X, blocks_Y); const size_t block_size = THREADS_PER_CHUNK; + const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; + const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; const size_t scale_stride_colwise = use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; @@ -619,168 +649,195 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, input.dtype(), IType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output->dtype(), OType, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, + + if (specialized::hasSpec() && + !WITH_GEMM_SWIZZLED_SCALES) { + switch (scaling_type) { + case ScalingType::ROWWISE: { + using traits = specialized::CastTraits; + auto kernel = specialized::quantize_mxfp8_kernel_cast_only; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + traits::smem); + + dim3 block(traits::threadLayout::num, traits::warpLayout::N, + traits::warpLayout::M); + dim3 grid((cols + traits::blockDimN - 1) / traits::blockDimN, + (rows + traits::blockDimM - 1) / traits::blockDimM); + kernel<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + scales_rowwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + + break; + } + case ScalingType::COLWISE: { + NVTE_WARN("Colwise scaling will fallback to original kernel."); + break; + } + case ScalingType::BIDIMENSIONAL: { + using traits = specialized::CastTraits; + auto kernel = specialized::quantize_mxfp8_kernel_cast_only; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + traits::smem); + // TMA for loading, so that we don't need STS for transposing + alignas(64) CUtensorMap tensor_map_input{}; + constexpr size_t input_type_bit_size = TypeInfo::size; + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, + traits::blockIterDim::M, traits::blockIterDim::N, + /*stride_elems=*/cols, + /*offset_elems=*/0, input_type_bit_size, + traits::input_swizzle_pattern); + + alignas(64) CUtensorMap tensor_map_rowwise_output{}; + alignas(64) CUtensorMap tensor_map_colwise_output{}; + constexpr size_t output_type_bit_size = TypeInfo::size; + create_2D_tensor_map(tensor_map_rowwise_output, output->data, rows, cols, + traits::blockIterDim::M, traits::blockIterDim::N, + /*stride_elems=*/cols, + /*offset_elems=*/0, output_type_bit_size, + traits::output_swizzle_pattern); + create_2D_tensor_map(tensor_map_colwise_output, output->columnwise_data, rows, + cols, traits::blockIterDim::M, traits::blockIterDim::N, + cols, 0, output_type_bit_size, + traits::output_swizzle_pattern); + + dim3 block(traits::rowThreadLayout::num, traits::numWarps); + dim3 grid((cols + traits::blockDIM::N - 1) / traits::blockDIM::N, + (rows + traits::blockDIM::M - 1) / traits::blockDIM::M); + kernel<<>>( + tensor_map_input, tensor_map_rowwise_output, tensor_map_colwise_output, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, 
scale_stride_rowwise, + scale_stride_colwise); + + break; + } + default: { + NVTE_ERROR("Invalid scaling type."); + } + } + return; + } - if (specialized::hasSpec()) { - switch (scaling_type) { - case ScalingType::ROWWISE: { - using traits = specialized::CastTraits; - auto kernel = specialized::quantize_mxfp8_kernel_cast_only; + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - traits::smem); + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; - dim3 block(traits::threadLayout::num, traits::warpLayout::N, traits::warpLayout::M); - dim3 grid((cols + traits::blockDimN - 1) / traits::blockDimN, - (rows + traits::blockDimM - 1) / traits::blockDimM); - kernel<<>>( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - scales_rowwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, + cols, 0, input_type_bit_size); - break; - } - case ScalingType::COLWISE: { - NVTE_WARN("Colwise scaling will fallback to original kernel."); - break; - } - case ScalingType::BIDIMENSIONAL: { - using traits = specialized::CastTraits; - auto kernel = specialized::quantize_mxfp8_kernel_cast_only; - - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - traits::smem); - // TMA for loading, so that we don't need STS for transposing - alignas(64) CUtensorMap tensor_map_input{}; - constexpr size_t input_type_bit_size = TypeInfo::size; - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, - traits::blockIterDim::M, traits::blockIterDim::N, - /*stride_elems=*/cols, - /*offset_elems=*/0, input_type_bit_size, - traits::input_swizzle_pattern); - - alignas(64) CUtensorMap tensor_map_rowwise_output{}; - alignas(64) CUtensorMap tensor_map_colwise_output{}; - constexpr size_t output_type_bit_size = TypeInfo::size; - create_2D_tensor_map(tensor_map_rowwise_output, output->data, rows, cols, - traits::blockIterDim::M, traits::blockIterDim::N, - /*stride_elems=*/cols, - /*offset_elems=*/0, output_type_bit_size, - traits::output_swizzle_pattern); - create_2D_tensor_map(tensor_map_colwise_output, output->columnwise_data, rows, cols, - traits::blockIterDim::M, traits::blockIterDim::N, cols, 0, - output_type_bit_size, traits::output_swizzle_pattern); - - dim3 block(traits::rowThreadLayout::num, traits::numWarps); - dim3 grid((cols + traits::blockDIM::N - 1) / traits::blockDIM::N, - (rows + traits::blockDIM::M - 1) / traits::blockDIM::M); - kernel<<>>( - tensor_map_input, tensor_map_rowwise_output, tensor_map_colwise_output, - scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - - break; + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, input_type_bit_size); } - default: { - NVTE_ERROR("Invalid scaling type."); - } - } - return; - } - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_colwise{}; - - constexpr size_t input_type_bit_size = TypeInfo::size; - constexpr size_t output_type_bit_size = TypeInfo::size; - 
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, - cols, 0, input_type_bit_size); - - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, BUFF_DIM_Y, - BUFF_DIM_X, cols, 0, input_type_bit_size); - } + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size); + } - if (use_rowwise_scaling) { - create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y, - BUFF_DIM_X, cols, 0, output_type_bit_size); - } + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size); + } - if (use_colwise_scaling) { - create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, - BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size); - } + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); + const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); + const size_t out_mem = out_rowwise_mem + out_colwise_mem; + + const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + + // Zero out swizzled scales if padding is needed + /// TODO (tmoon) Handle this within the cast kernel + if (with_gemm_swizzled_scales) { + constexpr size_t TILE_DIM_X = 128; // Tile dim in data buffer + constexpr size_t TILE_DIM_Y = 128; + if (cols % TILE_DIM_X != 0 || rows % TILE_DIM_Y != 0) { + if (use_rowwise_scaling) { + NVTE_CHECK_CUDA(cudaMemsetAsync(output->scale_inv.dptr, 0, + output->scale_inv.buffer_size_bytes(), stream)); + } + if (use_colwise_scaling) { + NVTE_CHECK_CUDA( + cudaMemsetAsync(output->columnwise_scale_inv.dptr, 0, + output->columnwise_scale_inv.buffer_size_bytes(), stream)); + } + } + } - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; - constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; - - const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); - const size_t out_colwise_mem = (use_colwise_scaling ? 
buff_size_aligned_out : 0); - const size_t out_mem = out_rowwise_mem + out_colwise_mem; - - const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; - - switch (scaling_type) { - case ScalingType::ROWWISE: { - auto kernel = - quantize_mxfp8_kernel; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - kernel<<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - } - case ScalingType::COLWISE: { - auto kernel = - quantize_mxfp8_kernel; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - kernel<<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - } - case ScalingType::BIDIMENSIONAL: { - auto kernel = - quantize_mxfp8_kernel; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - kernel<<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - } - } + switch (scaling_type) { + case ScalingType::ROWWISE: { + auto kernel = quantize_mxfp8_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + } + case ScalingType::COLWISE: { + auto kernel = quantize_mxfp8_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + } + case ScalingType::BIDIMENSIONAL: { + auto kernel = quantize_mxfp8_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + } + } - if constexpr (IS_DBIAS) { - common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - }); // NOLINT(*) - ); // NOLINT(*) + if constexpr (IS_DBIAS) { + common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) } } // namespace mxfp8 diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh 
b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index 5307cad37f0..ed36a90a3a7 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -80,6 +80,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) CheckInputTensor(input, "input"); CheckOutputTensor(*output, "output"); NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type."); + NVTE_CHECK(!input.with_gemm_swizzled_scales, "Input must have scales in compact format."); NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision."); NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh index 83ad8fd40bd..24641f3d249 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh @@ -142,17 +142,17 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr size_t buff_size_aligned_out_mxfp8 = DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_nvfp4_scales = - CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3); - constexpr size_t buff_size_mxfp8_scales = - (CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0); + // constexpr size_t buff_size_nvfp4_scales = + // CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3); + // constexpr size_t buff_size_mxfp8_scales = + // (CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0); constexpr size_t in_mem = buff_size_aligned_in; constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0); constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0); - constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0); - constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? buff_size_mxfp8_scales : 0); + // constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0); + // constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? 
buff_size_mxfp8_scales : 0); extern __shared__ char dynamic_shmem[]; uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); @@ -167,8 +167,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); fp8e4m3 *out_rowwise_scales_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); - e8m0_t *out_colwise_scales_sh = reinterpret_cast( - dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + (void)out_rowwise_scales_sh; // Suppress unused variable warning + // e8m0_t *out_colwise_scales_sh = reinterpret_cast( + // dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; @@ -557,6 +558,7 @@ inline void quantize(const Tensor &input, const Tensor *noop, Tensor *output, cu NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format."); bool use_colwise_scaling = output->has_columnwise_data(); if (use_colwise_scaling) { diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 7322bf2655d..f7ca307778c 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -1179,6 +1179,7 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format."); if (return_transpose) { NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated."); NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 56369db27fa..4762a3f186f 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -172,7 +172,17 @@ CommOverlapCore::~CommOverlapCore() { TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset, const std::vector &chunk_shape) { + // Check tensor format const auto scaling_mode = source.scaling_mode(); + NVTE_CHECK(scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_MXFP8_1D_SCALING, + "Unsupported tensor format (", to_string(scaling_mode), ")."); + if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + bool has_swizzled_scales = false; + nvte_get_tensor_param_v2(source.data(), NVTETensorParam::kNVTEWithGEMMSwizzledScales, + &has_swizzled_scales, sizeof(has_swizzled_scales), nullptr); + NVTE_CHECK(has_swizzled_scales, + "Expected MXFP8 tensor to have scales in GEMM swizzled format."); + } // Tensor dimensions std::vector shape = shape_to_vector(source.shape()); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 
0e264eaae34..91ee5a9556b 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -133,6 +133,20 @@ struct Tensor { NVTEScalingMode scaling_mode; NVTETensor nvte_tensor; + /*! Whether scaling factors are in format expected by GEMM */ + bool with_gemm_swizzled_scales = false; + + /*! Map from NVTETensorParam to parameter sizes */ + static constexpr size_t attr_sizes[] = { + sizeof(NVTEBasicTensor), // kNVTERowwiseData + sizeof(NVTEBasicTensor), // kNVTEColumnwiseData + sizeof(NVTEBasicTensor), // kNVTEScale + sizeof(NVTEBasicTensor), // kNVTEAmax + sizeof(NVTEBasicTensor), // kNVTERowwiseScaleInv + sizeof(NVTEBasicTensor), // kNVTEColumnwiseScaleInv + sizeof(NVTEBasicTensor), // kNVTEColumnwiseAmax + sizeof(bool) // kNVTEWithGEMMSwizzledScales + }; Tensor() : scaling_mode{NVTE_DELAYED_TENSOR_SCALING}, nvte_tensor{0} {} @@ -146,6 +160,7 @@ struct Tensor { scale_inv.clear(); columnwise_scale_inv.clear(); scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + with_gemm_swizzled_scales = false; } explicit operator NVTETensor() const noexcept { return nvte_tensor; } @@ -389,8 +404,6 @@ struct QuantizationConfig { bool force_pow_2_scales = false; float amax_epsilon = 0.0f; NVTETensor noop_tensor = nullptr; - Float8BlockScaleTensorFormat float8_block_scale_tensor_format = - Float8BlockScaleTensorFormat::GEMM_READY; NVTETensor rng_state = nullptr; bool nvfp4_2d_quantization = false; bool stochastic_rounding = false; @@ -400,7 +413,7 @@ struct QuantizationConfig { sizeof(bool), // force_pow_2_scales sizeof(float), // amax_epsilon sizeof(NVTETensor), // noop_tensor - sizeof(Float8BlockScaleTensorFormat), // float8_block_scale_tensor_format + sizeof(Float8BlockScaleTensorFormat), // (deprecated) sizeof(NVTETensor), // rng_seed and offset sizeof(bool), // nvfp4_2d_quantization sizeof(bool), // stochastic_rounding diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 118bf193353..67a45c48084 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -503,6 +503,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #if CUBLAS_VERSION >= 120800 NVTE_CHECK(cublas_version() >= 120800, "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version()); + + // Check that scales are in expected format + NVTE_CHECK(inputA->with_gemm_swizzled_scales, + "MXFP8 scales are not in format expected by GEMM"); + NVTE_CHECK(inputB->with_gemm_swizzled_scales, + "MXFP8 scales are not in format expected by GEMM"); + + // Configure cuBLAS scales fp8e8m0 *A_scale_inverse = reinterpret_cast(param.A_scale_inv); fp8e8m0 *B_scale_inverse = reinterpret_cast(param.B_scale_inv); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -513,6 +521,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, &B_scale_inverse, sizeof(B_scale_inverse))); scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling. // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set. 
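For orientation, a minimal sketch (assumed caller-side code, not part of this patch) of how MXFP8 scale-inv buffers are tagged as GEMM-swizzled through the new v2 parameter API before they reach the checks above; `tensor` stands in for an NVTETensor whose scales have already been swizzled:

// Illustrative sketch only; assumes the transformer_engine.h header is included
// and `tensor` is a valid NVTETensor whose scale-inv buffers are already swizzled.
bool swizzled = true;
nvte_set_tensor_param_v2(tensor, kNVTEWithGEMMSwizzledScales, &swizzled, sizeof(swizzled));

// The GEMM path reads the same flag back and rejects the operand if it is false.
bool has_swizzled_scales = false;
nvte_get_tensor_param_v2(tensor, kNVTEWithGEMMSwizzledScales, &has_swizzled_scales,
                         sizeof(has_swizzled_scales), nullptr);

The heuristic-cache workaround that follows is unrelated to the scale format and applies only to older cuBLAS versions.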
if (cublas_version() <= 120803) { @@ -529,17 +538,22 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #if CUBLAS_VERSION >= 120800 NVTE_CHECK(cublas_version() >= 120800, "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version()); - // make sure alpha beta computation dtype remains fp32 by CUBLASLT_MATMUL_DESC_SCALE_TYPE - cublasDataType_t scale_type = CUDA_R_32F; + + // Check that scales are in expected format + NVTE_CHECK(inputA->with_gemm_swizzled_scales, + "NVFP4 block scales are not in format expected by GEMM"); + NVTE_CHECK(inputB->with_gemm_swizzled_scales, + "NVFP4 block scales are not in format expected by GEMM"); + + // alpha and beta are device pointers to FP32 + const cublasDataType_t scale_type = CUDA_R_32F; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type))); - - // Set pointer mode: alpha and beta are both device pointers - // https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t - cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; + const cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode))); + // Configure cuBLAS scales fp8e4m3 *A_scale_inverse = reinterpret_cast(param.A_scale_inv); fp8e4m3 *B_scale_inverse = reinterpret_cast(param.B_scale_inv); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -561,6 +575,13 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK(cublas_version() >= 120900, "FP8 block scaling requires cuBLAS 12.9+, but run-time cuBLAS version is ", cublas_version()); + + // Check that matrix formats are valid + NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D && + inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)), + "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported, but got 2D by 2D"); + + // Configure cuBLAS scales float *A_scale_inverse = reinterpret_cast(param.A_scale_inv); float *B_scale_inverse = reinterpret_cast(param.B_scale_inv); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -569,9 +590,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale_inverse, sizeof(B_scale_inverse))); - NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D && - inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)), - "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported, but got 2D by 2D"); scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ? 
CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index 11325041aec..aec5264c6d0 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -739,6 +739,7 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + NVTE_CHECK(!output_.with_gemm_swizzled_scales, "Output must have scales in compact format."); const SimpleTensor &input = input_.data; SimpleTensor &global_amax = output_.amax; SimpleTensor &output_t = output_.data; diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 950014cc9be..677790992ad 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -57,24 +57,24 @@ NVTEMatmulConfig nvte_create_matmul_config(); /*! \brief Query an option in matrix multiplication configuration. * - * \param[in] config Matrix multiplication configuration. - * \param[in] attr Option type. - * \param[out] buf Memory address to write option value. Ignored if - * NULL. - * \param[in] size_in_bytes Size of buf. - * \param[out] size_written Number of bytes that have been written to - * buf. If buf is NULL, then the number of - * bytes that would have been written. + * \param[in] config Matrix multiplication configuration. + * \param[in] attr Option type. + * \param[out] buf Memory address to write option value. + * Ignored if NULL. + * \param[in] size_in_bytes Size of buf. + * \param[out] size_written Number of bytes that have been written to + * buf. If buf is NULL, then the number of + * bytes that would have been written. */ void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, void *buf, size_t size_in_bytes, size_t *size_written); /*! \brief Set an option in matrix multiplication configuration. * - * \param[in] config Matrix multiplication configuration. - * \param[in] attr Option type. - * \param[out] buf Memory address to read option value. - * \param[in] size_in_bytes Size of buf. + * \param[in/out] config Matrix multiplication configuration. + * \param[in] attr Option type. + * \param[in] buf Memory address to read option value. + * \param[in] size_in_bytes Size of buf. */ void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, const void *buf, size_t size_in_bytes); diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h index 624e71d1e3a..b429dc78388 100644 --- a/transformer_engine/common/include/transformer_engine/swizzle.h +++ b/transformer_engine/common/include/transformer_engine/swizzle.h @@ -46,7 +46,7 @@ void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETen /*! \brief Swizzling FP8 block scaling scaling factors into mxfp8 interleaved layout for GEMM * - * \param[in] input Input FP8 block scaling tensor with GEMM_READY scale_inv. 
+ * \param[in] input Input FP8 block-scaled tensor. * \param[in,out] output Output mxfp8 tensor which hosts swizzled scale_inv. * \param[in] stream CUDA stream used for the operation. * @@ -56,7 +56,6 @@ void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETen * Requirements: * - input is an FP8 block scaling tensor * - input has rowwise usage - * - input.scale_inv is in GEMM_READY format * - output is an MXFP8 tensor * - output has rowwise usage * - output.scale_inv has appropriate shape diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 19cb646be29..ddd30d8b537 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -60,13 +60,14 @@ struct NVTEBasicTensor { * \brief Indicates the kind of the tensor parameter to set/get. */ enum NVTETensorParam { - kNVTERowwiseData = 0, /*!< Data usable in rowwise manner */ - kNVTEColumnwiseData = 1, /*!< Data usable in columnwise manner */ - kNVTEScale = 2, /*!< Scale tensor */ - kNVTEAmax = 3, /*!< Amax tensor */ - kNVTERowwiseScaleInv = 4, /*!< Scale inverse tensor for decoding Rowwise Data */ - kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */ - kNVTEColumnwiseAmax = 6, /*!< Columnwise Amax tensor */ + kNVTERowwiseData = 0, /*!< Data usable in rowwise manner */ + kNVTEColumnwiseData = 1, /*!< Data usable in columnwise manner */ + kNVTEScale = 2, /*!< Scale tensor */ + kNVTEAmax = 3, /*!< Amax tensor */ + kNVTERowwiseScaleInv = 4, /*!< Scale inverse tensor for decoding Rowwise Data */ + kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */ + kNVTEColumnwiseAmax = 6, /*!< Columnwise Amax tensor */ + kNVTEWithGEMMSwizzledScales = 7, /*!< Whether scaling factors are in format expected by GEMM */ kNVTENumTensorParams }; @@ -264,6 +265,9 @@ NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor); void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream); /*! \brief Set a parameter of the tensor. + * + * This only supports tensor parameters of type NVTEBasicTensor. Use + * nvte_set_tensor_param_v2 for other parameter types. * * \param[in/out] tensor Tensor. * \param[in] param_name The parameter to be set. @@ -273,12 +277,39 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, const NVTEBasicTensor *param); /*! \brief Get a value of the parameter of the tensor. + * + * This only supports tensor parameters of type NVTEBasicTensor. Use + * nvte_get_tensor_param_v2 for other parameter types. * * \param[in] tensor Tensor. * \param[in] param_name The parameter to be set. */ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name); +/*! \brief Set a tensor parameter. + * + * \param[in/out] tensor Tensor. + * \param[in] param Tensor parameter type. + * \param[in] buf Memory address to read parameter value. + * \param[in] size_in_bytes Size of buf. + */ +void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const void *buf, + size_t size_in_bytes); + +/*! \brief Query a tensor parameter. + * + * \param[in] tensor Tensor. + * \param[in] param Tensor parameter type. + * \param[out] buf Memory address to write parameter value. + * Ignored if NULL. + * \param[in] size_in_bytes Size of buf. 
+ * \param[out] size_written Number of bytes that have been written to + * buf. If buf is NULL, then the number of + * bytes that would have been written. + */ +void nvte_get_tensor_param_v2(const NVTETensor tensor, NVTETensorParam param, void *buf, + size_t size_in_bytes, size_t *size_written); + /*! \brief Get the granularity of scaling of this tensor. * * \param[in] tensor Tensor. @@ -324,12 +355,7 @@ enum NVTEQuantizationConfigAttribute { conditional early even when captured in a static CUDA graph. */ kNVTEQuantizationConfigNoopTensor = 2, - /*! Data format for an FP8 block-scaled tensor - * - * This is not the right design since the tensor format is a - * property of the tensor, not the quantization. This enum will - * likely be refactored away in the future. - */ + /*! \warning Deprecated */ kNVTEQuantizationConfigFloat8BlockScaleTensorFormat = 3, /*! RNG state (NVTETensor with 2 elements - seed and offset */ kNVTEQuantizationConfigRNGState = 4, @@ -353,14 +379,14 @@ NVTEQuantizationConfig nvte_create_quantization_config(); /*! \brief Query an option in quantization config. * - * \param[in] config Quantization config. - * \param[in] attr Option type. - * \param[out] buf Memory address to write option value. Ignored if - * NULL. - * \param[in] size_in_bytes Size of buf. - * \param[out] size_written Number of bytes that have been written to - * buf. If buf is NULL, then the number of - * bytes that would have been written. + * \param[in] config Quantization config. + * \param[in] attr Option type. + * \param[out] buf Memory address to write option value. + * Ignored if NULL. + * \param[in] size_in_bytes Size of buf. + * \param[out] size_written Number of bytes that have been written to + * buf. If buf is NULL, then the number of + * bytes that would have been written. */ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, NVTEQuantizationConfigAttribute attr, void *buf, @@ -368,10 +394,10 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, /*! \brief Set an option in quantization config. * - * \param[in] config Quantization config. - * \param[in] attr Option type. - * \param[out] buf Memory address to read option value. - * \param[in] size_in_bytes Size of buf. + * \param[in/out] config Quantization config. + * \param[in] attr Option type. + * \param[in] buf Memory address to read option value. + * \param[in] size_in_bytes Size of buf. */ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, NVTEQuantizationConfigAttribute attr, const void *buf, @@ -586,20 +612,20 @@ class TensorWrapper { const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) { tensor_ = nvte_create_tensor(scaling_mode); NVTEBasicTensor data = {dptr, static_cast(dtype), shape}; - nvte_set_tensor_param(&tensor_, kNVTERowwiseData, &data); + nvte_set_tensor_param_v2(tensor_, kNVTERowwiseData, &data, sizeof(data)); NVTEBasicTensor amax = {amax_dptr, kNVTEFloat32, amax_dptr != nullptr ? defaultShape : emptyShape}; - nvte_set_tensor_param(&tensor_, kNVTEAmax, &amax); + nvte_set_tensor_param_v2(tensor_, kNVTEAmax, &amax, sizeof(amax)); NVTEBasicTensor scale = {scale_dptr, kNVTEFloat32, scale_dptr != nullptr ? 
defaultShape : emptyShape}; - nvte_set_tensor_param(&tensor_, kNVTEScale, &scale); + nvte_set_tensor_param_v2(tensor_, kNVTEScale, &scale, sizeof(scale)); if (scale_inv_dptr == nullptr && scale_inv_shape.ndim == defaultShape.ndim && scale_inv_shape.ndim == 1 && scale_inv_shape.data[0] == defaultShape.data[0]) { // Scale-inv pointer has not been provided and shape matches default scale_inv_shape = emptyShape; } NVTEBasicTensor scale_inv = {scale_inv_dptr, kNVTEFloat32, scale_inv_shape}; - nvte_set_tensor_param(&tensor_, kNVTERowwiseScaleInv, &scale_inv); + nvte_set_tensor_param_v2(tensor_, kNVTERowwiseScaleInv, &scale_inv, sizeof(scale_inv)); } /*! \brief Constructs new TensorWrapper. @@ -669,7 +695,7 @@ class TensorWrapper { const ShapeType &shape) noexcept { NVTEShape nvte_shape = this->convertShape(shape); NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; - nvte_set_tensor_param(&tensor_, param, &data); + nvte_set_tensor_param_v2(tensor_, param, &data, sizeof(data)); return *this; } @@ -708,10 +734,17 @@ class TensorWrapper { return set_parameter(kNVTEColumnwiseAmax, dptr, type, shape); } + void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales) { + nvte_set_tensor_param_v2(tensor_, kNVTEWithGEMMSwizzledScales, &with_gemm_swizzled_scales, + sizeof(with_gemm_swizzled_scales)); + } + // Parameter getters NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept { - return nvte_get_tensor_param(tensor_, param); + NVTEBasicTensor ret; + nvte_get_tensor_param_v2(tensor_, param, &ret, sizeof(ret), nullptr); + return ret; } NVTEBasicTensor get_rowwise_data() const noexcept { return get_parameter(kNVTERowwiseData); } @@ -736,6 +769,13 @@ class TensorWrapper { return get_parameter(kNVTEColumnwiseAmax); } + bool get_with_gemm_swizzled_scales() const { + bool with_gemm_swizzled_scales = false; + nvte_get_tensor_param_v2(tensor_, kNVTEWithGEMMSwizzledScales, &with_gemm_swizzled_scales, + sizeof(with_gemm_swizzled_scales), nullptr); + return with_gemm_swizzled_scales; + } + /*! \brief Get an underlying NVTETensor. * * \return NVTETensor held by this TensorWrapper. @@ -915,15 +955,8 @@ class TensorWrapper { NVTETensor tensor_ = nullptr; }; -/*! \enum Float8BlockScaleTensorFormat - * \brief Data format for an FP8 block-scaled tensor - */ -enum class Float8BlockScaleTensorFormat { - /*! FP8 data is transposed if needed and scales are swizzled */ - GEMM_READY = 0, - /*! FP8 data is untransposed and scales are not swizzled or padded */ - COMPACT = 1 -}; +/*! \warning Deprecated */ +enum class Float8BlockScaleTensorFormat { GEMM_READY = 0, COMPACT = 1, INVALID }; /*! \struct QuantizationConfigWrapper * \brief C++ wrapper for NVTEQuantizationConfigWrapper. @@ -978,12 +1011,8 @@ class QuantizationConfigWrapper { sizeof(NVTETensor)); } - /*! \brief Set FP8 block-scaled tensor format */ - void set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat format) { - nvte_set_quantization_config_attribute(config_, - kNVTEQuantizationConfigFloat8BlockScaleTensorFormat, - &format, sizeof(Float8BlockScaleTensorFormat)); - } + /*! \warning Deprecated */ + void set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat format) {} /*! 
\brief Set stochastic rounding state */
  void set_rng_state(NVTETensor rng_state) {
diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp
index b83ae25f250..b6a81d5e3b3 100644
--- a/transformer_engine/common/normalization/layernorm/ln_api.cpp
+++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp
@@ -27,10 +27,15 @@ void layernorm_fwd(const Tensor& x,  // BxSxhidden_size
                    const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, Tensor* workspace,
                    const int multiprocessorCount, const bool zero_centered_gamma,
                    cudaStream_t stream) {
+  // Check for unsupported configurations
   if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
       !is_mxfp8_scaling(z->scaling_mode)) {
     NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
   }
+  if (is_mxfp8_scaling(z->scaling_mode)) {
+    NVTE_CHECK(!z->with_gemm_swizzled_scales,
+               "MXFP8 output must have scales in compact format, not swizzled for GEMM.");
+  }
   NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor.");
   NVTE_CHECK(gamma.data.shape == beta.data.shape, "Gamma and Beta must have the same shape.");
diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
index ea6c972bf53..d294c616e6f 100644
--- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
+++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
@@ -23,10 +23,15 @@ using namespace normalization;
 void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z,
                  Tensor *rsigma, Tensor *workspace, const int multiprocessorCount,
                  const bool zero_centered_gamma, cudaStream_t stream) {
+  // Check for unsupported configurations
   if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
       !is_mxfp8_scaling(z->scaling_mode)) {
     NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
   }
+  if (is_mxfp8_scaling(z->scaling_mode)) {
+    NVTE_CHECK(!z->with_gemm_swizzled_scales,
+               "MXFP8 output must have scales in compact format, not swizzled for GEMM.");
+  }
   NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor.");
diff --git a/transformer_engine/common/recipe/mxfp8_scaling.cu b/transformer_engine/common/recipe/mxfp8_scaling.cu
index 8a7ecc6b01c..f9aa4c8d761 100644
--- a/transformer_engine/common/recipe/mxfp8_scaling.cu
+++ b/transformer_engine/common/recipe/mxfp8_scaling.cu
@@ -110,7 +110,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
   OType *output_rowwise_minus_offset = output_rowwise - start_offset;
   OType *output_colwise_minus_offset = output_colwise - start_offset;
   int warp_idx = threadIdx.x / 32;
-  int lane_idx = threadIdx.x % 32;
+  // int lane_idx = threadIdx.x % 32;
   int c = blockIdx.x * kColsPerTile + threadIdx.x;
   int r = blockIdx.y * kRowsPerTile;
diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu
index 2cb43e8f27c..02ee7052782 100644
--- a/transformer_engine/common/swizzle/swizzle.cu
+++ b/transformer_engine/common/swizzle/swizzle.cu
@@ -340,6 +340,10 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
   // Check tensors
   CheckInputTensor(*input, "scaling_factor_input");
   CheckInputTensor(*output, "scaling_factor_output");
+  NVTE_CHECK(!input->with_gemm_swizzled_scales,
+             "Expected input tensor with scales in compact format.");
+  NVTE_CHECK(output->with_gemm_swizzled_scales,
+             "Expected
output tensor with scales in GEMM swizzled format."); switch (scaling_mode) { case NVTE_MXFP8_1D_SCALING: NVTE_CHECK(is_fp8_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP8, got ", @@ -656,6 +660,11 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, NVTE_CHECK( (is_fp8 && is_mxfp8_scaling(scaling_mode)) || (is_fp4 && is_nvfp4_scaling(scaling_mode)), "Not implemented scaling mode " + to_string(scaling_mode) + "."); + NVTE_CHECK(!input[i]->with_gemm_swizzled_scales, + "Expected input tensors with scales in compact format."); + NVTE_CHECK(output[i]->with_gemm_swizzled_scales, + "Expected output tensors with scales in GEMM swizzled format."); + // We don't allow empty tensors. They should be filtered out before calling this function. NVTE_CHECK(input[i]->numel() != 0, "Tensor input[", i, "] is empty."); CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]"); diff --git a/transformer_engine/common/swizzle/swizzle_block_scaling.cu b/transformer_engine/common/swizzle/swizzle_block_scaling.cu index 4be85474af7..d4cabeedc53 100644 --- a/transformer_engine/common/swizzle/swizzle_block_scaling.cu +++ b/transformer_engine/common/swizzle/swizzle_block_scaling.cu @@ -98,7 +98,8 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) // calculate this warp's input base pointer constexpr uint32_t in_x_stride = WARP_SIZE * sizeof(uint4); - const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride; + const void* const warp_src = + (reinterpret_cast(in) + in_tile_y * in_y_stride + in_tile_x * in_x_stride); // load scaling factors for this lane's initial four 1x128 tiles uint4 sf; @@ -128,7 +129,8 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) // store them cooperatively for 512 1x32 tiles in a 128x128 tile constexpr uint32_t out_x_stride = 512; - void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride; + void* const warp_dst = + (reinterpret_cast(out) + out_tile_y * out_y_stride + out_tile_x * out_x_stride); reinterpret_cast(warp_dst)[lane] = sf; } @@ -192,7 +194,8 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) // calculate this warp's input base pointer constexpr uint32_t in_x_stride = sizeof(float); - const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride; + const void* const warp_src = + (reinterpret_cast(in) + in_tile_y * in_y_stride + in_tile_x * in_x_stride); // load scaling factor for this warp's 128x128 tile uint32_t sf = *reinterpret_cast(warp_src); @@ -206,7 +209,8 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) // store it cooperatively for 512 1x32 tiles in a 128x128 tile constexpr uint32_t out_x_stride = 512; - void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride; + void* const warp_dst = + (reinterpret_cast(out) + out_tile_y * out_y_stride + out_tile_x * out_x_stride); reinterpret_cast(warp_dst)[lane] = sf4; } @@ -259,6 +263,9 @@ void swizzle_block_scaling_to_mxfp8_scaling_factors(const Tensor* input, Tensor* NVTE_CHECK(output->scale_inv.dtype == DType::kFloat8E8M0, "Output must have E8M0 scaling factors"); + NVTE_CHECK(output->with_gemm_swizzled_scales, + "Expected output tensor with scales in GEMM swizzled format."); + NVTE_CHECK(input->data.dptr != nullptr, "Input must have rowwise data"); NVTE_CHECK(output->data.dptr == input->data.dptr, "Output must share data with input"); 
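As a point of reference, a hedged sketch (assumed setup, not part of this patch) of the calling convention these checks enforce on the C++ wrapper side: the compact-scale input and the GEMM-swizzled output are distinguished by the new flag before the swizzle kernel is launched.

// Illustrative sketch only; data and scale-inv buffers are assumed to be attached via
// set_rowwise_data()/set_rowwise_scale_inv() before the call.
transformer_engine::TensorWrapper input(NVTE_MXFP8_1D_SCALING);   // scales in compact format
transformer_engine::TensorWrapper output(NVTE_MXFP8_1D_SCALING);  // receives GEMM-swizzled scales
output.set_with_gemm_swizzled_scales(true);
cudaStream_t stream = 0;  // default stream, for illustration
nvte_swizzle_scaling_factors(input.data(), output.data(), stream);

The remaining checks below continue to validate the input's rowwise data and scaling factors.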
NVTE_CHECK(input->scale_inv.dptr != nullptr, "Input must have rowwise scaling factors"); diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 4a140b4376d..24a2ce730a2 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -6,12 +6,16 @@ #include +#include #include #include #include #include #include +#include +#include #include +#include #include "common.h" #include "common/util/cuda_runtime.h" @@ -778,7 +782,8 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, t->columnwise_amax = *param; break; default: - NVTE_ERROR("Unknown tensor parameter!"); + NVTE_ERROR("Unsupported tensor parameter (", static_cast(param_name), + "). Consider using nvte_set_tensor_param_v2 instead."); } } @@ -803,7 +808,148 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p case kNVTEColumnwiseAmax: return t.columnwise_amax; default: - NVTE_ERROR("Unknown tensor parameter!"); + NVTE_ERROR("Unsupported tensor parameter (", static_cast(param_name), + "). Consider using nvte_set_tensor_param_v2 instead."); + } +} + +void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const void *buf, + size_t size_in_bytes) { + // Check attribute and buffer + NVTE_CHECK(param < kNVTENumTensorParams, "Invalid NVTETensorParam (got ", static_cast(param), + ")"); + NVTE_CHECK(tensor != nullptr, "Tensor pointer can't be NULL."); + auto &t = *transformer_engine::convertNVTETensorCheck(tensor); + const auto &attr_size = transformer_engine::Tensor::attr_sizes[param]; + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for tensor parameter " + "(parameter ", + static_cast(param), " needs ", attr_size, " bytes, but buffer has ", + size_in_bytes, " bytes)"); + NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)"); + + // Read from buffer + switch (param) { + case kNVTERowwiseData: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.data = *basic_tensor; + break; + } + case kNVTEColumnwiseData: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.columnwise_data = *basic_tensor; + break; + } + case kNVTEScale: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.scale = *basic_tensor; + break; + } + case kNVTEAmax: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.amax = *basic_tensor; + break; + } + case kNVTERowwiseScaleInv: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.scale_inv = *basic_tensor; + break; + } + case kNVTEColumnwiseScaleInv: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.columnwise_scale_inv = *basic_tensor; + break; + } + case kNVTEColumnwiseAmax: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.columnwise_amax = *basic_tensor; + break; + } + case kNVTEWithGEMMSwizzledScales: + std::memcpy(&t.with_gemm_swizzled_scales, buf, attr_size); + break; + default: + NVTE_ERROR("Unsupported tensor parameter (", static_cast(param), ")"); + } +} + +void nvte_get_tensor_param_v2(const NVTETensor tensor, NVTETensorParam param, void *buf, + size_t size_in_bytes, size_t *size_written) { + using namespace transformer_engine; + + // Check param + NVTE_CHECK(param < kNVTENumTensorParams, "Invalid NVTETensorParam (got ", static_cast(param), + ")"); + + // Write attribute size if provided + const auto &attr_size = Tensor::attr_sizes[param]; + if (size_written != nullptr) { + *size_written = 
attr_size; + } + + // Return immediately if buffer is not provided + if (buf == nullptr) { + return; + } + + // Check buffer size + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for tensor parameter " + "(parameter ", + static_cast(param), " needs ", attr_size, " bytes, but buffer has ", + size_in_bytes, " bytes)"); + + // Get C++ tensor + const Tensor *t = convertNVTETensor(tensor); + std::optional dummy; + if (t == nullptr) { + // Make dummy tensor if provided tensor is invalid + dummy.emplace(); + t = &(*dummy); + } + + // Write to buffer + switch (param) { + case kNVTERowwiseData: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->data); + break; + } + case kNVTEColumnwiseData: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->columnwise_data); + break; + } + case kNVTEScale: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->scale); + break; + } + case kNVTEAmax: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->amax); + break; + } + case kNVTERowwiseScaleInv: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->scale_inv); + break; + } + case kNVTEColumnwiseScaleInv: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->columnwise_scale_inv); + break; + } + case kNVTEColumnwiseAmax: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->columnwise_amax); + break; + } + case kNVTEWithGEMMSwizzledScales: + std::memcpy(buf, &t->with_gemm_swizzled_scales, attr_size); + break; + default: + NVTE_ERROR("Unsupported tensor parameter (", static_cast(param), ")"); } } @@ -854,10 +1000,12 @@ NVTEQuantizationConfig nvte_create_quantization_config() { void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, NVTEQuantizationConfigAttribute attr, void *buf, size_t size_in_bytes, size_t *size_written) { + using namespace transformer_engine; + // Write attribute size NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes, "Invalid NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); - const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr]; + const auto &attr_size = QuantizationConfig::attr_sizes[attr]; if (size_written != nullptr) { *size_written = attr_size; } @@ -876,7 +1024,7 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, // Write to buffer NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)"); - const auto &config_ = *reinterpret_cast(config); + const auto &config_ = *reinterpret_cast(config); switch (attr) { case kNVTEQuantizationConfigForcePow2Scales: std::memcpy(buf, &config_.force_pow_2_scales, attr_size); @@ -887,9 +1035,12 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigNoopTensor: std::memcpy(buf, &config_.noop_tensor, attr_size); break; - case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat: - std::memcpy(buf, &config_.float8_block_scale_tensor_format, attr_size); + case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat: { + // Deprecated + const auto invalid = Float8BlockScaleTensorFormat::INVALID; + std::memcpy(buf, &invalid, attr_size); break; + } case kNVTEQuantizationConfigRNGState: std::memcpy(buf, &config_.rng_state, attr_size); break; @@ -910,10 +1061,12 @@ void 
nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, NVTEQuantizationConfigAttribute attr, const void *buf, size_t size_in_bytes) { + using namespace transformer_engine; + // Check attribute and buffer NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes, "Invalid NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); - const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr]; + const auto &attr_size = QuantizationConfig::attr_sizes[attr]; NVTE_CHECK(size_in_bytes >= attr_size, "Buffer is too small for quantization config attribute " "(attribute ", @@ -923,7 +1076,7 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, // Read from buffer NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)"); - auto &config_ = *reinterpret_cast(config); + auto &config_ = *reinterpret_cast(config); switch (attr) { case kNVTEQuantizationConfigForcePow2Scales: std::memcpy(&config_.force_pow_2_scales, buf, attr_size); @@ -935,7 +1088,7 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, std::memcpy(&config_.noop_tensor, buf, attr_size); break; case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat: - std::memcpy(&config_.float8_block_scale_tensor_format, buf, attr_size); + // Deprecated break; case kNVTEQuantizationConfigRNGState: std::memcpy(&config_.rng_state, buf, attr_size); diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index 89266f4bbc5..2ecb4c92beb 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -30,6 +30,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor const bool return_transpose, const bool pow_2_scale, const SimpleTensor &noop_tensor, cudaStream_t stream); +/// TODO Compact format is removed. Replace this enum with bool. // enum class for rowwise usage enum class FP8BlockwiseRowwiseOption { // No rowwise data, skip rowwise quantization @@ -40,6 +41,7 @@ enum class FP8BlockwiseRowwiseOption { ROWWISE_COMPACT }; +/// TODO Compact format is removed. Replace this enum with bool. // enum class for columnwise usage // For Hopper sm90 with only TN fp8 gemm, there is need to do columnwise transpose when doing 1D block scaling enum class FP8BlockwiseColumnwiseOption { diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index 661cf339ae8..2e3e9695dbe 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -492,7 +492,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor } NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); - const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; + const size_t row_length = input.shape.size() > 0 ? 
input.shape.back() : 1; size_t num_rows = 1; for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) { num_rows *= input.shape.at(i); @@ -511,12 +511,14 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor const float* noop_ptr = reinterpret_cast(noop_tensor.dptr); if (return_transpose) { - NVTE_CHECK(output_t.shape.size() == input.shape.size(), - "output_t must have same number of dimensions as input."); + NVTE_CHECK(output_t.shape.size() == input.shape.size(), "input (shape=", input.shape, + ") and output_t (shape=", output_t.shape, ") have incompatible dims."); if (output_t.shape.size() > 0) { - NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t."); + NVTE_CHECK(output_t.shape.front() == input.shape.back(), "input (shape=", input.shape, + ") and output_t (shape=", output_t.shape, ") have incompatible dims."); for (size_t i = 1; i < output_t.shape.size(); ++i) { - NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t"); + NVTE_CHECK(output_t.shape[i] == input.shape[i - 1], "input (shape=", input.shape, + ") and output_t (shape=", output_t.shape, ") have incompatible dims."); } } NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same type."); diff --git a/transformer_engine/common/transpose/transpose.cu b/transformer_engine/common/transpose/transpose.cu index 9f0acd8071e..ce5789ed441 100644 --- a/transformer_engine/common/transpose/transpose.cu +++ b/transformer_engine/common/transpose/transpose.cu @@ -14,8 +14,10 @@ #include "../util/rtc.h" #include "../util/string.h" #include "../utils.cuh" +#include "./transpose.h" namespace transformer_engine { +namespace detail { namespace { @@ -203,7 +205,8 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated."); NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated."); - NVTE_CHECK(input.data.dtype == output.data.dtype, "Input and output type must match."); + NVTE_CHECK(input.data.dtype == output.data.dtype, "Input (dtype=", to_string(input.data.dtype), + ") and output (dtype=", to_string(output.data.dtype), ") do not match."); if (noop.data.dptr != nullptr) { NVTE_CHECK(noop.numel() == 1, "Expected 1 element, ", "but found ", noop.numel(), "."); @@ -283,19 +286,20 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr }); // NOLINT(*) } +} // namespace detail } // namespace transformer_engine void nvte_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_transpose); using namespace transformer_engine; auto noop = Tensor(); - transpose(*convertNVTETensorCheck(input), noop, convertNVTETensor(output), stream); + detail::transpose(*convertNVTETensorCheck(input), noop, convertNVTETensor(output), stream); } void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_transpose_with_noop); using namespace transformer_engine; - transpose(*convertNVTETensorCheck(input), *convertNVTETensorCheck(noop), - convertNVTETensor(output), stream); + detail::transpose(*convertNVTETensorCheck(input), *convertNVTETensorCheck(noop), + convertNVTETensor(output), stream); } diff --git a/transformer_engine/common/transpose/transpose.h b/transformer_engine/common/transpose/transpose.h new file mode 100644 index 00000000000..055f8268f63 --- /dev/null +++ 
b/transformer_engine/common/transpose/transpose.h @@ -0,0 +1,20 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_TRANSPOSE_TRANSPOSE_H_ +#define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_TRANSPOSE_H_ + +#include "../common.h" + +namespace transformer_engine { +namespace detail { + +void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStream_t stream); + +} // namespace detail +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_TRANSPOSE_H_ diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 7f296c9e383..ddaca7e3495 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -840,6 +840,7 @@ __device__ __forceinline__ int32_t elect_one_sync(uint32_t mask = 0xFFFFFFFFu) { return pred; #else NVTE_DEVICE_ERROR("elect_one_sync is only supported on SM 10.0+."); + return 0; #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } @@ -891,6 +892,7 @@ __device__ __forceinline__ bf16 get_amax(bf16 a, bf16 b) { return r; #else NVTE_DEVICE_ERROR("get_amax is only supported on SM 10.0+."); + return 0.f; #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } @@ -903,6 +905,7 @@ __device__ __forceinline__ fp16 get_amax(fp16 a, fp16 b) { return r; #else NVTE_DEVICE_ERROR("get_amax is only supported on SM 10.0+."); + return 0.f; #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index bce124e705f..a7a5c75f8fa 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -83,7 +83,8 @@ pybind11::enum_( \ m, "Float8BlockScaleTensorFormat", pybind11::module_local()) \ .value("GEMM_READY", transformer_engine::Float8BlockScaleTensorFormat::GEMM_READY) \ - .value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT); \ + .value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT) \ + .value("INVALID", transformer_engine::Float8BlockScaleTensorFormat::INVALID); \ pybind11::enum_(m, "CommOverlapType", \ pybind11::module_local()) \ .value("RS", transformer_engine::CommOverlapType::RS) \ diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index caa5f5a7ee2..da039751990 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -62,12 +62,17 @@ def __init__( self.tp_group = tp_group # used in inspect_tensor calls self.iteration = TEDebugState.get_iteration() - # .internal = True is slightly faster, but results - # in errors when caching the weights. - # Setting .internal = False is safer. + # Configure parent quantizer if parent_quantizer is not None: + # .internal = True is slightly faster, but results + # in errors when caching the weights. + # Setting .internal = False is safer. parent_quantizer.internal = False + # .optimize_for_gemm = True is not supported because debug + # quantizers perform non-GEMM operations. 
+ parent_quantizer.optimize_for_gemm = False + self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name] # next iteration when this quantizer will call any API diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 6566ff1689e..001fd27d53a 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -65,23 +65,33 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( NVTE_CHECK(typeToSize(scale_dtype) == 1, "Inverse scale factors need to have an 8-bit data type."); } - if (!is_nvfp4) { + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { + // Assume MXFP8 scales are already swizzled if (rowwise) { input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } else { input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } - } else { // Swizzle for NVFP4 + input.set_with_gemm_swizzled_scales(true); + } else if (is_nvfp4) { // Swizzle for NVFP4 NVTE_CHECK(rowwise, "NVFP4 GEMM expects rowwise for both LHS and RHS"); input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); // Create tensor to hold swizzled scale factor TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); output.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); + output.set_with_gemm_swizzled_scales(true); // Launch swizzle kernel nvte_swizzle_scaling_factors(input.data(), output.data(), stream); // Set swizzled scales into the input tensor input.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); + input.set_with_gemm_swizzled_scales(true); + } else { // Tensor scaling + if (rowwise) { + input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); + } else { + input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); + } } } @@ -669,6 +679,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_swizzle_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); lhs_i.set_rowwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); } + lhs_i.set_with_gemm_swizzled_scales(true); if (rhs_use_colwise) { rhs_swizzle_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); rhs_swizzle_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); @@ -678,6 +689,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type rhs_swizzle_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); rhs_i.set_rowwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); } + rhs_i.set_with_gemm_swizzled_scales(true); if (!is_empty_gemm) { lhs_swizzle_wrapper_list.push_back(std::move(lhs_swizzle_i)); diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index d4ff0b96d9a..a90c8f18487 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -164,17 +164,9 @@ def general_gemm( bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] if isinstance(A, Float8BlockwiseQTensorStorage) or isinstance(B, Float8BlockwiseQTensorStorage): - # There is not use_split_accumulator == False - # implementation for Float8BlockwiseQTensorStorage GEMM + # FP8 block-scaling requires split accumulator 
use_split_accumulator = True - # Check that data format is supported - if ( - A._data_format != tex.Float8BlockScaleTensorFormat.GEMM_READY - or B._data_format != tex.Float8BlockScaleTensorFormat.GEMM_READY - ): - raise RuntimeError("GEMM with Float8BlockwiseQTensor requires GEMM_READY format") - args = ( A, transa, # transa diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index e054424dd4d..511d52cc1da 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -301,11 +301,13 @@ std::vector convertShape(const NVTEShape& shape) { return std::vector(shape.data, shape.data + shape.ndim); } -size_t roundup(const size_t value, const size_t multiple) { +size_t roundup(size_t value, size_t multiple) { assert(multiple > 0); return ((value + multiple - 1) / multiple) * multiple; } +size_t ceildiv(size_t numer, size_t denom) { return (numer + denom - 1) / denom; } + void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) { NVTE_SCOPED_GIL_RELEASE({ nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 978bee52dc1..d3a6b72163f 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -120,6 +120,7 @@ class Quantizer { bool rowwise_usage = true; bool columnwise_usage = true; bool internal = false; + bool optimize_for_gemm = false; py::handle quantizer; protected: @@ -231,8 +232,6 @@ class Float8BlockQuantizer : public Quantizer { bool force_pow_2_scales = false; // Amax within quantization tile has a floor of epsilon. float amax_epsilon = 0.0; - // Whether quantized tensor will be used in an all-gather - bool all_gather_usage = false; private: int block_scaling_dim = 2; @@ -358,11 +357,12 @@ inline size_t typeToNumBits(transformer_engine::DType t) { case transformer_engine::DType::kByte: case transformer_engine::DType::kFloat8E4M3: case transformer_engine::DType::kFloat8E5M2: + case transformer_engine::DType::kFloat8E8M0: return 8; case transformer_engine::DType::kFloat4E2M1: return 4; default: - NVTE_ERROR("Invalid type"); + NVTE_ERROR("Invalid type (", static_cast(t), ")."); } } @@ -386,8 +386,10 @@ inline at::ScalarType GetATenDType(transformer_engine::DType t) { return at::kFloat8_e4m3fn; case transformer_engine::DType::kFloat8E5M2: return at::kFloat8_e5m2; + case transformer_engine::DType::kFloat8E8M0: + return at::kByte; default: - NVTE_ERROR("Invalid type"); + NVTE_ERROR("Invalid type (", static_cast(t), ")."); } } @@ -414,8 +416,7 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) { case torch::kInt64: return transformer_engine::DType::kInt64; default: - std::cout << "Type: " << static_cast(t) << std::endl; - NVTE_ERROR("Invalid type"); + NVTE_ERROR("Invalid type (", static_cast(t), ")."); } } @@ -477,7 +478,9 @@ void* getDataPtr(at::Tensor tensor, int offset = 0); std::vector convertShape(const NVTEShape& shape); -size_t roundup(const size_t value, const size_t multiple); +size_t roundup(size_t value, size_t multiple); + +size_t ceildiv(size_t numer, size_t denom); NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 80479dccf48..2bb75f4a15a 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ 
-7,7 +7,12 @@ #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ #define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ +#include #include +#include +#include +#include +#include #include "common.h" @@ -78,11 +83,6 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, bool return_max_logit, bool cuda_graph); -std::pair quantizer_helper(py::handle quantizer, - const std::vector &shape, DType dtype, - bool create_hp_tensor_for_cs, - std::optional data); - std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -474,6 +474,13 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, std::vector input_row_list, std::vector unpadded_input_row_list); + +/*************************************************************************************************** + * Scale swizzling for GEMM + **************************************************************************************************/ + +void inplace_swizzle_scale_for_gemm(py::handle &tensor); + /*************************************************************************************************** * NVSHMEM APIs **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index aa9d800c7bb..200d1b76bf3 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -327,9 +327,9 @@ std::tuple, std::vector> bulk_allocate_fp (columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none()); // Construct Python tensor - tensor_py_list.emplace_back(Float8BlockwiseQTensorClass( - rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype, - quantizer_py_list[i], is_2D_scaled, Float8BlockScaleTensorFormat::GEMM_READY)); + tensor_py_list.emplace_back( + Float8BlockwiseQTensorClass(rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, + fp8_dtype, quantizer_py_list[i], is_2D_scaled)); // Construct C++ tensor tensor_cpp_list.emplace_back(makeTransformerEngineTensor( @@ -365,6 +365,8 @@ std::tuple, std::vector> bulk_allocate_mx const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); const auto fp8_dtype = quantizer_cpp_list[0]->dtype; + const bool with_gemm_swizzled_scales = quantizer_cpp_list[0]->optimize_for_gemm; + constexpr size_t fp8_elem_size = 1; constexpr size_t scale_elem_size = 1; @@ -475,8 +477,8 @@ std::tuple, std::vector> bulk_allocate_mx // Construct Python tensor tensor_py_list.emplace_back(MXFP8TensorClass(rowwise_data, rowwise_scale, columnwise_data, - columnwise_scale, fp8_dtype, - quantizer_py_list[i])); + columnwise_scale, fp8_dtype, quantizer_py_list[i], + with_gemm_swizzled_scales)); // Construct C++ tensor tensor_cpp_list.emplace_back(makeTransformerEngineTensor( @@ -488,6 +490,7 @@ std::tuple, std::vector> bulk_allocate_mx columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, columnwise_usage ? 
columnwise_scale_shapes[i] : std::vector{0}, scaling_mode)); + tensor_cpp_list.back().set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); } return retval; @@ -517,6 +520,7 @@ std::tuple, std::vector, bool> bulk_alloc const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); const auto fp4_dtype = quantizer_cpp_list[0]->dtype; + const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) Enable based on optimize_for_gemm; constexpr size_t scale_elem_size = 1; // Helper function to construct tensor view @@ -675,9 +679,9 @@ std::tuple, std::vector, bool> bulk_alloc py::object amax_columnwise = columnwise_usage ? py::cast(amax_columnwise_list[i]) : py::none(); // Construct Python tensor - tensor_py_list.emplace_back(NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data, - columnwise_scale, amax_rowwise, amax_columnwise, - fp4_dtype, quantizer_py_list[i])); + tensor_py_list.emplace_back(NVFP4TensorClass( + rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, amax_rowwise, + amax_columnwise, fp4_dtype, quantizer_py_list[i], with_gemm_swizzled_scales)); // Construct C++ tensor // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, @@ -693,6 +697,7 @@ std::tuple, std::vector, bool> bulk_alloc columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode); + tensor_wrapper.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { @@ -703,6 +708,7 @@ std::tuple, std::vector, bool> bulk_alloc tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, std::vector{1}); } + tensor_cpp_list.emplace_back(std::move(tensor_wrapper)); } } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 335052296f2..439be148a99 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -240,9 +240,12 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans auto main_stream = at::cuda::getCurrentCUDAStream(); if (A_tensor.numel() != 0 && B_tensor.numel() != 0) { // Optionally swizzle the scaling factors - swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa))); - swizzled_scale_inverses_list.emplace_back( - std::move(swizzle_scaling_factors(B_tensor, !transb))); + auto [A_row_scales, A_col_scales] = swizzle_scales_for_gemm(A_tensor, transa, !transa); + auto [B_row_scales, B_col_scales] = swizzle_scales_for_gemm(B_tensor, !transb, transb); + swizzled_scale_inverses_list.emplace_back(std::move(A_row_scales)); + swizzled_scale_inverses_list.emplace_back(std::move(A_col_scales)); + swizzled_scale_inverses_list.emplace_back(std::move(B_row_scales)); + swizzled_scale_inverses_list.emplace_back(std::move(B_col_scales)); // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer // as it is not natively supported by cublasLt @@ -501,9 +504,9 @@ std::optional> te_general_grouped_gemm( // Optionally swizzle the scaling factors swizzled_scale_inverses_list.emplace_back( - multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa)); + multi_tensor_swizzle_scales_for_gemm(te_A_wrappers, transa, !transa)); swizzled_scale_inverses_list.emplace_back( - 
multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb)); + multi_tensor_swizzle_scales_for_gemm(te_B_wrappers, !transb, transb)); // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer // as it is not natively supported by cublasLt diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 3c5c17fc6f2..d7bcc3cdcf9 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -89,14 +89,8 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe TensorWrapper mu_nvte = makeTransformerEngineTensor(mu_py); TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py); - // Output tensor + // Quantizer auto quantizer_cpp = convert_quantizer(quantizer); - TensorWrapper out_nvte; - if (out.is_none()) { - std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype); - } else { - out_nvte = makeTransformerEngineTensor(out, quantizer); - } // Choose implementation enum class Impl { @@ -135,6 +129,19 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe } } + // Output tensor + TensorWrapper out_nvte; + if (out.is_none()) { + if (impl == Impl::FULLY_FUSED) { + // FP8 has no special logic to optimize for GEMM, MXFP8 cuDNN + // kernel does not support GEMM swizzled scales + quantizer_cpp->optimize_for_gemm = false; + } + std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype); + } else { + out_nvte = makeTransformerEngineTensor(out, quantizer); + } + // Construct unquantized output tensor if needed TensorWrapper unquantized_out_nvte; py::object unquantized_out; @@ -318,14 +325,8 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w at::Tensor rsigma_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py); - // Output tensor + // Quantizer auto quantizer_cpp = convert_quantizer(quantizer); - TensorWrapper out_nvte; - if (out.is_none()) { - std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype); - } else { - out_nvte = makeTransformerEngineTensor(out, quantizer); - } // Choose implementation enum class Impl { @@ -364,6 +365,19 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w } } + // Output tensor + TensorWrapper out_nvte; + if (out.is_none()) { + if (impl == Impl::FULLY_FUSED) { + // FP8 has no special logic to optimize for GEMM, MXFP8 cuDNN + // kernel does not support GEMM swizzled scales + quantizer_cpp->optimize_for_gemm = false; + } + std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype); + } else { + out_nvte = makeTransformerEngineTensor(out, quantizer); + } + // Construct unquantized output tensor if needed TensorWrapper unquantized_out_nvte; py::object unquantized_out; diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d0f450bc712..1e585e87592 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -290,6 +290,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Fused Multi-tensor padding", py::call_guard()); m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding, "Fused Multi-tensor unpadding", py::call_guard()); + m.def("swizzle_scales_for_gemm_", &transformer_engine::pytorch::inplace_swizzle_scale_for_gemm, 
+ "Convert tensor block scales into GEMM swizzled format"); // attention kernels m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd, diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp new file mode 100644 index 00000000000..b35b0c186bd --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -0,0 +1,397 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include "common.h" +#include "common/common.h" +#include "extensions.h" +#include "pybind.h" +#include "util.h" + +namespace transformer_engine { +namespace pytorch { + +namespace { + +void reset_tensor_data(transformer_engine::TensorWrapper &tensor, bool rowwise, bool columnwise) { + NVTEShape shape; + shape.ndim = 1; + shape.data[0] = 0; + const transformer_engine::DType dtype = transformer_engine::DType::kFloat32; + if (rowwise) { + tensor.set_rowwise_data(nullptr, dtype, shape); + tensor.set_rowwise_scale_inv(nullptr, dtype, shape); + } + if (columnwise) { + tensor.set_columnwise_data(nullptr, dtype, shape); + tensor.set_columnwise_scale_inv(nullptr, dtype, shape); + } +} + +} // namespace + +std::tuple, std::optional> swizzle_scales_for_gemm( + transformer_engine::TensorWrapper &tensor, bool rowwise_usage, bool columnwise_usage) { + // Return early if scale swizzling is not required + const auto scaling_mode = tensor.scaling_mode(); + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: + case NVTE_NVFP4_1D_SCALING: + // Tensor format requires scale swizzling + break; + case NVTE_INVALID_SCALING: + NVTE_ERROR("Invalid scaling mode for swizzling scaling factors."); + default: + // Tensor format does not require scale swizzling for GEMM + return {std::nullopt, std::nullopt}; + } + + // Return early if scales are already swizzled + if (tensor.get_with_gemm_swizzled_scales()) { + return {std::nullopt, std::nullopt}; + } + + // CUDA stream + auto stream = at::cuda::getCurrentCUDAStream(); + + // TE tensors with only scales + TensorWrapper input_nvte(scaling_mode); + TensorWrapper output_nvte(scaling_mode); + output_nvte.set_with_gemm_swizzled_scales(true); + + // Swizzle row-wise scales if needed + std::optional rowwise_scales_pyt; + if (rowwise_usage) { + // Buffer for unswizzled scales + const auto input_scales_nvte = tensor.get_rowwise_scale_inv(); + void *input_scales_dptr = input_scales_nvte.data_ptr; + const NVTEShape input_scales_shape = input_scales_nvte.shape; + const auto scales_dtype = static_cast(input_scales_nvte.dtype); + + // Allocate buffer for swizzled scales + const NVTEShape output_scales_shape = input_scales_shape; + rowwise_scales_pyt = allocateSpace(input_scales_shape, scales_dtype, false); + void *output_scales_dptr = getDataPtr(*rowwise_scales_pyt); + + // Initialize TE tensors with scales + const auto data_nvte = tensor.get_rowwise_data(); + const auto data_dtype = static_cast(data_nvte.dtype); + reset_tensor_data(input_nvte, false, true); + input_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape); + input_nvte.set_rowwise_scale_inv(input_scales_dptr, scales_dtype, input_scales_shape); + reset_tensor_data(output_nvte, false, true); + output_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape); + 
output_nvte.set_rowwise_scale_inv(output_scales_dptr, scales_dtype, output_scales_shape); + + // Launch kernel + NVTE_SCOPED_GIL_RELEASE( + { nvte_swizzle_scaling_factors(input_nvte.data(), output_nvte.data(), stream); }); + + // Update tensor with swizzled scales + tensor.set_rowwise_scale_inv(output_scales_dptr, scales_dtype, output_scales_shape); + } + + // Swizzle column-wise scales if needed + std::optional columnwise_scales_pyt; + if (columnwise_usage) { + // Buffer for unswizzled scales + const auto input_scales_nvte = tensor.get_columnwise_scale_inv(); + void *input_scales_dptr = input_scales_nvte.data_ptr; + const NVTEShape input_scales_shape = input_scales_nvte.shape; + const auto scales_dtype = static_cast(input_scales_nvte.dtype); + + // Allocate buffer for swizzled scales + const NVTEShape output_scales_shape = input_scales_shape; + columnwise_scales_pyt = allocateSpace(input_scales_shape, scales_dtype, false); + void *output_scales_dptr = getDataPtr(*columnwise_scales_pyt); + + // Initialize TE tensors with scales + const auto data_nvte = tensor.get_columnwise_data(); + const auto data_dtype = static_cast(data_nvte.dtype); + reset_tensor_data(input_nvte, true, false); + input_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape); + input_nvte.set_columnwise_scale_inv(input_scales_dptr, scales_dtype, input_scales_shape); + reset_tensor_data(output_nvte, true, false); + output_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape); + output_nvte.set_columnwise_scale_inv(output_scales_dptr, scales_dtype, output_scales_shape); + + // Launch kernel + NVTE_SCOPED_GIL_RELEASE( + { nvte_swizzle_scaling_factors(input_nvte.data(), output_nvte.data(), stream); }); + + // Update tensor with swizzled scales + tensor.set_columnwise_scale_inv(output_scales_dptr, scales_dtype, output_scales_shape); + } + + // Update tensor + reset_tensor_data(tensor, !rowwise_usage, !columnwise_usage); + tensor.set_with_gemm_swizzled_scales(true); + + return {std::move(rowwise_scales_pyt), std::move(columnwise_scales_pyt)}; +} + +std::optional multi_tensor_swizzle_scales_for_gemm( + std::vector &tensors, bool rowwise_usage, + bool columnwise_usage) { + // Checks and trivial cases + NVTE_CHECK(rowwise_usage != columnwise_usage, + "Expect exactly one of rowwise_usage=", rowwise_usage, + " and columnwise_usage=", columnwise_usage, "."); + if (tensors.empty()) { + return std::nullopt; + } + const auto scaling_mode = tensors.front().scaling_mode(); + for (const auto &tensor : tensors) { + NVTE_CHECK(tensor.scaling_mode() == scaling_mode, "Tensors have different scaling modes"); + } + + // Return early if scale swizzling is not required + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: + case NVTE_NVFP4_1D_SCALING: + // Tensor format requires scale swizzling + break; + case NVTE_INVALID_SCALING: + NVTE_ERROR("Invalid scaling mode for swizzling scaling factors."); + default: + // Tensor format does not require scale swizzling for GEMM + return std::nullopt; + } + + // Filter out tensors that already have swizzled scales + std::vector tensors_needing_swizzle; + for (auto &tensor : tensors) { + if (!tensor.get_with_gemm_swizzled_scales()) { + tensors_needing_swizzle.push_back(&tensor); + } + } + if (tensors_needing_swizzle.empty()) { + return std::nullopt; + } + + // Determine buffer size needed for swizzled scales + std::vector output_scales_offsets; + size_t output_scales_bytes = 0; + for (auto &tensor : tensors_needing_swizzle) { + const auto scales_nvte = + (rowwise_usage ? 
tensor->get_rowwise_scale_inv() : tensor->get_columnwise_scale_inv()); + const auto &shape = scales_nvte.shape; + const auto dtype = static_cast(scales_nvte.dtype); + const auto dtype_bits = transformer_engine::pytorch::typeToNumBits(dtype); + const auto size = product(shape, 0, shape.ndim); + output_scales_bytes = roundup(output_scales_bytes, 16); // align to 16B + output_scales_offsets.push_back(output_scales_bytes); + output_scales_bytes += ceildiv(size * dtype_bits, 8); + } + + // Allocate buffer for swizzled scales + auto output_scales_pyt = allocateSpace(std::vector{output_scales_bytes}, + transformer_engine::DType::kByte, false); + uint8_t *output_scales_dptr = reinterpret_cast(getDataPtr(output_scales_pyt)); + + // Construct TE tensors with only scales + std::vector inputs_nvte, outputs_nvte; + for (size_t i = 0; i < tensors_needing_swizzle.size(); ++i) { + auto &tensor = *tensors_needing_swizzle[i]; + inputs_nvte.emplace_back(scaling_mode); + outputs_nvte.emplace_back(scaling_mode); + auto &input_nvte = inputs_nvte.back(); + auto &output_nvte = outputs_nvte.back(); + output_nvte.set_with_gemm_swizzled_scales(true); + if (rowwise_usage) { + const auto data_nvte = tensor.get_rowwise_data(); + const auto scales_nvte = tensor.get_rowwise_scale_inv(); + const auto data_dtype = static_cast(data_nvte.dtype); + const auto scales_dtype = static_cast(scales_nvte.dtype); + input_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape); + input_nvte.set_rowwise_scale_inv(scales_nvte.data_ptr, scales_dtype, scales_nvte.shape); + output_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape); + output_nvte.set_rowwise_scale_inv(output_scales_dptr + output_scales_offsets[i], scales_dtype, + scales_nvte.shape); + } else { + const auto data_nvte = tensor.get_columnwise_data(); + const auto scales_nvte = tensor.get_columnwise_scale_inv(); + const auto data_dtype = static_cast(data_nvte.dtype); + const auto scales_dtype = static_cast(scales_nvte.dtype); + input_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape); + input_nvte.set_columnwise_scale_inv(scales_nvte.data_ptr, scales_dtype, scales_nvte.shape); + output_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape); + output_nvte.set_columnwise_scale_inv(output_scales_dptr + output_scales_offsets[i], + scales_dtype, scales_nvte.shape); + } + } + + // Pack raw NVTETensors into vectors + std::vector inputs_nvte_raw, outputs_nvte_raw; + for (auto &tensor : inputs_nvte) { + inputs_nvte_raw.emplace_back(tensor.data()); + } + for (auto &tensor : outputs_nvte) { + outputs_nvte_raw.emplace_back(tensor.data()); + } + + // Launch kernel + NVTE_SCOPED_GIL_RELEASE({ + nvte_multi_tensor_swizzle_scaling_factors(inputs_nvte_raw.data(), outputs_nvte_raw.data(), + inputs_nvte_raw.size(), + at::cuda::getCurrentCUDAStream()); + }); + + // Update tensors with swizzled scales + for (size_t i = 0; i < tensors_needing_swizzle.size(); ++i) { + auto &tensor = *tensors_needing_swizzle[i]; + reset_tensor_data(tensor, !rowwise_usage, !columnwise_usage); + tensor.set_with_gemm_swizzled_scales(true); + if (rowwise_usage) { + auto scales_nvte = outputs_nvte[i].get_rowwise_scale_inv(); + const auto scales_dtype = static_cast(scales_nvte.dtype); + tensor.set_rowwise_scale_inv(output_scales_dptr + output_scales_offsets[i], scales_dtype, + scales_nvte.shape); + } else { + auto scales_nvte = outputs_nvte[i].get_columnwise_scale_inv(); + const auto scales_dtype = static_cast(scales_nvte.dtype); + tensor.set_columnwise_scale_inv(output_scales_dptr + 
output_scales_offsets[i], scales_dtype, + scales_nvte.shape); + } + } + + return std::move(output_scales_pyt); +} + +at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper &input, + bool rowwise) { + // Check input tensor + const NVTEScalingMode scaling_mode = input.scaling_mode(); + NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D, + "Input tensor must be a block scaling tensor"); + + // Get tensor data + NVTEBasicTensor data; + size_t data_flat_first_dim = 1; + size_t data_flat_last_dim = 1; + if (rowwise) { + data = input.get_rowwise_data(); + for (size_t i = 0; i < data.shape.ndim - 1; ++i) { + data_flat_first_dim *= data.shape.data[i]; + } + data_flat_last_dim = data.shape.data[data.shape.ndim - 1]; + } else { + data = input.get_columnwise_data(); + data_flat_first_dim = data.shape.data[0]; + for (size_t i = 1; i < data.shape.ndim; ++i) { + data_flat_last_dim *= data.shape.data[i]; + } + } + NVTEShape data_shape{}; + data_shape.data[0] = data_flat_first_dim; + data_shape.data[1] = data_flat_last_dim; + data_shape.ndim = 2; + + // Recreate input tensor with rowwise usage + transformer_engine::TensorWrapper input_cu(scaling_mode); + input_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape); + const NVTEBasicTensor scale_inv = + rowwise ? input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv(); + input_cu.set_rowwise_scale_inv( + scale_inv.data_ptr, static_cast(scale_inv.dtype), scale_inv.shape); + + // Create output tensor + transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + output_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape); + // Output swizzled mxfp8 scaling factor dimensions + const size_t swizzled_scale_inv_first_dim = ceildiv(data_flat_first_dim, 128) * 128; + const size_t swizzled_scale_inv_last_dim = ceildiv(data_flat_last_dim, 128) * 4; + // Allocate memory for swizzled mxfp8 scaling factors + at::Tensor swizzled_scale_inv = + allocateSpace(std::vector{swizzled_scale_inv_first_dim, swizzled_scale_inv_last_dim}, + transformer_engine::DType::kByte, false); + // Set rowwise scaling factors on output + void *const swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); + NVTEShape swizzled_scale_inv_shape{}; + swizzled_scale_inv_shape.data[0] = swizzled_scale_inv_first_dim; + swizzled_scale_inv_shape.data[1] = swizzled_scale_inv_last_dim; + swizzled_scale_inv_shape.ndim = 2; + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, + swizzled_scale_inv_shape); + output_cu.set_with_gemm_swizzled_scales(true); + + // Convert scaling factors from FP8 block scaling GEMM_READY format to mxfp8 swizzled format + NVTE_SCOPED_GIL_RELEASE({ + nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(input_cu.data(), output_cu.data(), + at::cuda::getCurrentCUDAStream()); + }); + + // Set the input tensor to be the converted mxfp8 tensor and return the swizzled scaling factor + // for it to be kept alive during the GEMM + input = std::move(output_cu); + return swizzled_scale_inv; +} + +void inplace_swizzle_scale_for_gemm(py::handle &tensor) { + // Convert Python tensor to C++ tensor + auto tensor_nvte = makeTransformerEngineTensor(tensor, py::none()); + + // Return early if scale swizzling is not required + const auto scaling_mode = tensor_nvte.scaling_mode(); + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: + case NVTE_NVFP4_1D_SCALING: + // Tensor format requires scale swizzling + break; + case NVTE_INVALID_SCALING: + 
NVTE_ERROR("Invalid scaling mode for swizzling scaling factors."); + default: + // Tensor format does not require scale swizzling for GEMM + return; + } + + // Return early if scales are already swizzled + if (tensor_nvte.get_with_gemm_swizzled_scales()) { + return; + } + + // Check what scaling factors the tensor contains + auto is_empty = [](const NVTEBasicTensor &t) -> bool { + return t.shape.ndim == 1 && t.shape.data[0] == 0; + }; + const bool has_rowwise_scales = !is_empty(tensor_nvte.get_rowwise_scale_inv()); + const bool has_columnwise_scales = !is_empty(tensor_nvte.get_columnwise_scale_inv()); + + // Swizzle scaling factors + auto [rowwise_scales, columnwise_scales] = + swizzle_scales_for_gemm(tensor_nvte, has_rowwise_scales, has_columnwise_scales); + + // Update Python tensor with swizzled scales + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: + if (has_rowwise_scales) { + tensor.attr("_rowwise_scale_inv") = rowwise_scales; + } + if (has_columnwise_scales) { + tensor.attr("_columnwise_scale_inv") = columnwise_scales; + } + tensor.attr("_with_gemm_swizzled_scales") = true; + break; + case NVTE_NVFP4_1D_SCALING: + if (has_rowwise_scales) { + tensor.attr("_rowwise_scale_inv") = rowwise_scales; + } + if (has_columnwise_scales) { + tensor.attr("_columnwise_scale_inv") = columnwise_scales; + } + tensor.attr("_with_gemm_swizzled_scales") = true; + break; + default: + NVTE_ERROR("Invalid scaling mode for swizzling scaling factors."); + } +} + +} // namespace pytorch +} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index fd748d1b21e..c97124a5dbf 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -52,10 +52,12 @@ Quantizer::Quantizer(const py::handle& quantizer) { this->rowwise_usage = true; this->columnwise_usage = true; this->internal = false; + this->optimize_for_gemm = false; } else { this->rowwise_usage = quantizer.attr("rowwise_usage").cast(); this->columnwise_usage = quantizer.attr("columnwise_usage").cast(); this->internal = quantizer.attr("internal").cast(); + this->optimize_for_gemm = quantizer.attr("optimize_for_gemm").cast(); this->quantizer = quantizer; } } @@ -555,7 +557,6 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti this->amax_epsilon = quantizer.attr("amax_epsilon").cast(); NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2, "Unsupported block scaling dim."); - this->all_gather_usage = quantizer.attr("all_gather_usage").cast(); } void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {} @@ -575,10 +576,6 @@ std::pair Float8BlockQuantizer::create_tensor( opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); - Float8BlockScaleTensorFormat data_format = - (all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT - : Float8BlockScaleTensorFormat::GEMM_READY); - if (rowwise_usage) { data_rowwise = at::empty(torch_shape, opts); auto scale_shape = get_scale_shape(shape, false); @@ -597,21 +594,13 @@ std::pair Float8BlockQuantizer::create_tensor( NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape. 
Shape ", columnwise_shape, " torch shape: ", torch_columnwise_shape); if (torch_shape.size() > 0) { - if (!all_gather_usage) { - torch_columnwise_shape.reserve(torch_shape.size()); - columnwise_shape.reserve(shape.size()); - torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); - columnwise_shape.push_back(shape[shape.size() - 1]); - for (size_t i = 0; i < torch_shape.size() - 1; ++i) { - torch_columnwise_shape.push_back(torch_shape[i]); - columnwise_shape.push_back(shape[i]); - } - } else { - // assert we are doing 1D scaling - NVTE_CHECK(block_scaling_dim == 1, - "Compact columnwise format is not supported for 128x128 2D block scaling."); - torch_columnwise_shape = torch_shape; - columnwise_shape = shape; + torch_columnwise_shape.reserve(torch_shape.size()); + columnwise_shape.reserve(shape.size()); + torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); + columnwise_shape.push_back(shape[shape.size() - 1]); + for (size_t i = 0; i < torch_shape.size() - 1; ++i) { + torch_columnwise_shape.push_back(torch_shape[i]); + columnwise_shape.push_back(shape[i]); } } auto scale_shape = get_scale_shape(shape, true); @@ -635,7 +624,7 @@ std::pair Float8BlockQuantizer::create_tensor( "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer, - "is_2D_scaled"_a = (block_scaling_dim == 2), "data_format"_a = data_format); + "is_2D_scaled"_a = (block_scaling_dim == 2)); } else { py::handle Float8BlockwiseQTensorClass( reinterpret_cast(Float8BlockwiseQTensorPythonClass)); @@ -643,8 +632,7 @@ std::pair Float8BlockQuantizer::create_tensor( "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2), - "data_format"_a = data_format); + "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2)); } return {std::move(tensor), std::move(ret)}; @@ -654,6 +642,7 @@ std::pair Float8BlockQuantizer::convert_and_update_te py::object tensor) const { const DType dtype = tensor.attr("_fp8_dtype").cast(); bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast(); + const bool with_gemm_swizzled_scales = true; // Extract buffers from Python tensor auto get_tensor = [&tensor](const char* name) -> std::optional { @@ -675,13 +664,10 @@ std::pair Float8BlockQuantizer::convert_and_update_te opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); - auto get_columnwise_shape = [&columnwise_data](bool all_gather_usage) -> std::vector { + auto get_columnwise_shape = [&columnwise_data]() -> std::vector { if (!columnwise_data) { return std::vector(); } - if (all_gather_usage) { - return getTensorShape(*columnwise_data); - } std::vector shape = getTensorShape(*columnwise_data); std::vector shape_transposed(shape.size()); for (size_t i = 0; i + 1 < shape.size(); ++i) { @@ -696,12 +682,12 @@ std::pair Float8BlockQuantizer::convert_and_update_te if (rowwise_data) { shape = getTensorShape(*rowwise_data); if (columnwise_data) { - auto expected_shape = get_columnwise_shape(all_gather_usage); + auto expected_shape = get_columnwise_shape(); NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data 
(shape=", shape, ") and column-wise data (shape=", expected_shape, ") do not match"); } } else { - shape = get_columnwise_shape(all_gather_usage); + shape = get_columnwise_shape(); } std::vector torch_shape; for (auto s : shape) { @@ -738,21 +724,13 @@ std::pair Float8BlockQuantizer::convert_and_update_te std::vector columnwise_shape; std::vector torch_columnwise_shape; if (torch_shape.size() > 0) { - if (!all_gather_usage) { - torch_columnwise_shape.reserve(torch_shape.size()); - columnwise_shape.reserve(shape.size()); - torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); - columnwise_shape.push_back(shape[shape.size() - 1]); - for (size_t i = 0; i < torch_shape.size() - 1; ++i) { - torch_columnwise_shape.push_back(torch_shape[i]); - columnwise_shape.push_back(shape[i]); - } - } else { - // assert we are doing 1D scaling - NVTE_CHECK(block_scaling_dim == 1, - "Compact columnwise format is not supported for 128x128 2D block scaling."); - torch_columnwise_shape = torch_shape; - columnwise_shape = shape; + torch_columnwise_shape.reserve(torch_shape.size()); + columnwise_shape.reserve(shape.size()); + torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); + columnwise_shape.push_back(shape[shape.size() - 1]); + for (size_t i = 0; i < torch_shape.size() - 1; ++i) { + torch_columnwise_shape.push_back(torch_shape[i]); + columnwise_shape.push_back(shape[i]); } } if (!columnwise_data) { @@ -798,6 +776,7 @@ std::pair Float8BlockQuantizer::convert_and_update_te const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape); } + ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); set_quantization_params(&ret); return {std::move(ret), std::move(tensor)}; } @@ -813,9 +792,6 @@ void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& o } quant_config.set_force_pow_2_scales(force_pow_2_scales); quant_config.set_amax_epsilon(amax_epsilon); - if (all_gather_usage) { - quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT); - } NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); }); @@ -832,10 +808,6 @@ std::vector Float8BlockQuantizer::get_scale_shape(const std::vector scale_shape; bool rowwise_usage = !columnwise; @@ -845,26 +817,17 @@ std::vector Float8BlockQuantizer::get_scale_shape(const std::vector Float8BlockQuantizer::get_scale_shape(const std::vector MXFP8Quantizer::create_tensor(const std::ve DType dtype) const { using namespace pybind11::literals; + // Scaling factor format + const bool with_gemm_swizzled_scales = this->optimize_for_gemm; + // Tensor dimensions const std::vector shape_int64(shape.begin(), shape.end()); size_t flat_first_dim = 1; @@ -951,19 +909,17 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve py::object out_py; if (internal) { py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorStoragePythonClass)); - out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py, - "columnwise_data"_a = columnwise_data_py, - "rowwise_scale_inv"_a = rowwise_scale_inv_py, - "columnwise_scale_inv"_a = columnwise_scale_inv_py, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + out_py = MXFP8TensorClass(rowwise_data_py, rowwise_scale_inv_py, columnwise_data_py, + columnwise_scale_inv_py, this->dtype, this->quantizer, + with_gemm_swizzled_scales); } else { py::handle 
MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); - out_py = MXFP8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "rowwise_data"_a = rowwise_data_py, - "columnwise_data"_a = columnwise_data_py, - "rowwise_scale_inv"_a = rowwise_scale_inv_py, - "columnwise_scale_inv"_a = columnwise_scale_inv_py, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + out_py = MXFP8TensorClass( + "shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), + "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, + "rowwise_scale_inv"_a = rowwise_scale_inv_py, + "columnwise_scale_inv"_a = columnwise_scale_inv_py, "fp8_dtype"_a = this->dtype, + "quantizer"_a = this->quantizer, "with_gemm_swizzled_scales"_a = with_gemm_swizzled_scales); } // Construct C++ MXFP8 tensor @@ -978,6 +934,7 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0, columnwise_scale_inv_shape); } + out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); this->set_quantization_params(&out_cpp); return {std::move(out_cpp), std::move(out_py)}; @@ -987,6 +944,9 @@ std::pair MXFP8Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor."); + // Scaling factor format + const bool with_gemm_swizzled_scales = this->optimize_for_gemm; + // Extract buffers from Python tensor auto get_tensor = [&tensor](const char* name) -> std::optional { auto attr_py = tensor.attr(name); @@ -1070,6 +1030,7 @@ std::pair MXFP8Quantizer::convert_and_update_tensor( // Coerce other attrs tensor.attr("_fp8_dtype") = dtype; + tensor.attr("_with_gemm_swizzled_scales") = with_gemm_swizzled_scales; // Construct C++ MXFP8 tensor TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING); @@ -1083,6 +1044,7 @@ std::pair MXFP8Quantizer::convert_and_update_tensor( out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0, getTensorShape(*columnwise_scale_inv)); } + out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); this->set_quantization_params(&out_cpp); return {std::move(out_cpp), std::move(tensor)}; @@ -1173,6 +1135,9 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve DType dtype) const { using namespace pybind11::literals; + // Scaling factor format + const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) self->optimize_for_gemm + // Tensor dimensions const std::vector shape_int64(shape.begin(), shape.end()); size_t flat_first_dim = 1; @@ -1235,12 +1200,9 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve py::object out_py; if (internal) { py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorStoragePythonClass)); - out_py = NVFP4TensorClass( - "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, - "rowwise_scale_inv"_a = rowwise_scale_inv_py, - "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py, - "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer); + out_py = NVFP4TensorClass(rowwise_data_py, rowwise_scale_inv_py, columnwise_data_py, + columnwise_scale_inv_py, amax_rowwise_py, amax_columnwise_py, + this->dtype, this->quantizer, with_gemm_swizzled_scales); } else { py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorPythonClass)); out_py = NVFP4TensorClass( @@ -1249,7 +1211,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve 
"rowwise_scale_inv"_a = rowwise_scale_inv_py, "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py, "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer); + "quantizer"_a = this->quantizer, "with_gemm_swizzled_scales"_a = with_gemm_swizzled_scales); } // Construct C++ tensor @@ -1272,6 +1234,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, std::vector{1}); } + out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); this->set_quantization_params(&out_cpp); return {std::move(out_cpp), std::move(out_py)}; @@ -1301,6 +1264,9 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to IsNVFP4Tensor."); + // Scaling factor format + const bool with_gemm_swizzled_scales = false; // TODO (tmoon) Enable with optimize_for_gemm + // Extract buffers from Python tensor auto get_tensor = [&tensor](const char* name) -> std::optional { auto attr_py = tensor.attr(name); @@ -1438,6 +1404,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( out_cpp.set_columnwise_amax(amax_columnwise->data_ptr(), DType::kFloat32, std::vector{1}); } + out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); this->set_quantization_params(&out_cpp); return {std::move(out_cpp), std::move(tensor)}; diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 368e9dcdfa3..83a7153c1f5 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -55,8 +55,9 @@ TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) { auto ret = TensorWrapper(NVTE_MXFP8_1D_SCALING); - bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); - bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); + const bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); + const bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); + const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast(); NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for MXFP8 Tensor."); @@ -78,6 +79,9 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) getTensorShape(scale_inv)); } + // Scale layout + ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); + // Quantizer state quantizer->set_quantization_params(&ret); @@ -93,6 +97,7 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer auto ret = TensorWrapper(is_2D_scaled ? 
NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); + // Row-wise data if (rowwise_usage) { const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast(); const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); @@ -102,6 +107,8 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape); } + + // Column-wise data if (columnwise_usage) { const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast(); const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); @@ -112,7 +119,10 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape); } + + // Quantizer state quantizer->set_quantization_params(&ret); + return ret; } @@ -123,6 +133,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); + const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast(); NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor."); @@ -150,6 +161,9 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) getTensorShape(amax_columnwise)); } + // Scale layout + ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); + // Quantizer state quantizer->set_quantization_params(&ret); diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp deleted file mode 100644 index ce547d302e6..00000000000 --- a/transformer_engine/pytorch/csrc/util.cpp +++ /dev/null @@ -1,263 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "util.h" - -#include "common.h" -#include "common/common.h" - -std::optional swizzle_scaling_factors(transformer_engine::TensorWrapper& input, - bool rowwise) { - using namespace transformer_engine::pytorch; - - if (input.scaling_mode() == NVTE_INVALID_SCALING) { - NVTE_ERROR("Invalid scaling mode for swizzle."); - } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING && - input.scaling_mode() != NVTE_NVFP4_1D_SCALING) { - return std::nullopt; - } - - NVTE_CHECK(input.element_size_bits() == 4 || input.element_size_bits() == 8, - "4-bit or 8-bit input required for swizzling scaling factors."); - - const auto nvfp4 = input.scaling_mode() == NVTE_NVFP4_1D_SCALING; - - NVTEBasicTensor scale_inv; - NVTEShape nvte_input_shape; - if (rowwise) { - nvte_input_shape = input.shape(); - scale_inv = input.get_rowwise_scale_inv(); - } else { - nvte_input_shape = input.get_columnwise_data().shape; - scale_inv = input.get_columnwise_scale_inv(); - } - - auto input_shape = nvte_shape_to_vector(nvte_input_shape); - auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape); - - NVTE_CHECK(input_shape.size() >= 2, "Wrong ndims for swizzle input shape."); - - // Allocate memory for swizzled output. 
- auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA); - std::vector scale_inv_shape_int; - for (size_t i = 0; i < scale_inv_shape.size(); ++i) { - scale_inv_shape_int.push_back(static_cast(scale_inv_shape[i])); - } - auto swizzled_scale_inv = at::empty(scale_inv_shape_int, options); - void* scale_inv_dptr = scale_inv.data_ptr; - void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); - - transformer_engine::TensorWrapper input_cu(input.scaling_mode()); - transformer_engine::TensorWrapper output_cu(input.scaling_mode()); - - const auto input_dtype = - (nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3; - const auto scale_inv_dtype = - (nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0; - - if (rowwise) { - input_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape); - input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); - output_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape); - output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); - } else { - input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape); - input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); - output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape); - output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); - } - - // Launch kernel - nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - if (rowwise) { - input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); - } else { - input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); - } - - return swizzled_scale_inv; -} - -std::optional multi_tensor_swizzle_scaling_factors( - std::vector& tensors, bool rowwise) { - using namespace transformer_engine::pytorch; - - if (tensors.empty()) { - return std::nullopt; - } - - bool all_same_scaling_mode = std::all_of( - tensors.cbegin(), tensors.cend(), [&tensors](const transformer_engine::TensorWrapper& val) { - return val.scaling_mode() == tensors.front().scaling_mode(); - }); - NVTE_CHECK(all_same_scaling_mode, "Scaling mode of the input tensors must be the same."); - - if (tensors.front().scaling_mode() == NVTE_INVALID_SCALING) { - NVTE_ERROR("Invalid scaling mode for swizzle."); - } else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING && - tensors.front().scaling_mode() != NVTE_NVFP4_1D_SCALING) { - return std::nullopt; - } - - const auto scaling_mode = tensors.front().scaling_mode(); - const auto nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING; - - std::vector wrappers; - std::vector input_tensors, output_tensors; - - // Collect scale_inv shapes and calculate buffer size and offsets for scale_invs - std::vector> scale_inv_shapes; - std::vector scale_inv_dptrs; - size_t buffer_size = 0; - std::vector scale_inv_offsets; - constexpr size_t scale_elem_size = 1; - for (auto& tensor : tensors) { - NVTEBasicTensor scale_inv; - if (rowwise) { - scale_inv = tensor.get_rowwise_scale_inv(); - } else { - scale_inv = tensor.get_columnwise_scale_inv(); - } - auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape); - buffer_size = roundup(buffer_size, 16); // align to 16B - scale_inv_offsets.push_back(buffer_size); - buffer_size += product(scale_inv_shape) * scale_elem_size; - 
scale_inv_shapes.emplace_back(scale_inv_shape); - scale_inv_dptrs.push_back(scale_inv.data_ptr); - } - - // Allocate full buffer - auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); - - const auto input_dtype = - (nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3; - const auto scale_inv_dtype = - (nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0; - - for (size_t i = 0; i < tensors.size(); ++i) { - auto& tensor = tensors[i]; - void* scale_inv_dptr = scale_inv_dptrs[i]; - void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]); - - // Empty tensors don't require scale swizzling - if (tensor.numel() == 0) { - continue; - } - - // Tensor shape - NVTEShape nvte_input_shape; - if (rowwise) { - nvte_input_shape = tensor.shape(); - } else { - nvte_input_shape = tensor.get_columnwise_data().shape; - } - - auto input_shape = nvte_shape_to_vector(nvte_input_shape); - // Reconstruct input only to avoid swizzling both directions if not needed. - // Use any 8 bit type, it's irrelevant. - transformer_engine::TensorWrapper input_cu(scaling_mode); - transformer_engine::TensorWrapper output_cu(scaling_mode); - if (rowwise) { - input_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape); - input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]); - output_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape); - output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, - scale_inv_shapes[i]); - // Set the swizzled scaling factor to the original tensor. - tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]); - } else { - input_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape); - input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]); - output_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape); - output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, - scale_inv_shapes[i]); - // Set the swizzled scaling factor to the original tensor. 
- tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, - scale_inv_shapes[i]); - } - - input_tensors.emplace_back(input_cu.data()); - output_tensors.emplace_back(output_cu.data()); - wrappers.emplace_back(std::move(input_cu)); - wrappers.emplace_back(std::move(output_cu)); - } - - // Launch kernel - nvte_multi_tensor_swizzle_scaling_factors(input_tensors.data(), output_tensors.data(), - input_tensors.size(), at::cuda::getCurrentCUDAStream()); - - return buffer; -} - -at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper& input, - bool rowwise) { - using namespace transformer_engine::pytorch; - using transformer_engine::DIVUP; - - // Check input tensor - const NVTEScalingMode scaling_mode = input.scaling_mode(); - NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D, - "Input tensor must be a block scaling tensor"); - - // Get tensor data - NVTEBasicTensor data; - size_t data_flat_first_dim = 1; - size_t data_flat_last_dim = 1; - if (rowwise) { - data = input.get_rowwise_data(); - for (size_t i = 0; i < data.shape.ndim - 1; ++i) { - data_flat_first_dim *= data.shape.data[i]; - } - data_flat_last_dim = data.shape.data[data.shape.ndim - 1]; - } else { - data = input.get_columnwise_data(); - data_flat_first_dim = data.shape.data[0]; - for (size_t i = 1; i < data.shape.ndim; ++i) { - data_flat_last_dim *= data.shape.data[i]; - } - } - NVTEShape data_shape{}; - data_shape.data[0] = data_flat_first_dim; - data_shape.data[1] = data_flat_last_dim; - data_shape.ndim = 2; - - // Recreate input tensor with rowwise usage - transformer_engine::TensorWrapper input_cu(scaling_mode); - input_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape); - const NVTEBasicTensor scale_inv = - rowwise ? 
input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv(); - input_cu.set_rowwise_scale_inv( - scale_inv.data_ptr, static_cast(scale_inv.dtype), scale_inv.shape); - - // Create output tensor - transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); - output_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape); - // Output swizzled mxfp8 scaling factor dimensions - const size_t swizzled_scale_inv_first_dim = DIVUP(data_flat_first_dim, 128) * 128; - const size_t swizzled_scale_inv_last_dim = DIVUP(data_flat_last_dim, 128) * 4; - // Allocate memory for swizzled mxfp8 scaling factors - const auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA); - at::Tensor swizzled_scale_inv = at::empty( - std::vector{swizzled_scale_inv_first_dim, swizzled_scale_inv_last_dim}, options); - // Set rowwise scaling factors on output - void* const swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); - NVTEShape swizzled_scale_inv_shape{}; - swizzled_scale_inv_shape.data[0] = swizzled_scale_inv_first_dim; - swizzled_scale_inv_shape.data[1] = swizzled_scale_inv_last_dim; - swizzled_scale_inv_shape.ndim = 2; - output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - swizzled_scale_inv_shape); - - // Convert scaling factors from FP8 block scaling GEMM_READY format to mxfp8 swizzled format - nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(input_cu.data(), output_cu.data(), - at::cuda::getCurrentCUDAStream()); - - // Set the input tensor to be the converted mxfp8 tensor and return the swizzled scaling factor - // for it to be kept alive during the GEMM - input = std::move(output_cu); - return swizzled_scale_inv; -} diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 57eee86d2aa..0776067e62f 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -10,33 +10,44 @@ #include #include +#include +#include #include "transformer_engine/transformer_engine.h" -/*! \brief Swizzle the scaling factor of the input tensor. +namespace transformer_engine { +namespace pytorch { + +/*! \brief Convert tensor block scales into GEMM swizzled format. * - * The returned swizzled scaling factor tensor should be kept alive during the GEMM. + * The returned swizzled scales should be kept alive during the GEMM. */ -std::optional swizzle_scaling_factors(transformer_engine::TensorWrapper &input, - bool rowwise); +std::tuple, std::optional> swizzle_scales_for_gemm( + TensorWrapper& tensor, bool rowwise_usage, bool columnwise_usage); -/*! \brief Swizzle the scaling factor of the input tensors. +/*! \brief Convert multiple tensor block scales into GEMM swizzled format. * - * The returned swizzled scaling factor tensors should be kept alive during the GEMMs. + * The returned swizzled scales should be kept alive during the GEMMs. */ -std::optional multi_tensor_swizzle_scaling_factors( - std::vector &inputs, bool rowwise); +std::optional multi_tensor_swizzle_scales_for_gemm(std::vector& tensors, + bool rowwise_usage, + bool columnwise_usage); /*! \brief Convert a block scaling tensor to an mxfp8 tensor in-place. * - * If rowwise==false, the columnwise data will be reinterpreted as rowwise data to avoid - * transposing it in memory. Due to differences in how block scaling and mxfp8 store data, - * this requires the calling code to treat the output tensor as having been tranposed in this case. 
+ * If rowwise==false, the columnwise data will be reinterpreted as
+ * rowwise data to avoid transposing it in memory. Due to differences
+ * in how block scaling and mxfp8 store data, this requires the
+ * calling code to treat the output tensor as having been transposed
+ * in this case.
  *
- * Returns the swizzled scaling factor of the converted mxfp8 tensor.
- * The returned swizzled scaling factor tensor should be kept alive during the GEMM.
+ * Returns the swizzled scaling factor of the converted mxfp8 tensor.
+ * The returned swizzled scaling factor tensor should be kept alive
+ * during the GEMM.
  */
-at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper &input,
-                                                 bool rowwise);
+at::Tensor convert_block_scaling_to_mxfp8_tensor(TensorWrapper& input, bool rowwise);
+
+}  // namespace pytorch
+}  // namespace transformer_engine
 
 #endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index deb9b3ff918..180989bb597 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -48,7 +48,7 @@
 from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
 from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
 from .tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
-from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer
+from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
 
 __all__ = ["checkpoint", "CudaRNGStatesTracker"]
 
@@ -909,6 +909,34 @@ def reduce_scatter_along_first_dim(
     return output, handle
 
 
+@dataclass
+class _AsyncHandle:
+    """Handle for asynchronous collectives."""
+
+    async_handle: torch.distributed.Work
+    post_process_function: Optional[Callable] = None
+    post_process_function_args: Optional[Tuple[Any, ...]] = None
+    post_process_function_kwargs: Optional[Dict[str, Any]] = None
+    _synchronized: bool = False
+
+    def wait(self) -> None:
+        """Synchronize the asynchronous communication.
+
+        Perform post-processing if needed.
+ + """ + if self._synchronized: + return + self.async_handle.wait() + if self.post_process_function is not None: + args = self.post_process_function_args + args = () if args is None else args + kwargs = self.post_process_function_kwargs + kwargs = {} if kwargs is None else kwargs + self.post_process_function(*args, **kwargs) + self._synchronized = True + + def _all_gather_fp8( inp: torch.Tensor, process_group: dist_group_type, @@ -999,73 +1027,7 @@ def _all_gather_fp8( return out, handle -def _get_quantizer_format(quantizer: Quantizer) -> Optional[bool]: - """Get quantizer format.""" - if isinstance(quantizer, DebugQuantizer): - quantizer = quantizer.parent_quantizer - if isinstance(quantizer, Float8BlockQuantizer): - return quantizer.all_gather_usage - return None - - -def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None: - """Make quantizer compact""" - _quantizer = quantizer - if isinstance(quantizer, DebugQuantizer): - _quantizer = quantizer.parent_quantizer - if isinstance(_quantizer, Float8BlockQuantizer): - _quantizer.all_gather_usage = compact - - -def _post_process_fp8_blockwise_gather( - out: Float8BlockwiseQTensorStorage, - quantizer: Float8BlockQuantizer, - handle: Optional[torch.distributed.Work] = None, -) -> Float8BlockwiseQTensorStorage: - """Post-process FP8 blockwise gather.""" - if handle is not None: - handle.wait() - handle = None - - if out._is_gemm_ready_format(): - return out - - needs_columnwise_data_transpose = quantizer is not None and quantizer.columnwise_usage - need_rowwise_scale_transpose = quantizer is not None and quantizer.rowwise_usage - - # CuBLAS requires transpose of the scale inv tensor, suppose orig input is 256x1024 - # columnwise compact format means doing 128x1 quantization of it - # so quantized tensor is 256x1024, scale inv is 2x1024 - # If we were doing GEMM_READY format, then it's equivalent to do 1x128 quantization - # on a transposed 1024x256 tensor, so scale inv is 1024x2, cublas requries 2x1024 - # Thereforce, it turns out we don't need to transpose the scale inv, only columnwise data - if needs_columnwise_data_transpose: - out._transpose_columnwise_data() - if need_rowwise_scale_transpose: - out._rowwise_scale_inv = out._rowwise_scale_inv.transpose(-2, -1).contiguous() - out._data_format = tex.Float8BlockScaleTensorFormat.GEMM_READY - return out - - -@dataclass -class _FP8BlockwiseAllGatherAsyncHandle: - """Handle for asynchronous FP8 blockwise all-gather.""" - - tensor: Float8BlockwiseQTensorStorage - quantizer: Float8BlockQuantizer - async_handle: torch.distributed.Work - _synchronized: bool = False - - def wait(self) -> None: - """Wait for the async operation to complete and post-process the tensor.""" - if self._synchronized: - return - self.async_handle.wait() - _post_process_fp8_blockwise_gather(self.tensor, self.quantizer) - self._synchronized = True - - -def _all_gather_fp8_blockwise( +def _start_all_gather_fp8_blockwise( inp: torch.Tensor, process_group: dist_group_type, *, @@ -1104,44 +1066,24 @@ def _all_gather_fp8_blockwise( ) world_size = get_distributed_world_size(process_group) - # Check that quantizer is valid - if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer): - raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})") - if not (quantizer.block_scaling_dim == 1 and quantizer.block_len == 128): - raise NotImplementedError("Only 1D blockwise quantization is supported for allgather") - # Output tensor dims if out_shape is None: out_shape = 
list(inp.size()) out_shape[0] *= world_size - # Doing BF16 gather for now as baseline because it's simpler - if ( - not isinstance(inp, Float8BlockwiseQTensorStorage) - and quantizer is not None - and not quantizer.is_quantizable(inp) - ): - out = torch.empty( - out_shape, - dtype=dtype, - device=device, - memory_format=torch.contiguous_format, - ) + # Check that quantizer is valid + if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer): + raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})") + + # Fall back to high-precision all-gather if FP8 is not supported + if quantizer is None or not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1: + out = torch.empty(out_shape, dtype=dtype, device=device) torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False) - orig_all_gather_usage = quantizer.all_gather_usage - quantizer.all_gather_usage = False - out = quantizer(out) - quantizer.all_gather_usage = orig_all_gather_usage + if quantizer is not None: + out = quantizer(out) return out, None - # Implementation of fp8 gather needs to account for: - # * Getting columnwise data as a transpose of how it is stored for GEMMS. - # * Gathering non GEMM swizzled scales. - - # Cast input tensor to Float8BlockwiseQTensor with required data - # Set to compact usage in case the quantizer is not correctly configured - orig_all_gather_usage = quantizer.all_gather_usage - quantizer.all_gather_usage = True + # Quantize input tensor if needed if not isinstance(inp, Float8BlockwiseQTensorStorage): inp = quantizer(inp) elif (quantizer.rowwise_usage and inp._rowwise_data is None) or ( @@ -1156,14 +1098,9 @@ def _all_gather_fp8_blockwise( # Construct Float8BlockwiseQTensor output tensor out = quantizer.make_empty(out_shape, dtype=dtype, device=device) - quantizer.all_gather_usage = orig_all_gather_usage - - # Begin to do network communication, need to make sure compact format - if inp._data_format != tex.Float8BlockScaleTensorFormat.COMPACT: - raise RuntimeError( - "All-gather with FP8 block-wise quantized tensor requires compact data format, " - f"but found data_format={inp._data_format}" - ) + # Temporary buffers for all-gathering transposed buffers + interleaved_rowwise_scale_inv = None + interleaved_columnwise_data = None # Coalesce NCCL collectives with torch.distributed._coalescing_manager( @@ -1172,11 +1109,17 @@ def _all_gather_fp8_blockwise( async_ops=async_op, ) as coalescing_manager: - # Gather Float8BlockwiseQTensor data for row-wise usage + # Gather row-wise data if quantizer.rowwise_usage: - # Launch all-gathers + scale_inv_shape = list(inp._rowwise_scale_inv.size()) + scale_inv_shape[0] *= world_size + interleaved_rowwise_scale_inv = torch.empty( + scale_inv_shape, + dtype=inp._rowwise_scale_inv.dtype, + device=device, + ) torch.distributed.all_gather_into_tensor( - out._rowwise_scale_inv, + interleaved_rowwise_scale_inv, inp._rowwise_scale_inv, group=process_group, ) @@ -1186,36 +1129,73 @@ def _all_gather_fp8_blockwise( group=process_group, ) - # Gather Float8BlockwiseQTensor data for column-wise usage + # Column-wise data if quantizer.columnwise_usage: - # Launch all-gathers + data_shape = list(inp._columnwise_data.size()) + data_shape[0] *= world_size + interleaved_columnwise_data = torch.empty( + data_shape, + dtype=inp._columnwise_data.dtype, + device=device, + ) torch.distributed.all_gather_into_tensor( out._columnwise_scale_inv, inp._columnwise_scale_inv, group=process_group, ) 
torch.distributed.all_gather_into_tensor( - out._columnwise_data, + interleaved_columnwise_data, inp._columnwise_data, group=process_group, ) - handle = coalescing_manager if async_op else None - - # Unlike MXFP8, this fp8 blockwise tensor primarily works with Hopper - # This means that we need to transpose the gathered columnwise data - # Example usage is grad_output tensor, ie. dY in linear backward - # We want to gather two FP8 tensors (rowwise and columnwise) along dim0 - # and then transpose the columnwise data to match the rowwise data - # Make sure FP8 transpose is populated if needed - + # Finalize communication if needed + async_handle = None if async_op: - handle = _FP8BlockwiseAllGatherAsyncHandle(out, quantizer, handle) + async_handle = _AsyncHandle( + coalescing_manager, + post_process_function=_finish_all_gather_fp8_blockwise, + post_process_function_args=( + out, + world_size, + interleaved_rowwise_scale_inv, + interleaved_columnwise_data, + ), + ) else: - # if it's a sync op, we need to do the transpose here as post processing step - _post_process_fp8_blockwise_gather(out, quantizer, handle) + _finish_all_gather_fp8_blockwise( + out, + world_size, + interleaved_rowwise_scale_inv, + interleaved_columnwise_data, + ) - return out, handle + return out, async_handle + + +def _finish_all_gather_fp8_blockwise( + out: Float8BlockwiseQTensorStorage, + world_size: int, + interleaved_rowwise_scale_inv: Optional[torch.Tensor], + interleaved_columnwise_data: Optional[torch.Tensor], +) -> Float8BlockwiseQTensorStorage: + """Post-process FP8 blockwise gather.""" + + # Fix interleaving in row-wise scales + if interleaved_rowwise_scale_inv is not None: + dim0 = out._rowwise_scale_inv.size(0) + view_in = interleaved_rowwise_scale_inv.view(world_size, dim0, -1) + view_out = out._rowwise_scale_inv.view(dim0, world_size, -1) + tex.swap_first_dims(view_in, out=view_out) + + # Fix interleaving in column-wise data + if interleaved_columnwise_data is not None: + dim0 = out._columnwise_data.size(0) + view_in = interleaved_columnwise_data.view(world_size, dim0, -1) + view_out = out._columnwise_data.view(dim0, world_size, -1) + tex.swap_first_dims(view_in, out=view_out) + + return out def _swap_first_dims(tensor: torch.Tensor, world_size: int): @@ -1229,7 +1209,7 @@ def _swap_first_dims(tensor: torch.Tensor, world_size: int): """ shape = tensor.shape - assert tensor.ndim >= 2, "Wrong number of dimensions for fixing interleave." + assert len(shape) >= 2, "Wrong number of dimensions for fixing interleave." first_dim = shape[0] flattened_trailing = math.prod(shape[1:]) assert first_dim % world_size == 0, "Wrong dimensions for fixing interleave." 
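The gather rewrite above splits the collective from its layout fix-up: torch.distributed.all_gather_into_tensor stacks per-rank chunks along dim 0, but the global row-wise scales and column-wise data need those chunks laid out along the next dimension, so _finish_all_gather_fp8_blockwise runs either immediately or deferred inside _AsyncHandle.wait(). The sketch below illustrates the pattern with plain PyTorch; the helper names are illustrative only, and tex.swap_first_dims(view_in, out=view_out) is assumed to behave like view_out.copy_(view_in.transpose(0, 1)).

import torch
import torch.distributed as dist


def _deinterleave(gathered: torch.Tensor, out: torch.Tensor, world_size: int) -> None:
    """Copy a rank-major gathered buffer into ``out`` in chunk-major layout."""
    dim0 = out.size(0)
    view_in = gathered.view(world_size, dim0, -1)  # view_in[r, j] is rank r's row j
    view_out = out.view(dim0, world_size, -1)      # view_out[j] is row j from all ranks, in rank order
    view_out.copy_(view_in.transpose(0, 1))        # stand-in for tex.swap_first_dims


def gather_along_second_dim(local: torch.Tensor, group=None):
    """All-gather a 2D buffer whose logical concat axis is dim 1, not dim 0."""
    world_size = dist.get_world_size(group)
    gathered = torch.empty(
        (world_size * local.size(0), local.size(1)), dtype=local.dtype, device=local.device
    )
    out = torch.empty(
        (local.size(0), world_size * local.size(1)), dtype=local.dtype, device=local.device
    )
    work = dist.all_gather_into_tensor(gathered, local, group=group, async_op=True)
    handle = _AsyncHandle(
        work,
        post_process_function=_deinterleave,
        post_process_function_args=(gathered, out, world_size),
    )
    return out, handle  # overlap independent compute, then call handle.wait() before reading out

Because wait() records _synchronized, repeated calls are harmless, which keeps the calling code simple when several code paths may finish the same communication.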
@@ -1660,7 +1640,7 @@ def gather_along_first_dim( if isinstance(inp, Float8BlockwiseQTensorStorage) or isinstance( quantizer, Float8BlockQuantizer ): - return _all_gather_fp8_blockwise( + return _start_all_gather_fp8_blockwise( inp, process_group, async_op=async_op, @@ -1698,10 +1678,6 @@ def gather_along_first_dim( ) if isinstance(inp, QuantizedTensorStorage): inp = inp.dequantize() - # Falling back to high-precision all-gather for Float8BlockQuantizer - # means that it should directly output GEMM_READY format - compact = _get_quantizer_format(quantizer) - _set_quantizer_format(quantizer, compact=False) out = torch.empty( out_shape, dtype=inp.dtype, @@ -1710,7 +1686,6 @@ def gather_along_first_dim( ) torch.distributed.all_gather_into_tensor(out, inp, group=process_group) out = quantizer(out) - _set_quantizer_format(quantizer, compact=compact) return out, None # Dequantize quantized tensor if not supported diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index ab7cd9ab47a..313fc5569f1 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -560,6 +560,8 @@ def fill_userbuffers_buffer_for_all_gather( "Userbuffers requires MXFP8 tensor dims that are divisible by 128, " f"but got MXFP8 tensor with shape={tuple(local_shape)}" ) + if local_tensor._with_gemm_swizzled_scales: + raise ValueError("Userbuffers assumes MXFP8 tensors have unswizzled scales") local_scale_inv = ( local_tensor._rowwise_scale_inv if with_rowwise_data @@ -592,6 +594,7 @@ def fill_userbuffers_buffer_for_all_gather( columnwise_scale_inv=columnwise_scale_inv, fp8_dtype=local_tensor._fp8_dtype, quantizer=quantizer, + with_gemm_swizzled_scales=False, ) return global_tensor, local_tensor diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index c4d35a9c2cd..b49d6cfc832 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -715,13 +715,9 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) - # customize quantizers based on each recipe & layer configs + # Recipe-specific quantizer configuration recipe = FP8GlobalStateManager.get_fp8_recipe() if recipe.float8_current_scaling(): - assert not self.tp_size > 1, ( - "GroupedLinear doesn't support TP > 1 with Float8 current scaling. " - "Because the TP communication is handled outside of this module." - ) self._customize_quantizers_float8_current_scaling(fwd, recipe) def reset_parameters(self, defer_init=False): @@ -874,9 +870,12 @@ def backward_dw(self): def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: """Customize quantizers based on current scaling recipe + linear.""" - assert ( - recipe.float8_current_scaling() - ), "current scaling recipe quantizer customization here" + + assert not self.tp_size > 1, ( + "GroupedLinear doesn't support TP > 1 with Float8 current scaling. " + "Because the TP communication is handled outside of this module." + ) + if fwd: for i in range(self.num_gemms): # set configs about amax epsilon and power_2_scale @@ -949,9 +948,9 @@ def _get_quantizers(self): ] for i in range(self.num_gemms) ] - # TODO: use internal after #1638 is merged. 
# pylint: disable=fixme for i in range(self.num_gemms): - input_quantizers[i].internal = False + input_quantizers[i].internal = True + input_quantizers[i].optimize_for_gemm = True if torch.is_grad_enabled(): grad_output_quantizers = [ self.quantizers["scaling_bwd"][ @@ -961,6 +960,7 @@ def _get_quantizers(self): ] for i in range(self.num_gemms): grad_output_quantizers[i].internal = True + grad_output_quantizers[i].optimize_for_gemm = True return ( input_quantizers, weight_quantizers, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 667c199c49f..86a4632c16e 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -64,7 +64,6 @@ restore_from_saved, ) from ...debug.pytorch.debug_state import TEDebugState -from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..cpu_offload import ( is_cpu_offload_enabled, @@ -253,8 +252,6 @@ def forward( if fp8 or debug: ln_out = input_quantizer(ln_out) input_quantizer.set_usage(rowwise=True, columnwise=False) - if isinstance(input_quantizer, Float8BlockQuantizer): - input_quantizer.all_gather_usage = False ln_out_total = input_quantizer(ln_out_total) else: quantizer = None @@ -1409,15 +1406,12 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) - # customize quantizers based on each recipe & layer configs + # Recipe-specific quantizer configuration recipe = FP8GlobalStateManager.get_fp8_recipe() if recipe.float8_current_scaling(): self._customize_quantizers_float8_current_scaling(fwd, recipe) - elif recipe.float8_block_scaling(): - self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) elif recipe.nvfp4(): self._customize_quantizers_nvfp4(fwd, recipe) - # elif other recipes (mxfp8, etc) def reset_layer_norm_parameters(self) -> None: """Init LN params""" @@ -1619,12 +1613,16 @@ def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): output_quantizer = None input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] input_quantizer.internal = True + if not (self.parallel_mode == "column" and self.sequence_parallel): + input_quantizer.optimize_for_gemm = True (weight_quantizer,) = self._get_weight_quantizers() if fp8_output: output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] if is_grad_enabled: grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] grad_output_quantizer.internal = True + if not (self.parallel_mode == "row" and self.sequence_parallel): + grad_output_quantizer.optimize_for_gemm = True if fp8_grad: grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] @@ -1808,14 +1806,3 @@ def _get_weight_quantizers(self) -> List[Quantizer]: weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer.internal = True return [weight_quantizer] - - def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None: - """Customize quantizers based on blockwise scaling recipe + layernorm_linear.""" - assert ( - recipe.float8_block_scaling() - ), "blockwise scaling recipe quantizer customization here" - if fwd: - if self.sequence_parallel and self.parallel_mode == "column": - self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT - ].all_gather_usage = True 
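The optimize_for_gemm gating added in layernorm_linear.py above (and mirrored in linear.py and the operation-based modules below) follows a single rule: a quantizer may request the GEMM-optimized layout only when its output feeds a GEMM directly, i.e. when it will not first be all-gathered for sequence parallelism. A minimal sketch of that rule, using a hypothetical helper that is not part of the module API:

def may_optimize_for_gemm(parallel_mode: str, sequence_parallel: bool, role: str) -> bool:
    """Whether a quantized tensor may use the GEMM-optimized layout (sketch only)."""
    if role == "input":
        # Column-parallel + sequence-parallel layers all-gather the quantized
        # input along the sequence dimension before the forward GEMM.
        return not (parallel_mode == "column" and sequence_parallel)
    if role == "grad_output":
        # Row-parallel + sequence-parallel layers all-gather the quantized
        # grad output before the weight-gradient GEMM in the backward pass.
        return not (parallel_mode == "row" and sequence_parallel)
    return True

GroupedLinear enables the flag unconditionally because, as its assertion notes, tensor-parallel communication is handled outside the module.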
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 56e050fe886..95a22830bbd 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -431,8 +431,6 @@ def _forward( if fp8 or debug: ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) - if isinstance(fc1_input_quantizer, Float8BlockQuantizer): - fc1_input_quantizer.all_gather_usage = False ln_out_total = fc1_input_quantizer(ln_out_total) else: quantizer = None @@ -1964,15 +1962,12 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) - # customize quantizers based on each recipe & layer configs + # Recipe-specific quantizer configuration recipe = FP8GlobalStateManager.get_fp8_recipe() if recipe.float8_current_scaling(): self._customize_quantizers_float8_current_scaling(fwd, recipe) - elif recipe.float8_block_scaling(): - self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) elif recipe.nvfp4(): self._customize_quantizers_nvfp4(fwd, recipe) - # elif for other recipes (mxfp8, etc.) def reset_layer_norm_parameters(self) -> None: """Init LN params""" @@ -2193,6 +2188,8 @@ def _get_quantizers(self, fp8_output, is_grad_enabled): if self.fp8 or self.fp8_calibration: fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] fc1_input_quantizer.internal = True + if not self.sequence_parallel: + fc1_input_quantizer.optimize_for_gemm = True fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] fc2_input_quantizer.set_usage( rowwise=True, @@ -2201,7 +2198,8 @@ def _get_quantizers(self, fp8_output, is_grad_enabled): (MXFP8Quantizer, Float8BlockQuantizer, NVFP4Quantizer), ), ) - fc1_input_quantizer.internal = True + fc2_input_quantizer.internal = True + fc2_input_quantizer.optimize_for_gemm = True if fp8_output: fc2_output_quantizer = self.quantizers["scaling_fwd"][ tex.FP8FwdTensors.GEMM2_OUTPUT @@ -2211,10 +2209,13 @@ def _get_quantizers(self, fp8_output, is_grad_enabled): tex.FP8BwdTensors.GRAD_OUTPUT2 ] fc2_grad_output_quantizer.internal = True + if not self.sequence_parallel: + fc2_grad_output_quantizer.optimize_for_gemm = True fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ] fc1_grad_output_quantizer.internal = True + fc1_grad_output_quantizer.optimize_for_gemm = True return ( fc1_input_quantizer, @@ -2458,22 +2459,6 @@ def _get_weight_quantizers(self) -> List[Quantizer]: fc2_weight_quantizer.internal = True return [fc1_weight_quantizer, fc2_weight_quantizer] - def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None: - """Customize quantizers based on blockwise scaling recipe + layernorm_mlp.""" - assert ( - recipe.float8_block_scaling() - ), "blockwise scaling recipe quantizer customization here" - if fwd: - if self.sequence_parallel and self.set_parallel_mode: - self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT - ].all_gather_usage = True - else: - if self.sequence_parallel and self.set_parallel_mode: - self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT2 - ].all_gather_usage = True - def backward_dw(self): """ Execute the delayed weight gradient computation. 
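For LayerNormMLP the same rule specializes per tensor, since fc1 is column-parallel and fc2 is row-parallel. A sketch of the resulting mapping (names are illustrative; the why-comments are an interpretation of the conditions in the hunk above):

def mlp_gemm_optimized_quantizers(sequence_parallel: bool) -> dict:
    """Which LayerNormMLP quantizers may produce GEMM-optimized tensors (sketch)."""
    return {
        "fc1_input": not sequence_parallel,        # may be all-gathered under SP before fc1
        "fc2_input": True,                         # produced locally, feeds fc2 directly
        "fc1_grad_output": True,                   # stays local in the backward pass
        "fc2_grad_output": not sequence_parallel,  # may be all-gathered under SP in backward
    }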
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b65f7005eb3..3c7a58a08c6 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1312,15 +1312,12 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) - # customize quantizers based on each recipe & layer configs + # Recipe-specific quantizer configuration recipe = FP8GlobalStateManager.get_fp8_recipe() if recipe.float8_current_scaling(): self._customize_quantizers_float8_current_scaling(fwd, recipe) - elif recipe.float8_block_scaling(): - self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) elif recipe.nvfp4(): self._customize_quantizers_nvfp4(fwd, recipe) - # elif for other recipes (mxfp8, etc.) def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) @@ -1488,12 +1485,16 @@ def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): output_quantizer = None input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] input_quantizer.internal = True + if not (self.parallel_mode == "column" and self.sequence_parallel): + input_quantizer.optimize_for_gemm = True (weight_quantizer,) = self._get_weight_quantizers() if fp8_output: output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] if is_grad_enabled: grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] grad_output_quantizer.internal = True + if not (self.parallel_mode == "row" and self.sequence_parallel): + grad_output_quantizer.optimize_for_gemm = True if fp8_grad: grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] return ( @@ -1668,22 +1669,3 @@ def _get_weight_quantizers(self) -> List[Quantizer]: weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer.internal = True return [weight_quantizer] - - def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None: - """Customize quantizers based on blockwise scaling recipe + linear.""" - assert ( - recipe.float8_block_scaling() - ), "blockwise scaling recipe quantizer customization here" - - if fwd: - if self.sequence_parallel and self.parallel_mode == "column": - # set compact for inp tensor X - self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT - ].all_gather_usage = True - else: - if self.sequence_parallel and self.parallel_mode == "row": - # set compact for grad_output tensor dY - self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 - ].all_gather_usage = True diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 9f09e6634be..3e28b6e7ccf 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -342,15 +342,21 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: super().reset_recipe_state(recipe=recipe) - # Input/grad output quantizers use internal tensors + # Configure input/grad output tensor + # Note: These tensors are only used internally. If there is no + # tensor-parallel communication, they are only used for GEMM. 
input_quantizer = self.get_quantizer("forward", 0) grad_output_quantizer = self.get_quantizer("backward", 0) if input_quantizer is not None: input_quantizer.internal = True + if not (self.tensor_parallel_mode == "column" and self.sequence_parallel): + input_quantizer.optimize_for_gemm = True if grad_output_quantizer is not None: grad_output_quantizer.internal = True + if not (self.tensor_parallel_mode == "row" and self.sequence_parallel): + grad_output_quantizer.optimize_for_gemm = True - # Handle weight quantizer + # Configure weight quantizer # Note: This function may be called in base class constructor, # before any basic linear attrs have been set. weight_quantizer = self.get_quantizer("forward", 1) diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 5149aa1ffb3..012e9ca8eb7 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -292,6 +292,7 @@ def _functional_backward( rowwise=True, columnwise=with_columnwise, ) + grad_output_quantizer.optimize_for_gemm = False dy_local = grad_output_quantizer(dy_local) else: dy_local = maybe_dequantize(dy_local, dtype) diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 9fce9cefcf7..556411e6adb 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -290,6 +290,7 @@ def forward( columnwise_scale_inv=None, quantizer=None, requires_grad=output.requires_grad, + with_gemm_swizzled_scales=False, ) ctx.save_for_backward(row_id_map) @@ -493,6 +494,7 @@ def backward(ctx, unpermuted_act_grad): columnwise_scale_inv=None, quantizer=None, requires_grad=act_grad.requires_grad, + with_gemm_swizzled_scales=False, ) if not ctx.needs_input_grad[2]: diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index c9a4467a82f..b301f7c194d 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -204,10 +204,21 @@ class Quantizer(abc.ABC): """ internal: bool + """Whether to solely optimize for matrix multiplication + + The resulting quantized tensors are not guaranteed to support any + operation other than matrix multiplication. Use with care since + this is likely to break communication, checkpointing, and many + other features. + + """ + optimize_for_gemm: bool + def __init__(self, *, rowwise: bool, columnwise: bool) -> None: self.rowwise_usage = rowwise self.columnwise_usage = columnwise self.internal = False + self.optimize_for_gemm = False def __repr__(self): return ( @@ -319,7 +330,11 @@ def supports_only_rowwise_all_gather(self) -> bool: return False def is_quantizable(self, inp: torch.Tensor) -> bool: # pylint: disable=unused-argument - """Returns whether or not given tensor can be quantized""" + """Whether tensor supports quantized all-gather + + Consider a less misleading function name. 
+ + """ return True def get_usages(self) -> Dict[str, bool]: diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 01e03e53551..7d976b1a091 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -4,14 +4,14 @@ """Tensor class with FP8 data quantized with NxN tiles""" from __future__ import annotations -from typing import Optional, Tuple, Iterable, Union - +from collections.abc import Iterable import math +from typing import Any, Optional, Tuple, Union + import torch + import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from transformer_engine_torch import Float8BlockScaleTensorFormat - from transformer_engine.common.recipe import Float8BlockScaling, Recipe from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ..quantized_tensor import QuantizedTensor, Quantizer @@ -35,8 +35,6 @@ class Float8BlockQuantizer(Quantizer): amax_epsilon: float force_pow_2_scales: bool block_scaling_dim: int - # Whether to produce tensors that will be used in all-gather - all_gather_usage: bool def __init__( self, @@ -47,7 +45,6 @@ def __init__( amax_epsilon: float = 0.0, force_pow_2_scales: bool = True, block_scaling_dim: int = 2, - all_gather_usage: bool = False, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) self.dtype = fp8_dtype @@ -55,7 +52,6 @@ def __init__( self.force_pow_2_scales = force_pow_2_scales self.amax_epsilon = amax_epsilon self.block_scaling_dim = block_scaling_dim - self.all_gather_usage = all_gather_usage def copy(self) -> Float8BlockQuantizer: """Create shallow copy""" @@ -65,11 +61,11 @@ def copy(self) -> Float8BlockQuantizer: rowwise=self.rowwise_usage, columnwise=self.columnwise_usage, block_scaling_dim=self.block_scaling_dim, - all_gather_usage=self.all_gather_usage, amax_epsilon=self.amax_epsilon, force_pow_2_scales=self.force_pow_2_scales, ) quantizer.internal = self.internal + quantizer.optimize_for_gemm = self.optimize_for_gemm return quantizer @@ -123,103 +119,86 @@ def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: return tex.quantize(tensor, self) def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]: - """Calculate the shape of the scaling tensor for blockwise quantization. + """Scaling tensor shape. - This method determines the shape of the scaling tensor needed for blockwise quantization, - taking into account the input tensor shape and whether columnwise scaling is used. - The scales are padded to multiples of 4 on the inner dimension for compatibility with GEMM. + This method determines the shape of the scaling tensor based + on the quantizer configuration. The scales are padded to + multiples of 4 for compatibility with GEMM. Parameters ---------- shape : Iterable[int] - Shape of the input tensor to be quantized + Logical tensor shape. columnwise : bool - Whether to use columnwise scaling (True) or rowwise scaling (False) + Whether the data is scaled column-wise (True) or row-wise (False). 
Returns ------- Tuple[int, int] - Shape of the scaling tensor as (outer_dim, inner_dim) - For 2D tensors: - - If columnwise: (roundup(K/blocksize), round_to_multiple(roundup(M/blocksize), 4)) - - If rowwise: (roundup(M/blocksize), round_to_multiple(roundup(K/blocksize), 4)) - For 1D tensors: - - If columnwise: (roundup(M/blocksize), round_to_multiple(K, 4)) - - If rowwise: (roundup(K/blocksize), round_to_multiple(M, 4)) + Scaling tensor shape. + """ - M, K = 1, 1 - for i in range(len(shape) - 1): - M *= shape[i] - if len(shape) > 0: - K = shape[-1] - # 2D 128x128 quantization block scaling - # CuBLAS requries 128x128 scaling factor to be padded - # currently rowwise and columnwise format option doesn't apply to 2D scaling + + # Flatten tensor to 2D + dim0 = math.prod(shape[:-1]) + dim1 = shape[-1] if shape else 1 + + # Check block dims + if self.block_scaling_dim not in (1, 2): + raise RuntimeError( + "Only 1D or 2D blocks are supported, " + f"but got block_scaling_dim={self.block_scaling_dim}" + ) + + # 128x128 block scaling if self.block_scaling_dim == 2: + scale_dim0 = (dim0 + self.block_len - 1) // self.block_len + scale_dim1 = (dim1 + self.block_len - 1) // self.block_len if columnwise: - outer = math.ceil(K / self.block_len) - inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4) - return (outer, inner) - # rowwise - outer = math.ceil(M / self.block_len) - inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4) - return (outer, inner) - # 1D 1x128 quantization block scaling - # CuBLAS requries 1x128 scaling factor to be padded and transposed - assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported" + return (scale_dim1, round_up_to_nearest_multiple(scale_dim0, 4)) + return (scale_dim0, round_up_to_nearest_multiple(scale_dim1, 4)) + + # 1x128 block scaling if columnwise: - columnwise_compact = self.all_gather_usage - outer = math.ceil(M / self.block_len) - inner = round_up_to_nearest_multiple(K, 4) if not columnwise_compact else K - # GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS - # for COMPACT case, since we apply 1x128 scaling here without transposing columnwise data, scaling factor is also [outer, inner] - # so no need to swap inner outer here - return (outer, inner) - # rowwise - rowwise_compact = self.all_gather_usage - outer = math.ceil(K / self.block_len) - inner = round_up_to_nearest_multiple(M, 4) if not rowwise_compact else M - # GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS need - # for COMPACT case, since we apply 128x1 scaling, scaling block applies to inner dim, so we need to swap outer and inner here - return (outer, inner) if not rowwise_compact else (inner, outer) + return ( + (dim0 + self.block_len - 1) // self.block_len, + round_up_to_nearest_multiple(dim1, 4), + ) + return ( + (dim1 + self.block_len - 1) // self.block_len, + round_up_to_nearest_multiple(dim0, 4), + ) def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]: - """Calculate the shape of a tensor after columnwise permutation. + """Column-wise data shape - This method rearranges the dimensions of a tensor to be columnwise, - moving the last dimension to the front and keeping the order of other dimensions. + GEMMs expect that the column-wise data is transposed relative + to the logical tensor shape. Parameters ---------- shape : Iterable[int] - Original shape of the tensor + Logical tensor shape. Returns ------- Tuple[int, ...] 
- New shape with dimensions rearranged for columnwise layout. - For a shape (d1, d2, ..., dn), returns (dn, d1, d2, ..., dn-1). - Returns empty tuple for empty input shape. + Column-wise data shape. """ - if len(shape) == 0: - return tuple() - # currently columnwise format option only applies to 1D quantizer - # for 2D scaling, columnwise format should always be GEMM_READY_DATA_AND_SCALES - # since currently 2D scaling only applies to module weights - if self.block_scaling_dim == 1 and self.all_gather_usage: - return shape - colwise_shape = [shape[-1]] - for i in range(len(shape) - 1): - colwise_shape.append(shape[i]) + colwise_shape = [] + if shape: + colwise_shape.append(shape[-1]) + colwise_shape.extend(shape[:-1]) return tuple(colwise_shape) def is_quantizable(self, inp: torch.Tensor) -> bool: """Returns whether or not given inp can be quantized""" - if inp.ndim < 2: + shape = inp.size() + if len(shape) < 2: return False - if inp.shape[-1] % self.block_len != 0: + if shape[-1] % self.block_len != 0: return False - if math.prod(inp.shape[:-1]) % self.block_len != 0: + if math.prod(shape[:-1]) % self.block_len != 0: return False return True @@ -233,44 +212,36 @@ def make_empty( pin_memory: bool = False, ) -> Float8BlockwiseQTensor: """Construct quantized tensor with uninitialized data""" - if device is None: - device = torch.device("cuda") - data_format = ( - tex.Float8BlockScaleTensorFormat.COMPACT - if self.all_gather_usage - else tex.Float8BlockScaleTensorFormat.GEMM_READY - ) + tensor_kwargs = { + "device": torch.device("cuda") if device is None else device, + "pin_memory": pin_memory, + } - # Allocate FP8 data - data = None - scale_inv = None + # Allocate buffers for row-scaled data + rowwise_data = None + rowwise_scale_inv = None if self.rowwise_usage: - data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) - scale_shape = self.get_scale_shape(shape, columnwise=False) - scale_inv = torch.empty( - scale_shape, + rowwise_data = torch.empty(shape, dtype=torch.uint8, **tensor_kwargs) + rowwise_scale_inv = torch.empty( + self.get_scale_shape(shape, columnwise=False), dtype=torch.float32, - device=device, - pin_memory=pin_memory, + **tensor_kwargs, ) - # Allocate FP8 data transpose if needed + # Allocate buffers for column-scaled data columnwise_data = None columnwise_scale_inv = None if self.columnwise_usage: columnwise_data = torch.empty( self.get_columnwise_shape(shape), dtype=torch.uint8, - device=device, - pin_memory=pin_memory, + **tensor_kwargs, ) - columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) columnwise_scale_inv = torch.empty( - columnwise_scale_shape, + self.get_scale_shape(shape, columnwise=True), dtype=torch.float32, - device=device, - pin_memory=pin_memory, + **tensor_kwargs, ) # Construct FP8 tensor @@ -278,13 +249,12 @@ def make_empty( shape=shape, dtype=dtype, fp8_dtype=self.dtype, - rowwise_data=data, - rowwise_scale_inv=scale_inv, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, quantizer=self, is_2D_scaled=self.block_scaling_dim == 2, - data_format=data_format, requires_grad=requires_grad, ) @@ -334,7 +304,6 @@ def __new__( fp8_dtype: TE_DType, quantizer: Quantizer, is_2D_scaled: bool, - data_format: tex.Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY, **kwargs, ): instance = super().__new__( @@ -346,7 +315,6 @@ def __new__( fp8_dtype, quantizer, is_2D_scaled, - data_format, *args, **kwargs, ) @@ 
-357,8 +325,7 @@ def __repr__(self, *, tensor_contents=None): return ( f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype}," f" is_2D_scaled={self._is_2D_scaled}," - f" data={self.dequantize(dtype=self.dtype)})," - f" data_format={self._data_format}" + f" data={self.dequantize(dtype=self.dtype)})" ) def quantize_( @@ -509,7 +476,7 @@ def _make_in_reduce_ex( dtype: torch.dtype, quantizer: Quantizer, is_2D_scaled: bool, - data_format: tex.Float8BlockScaleTensorFormat, + data_format: Any = None, # pylint: disable=unused-argument ) -> Float8BlockwiseQTensor: """Build Float8BlockwiseQTensor, for use in __reduce__ @@ -527,7 +494,6 @@ def _make_in_reduce_ex( dtype=dtype, quantizer=quantizer, is_2D_scaled=is_2D_scaled, - data_format=data_format, ) def __reduce_ex__(self, protocol: int) -> tuple: @@ -544,7 +510,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: self.dtype, self._quantizer, self._is_2D_scaled, - self._data_format, + None, # data_format ), ) @@ -570,7 +536,6 @@ def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): dst._fp8_dtype = src._fp8_dtype dst._rowwise_scale_inv = src._rowwise_scale_inv dst._columnwise_scale_inv = src._columnwise_scale_inv - dst._data_format = src._data_format # Check that tensor dimensions match if ( @@ -618,13 +583,6 @@ def forward( ) -> Float8BlockwiseQTensor: # pylint: disable=missing-function-docstring - # Check for invalid configurations - if not tensor._is_gemm_ready_format(): - raise NotImplementedError( - "View is only supported with GEMM_READY data format, " - f"but found data_format={tensor._data_format}" - ) - # Return input tensor if shape is not provided ctx.shape = tensor.shape if shape is None: @@ -693,14 +651,6 @@ def backward( # pylint: disable=missing-function-docstring if isinstance(grad, Float8BlockwiseQTensor): - - # Check for invalid configurations - if not grad._is_gemm_ready_format(): - raise NotImplementedError( - "View is only supported with GEMM_READY data format, " - f"but found data_format={grad._data_format}" - ) - new_data = ( grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None ) @@ -740,13 +690,6 @@ def forward( ) -> Float8BlockwiseQTensor: # pylint: disable=missing-function-docstring - # Check for invalid configurations - if not tensor._is_gemm_ready_format(): - raise NotImplementedError( - "Reshape is only supported with GEMM_READY data format, " - f"but found data_format={tensor._data_format}" - ) - # Return input tensor if shape is not provided ctx.shape = tensor.shape if shape is None: @@ -814,14 +757,6 @@ def backward( # pylint: disable=missing-function-docstring if isinstance(grad, Float8BlockwiseQTensor): - - # Check for invalid configurations - if not grad._is_gemm_ready_format(): - raise NotImplementedError( - "Reshape is only supported with GEMM_READY data format, " - f"but found data_format={grad._data_format}" - ) - new_rowwise_data = None new_columnwise_data = None if grad._rowwise_data is not None: diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 67df40c047e..e39e7f5a1e2 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -293,6 +293,7 @@ def copy(self) -> Float8CurrentScalingQuantizer: amax=self.amax, ) quantizer.internal = self.internal + quantizer.optimize_for_gemm = self.optimize_for_gemm return quantizer diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py 
b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 6dcf9ae79a6..66c29c68301 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -54,6 +54,7 @@ def copy(self) -> MXFP8Quantizer: columnwise=self.columnwise_usage, ) quantizer.internal = self.internal + quantizer.optimize_for_gemm = self.optimize_for_gemm return quantizer @@ -156,6 +157,7 @@ def make_empty( columnwise_scale_inv=columnwise_scale_inv, quantizer=self, requires_grad=requires_grad, + with_gemm_swizzled_scales=self.optimize_for_gemm, ) def calibrate(self, tensor: torch.Tensor) -> None: @@ -179,6 +181,7 @@ def create_tensor_from_data( columnwise_scale_inv=None, fp8_dtype=fp8_dtype, quantizer=self, + with_gemm_swizzled_scales=False, ) def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: @@ -188,6 +191,10 @@ def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: return self.create_tensor_from_data(data, scale_inv, fake_dtype=torch.float32) def onnx_dequantize(self, tensor: Union[MXFP8TensorStorage, MXFP8Tensor]) -> torch.Tensor: + if tensor._with_gemm_swizzled_scales: + raise NotImplementedError( + "ONNX MXFP8 dequantization is only supported with scales in compact format." + ) return torch.ops.tex.mxfp8_dequantize(tensor._rowwise_data, tensor._rowwise_scale_inv) def _get_compatible_recipe(self) -> Union[type[Recipe], None]: @@ -229,9 +236,10 @@ def __new__( columnwise_scale_inv: Optional[torch.Tensor], fp8_dtype: TE_DType, quantizer: Optional[Quantizer], + with_gemm_swizzled_scales: bool, **kwargs, ): - instance = super().__new__( + return super().__new__( cls, rowwise_data, rowwise_scale_inv, @@ -239,10 +247,10 @@ def __new__( columnwise_scale_inv, fp8_dtype, quantizer, + with_gemm_swizzled_scales, *args, **kwargs, ) - return instance def __repr__(self, *, tensor_contents=None): return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize(dtype=self.dtype)})" @@ -334,39 +342,44 @@ def contiguous( @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): - # View op if func == aten.view.default: tensor = args[0] - data = tensor._rowwise_data - out_data = data.__torch_dispatch__( - func, - types, - [data] + list(args[1:]), - kwargs, - ) - out_shape = out_data.size() + shape = args[1] + if len(shape) < 2 or shape[-1] != tensor.size(-1): + raise ValueError( + f"Attempted to make view with size={tuple(shape)} " + f"from MXFP8 tensor with shape={tuple(tensor.size())}." + ) + rowwise_data_view = None + columnwise_data_view = None + if tensor._rowwise_data is not None: + rowwise_data_view = tensor._rowwise_data.view(shape) + if tensor._columnwise_data is not None: + columnwise_data_view = tensor._columnwise_data.view(shape) return MXFP8Tensor( - shape=out_shape, + shape=shape, dtype=tensor.dtype, - rowwise_data=out_data, + rowwise_data=rowwise_data_view, rowwise_scale_inv=tensor._rowwise_scale_inv, - columnwise_data=tensor._columnwise_data, + columnwise_data=columnwise_data_view, columnwise_scale_inv=tensor._columnwise_scale_inv, quantizer=tensor._quantizer, requires_grad=False, fp8_dtype=tensor._fp8_dtype, + with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, ) if func == torch.ops.aten.copy_.default: dst, src = args[0], args[1] if isinstance(src, MXFP8Tensor) and isinstance(dst, MXFP8Tensor): - # Booleans to check if src has all the usages that dst needs to respect dst quantizer usages. - # If not, default to base class behavior. 
- rowwise_matches = src._rowwise_data is not None or dst._rowwise_data is None - columnwise_matches = ( - src._columnwise_data is not None or dst._columnwise_data is None - ) - if rowwise_matches and columnwise_matches: + if src._rowwise_data is None and dst._rowwise_data is not None: + pass + elif src._columnwise_data is None and dst._columnwise_data is not None: + pass + elif src._with_gemm_swizzled_scales != dst._with_gemm_swizzled_scales: + pass + else: + # src and dst match, so we can directly copy data if dst._rowwise_data is not None: dst._rowwise_data.copy_(src._rowwise_data.detach(), *args[2:], **kwargs) dst._rowwise_scale_inv.copy_( @@ -381,26 +394,25 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): ) return dst - # FSDP2 related functions. if func == aten.split.Tensor: - # This is called if entire model is initialized on CUDA device and - # then splitted. Finally the shard needed by the process is used - # and other splitted shards are discarded. + # With FSDP2, this is called if entire model is + # initialized on CUDA device and then splitted. Finally + # the shard needed by the process is used and other + # splitted shards are discarded. + tensor = args[0] + split_size = args[1] if "dim" in kwargs: dim_to_split = kwargs["dim"] else: dim_to_split = args[2] if len(args) > 2 else 0 - tensor = args[0] - split_size = args[1] - dim0_size = tensor.size(0) - dimlast_size = math.prod(tensor.shape[1:]) + + # Fall back to high-precision if split is non-trivial if ( - dim0_size % split_size != 0 - or dim_to_split != 0 + dim_to_split != 0 + or tensor.size(0) % split_size != 0 or split_size % MXFP8_BLOCK_SCALING_SIZE != 0 - or dimlast_size % MXFP8_BLOCK_SCALING_SIZE != 0 + or tensor._with_gemm_swizzled_scales ): - # Handle splitting by dequantizing and splitting the hp tensor return super().__torch_dispatch__(func, types, args, kwargs) out_data = [] @@ -460,28 +472,26 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): quantizer=tensor._quantizer, requires_grad=False, fp8_dtype=tensor._fp8_dtype, + with_gemm_swizzled_scales=False, ) for splitted_tensor_data in zip(*out_data) ] + if func == torch.ops.aten.as_strided.default: # Applied on unsharded param in FSDP2. In our case, this should be a no-op # This is needed for the case where some MXFP8 shards need padding i.e dimension 0 # of the unsharded param is not a multiple of the world size. If that is the case, # we down the dequantization route and weights are allgathered in high precision. # If weight doesnt need padding, this is just a no-op. + tensor = args[0] shape = args[1] strides = args[2] - tensor = args[0] if ( - len(shape) != 2 - or len(strides) != 2 - or strides[1] != 1 - or shape[0] != tensor.shape[0] - or shape[1] != tensor.shape[1] + len(shape) == len(strides) == 2 + and tuple(strides) == (shape[-1], 1) + and tuple(shape) == tuple(tensor.size()) ): - return super().__torch_dispatch__(func, types, args, kwargs) - - return MXFP8Tensor.make_like(tensor) + return MXFP8Tensor.make_like(tensor) if func == aten.slice.Tensor: # FSDP2 needed function. @@ -489,19 +499,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # of the unsharded param is not a multiple of the world size. If that is the case, # we down the dequantization route and weights are allgathered in high precision instead. # If sharded weight doesnt have padding, this is just a no-op. 
+ tensor = args[0] dim = args[1] start = args[2] length = args[3] - tensor = args[0] - if ( - dim != 0 - or length != tensor.shape[0] - or start != 0 - or length % MXFP8_BLOCK_SCALING_SIZE != 0 - or start % MXFP8_BLOCK_SCALING_SIZE != 0 - ): - return super().__torch_dispatch__(func, types, args, kwargs) - return MXFP8Tensor.make_like(tensor) + if start == 0 and length == tensor.size(dim): + return MXFP8Tensor.make_like(tensor) if func == aten.new_zeros.default: rowwise_data = None @@ -558,7 +561,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): quantizer=tensor._quantizer, requires_grad=False, fp8_dtype=tensor._fp8_dtype, + with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, ) + # Default case return super().__torch_dispatch__(func, types, args, kwargs) @@ -584,19 +589,24 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m # pylint: disable=unused-argument from transformer_engine.pytorch.distributed import _get_module_fsdp_state + # Get FSDP state fsdp_state = _get_module_fsdp_state(module) reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward + # Remove padding from scale inverses before allgather # Rowwise scale_inv should be divisible by [128,4], columnwise by [4, 128] rowwise_scale_inv = self._rowwise_scale_inv columnwise_scale_inv = self._columnwise_scale_inv shape = self.shape + if self._with_gemm_swizzled_scales: + raise NotImplementedError( + "FSDP2 is only supported for MXFP8Tensors with compact scales" + ) if rowwise_scale_inv is not None: # Remove padding from rowwise scale_inv flattened_in_shape0 = math.prod(shape[:-1]) if rowwise_scale_inv.size(0) != flattened_in_shape0: rowwise_scale_inv = rowwise_scale_inv[:flattened_in_shape0] - if columnwise_scale_inv is not None: # Remove padding from columnwise scale_inv flattened_in_shape0 = math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE @@ -681,7 +691,7 @@ def fsdp_post_all_gather( out._columnwise_data = columnwise_data out._columnwise_scale_inv = columnwise_scale_inv else: - # We ll be here when post all gather is called the first time. + # We'll be here when post all gather is called the first time. # MXFP8Tensor constructor makes a copy of the quantizer to # save as its own quantizer. For the consequent iterations, # the same quantizer is used. 
Copy is needed in the first iteration, @@ -696,6 +706,7 @@ def fsdp_post_all_gather( dtype=param_dtype, shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape, quantizer=self._quantizer, + with_gemm_swizzled_scales=False, ) out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage) return out, all_gather_outputs @@ -711,6 +722,7 @@ def _make_in_reduce_ex( dtype: torch.dtype, shape: torch.shape, quantizer: Optional[Quantizer] = None, + with_gemm_swizzled_scales: bool = False, ) -> MXFP8Tensor: """Build MXFP8Tensor, for use in __reduce__ @@ -727,6 +739,7 @@ def _make_in_reduce_ex( dtype=dtype, shape=shape, quantizer=quantizer, + with_gemm_swizzled_scales=with_gemm_swizzled_scales, ) def __reduce_ex__(self, protocol: int) -> tuple: @@ -742,6 +755,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: self.dtype, self.shape, self._quantizer, + self._with_gemm_swizzled_scales, ), ) @@ -763,7 +777,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: if not devices_match(new_device, tensor.device): tensor = tensor.to(device=new_device) - # Just copy FP8 data if other tensor is MXFP8Tensor + # Just copy data if other tensor is MXFP8Tensor if isinstance(tensor, MXFP8Tensor): if ( # pylint: disable=too-many-boolean-expressions self.size() != tensor.size() @@ -791,6 +805,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: self._fp8_dtype = tensor._fp8_dtype self._rowwise_scale_inv = tensor._rowwise_scale_inv self._columnwise_scale_inv = tensor._columnwise_scale_inv + self._with_gemm_swizzled_scales = tensor._with_gemm_swizzled_scales return # Quantize to FP8 @@ -862,6 +877,7 @@ def forward( columnwise_scale_inv=tensor._columnwise_scale_inv, fp8_dtype=tensor._fp8_dtype, quantizer=tensor._quantizer, + with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, ) @staticmethod @@ -888,6 +904,7 @@ def backward( columnwise_scale_inv=grad._columnwise_scale_inv, fp8_dtype=grad._fp8_dtype, quantizer=grad._quantizer, + with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, ) return dgrad, None return grad.view(ctx.shape), None @@ -948,6 +965,7 @@ def forward( columnwise_scale_inv=tensor._columnwise_scale_inv, fp8_dtype=tensor._fp8_dtype, quantizer=tensor._quantizer, + with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, ) @staticmethod @@ -973,6 +991,7 @@ def backward( columnwise_scale_inv=grad._columnwise_scale_inv, fp8_dtype=grad._fp8_dtype, quantizer=grad._quantizer, + with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, ) return dgrad, None return grad.view(ctx.shape), None diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 0c244628d65..c4a415d54e9 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -193,6 +193,7 @@ def copy(self) -> NVFP4Quantizer: stochastic_rounding=self.stochastic_rounding, ) quantizer.internal = self.internal + quantizer.optimize_for_gemm = self.optimize_for_gemm quantizer.rht_matrix = self.rht_matrix quantizer.rht_matrix_random_sign_mask_t = self.rht_matrix_random_sign_mask_t @@ -359,6 +360,7 @@ def make_empty( fp4_dtype=self.dtype, quantizer=self, requires_grad=requires_grad, + with_gemm_swizzled_scales=False, ) def calibrate(self, tensor: torch.Tensor) -> None: @@ -418,6 +420,7 @@ def __new__( amax_columnwise: Optional[torch.Tensor], fp4_dtype: TE_DType, quantizer: Quantizer, + with_gemm_swizzled_scales: bool, **kwargs, ): instance = super().__new__( @@ -430,6 +433,7 @@ def 
__new__( amax_columnwise, fp4_dtype, quantizer, + with_gemm_swizzled_scales, *args, **kwargs, ) @@ -592,6 +596,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): amax_columnwise=amax_columnwise, quantizer=tensor._quantizer, requires_grad=tensor.requires_grad, + with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, ) # Default case @@ -610,6 +615,7 @@ def _make_in_reduce_ex( fp4_dtype: TE_DType, dtype: torch.dtype, quantizer: Quantizer, + with_gemm_swizzled_scales: bool = False, ) -> NVFP4Tensor: """Build NVFP4Tensor, for use in __reduce__ @@ -629,6 +635,7 @@ def _make_in_reduce_ex( amax_columnwise=amax_columnwise, quantizer=quantizer, requires_grad=False, + with_gemm_swizzled_scales=with_gemm_swizzled_scales, ) def __reduce_ex__(self, protocol: int) -> tuple: @@ -646,6 +653,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: self._fp4_dtype, self.dtype, self._quantizer, + self._with_gemm_swizzled_scales, ), ) @@ -696,6 +704,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: self._columnwise_scale_inv = tensor._columnwise_scale_inv self._amax_rowwise = tensor._amax_rowwise self._amax_columnwise = tensor._amax_columnwise + self._with_gemm_swizzled_scales = tensor._with_gemm_swizzled_scales return # Quantize to FP8 @@ -782,6 +791,7 @@ def forward( quantizer=tensor._quantizer, fp4_dtype=tensor._fp4_dtype, requires_grad=tensor.requires_grad, + with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, ) @staticmethod @@ -823,6 +833,7 @@ def backward( quantizer=grad._quantizer, fp4_dtype=grad._fp4_dtype, requires_grad=grad.requires_grad, + with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, ) return dgrad, None return grad.view(ctx.shape), None @@ -902,6 +913,7 @@ def forward( quantizer=tensor._quantizer, fp4_dtype=tensor._fp4_dtype, requires_grad=tensor.requires_grad, + with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, ) @staticmethod @@ -943,6 +955,7 @@ def backward( quantizer=grad._quantizer, fp4_dtype=grad._fp4_dtype, requires_grad=grad.requires_grad, + with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, ) return dgrad, None return grad.view(ctx.shape), None diff --git a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index 38d117b2a06..0f2f565bff6 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -11,7 +11,6 @@ import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from transformer_engine_torch import Float8BlockScaleTensorFormat from ...quantized_tensor import QuantizedTensorStorage, Quantizer @@ -36,7 +35,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage): _rowwise_scale_inv: Optional[torch.Tensor] _columnwise_scale_inv: Optional[torch.Tensor] _is_2D_scaled: bool - _data_format: Float8BlockScaleTensorFormat def __new__( cls, @@ -47,7 +45,6 @@ def __new__( fp8_dtype: TE_DType, quantizer: Quantizer, is_2D_scaled: bool, - data_format: Float8BlockScaleTensorFormat, *args, **kwargs, ): @@ -62,7 +59,6 @@ def __new__( instance._rowwise_scale_inv = rowwise_scale_inv instance._columnwise_scale_inv = columnwise_scale_inv instance._is_2D_scaled = is_2D_scaled - instance._data_format = data_format return instance @@ -87,13 +83,8 @@ def get_metadata(self) -> Dict[str, Any]: "fp8_dtype": self._fp8_dtype, "quantizer": self._quantizer, "is_2D_scaled": 
self._is_2D_scaled, - "data_format": self._data_format, } - def _is_gemm_ready_format(self) -> bool: - """Whether data is in GEMM_READY format""" - return self._data_format == Float8BlockScaleTensorFormat.GEMM_READY - def prepare_for_saving( self, ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorStorage]: @@ -153,36 +144,18 @@ def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch for i in range(len(q.shape) - 1): q_M *= q.shape[i] inner_q_dimension_tiled = True - if self._is_gemm_ready_format(): - scales_tiled_dim, scales_untiled_dim = scale_inv.shape - inner_scale_dimension_tiled = False - scales_are_compact = False - else: - scales_untiled_dim, scales_tiled_dim = scale_inv.shape - inner_scale_dimension_tiled = True - scales_are_compact = True + scales_tiled_dim, scales_untiled_dim = scale_inv.shape else: assert self._columnwise_data is not None, "No data to dequantize" q = self._columnwise_data scale_inv = self._columnwise_scale_inv scales_tiled_dim, scales_untiled_dim = scale_inv.shape - inner_scale_dimension_tiled = False - if self._is_gemm_ready_format(): - inner_q_dimension_tiled = True - transpose_output = True - if len(q.shape) >= 1: - q_M = q.shape[0] - for i in range(1, len(q.shape)): - q_K *= q.shape[i] - scales_are_compact = False - else: - inner_q_dimension_tiled = False - transpose_output = False - if len(q.shape) >= 1: - q_K = q.shape[-1] - for i in range(len(q.shape) - 1): - q_M *= q.shape[i] - scales_are_compact = True + inner_q_dimension_tiled = True + transpose_output = True + if len(q.shape) >= 1: + q_M = q.shape[0] + for i in range(1, len(q.shape)): + q_K *= q.shape[i] orig_shape = q.shape q = q.reshape(q_M, q_K) @@ -202,15 +175,10 @@ def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch ).contiguous() padded_M, padded_K = q.shape q_tiled = q.reshape(scales_tiled_dim, block_len, q_K) - if not scales_are_compact and scales_untiled_dim > q_M: + if scales_untiled_dim > q_M: # untiled scale dimension is 4 element aligned. scale_inv = scale_inv[:, :q_M].contiguous() - if scales_are_compact and inner_scale_dimension_tiled: - dq_scale = scale_inv.contiguous().reshape(q_M, scales_tiled_dim, 1) - elif scales_are_compact and not inner_scale_dimension_tiled: - dq_scale = scale_inv.contiguous().reshape(scales_tiled_dim, 1, q_K) - else: - dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, scales_tiled_dim, 1) + dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, scales_tiled_dim, 1) torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype] result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale if padded_M != q_M or padded_K != q_K: @@ -233,12 +201,6 @@ def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: if not self._is_2D_scaled: return self._dequantize_vectorwise(dtype=dtype) - if not self._is_gemm_ready_format(): - raise NotImplementedError( - "Dequantize is only supported with GEMM_READY data format, " - f"but found _data_format={self._data_format}" - ) - def format_scale_as_logical_shape(q_K, scales, block_len): # The GEMM for 2D blocks required padding in the scales. 
derived_scale_k_shape = math.ceil(q_K / block_len) @@ -304,8 +266,6 @@ def size(self, *args, **kwargs): if self._rowwise_data is not None: return self._rowwise_data.size(*args, **kwargs) dims = list(self._columnwise_data.size(*args, **kwargs)) - if not self._is_gemm_ready_format(): # compact format - return torch.Size(dims) reordered = [] for i in range(1, len(dims)): reordered.append(dims[i]) @@ -366,7 +326,7 @@ def __repr__(self): return ( "Float8BlockwiseQTensorStorage(" f"fp8_dtype={self._fp8_dtype}, " - f"{descriptor}_scaled_data={data}" + f"{descriptor}_scaled_data={data})" ) def update_usage( diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index e7840d2c437..3fb4f34afe1 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -57,13 +57,23 @@ class MXFP8TensorStorage(QuantizedTensorStorage): """ + # Row-scaled FP8 data _rowwise_data: Optional[torch.Tensor] + # Column-scaled FP8 data _columnwise_data: Optional[torch.Tensor] - _quantizer: Optional[Quantizer] - _fp8_dtype: TE_DType + # Scaling factors for row-scaled FP8 data _rowwise_scale_inv: torch.Tensor + # Scaling factors for column-scaled FP8 data _columnwise_scale_inv: torch.Tensor + # Builder class for casting to MXFP8 + _quantizer: Optional[Quantizer] + # FP8 data type + _fp8_dtype: TE_DType + # Whether scaling factors are in the swizzled format expected by + # GEMM + _with_gemm_swizzled_scales: bool + def __new__( cls, rowwise_data: Optional[torch.Tensor], @@ -72,6 +82,7 @@ def __new__( columnwise_scale_inv: Optional[torch.Tensor], fp8_dtype: TE_DType, quantizer: Optional[Quantizer], + with_gemm_swizzled_scales: bool, *args, **kwargs, ): @@ -81,10 +92,11 @@ def __new__( instance = super().__new__(cls, *args, **kwargs) instance._rowwise_data = rowwise_data instance._columnwise_data = columnwise_data - instance._quantizer = quantizer.copy() if quantizer is not None else None - instance._fp8_dtype = fp8_dtype instance._rowwise_scale_inv = rowwise_scale_inv instance._columnwise_scale_inv = columnwise_scale_inv + instance._quantizer = quantizer.copy() if quantizer is not None else None + instance._fp8_dtype = fp8_dtype + instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales return instance @@ -108,6 +120,7 @@ def get_metadata(self) -> Dict[str, Any]: "columnwise_scale_inv": self._columnwise_scale_inv, "fp8_dtype": self._fp8_dtype, "quantizer": self._quantizer, + "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales, } def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorStorage]: @@ -197,6 +210,7 @@ def view(self, shape: torch.Size): columnwise_scale_inv=self._columnwise_scale_inv, fp8_dtype=self._fp8_dtype, quantizer=self._quantizer, + with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, ) def __repr__(self): diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 04ab092ee28..f5bba9197b7 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -71,15 +71,29 @@ class NVFP4TensorStorage(QuantizedTensorStorage): """ + # Row-scaled FP4 data _rowwise_data: Optional[torch.Tensor] + # Column-scaled FP4 data _columnwise_data: Optional[torch.Tensor] - _quantizer: Optional[Quantizer] + # Block scaling 
factors for row-scaled FP4 data _rowwise_scale_inv: torch.Tensor + # Block scaling factors for column-scaled FP4 data _columnwise_scale_inv: torch.Tensor - _fp4_dtype: TE_DType + # Input absolute maximum value (used to compute tensor scale for + # row-scaled FP4 data) _amax_rowwise: torch.Tensor + # Input absolute maximum value (used to compute tensor scale for + # column-scaled FP4 data) _amax_columnwise: torch.Tensor + # Builder class for casting to NVFP4 + _quantizer: Optional[Quantizer] + # FP4 data type + _fp4_dtype: TE_DType + # Whether scaling factors are in the swizzled format expected by + # GEMM + _with_gemm_swizzled_scales: bool + def __new__( cls, rowwise_data: Optional[torch.Tensor], @@ -90,6 +104,7 @@ def __new__( amax_columnwise: torch.Tensor, fp4_dtype: TE_DType, quantizer: Optional[Quantizer], + with_gemm_swizzled_scales: bool, *args, **kwargs, ): @@ -104,6 +119,7 @@ def __new__( instance._columnwise_scale_inv = columnwise_scale_inv instance._amax_rowwise = amax_rowwise instance._amax_columnwise = amax_columnwise + instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales return instance @@ -131,6 +147,7 @@ def get_metadata(self) -> Dict[str, Any]: "amax_columnwise": self._amax_columnwise, "fp4_dtype": self._fp4_dtype, "quantizer": self._quantizer, + "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales, } def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorStorage]: @@ -248,6 +265,7 @@ def view(self, shape: torch.Size): amax_columnwise=self._amax_columnwise, quantizer=self._quantizer, fp4_dtype=self._fp4_dtype, + with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, ) def __repr__(self):
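One recurring detail in the quantizer changes above: copy() must forward configuration that is not a constructor argument, or the copy silently reverts to defaults. A toy illustration of the pattern (the class is hypothetical, not part of the library):

class ToyQuantizer:
    """Hypothetical quantizer, used only to illustrate the copy() pattern."""

    def __init__(self, rowwise: bool = True, columnwise: bool = False) -> None:
        self.rowwise_usage = rowwise
        self.columnwise_usage = columnwise
        # Flags set after construction, like the real quantizers' ``internal``
        # and the new ``optimize_for_gemm``.
        self.internal = False
        self.optimize_for_gemm = False

    def copy(self) -> "ToyQuantizer":
        new = ToyQuantizer(rowwise=self.rowwise_usage, columnwise=self.columnwise_usage)
        # Forward post-construction flags explicitly, as the copy() methods
        # above now do for optimize_for_gemm alongside internal.
        new.internal = self.internal
        new.optimize_for_gemm = self.optimize_for_gemm
        return new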