From d32b7ab9ed3de67ae94b2d83a9f7af720829fe2d Mon Sep 17 00:00:00 2001
From: Varun Thumbe
Date: Fri, 12 Dec 2025 04:17:16 -0800
Subject: [PATCH 1/5] avoiding shape copy, torch dynamo and torch autograd
 overheads

Signed-off-by: Varun Thumbe
---
 transformer_engine/pytorch/csrc/common.cpp    |  62 +++++++-
 transformer_engine/pytorch/csrc/common.h      |  19 ++-
 .../pytorch/csrc/extensions/bias.cpp          |   9 +-
 .../pytorch/csrc/extensions/gemm.cpp          |   8 +-
 .../pytorch/csrc/extensions/transpose.cpp     |   6 +-
 transformer_engine/pytorch/csrc/quantizer.cpp |  34 +++--
 .../pytorch/csrc/type_converters.cpp          |   4 +-
 transformer_engine/pytorch/module/linear.py   | 138 ++++++++----------
 8 files changed, 173 insertions(+), 107 deletions(-)

diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp
index e054424dd4d..f7a8540197f 100644
--- a/transformer_engine/pytorch/csrc/common.cpp
+++ b/transformer_engine/pytorch/csrc/common.cpp
@@ -26,12 +26,8 @@ std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape
   return ret;
 }
 
-std::vector<size_t> getTensorShape(const at::Tensor& t) {
-  std::vector<size_t> shape;
-  for (auto s : t.sizes()) {
-    shape.push_back(s);
-  }
-  return shape;
+NVTEShape getTensorShape(const at::Tensor& t) {
+  return convertTorchShape(t.sizes());
 }
 
 NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) {
@@ -178,6 +174,38 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
   return ret;
 }
 
+transformer_engine::TensorWrapper makeTransformerEngineTensor(
+    void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type,
+    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
+    NVTEScalingMode scaling_mode) {
+  TensorWrapper ret(scaling_mode);
+  ret.set_rowwise_data(data_ptr, type, shape);
+  const size_t meta_shape_data[1] = {1};
+  const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1);
+  ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
+  ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
+  auto scale_inv_dtype =
+      (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32;
+  ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape);
+  return ret;
+}
+
+transformer_engine::TensorWrapper makeTransformerEngineTensor(
+    void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
+    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
+    NVTEScalingMode scaling_mode) {
+  TensorWrapper ret(scaling_mode);
+  ret.set_rowwise_data(data_ptr, type, shape);
+  const size_t meta_shape_data[1] = {1};
+  const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1);
+  ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
+  ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
+  auto scale_inv_dtype =
+      (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32;
+  ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape);
+  return ret;
+}
+
 transformer_engine::TensorWrapper makeTransformerEngineTensor(
     void* data_ptr, void* columnwise_data_ptr, const std::vector<size_t>& shape,
     const std::vector<size_t>& columnwise_shape, const transformer_engine::DType type,
@@ -199,6 +227,28 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
   return ret;
 }
 
+transformer_engine::TensorWrapper makeTransformerEngineTensor(
+    void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape,
+    const NVTEShape& columnwise_shape, const transformer_engine::DType type,
+    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
+    const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape,
+    NVTEScalingMode scaling_mode) {
+  TensorWrapper ret(scaling_mode);
+  ret.set_rowwise_data(data_ptr, type, shape);
+  ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape);
+  const size_t meta_shape_data[1] = {1};
+  const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1);
+  ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
+  ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
+  auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING)   ? DType::kFloat8E8M0
+                         : (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat8E4M3
+                                                                   : DType::kFloat32;
+  ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape);
+  ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype,
+                               columnwise_scale_inv_shape);
+  return ret;
+}
+
 transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, at::Tensor amax,
                                                               const at::Tensor scale,
                                                               at::Tensor scale_inv,
diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h
index 978bee52dc1..883c2a24cad 100644
--- a/transformer_engine/pytorch/csrc/common.h
+++ b/transformer_engine/pytorch/csrc/common.h
@@ -339,7 +339,7 @@ class NVFP4Quantizer : public Quantizer {
 
 std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);
 
-std::vector<size_t> getTensorShape(const at::Tensor& t);
+NVTEShape getTensorShape(const at::Tensor& t);
 
 transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
                                                       const std::string& fp8_recipe);
@@ -432,6 +432,16 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
     void* amax_ptr, void* scale_ptr, void* scale_inv_ptr,
     std::vector<size_t> scale_inv_shape = {1},
     NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
 
+transformer_engine::TensorWrapper makeTransformerEngineTensor(
+    void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type,
+    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
+    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
+
+transformer_engine::TensorWrapper makeTransformerEngineTensor(
+    void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
+    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
+    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
+
 transformer_engine::TensorWrapper makeTransformerEngineTensor(
     void* data_ptr, void* columnwise_data_ptr, const std::vector<size_t>& shape,
     const std::vector<size_t>& columnwise_shape, const transformer_engine::DType type,
@@ -440,6 +450,13 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
     const std::vector<size_t>& columnwise_scale_inv_shape = {1},
     NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
 
+transformer_engine::TensorWrapper makeTransformerEngineTensor(
+    void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape,
+    const NVTEShape& columnwise_shape, const transformer_engine::DType type,
+    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
+    const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape,
+    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
+
 transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                               const NVTEShape& shape,
                                                               const transformer_engine::DType type);
diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp
index b0435d27230..2eef7438068 100644
--- a/transformer_engine/pytorch/csrc/extensions/bias.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp
@@ -26,7 +26,8 @@ std::vector<py::object> bgrad_quantize(const at::Tensor &grad_output, py::handle
   // Grad output tensor
   auto grad_output_torch = grad_output.contiguous();
   const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch);
-  const auto shape = getTensorShape(grad_output_torch);
+  const auto shape_nvte = getTensorShape(grad_output_torch);
+  const auto shape = convertShape(shape_nvte);
   auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type());
 
   // Construct grad bias tensor
@@ -116,11 +117,13 @@ std::vector<py::object> dact_dbias(
   // Grad output and activation input tensors
   grad_output_torch = grad_output_torch.contiguous();
   const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch);
-  const auto output_shape = getTensorShape(grad_output_torch);
+  const auto output_shape_nvte = getTensorShape(grad_output_torch);
+  const auto output_shape = convertShape(output_shape_nvte);
   auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type());
   act_input_torch = act_input_torch.contiguous();
   const TensorWrapper &act_input_nvte = makeTransformerEngineTensor(act_input_torch);
-  const auto input_shape = getTensorShape(act_input_torch);
+  const auto input_shape_nvte = getTensorShape(act_input_torch);
+  const auto input_shape = convertShape(input_shape_nvte);
 
   // Construct tensors
   auto quantizer_cpp = convert_quantizer(quantizer_py);
diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
index 13e8bfb6e5f..f704864cb60 100644
--- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
@@ -365,12 +365,16 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
   NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING;
   NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING;
 
+  const size_t A_shape_data[2] = {static_cast<size_t>(A.size(0)), static_cast<size_t>(A.size(1))};
+  const NVTEShape A_shape = nvte_make_shape(A_shape_data, 2);
   auto te_A = makeTransformerEngineTensor(
-      A.data_ptr(), {static_cast<size_t>(A.size(0)), static_cast<size_t>(A.size(1))}, A_type,
+      A.data_ptr(), A_shape, A_type,
       nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse),
       nvte_scaling_modeA);
+  const size_t B_shape_data[2] = {static_cast<size_t>(B.size(0)), static_cast<size_t>(B.size(1))};
+  const NVTEShape B_shape = nvte_make_shape(B_shape_data, 2);
   auto te_B = makeTransformerEngineTensor(
-      B.data_ptr(), {static_cast<size_t>(B.size(0)), static_cast<size_t>(B.size(1))}, B_type,
+      B.data_ptr(), B_shape, B_type,
       nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse),
       nvte_scaling_modeB);
 
   // TODO: D_scale_inv cannot be nullptr when D_type is FP8.
diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp
index 7dfdf995475..5ace996afcc 100644
--- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp
@@ -19,7 +19,8 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor> out) {
-  const auto shape = getTensorShape(input);
+  const auto shape_nvte = getTensorShape(input);
+  const auto shape = convertShape(shape_nvte);
   std::vector<int64_t> transpose_shape_int64;
   if (shape.size() > 0) {
     transpose_shape_int64.push_back(shape.back());
@@ -60,7 +61,8 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out) {
   // Allocate output tensor if needed
   if (!out) {
-    auto in_shape = getTensorShape(input);
+    const auto in_shape_nvte = getTensorShape(input);
+    const auto in_shape = convertShape(in_shape_nvte);
     NVTE_CHECK(in_shape.size() >= 2, "Invalid input tensor dimensions (shape=", in_shape, ")");
     std::vector<int64_t> out_shape_int64(in_shape.begin(), in_shape.end());
     out_shape_int64[0] = static_cast<int64_t>(in_shape[1]);
diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp
index d7e8912ac74..3b94d38ac16 100644
--- a/transformer_engine/pytorch/csrc/quantizer.cpp
+++ b/transformer_engine/pytorch/csrc/quantizer.cpp
@@ -209,7 +209,8 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::convert_and_update_tensor(
   // Tensor dimensions
   std::vector<size_t> shape;
   if (has_transpose) {
-    const auto transpose_shape = getTensorShape(*transpose_tensor);
+    const auto transpose_shape_nvte = getTensorShape(*transpose_tensor);
+    const auto transpose_shape = convertShape(transpose_shape_nvte);
     if (transpose_shape.size() > 0) {
       for (size_t i = 1; i < transpose_shape.size(); ++i) {
         shape.push_back(transpose_shape[i]);
@@ -217,12 +218,13 @@
       shape.push_back(transpose_shape.front());
     }
     if (has_data) {
-      auto expected_shape = getTensorShape(*data_tensor);
+      const auto expected_shape_nvte = getTensorShape(*data_tensor);
+      const auto expected_shape = convertShape(expected_shape_nvte);
       NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape,
                  ") and transpose (shape=", transpose_shape, ") do not match");
     }
   } else {  // Already checked has_data == true
-    shape = getTensorShape(*data_tensor);
+    shape = convertShape(getTensorShape(*data_tensor));
   }
 
   // Coerce data tensor
@@ -430,7 +432,8 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::convert_and_
   // Tensor dimensions
   std::vector<size_t> shape;
   if (has_transpose) {
-    const auto transpose_shape = getTensorShape(*transpose_tensor);
+    const auto transpose_shape_nvte = getTensorShape(*transpose_tensor);
+    const auto transpose_shape = convertShape(transpose_shape_nvte);
     if (transpose_shape.size() > 0) {
       for (size_t i = 1; i < transpose_shape.size(); ++i) {
         shape.push_back(transpose_shape[i]);
@@ -438,12 +441,13 @@
       shape.push_back(transpose_shape.front());
     }
     if (has_data) {
-      auto expected_shape = getTensorShape(*data_tensor);
+      const auto expected_shape_nvte = getTensorShape(*data_tensor);
+      const auto expected_shape = convertShape(expected_shape_nvte);
       NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape,
                  ") and transpose (shape=", transpose_shape, ") do not match");
     }
   } else {  // Already checked has_data == true
-    shape = getTensorShape(*data_tensor);
+    shape = convertShape(getTensorShape(*data_tensor));
   }
 
   // Coerce data tensor in Python tensor
@@ -680,9 +684,9 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
       return std::vector<size_t>();
     }
     if (all_gather_usage) {
-      return getTensorShape(*columnwise_data);
+      return convertShape(getTensorShape(*columnwise_data));
     }
-    std::vector<size_t> shape = getTensorShape(*columnwise_data);
+    std::vector<size_t> shape = convertShape(getTensorShape(*columnwise_data));
     std::vector<size_t> shape_transposed(shape.size());
     for (size_t i = 0; i + 1 < shape.size(); ++i) {
       shape_transposed[i] = shape[i + 1];
@@ -694,7 +698,7 @@
   };
   std::vector<size_t> shape;
   if (rowwise_data) {
-    shape = getTensorShape(*rowwise_data);
+    shape = convertShape(getTensorShape(*rowwise_data));
    if (columnwise_data) {
       auto expected_shape = get_columnwise_shape(all_gather_usage);
       NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape,
@@ -1004,14 +1008,14 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
   // Tensor dimensions
   std::vector<size_t> shape;
   if (columnwise_data) {
-    shape = getTensorShape(*columnwise_data);
+    shape = convertShape(getTensorShape(*columnwise_data));
     if (rowwise_data) {
-      auto expected_shape = getTensorShape(*rowwise_data);
+      const auto expected_shape = convertShape(getTensorShape(*rowwise_data));
       NVTE_CHECK(shape == expected_shape, "MXFP8 row-wise data (shape=", expected_shape,
                  ") and column-wise data (shape=", shape, ") do not match");
     }
   } else {  // Already checked columnwise_data_tensor == true
-    shape = getTensorShape(*rowwise_data);
+    shape = convertShape(getTensorShape(*rowwise_data));
   }
 
   // Coerce row-wise data
@@ -1320,14 +1324,14 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
   // Tensor dimensions, shape means original shape
   std::vector<size_t> shape;
   if (columnwise_data) {
-    shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true);
+    shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*columnwise_data)), true);
     if (rowwise_data) {
-      auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);
+      auto expected_shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)), false);
       NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape,
                  ") and column-wise data (shape=", shape, ") do not match");
     }
   } else {  // Already checked columnwise_data_tensor == true
-    shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);
+    shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)), false);
   }
 
   size_t flat_first_dim = 1;
diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp
index 368e9dcdfa3..780a08da7f8 100644
--- a/transformer_engine/pytorch/csrc/type_converters.cpp
+++ b/transformer_engine/pytorch/csrc/type_converters.cpp
@@ -132,7 +132,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer)
   const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
   const auto &amax_rowwise = tensor.attr("_amax_rowwise").cast<at::Tensor>();
   ret.set_rowwise_data(data.data_ptr(), dtype,
-                       convert_shape_back_from_fp4(getTensorShape(data), false));
+                       convert_shape_back_from_fp4(convertShape(getTensorShape(data)), false));
   ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv));
   ret.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise));
 }
@@ -143,7 +143,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer)
   const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
   const auto &amax_columnwise =
tensor.attr("_amax_columnwise").cast(); ret.set_columnwise_data(data.data_ptr(), DType::kFloat4E2M1, - convert_shape_back_from_fp4(getTensorShape(data), false)); + convert_shape_back_from_fp4(convertShape(getTensorShape(data)), false)); ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv)); ret.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b65f7005eb3..7557f5c5396 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -96,40 +96,66 @@ def forward( ( is_first_microbatch, - fp8, - fp8_calibration, - wgrad_store, - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - fuse_wgrad_accumulation, cpu_offloading, - tp_group, - tp_size, - sequence_parallel, - tensor_parallel, - activation_dtype, - parallel_mode, is_grad_enabled, - ub_overlap_rs_fprop, - ub_overlap_ag_dgrad, - ub_overlap_ag_fprop, - ub_overlap_rs_dgrad, - ub_bulk_dgrad, - ub_bulk_wgrad, - ub_name, - fp8_output, # pylint: disable=unused-variable - fsdp_group, + fp8_output, + fp8_grad, module, skip_fp8_weight_update, - symmetric_ar_type, - save_original_input, debug, ) = non_tensor_args + (fp8, + fp8_calibration, + wgrad_store, + fuse_wgrad_accumulation, + tp_group, + tp_size, + sequence_parallel, + tensor_parallel, + activation_dtype, + parallel_mode, + ub_overlap_rs_fprop, + ub_overlap_ag_dgrad, + ub_overlap_ag_fprop, + ub_overlap_rs_dgrad, + ub_bulk_dgrad, + ub_bulk_wgrad, + ub_name, + fsdp_group, + symmetric_ar_type, + save_original_input + ) = (module.fp8, + module.fp8_calibration, + module.wgrad_store, + module.fuse_wgrad_accumulation, + module.tp_group, + module.tp_size, + module.sequence_parallel, + module.tp_size > 1, + module.activation_dtype, + module.parallel_mode, + module.ub_overlap_rs_fprop, + module.ub_overlap_ag_dgrad, + module.ub_overlap_ag_fprop, + module.ub_overlap_rs_dgrad, + module.ub_bulk_dgrad, + module.ub_bulk_wgrad, + module.ub_name, + module.fsdp_group, + module.symmetric_ar_type, + module.save_original_input, + ) + quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + + if debug: + quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + if module.no_debug_features_active(quantizers): + debug = False + quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + (input_quantizer, weight_quantizer, output_quantizer, grad_input_quantizer, grad_weight_quantizer, grad_output_quantizer) = quantizers + + # NVTX label for profiling nvtx_label = "transformer_engine._Linear.forward" if ub_name is not None: @@ -981,7 +1007,6 @@ def wgrad_gemm( None, ) - class Linear(TransformerEngineBaseModule): """Applies a linear transformation to the incoming data :math:`y = xA^T + b` @@ -1343,7 +1368,6 @@ def reset_parameters(self, defer_init=False): elif self.parallel_mode == "column": set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) - @no_torch_dynamo() def forward( self, inp: torch.Tensor, @@ -1401,28 +1425,7 @@ def forward( inp, allow_non_contiguous=isinstance(inp, QuantizedTensor), ) as inp: - weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - - quantizers = ( - self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - if not debug - else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ) - if debug: - if 
-                if self.no_debug_features_active(quantizers):
-                    debug = False
-                    quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
-
-            (
-                input_quantizer,
-                weight_quantizer,
-                output_quantizer,
-                grad_input_quantizer,
-                grad_weight_quantizer,
-                grad_output_quantizer,
-            ) = quantizers
-
             if is_grad_enabled:
                 linear_fn = _Linear.apply
                 autograd_ctx = []
@@ -1432,37 +1435,12 @@ def forward(
 
             non_tensor_args = (
                 is_first_microbatch,
-                self.fp8,
-                self.fp8_calibration,
-                self.wgrad_store,
-                input_quantizer,
-                weight_quantizer,
-                output_quantizer,
-                grad_input_quantizer,
-                grad_weight_quantizer,
-                grad_output_quantizer,
-                self.fuse_wgrad_accumulation,
                 is_cpu_offload_enabled(),
-                self.tp_group,
-                self.tp_size,
-                self.sequence_parallel,
-                self.tp_size > 1,
-                self.activation_dtype,
-                self.parallel_mode,
                 is_grad_enabled,
-                self.ub_overlap_rs_fprop,
-                self.ub_overlap_ag_dgrad,
-                self.ub_overlap_ag_fprop,
-                self.ub_overlap_rs_dgrad,
-                self.ub_bulk_dgrad,
-                self.ub_bulk_wgrad,
-                self.ub_name,
                 fp8_output,
-                self.fsdp_group,
+                fp8_grad,
                 self,
                 skip_fp8_weight_update,
-                self.symmetric_ar_type,
-                self.save_original_input,
                 debug,
             )
             out = linear_fn(
@@ -1687,3 +1665,11 @@ def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Reci
                 self.quantizers["scaling_bwd"][
                     tex.FP8BwdTensors.GRAD_OUTPUT1
                 ].all_gather_usage = True
+
+
+# Disable torch dynamo just once to reduce wrapped-function overhead on each
+# forward call of TE Linear.
+if torch.__version__ >= "2":
+    Linear.forward._torchdynamo_disable = True
+    Linear.forward._torchdynamo_disable_msg = None

From e7248151ecf7877e3217b6bdd1fcf3e4b59d28ae Mon Sep 17 00:00:00 2001
From: Varun Thumbe
Date: Fri, 12 Dec 2025 22:47:01 +0000
Subject: [PATCH 2/5] minor additional change

Signed-off-by: Varun Thumbe
---
 transformer_engine/pytorch/csrc/common.cpp | 2 +-
 transformer_engine/pytorch/csrc/common.h   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp
index f7a8540197f..3467223d2ac 100644
--- a/transformer_engine/pytorch/csrc/common.cpp
+++ b/transformer_engine/pytorch/csrc/common.cpp
@@ -30,7 +30,7 @@ NVTEShape getTensorShape(const at::Tensor& t) {
   return convertTorchShape(t.sizes());
 }
 
-NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) {
+NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape) {
   NVTEShape ret;
   ret.ndim = torch_shape.size();
   constexpr int max_dimensions = sizeof(ret.data) / sizeof(size_t);
diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h
index 883c2a24cad..22061de4773 100644
--- a/transformer_engine/pytorch/csrc/common.h
+++ b/transformer_engine/pytorch/csrc/common.h
@@ -496,7 +496,7 @@ std::vector<size_t> convertShape(const NVTEShape& shape);
 
 size_t roundup(const size_t value, const size_t multiple);
 
-NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
+NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape);
 
 std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape, bool transpose);

From b725f5b31d52adc800514898cf34f8c65851e6be Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 12 Dec 2025 23:02:13 +0000
Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 transformer_engine/pytorch/csrc/common.cpp    | 12 +--
 transformer_engine/pytorch/csrc/common.h      |  8 +-
 .../pytorch/csrc/extensions/gemm.cpp          | 14 ++-
 transformer_engine/pytorch/csrc/quantizer.cpp |  3 +-
 transformer_engine/pytorch/module/linear.py   | 94 ++++++++++---------
 5 files changed, 68 insertions(+), 63 deletions(-)

diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp
index 3467223d2ac..c7f0975216b 100644
--- a/transformer_engine/pytorch/csrc/common.cpp
+++ b/transformer_engine/pytorch/csrc/common.cpp
@@ -26,9 +26,7 @@ std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape
   return ret;
 }
 
-NVTEShape getTensorShape(const at::Tensor& t) {
-  return convertTorchShape(t.sizes());
-}
+NVTEShape getTensorShape(const at::Tensor& t) { return convertTorchShape(t.sizes()); }
 
 NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape) {
   NVTEShape ret;
@@ -175,8 +173,8 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
 }
 
 transformer_engine::TensorWrapper makeTransformerEngineTensor(
-    void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type,
-    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
+    void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr,
+    void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
     NVTEScalingMode scaling_mode) {
   TensorWrapper ret(scaling_mode);
   ret.set_rowwise_data(data_ptr, type, shape);
@@ -229,8 +227,8 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
 
 transformer_engine::TensorWrapper makeTransformerEngineTensor(
     void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape,
-    const NVTEShape& columnwise_shape, const transformer_engine::DType type,
-    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
+    const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr,
+    void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
     const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape,
     NVTEScalingMode scaling_mode) {
   TensorWrapper ret(scaling_mode);
diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h
index 22061de4773..e6c22880323 100644
--- a/transformer_engine/pytorch/csrc/common.h
+++ b/transformer_engine/pytorch/csrc/common.h
@@ -433,8 +433,8 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
     NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
 
 transformer_engine::TensorWrapper makeTransformerEngineTensor(
-    void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type,
-    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
+    void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr,
+    void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
     NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
 
 transformer_engine::TensorWrapper makeTransformerEngineTensor(
@@ -452,8 +452,8 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
 transformer_engine::TensorWrapper makeTransformerEngineTensor(
     void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape,
-    const NVTEShape& columnwise_shape, const transformer_engine::DType type,
-    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
+    const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr,
+    void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
     const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape,
     NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
 
diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
index f704864cb60..35b523b5192 100644
--- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
@@ -367,16 +367,14 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
 
   const size_t A_shape_data[2] = {static_cast<size_t>(A.size(0)), static_cast<size_t>(A.size(1))};
   const NVTEShape A_shape = nvte_make_shape(A_shape_data, 2);
-  auto te_A = makeTransformerEngineTensor(
-      A.data_ptr(), A_shape, A_type,
-      nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse),
-      nvte_scaling_modeA);
+  auto te_A = makeTransformerEngineTensor(A.data_ptr(), A_shape, A_type, nullptr, nullptr,
+                                          A_scale_inverse.data_ptr(),
+                                          getTensorShape(A_scale_inverse), nvte_scaling_modeA);
   const size_t B_shape_data[2] = {static_cast<size_t>(B.size(0)), static_cast<size_t>(B.size(1))};
   const NVTEShape B_shape = nvte_make_shape(B_shape_data, 2);
-  auto te_B = makeTransformerEngineTensor(
-      B.data_ptr(), B_shape, B_type,
-      nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse),
-      nvte_scaling_modeB);
+  auto te_B = makeTransformerEngineTensor(B.data_ptr(), B_shape, B_type, nullptr, nullptr,
+                                          B_scale_inverse.data_ptr(),
+                                          getTensorShape(B_scale_inverse), nvte_scaling_modeB);
 
   // TODO: D_scale_inv cannot be nullptr when D_type is FP8.
   auto te_D = makeTransformerEngineTensor(
       D.data_ptr(),
diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp
index 3b94d38ac16..aa8416121d0 100644
--- a/transformer_engine/pytorch/csrc/quantizer.cpp
+++ b/transformer_engine/pytorch/csrc/quantizer.cpp
@@ -1326,7 +1326,8 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
   if (columnwise_data) {
     shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*columnwise_data)), true);
     if (rowwise_data) {
-      auto expected_shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)), false);
+      auto expected_shape =
+          convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)), false);
       NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape,
                  ") and column-wise data (shape=", shape, ") do not match");
     }
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 7557f5c5396..965367ac31b 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -105,46 +105,48 @@ def forward(
             debug,
         ) = non_tensor_args
 
-        (fp8,
-         fp8_calibration,
-         wgrad_store,
-         fuse_wgrad_accumulation,
-         tp_group,
-         tp_size,
-         sequence_parallel,
-         tensor_parallel,
-         activation_dtype,
-         parallel_mode,
-         ub_overlap_rs_fprop,
-         ub_overlap_ag_dgrad,
-         ub_overlap_ag_fprop,
-         ub_overlap_rs_dgrad,
-         ub_bulk_dgrad,
-         ub_bulk_wgrad,
-         ub_name,
-         fsdp_group,
-         symmetric_ar_type,
-         save_original_input
-        ) = (module.fp8,
-             module.fp8_calibration,
-             module.wgrad_store,
-             module.fuse_wgrad_accumulation,
-             module.tp_group,
-             module.tp_size,
-             module.sequence_parallel,
-             module.tp_size > 1,
-             module.activation_dtype,
-             module.parallel_mode,
-             module.ub_overlap_rs_fprop,
-             module.ub_overlap_ag_dgrad,
-             module.ub_overlap_ag_fprop,
-             module.ub_overlap_rs_dgrad,
-             module.ub_bulk_dgrad,
-             module.ub_bulk_wgrad,
-             module.ub_name,
-             module.fsdp_group,
-             module.symmetric_ar_type,
-             module.save_original_input,
+        (
+            fp8,
+            fp8_calibration,
+            wgrad_store,
+            fuse_wgrad_accumulation,
+            tp_group,
+            tp_size,
+            sequence_parallel,
+            tensor_parallel,
+            activation_dtype,
+            parallel_mode,
+            ub_overlap_rs_fprop,
+            ub_overlap_ag_dgrad,
+            ub_overlap_ag_fprop,
+            ub_overlap_rs_dgrad,
+            ub_bulk_dgrad,
+            ub_bulk_wgrad,
+            ub_name,
+            fsdp_group,
+            symmetric_ar_type,
+            save_original_input,
+        ) = (
+            module.fp8,
+            module.fp8_calibration,
+            module.wgrad_store,
+            module.fuse_wgrad_accumulation,
+            module.tp_group,
+            module.tp_size,
+            module.sequence_parallel,
+            module.tp_size > 1,
+            module.activation_dtype,
+            module.parallel_mode,
+            module.ub_overlap_rs_fprop,
+            module.ub_overlap_ag_dgrad,
+            module.ub_overlap_ag_fprop,
+            module.ub_overlap_rs_dgrad,
+            module.ub_bulk_dgrad,
+            module.ub_bulk_wgrad,
+            module.ub_name,
+            module.fsdp_group,
+            module.symmetric_ar_type,
+            module.save_original_input,
         )
         quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
 
@@ -153,8 +155,14 @@ def forward(
             if module.no_debug_features_active(quantizers):
                 debug = False
                 quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
-        (input_quantizer, weight_quantizer, output_quantizer, grad_input_quantizer, grad_weight_quantizer, grad_output_quantizer) = quantizers
-
+        (
+            input_quantizer,
+            weight_quantizer,
+            output_quantizer,
+            grad_input_quantizer,
+            grad_weight_quantizer,
+            grad_output_quantizer,
+        ) = quantizers
 
         # NVTX label for profiling
         nvtx_label = "transformer_engine._Linear.forward"
@@ -1007,6 +1015,7 @@ def wgrad_gemm(
             None,
         )
 
+
 class Linear(TransformerEngineBaseModule):
     """Applies a linear transformation to the incoming data :math:`y = xA^T + b`
 
@@ -1672,4 +1681,3 @@ def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Reci
 if torch.__version__ >= "2":
     Linear.forward._torchdynamo_disable = True
     Linear.forward._torchdynamo_disable_msg = None
-

From 7b031d011331324b5f872f9e4038a9ace4a9f86c Mon Sep 17 00:00:00 2001
From: Varun Thumbe
Date: Tue, 23 Dec 2025 12:39:17 +0000
Subject: [PATCH 4/5] changes done to remove the additional nvte_make_shape
 calls

Signed-off-by: Varun Thumbe
---
 transformer_engine/pytorch/csrc/common.cpp    | 31 ++++--
 transformer_engine/pytorch/csrc/common.h      |  9 ++
 .../pytorch/csrc/extensions/attention.cpp     |  4 +-
 .../pytorch/csrc/extensions/bias.cpp          |  9 +-
 .../pytorch/csrc/extensions/cast.cpp          |  8 +-
 .../pytorch/csrc/extensions/gemm.cpp          | 105 ++++++++++++------
 .../pytorch/csrc/extensions/padding.cpp       |  2 +-
 transformer_engine/pytorch/csrc/quantizer.cpp | 54 +++++----
 .../pytorch/csrc/type_converters.cpp          |  4 +-
 transformer_engine/pytorch/csrc/util.cpp      | 22 ++--
 10 files changed, 155 insertions(+), 93 deletions(-)

diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp
index c7f0975216b..b6a3853f6fd 100644
--- a/transformer_engine/pytorch/csrc/common.cpp
+++ b/transformer_engine/pytorch/csrc/common.cpp
@@ -26,7 +26,17 @@ std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape
   return ret;
 }
 
-NVTEShape getTensorShape(const at::Tensor& t) { return convertTorchShape(t.sizes()); }
+NVTEShape getTensorShape(const at::Tensor& t) {
+  return convertTorchShape(t.sizes());
+}
+
+std::vector<size_t> getTensorShapeVector(const at::Tensor& t) {
+  std::vector<size_t> shape;
+  for (auto s : t.sizes()) {
+    shape.push_back(s);
+  }
+  return shape;
+}
 
 NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape) {
   NVTEShape ret;
@@ -113,10 +123,7 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
 transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor) {
   transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type());
-  std::vector<size_t> shape;
-  for (auto s : tensor.sizes()) {
-    shape.push_back(s);
-  }
+  NVTEShape shape = getTensorShape(tensor);
   return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype);
 }
 
@@ -179,7 +186,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
   TensorWrapper ret(scaling_mode);
   ret.set_rowwise_data(data_ptr, type, shape);
   const size_t meta_shape_data[1] = {1};
-  const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1);
+  NVTEShape meta_shape;
+  meta_shape.ndim = 1;
+  meta_shape.data[0] = 1;
   ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
   ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
   auto scale_inv_dtype =
@@ -194,8 +203,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
     NVTEScalingMode scaling_mode) {
   TensorWrapper ret(scaling_mode);
   ret.set_rowwise_data(data_ptr, type, shape);
-  const size_t meta_shape_data[1] = {1};
-  const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1);
+  NVTEShape meta_shape;
+  meta_shape.ndim = 1;
+  meta_shape.data[0] = 1;
   ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
   ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
   auto scale_inv_dtype =
@@ -234,8 +244,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
   TensorWrapper ret(scaling_mode);
   ret.set_rowwise_data(data_ptr, type, shape);
   ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape);
-  const size_t meta_shape_data[1] = {1};
-  const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1);
+  NVTEShape meta_shape;
+  meta_shape.ndim = 1;
+  meta_shape.data[0] = 1;
   ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
   ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
   auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0
diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h
index e6c22880323..a9e7d895192 100644
--- a/transformer_engine/pytorch/csrc/common.h
+++ b/transformer_engine/pytorch/csrc/common.h
@@ -141,6 +141,13 @@ class NoneQuantizer : public Quantizer {
   std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
                                                      at::Tensor data) const;
 
+  std::pair<TensorWrapper, py::object> create_tensor(const NVTEShape& shape,
+                                                     DType dtype) const;
+
+  /*! @brief Construct a tensor with pre-initialized data */
+  std::pair<TensorWrapper, py::object> create_tensor(const NVTEShape& shape, DType dtype,
+                                                     at::Tensor data) const;
+
   std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object tensor) const override;
 
   void quantize(const TensorWrapper& input, TensorWrapper& out,
@@ -341,6 +348,8 @@ std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);
 
 NVTEShape getTensorShape(const at::Tensor& t);
 
+std::vector<size_t> getTensorShapeVector(const at::Tensor& t);
+
 transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
                                                       const std::string& fp8_recipe);
diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp
index 2480d9aba9b..804a4667d71 100644
--- a/transformer_engine/pytorch/csrc/extensions/attention.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp
@@ -479,9 +479,9 @@ std::vector<py::object> fused_attn_bwd(
     std::vector<size_t> cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(),
                                                    cu_seqlens_kv_padded_sizes.end()};
     te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(),
-                                                         cu_seqlens_q_padded_shape, DType::kInt32);
+        nvte_make_shape(cu_seqlens_q_padded_shape.data(), cu_seqlens_q_padded_shape.size()), DType::kInt32);
     te_cu_seqlens_kv_padded = makeTransformerEngineTensor(
-        cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32);
+        cu_seqlens_kv_padded.value().data_ptr(), nvte_make_shape(cu_seqlens_kv_padded_shape.data(), cu_seqlens_kv_padded_shape.size()), DType::kInt32);
   }
 
   // convert auxiliary tensors from forward to NVTETensors
diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp
index 2eef7438068..c3e89ed0856 100644
--- a/transformer_engine/pytorch/csrc/extensions/bias.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp
@@ -26,8 +26,7 @@ std::vector<py::object> bgrad_quantize(const at::Tensor &grad_output, py::handle
   // Grad output tensor
   auto grad_output_torch = grad_output.contiguous();
   const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch);
-  const auto shape_nvte = getTensorShape(grad_output_torch);
-  const auto shape = convertShape(shape_nvte);
+  const auto shape = getTensorShapeVector(grad_output_torch);
   auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type());
 
   // Construct grad bias tensor
@@ -117,13 +116,11 @@ std::vector<py::object> dact_dbias(
   // Grad output and activation input tensors
   grad_output_torch = grad_output_torch.contiguous();
   const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch);
-  const auto output_shape_nvte = getTensorShape(grad_output_torch);
-  const auto output_shape = convertShape(output_shape_nvte);
+  const auto output_shape = getTensorShapeVector(grad_output_torch);
   auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type());
   act_input_torch = act_input_torch.contiguous();
   const TensorWrapper &act_input_nvte = makeTransformerEngineTensor(act_input_torch);
-  const auto input_shape_nvte = getTensorShape(act_input_torch);
-  const auto input_shape = convertShape(input_shape_nvte);
+  const auto input_shape = getTensorShapeVector(act_input_torch);
 
   // Construct tensors
   auto quantizer_cpp = convert_quantizer(quantizer_py);
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
index b12da7542bb..3f107f443c7 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -334,12 +334,12 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
     tensor_cpp_list.emplace_back(makeTransformerEngineTensor(
         rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr,
         columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr,
-        rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{},
-        columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{}, fp8_dtype, nullptr,
+        rowwise_usage ? nvte_make_shape(rowwise_data_shapes[i].data(), rowwise_data_shapes[i].size()) : NVTEShape{},
+        columnwise_usage ? nvte_make_shape(columnwise_data_shapes[i].data(), columnwise_data_shapes[i].size()) : NVTEShape{}, fp8_dtype, nullptr,
         nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr,
         columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr,
-        rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{},
-        columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{}, scaling_mode));
+        rowwise_usage ? nvte_make_shape(rowwise_scale_shapes[i].data(), rowwise_scale_shapes[i].size()) : NVTEShape{},
+        columnwise_usage ? nvte_make_shape(columnwise_scale_shapes[i].data(), columnwise_scale_shapes[i].size()) : NVTEShape{}, scaling_mode));
   }
 
   return retval;
diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
index 35b523b5192..11be2d4e2fe 100644
--- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
@@ -40,8 +40,8 @@ bool is_low_precision(const DType type) {
   return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2;
 }
 
-std::vector<size_t> getGemmOutputShape(const NVTEShape& A_shape, const bool transa,
-                                       const NVTEShape& B_shape, const bool transb) {
+NVTEShape getGemmOutputShape(const NVTEShape& A_shape, const bool transa,
+                             const NVTEShape& B_shape, const bool transb) {
   // Flatten outer dims to get 2D matrices
   const size_t A0 = product(A_shape, 0, A_shape.ndim - 1);
   const size_t A1 = A_shape.data[A_shape.ndim - 1];
@@ -53,27 +53,29 @@
              A0, ",", A1, "), transa=", transa, ", B=(", B0, ",", B1, "), transb=", transb, ")");
 
   // Construct output dims
-  std::vector<size_t> ret;
+  NVTEShape ret;
+  size_t idx = 0;
   if (transb) {
-    ret.emplace_back(B1);
+    ret.data[idx++] = B1;
   } else {
     // Unflatten B0
     for (size_t i = 0; i < B_shape.ndim - 1; ++i) {
-      ret.emplace_back(B_shape.data[i]);
+      ret.data[idx++] = B_shape.data[i];
     }
   }
   if (transa) {
-    ret.emplace_back(A0);
+    ret.data[idx++] = A0;
   } else {
-    ret.emplace_back(A1);
+    ret.data[idx++] = A1;
   }
+  ret.ndim = idx;
   return ret;
 }
 
-bool checkGemmShape(const std::vector<size_t>& expected, const NVTEShape& actual) {
-  if (expected.size() != actual.ndim) return false;
-  for (size_t i = 0; i < expected.size(); ++i) {
-    if (expected[i] != actual.data[i]) return false;
+bool checkGemmShape(const NVTEShape& expected, const NVTEShape& actual) {
+  if (expected.ndim != actual.ndim) return false;
+  for (size_t i = 0; i < expected.ndim; ++i) {
+    if (expected.data[i] != actual.data[i]) return false;
   }
   return true;
 }
@@ -117,7 +119,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   // Check tensor dimensions
   const auto& A_shape = A_tensor.shape();
   const auto& B_shape = B_tensor.shape();
-  const auto& D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb);
+  const NVTEShape D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb);
 
   NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension");
   NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension");
@@ -138,7 +140,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   // Output tensor
   TensorWrapper D_tensor;
   if (D.is_none()) {
-    std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer);
+    std::tie(D_tensor, D) = createOutputTensor(convertShape(D_shape), output_dtype, quantizer);
   } else {
     D_tensor = makeTransformerEngineTensor(D, quantizer);
     NVTE_CHECK(detail::checkGemmShape(D_shape, D_tensor.shape()),
@@ -168,7 +170,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
 
   if (unfused_quantization_needed) {
     NoneQuantizer q{none};
-    std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(D_shape, output_dtype);
+    std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(convertShape(D_shape), output_dtype);
   }
   TensorWrapper& out_tensor = unfused_quantization_needed ? unquantized_D_tensor : D_tensor;
 
@@ -197,8 +199,8 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
         auto dtype = GetATenDType(gelu_type);
         auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
         std::vector<int64_t> torch_shape;
-        for (auto v : D_shape) {
-          torch_shape.push_back(v);
+        for (size_t i = 0; i < D_shape.ndim; ++i) {
+          torch_shape.push_back(static_cast<int64_t>(D_shape.data[i]));
         }
         pre_gelu_out = at::empty(torch_shape, opts);
       } else {
@@ -207,14 +209,21 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
       }
     }
   }
-  const auto gelu_shape = gelu ? D_shape : std::vector<size_t>{0};
+  NVTEShape gelu_shape;
+  gelu_shape.ndim = 1;
+  gelu_shape.data[0] = 0;
+  if (gelu) {
+    gelu_shape = D_shape;
+  }
   auto te_pre_gelu_out =
       makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type);
 
   // Workspace
-  auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
-                                                  std::vector<size_t>{workspaceSize}, DType::kByte);
+  NVTEShape workspace_shape;
+  workspace_shape.ndim = 1;
+  workspace_shape.data[0] = workspaceSize;
+  auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), workspace_shape, DType::kByte);
 
   // Set an external SM Margin to all the GEMMs.
   // This comes in handy when DP is overlapped with GEMMs
@@ -263,8 +272,10 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
     if (extra_output.has_value()) {
       extra_output_tensor = makeTransformerEngineTensor(*extra_output);
     } else {
+      NVTEShape extra_output_shape;
+      extra_output_shape.ndim = 0;
       extra_output_tensor =
-          makeTransformerEngineTensor(nullptr, std::vector<size_t>{0}, DType::kByte);
+          makeTransformerEngineTensor(nullptr, extra_output_shape, DType::kByte);
     }
 
     // Direct GEMM call to the correct overlap
@@ -367,28 +378,47 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
 
   const size_t A_shape_data[2] = {static_cast<size_t>(A.size(0)), static_cast<size_t>(A.size(1))};
   const NVTEShape A_shape = nvte_make_shape(A_shape_data, 2);
-  auto te_A = makeTransformerEngineTensor(A.data_ptr(), A_shape, A_type, nullptr, nullptr,
-                                          A_scale_inverse.data_ptr(),
-                                          getTensorShape(A_scale_inverse), nvte_scaling_modeA);
+  auto te_A = makeTransformerEngineTensor(
+      A.data_ptr(), A_shape, A_type,
+      nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse),
+      nvte_scaling_modeA);
   const size_t B_shape_data[2] = {static_cast<size_t>(B.size(0)), static_cast<size_t>(B.size(1))};
   const NVTEShape B_shape = nvte_make_shape(B_shape_data, 2);
-  auto te_B = makeTransformerEngineTensor(B.data_ptr(), B_shape, B_type, nullptr, nullptr,
-                                          B_scale_inverse.data_ptr(),
-                                          getTensorShape(B_scale_inverse), nvte_scaling_modeB);
+  auto te_B = makeTransformerEngineTensor(
+      B.data_ptr(), B_shape, B_type,
+      nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse),
+      nvte_scaling_modeB);
 
   // TODO: D_scale_inv cannot be nullptr when D_type is FP8.
+  NVTEShape D_shape, D_scale_inv_shape;
+  D_shape.ndim = 2;
+  D_scale_inv_shape.ndim = 1;
+  D_scale_inv_shape.data[0] = 1;
+  D_shape.data[0] = static_cast<size_t>(D.size(0));
+  D_shape.data[1] = static_cast<size_t>(D.size(1));
   auto te_D = makeTransformerEngineTensor(
       D.data_ptr(),
-      std::vector<size_t>{static_cast<size_t>(D.size(0)), static_cast<size_t>(D.size(1))}, D_type,
-      D_amax.data_ptr(), D_scale.data_ptr(), nullptr);
+      D_shape, D_type,
+      D_amax.data_ptr(), D_scale.data_ptr(), nullptr, D_scale_inv_shape);
+  NVTEShape bias_shape;
+  bias_shape.ndim = 1;
+  bias_shape.data[0] = static_cast<size_t>(bias.size(0));
   auto te_bias = makeTransformerEngineTensor(
-      bias.data_ptr(), std::vector<size_t>{static_cast<size_t>(bias.size(0))}, bias_type);
+      bias.data_ptr(), bias_shape, bias_type);
+  NVTEShape counter_shape;
+  counter_shape.ndim = 1;
+  counter_shape.data[0] = static_cast<size_t>(counter.size(0));
   auto te_counter = makeTransformerEngineTensor(
-      counter.data_ptr(), std::vector<size_t>{static_cast<size_t>(counter.size(0))}, DType::kInt32);
+      counter.data_ptr(), counter_shape, DType::kInt32);
 
-  const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr
-                              ? std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0))}
-                              : std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0)),
-                                                    static_cast<size_t>(pre_gelu_out.size(1))};
+  NVTEShape gelu_shape;
+  if (pre_gelu_out.data_ptr() == nullptr) {
+    gelu_shape.ndim = 1;
+    gelu_shape.data[0] = static_cast<size_t>(pre_gelu_out.size(0));
+  } else {
+    gelu_shape.ndim = 2;
+    gelu_shape.data[0] = static_cast<size_t>(pre_gelu_out.size(0));
+    gelu_shape.data[1] = static_cast<size_t>(pre_gelu_out.size(1));
+  }
   auto te_pre_gelu_out = makeTransformerEngineTensor(
       pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type()));
   auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
@@ -432,12 +462,13 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
 
     // if there is single output
     at::Tensor out_tensor;
-    auto size_t_shape =
+    const NVTEShape nvte_D_shape =
         pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb);
     bool D_numel_is_zero = false;
     std::vector<int64_t> D_shape;
-    for (size_t t : size_t_shape) {
-      D_shape.push_back(t);
+    for (size_t j = 0; j < nvte_D_shape.ndim; ++j) {
+      const size_t t = nvte_D_shape.data[j];
+      D_shape.push_back(static_cast<int64_t>(t));
       if (t == 0) {
         D_numel_is_zero = true;
       }
diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cpp b/transformer_engine/pytorch/csrc/extensions/padding.cpp
index d4b64a485c1..389308405b1 100644
--- a/transformer_engine/pytorch/csrc/extensions/padding.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/padding.cpp
@@ -34,7 +34,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
         input_row_list[tensor_id] * input.size(1) * input.element_size();
     input_char_ptr += input_dptr_offset;
     d_input_ptr = reinterpret_cast<void*>(input_char_ptr);
-
+    NVTEShape input_shape = {input_row_list[tensor_id], static_cast<size_t>(input.size(1))};
     input_shape_list.push_back({input_row_list[tensor_id], static_cast<size_t>(input.size(1))});
     input_type_list.push_back(GetTransformerEngineDType(input.scalar_type()));
diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp
index aa8416121d0..00f43433435 100644
--- a/transformer_engine/pytorch/csrc/quantizer.cpp
+++ b/transformer_engine/pytorch/csrc/quantizer.cpp
@@ -77,6 +77,16 @@ std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const std::vec
   return create_tensor(shape, dtype, at::empty(shape_int64, opts));
 }
 
+std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const NVTEShape& shape,
+                                                                  DType dtype) const {
+  std::vector<int64_t> shape_int64;
+  for (size_t i = 0; i < shape.ndim; ++i) {
+    shape_int64.push_back(static_cast<int64_t>(shape.data[i]));
+  }
+  const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA);
+  return create_tensor(shape, dtype, at::empty(shape_int64, opts));
+}
+
 std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const std::vector<size_t>& shape,
                                                                   DType dtype,
                                                                   at::Tensor data) const {
@@ -86,6 +96,15 @@
   return {std::move(out_cpp), py::cast(data)};
 }
 
+std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const NVTEShape& shape,
+                                                                  DType dtype,
+                                                                  at::Tensor data) const {
+  TensorWrapper out_cpp;
+  out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape);
+  set_quantization_params(&out_cpp);
+  return {std::move(out_cpp), py::cast(data)};
+}
+
 std::pair<TensorWrapper, py::object> NoneQuantizer::convert_and_update_tensor(
     py::object tensor) const {
   auto tensor_pyt = tensor.cast<at::Tensor>();
@@ -209,8 +228,7 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::convert_and_update_tensor(
   // Tensor dimensions
   std::vector<size_t> shape;
   if (has_transpose) {
-    const auto transpose_shape_nvte = getTensorShape(*transpose_tensor);
-    const auto transpose_shape = convertShape(transpose_shape_nvte);
+    const auto transpose_shape = getTensorShapeVector(*transpose_tensor);
     if (transpose_shape.size() > 0) {
       for (size_t i = 1; i < transpose_shape.size(); ++i) {
         shape.push_back(transpose_shape[i]);
@@ -218,13 +236,12 @@
       shape.push_back(transpose_shape.front());
     }
     if (has_data) {
-      const auto expected_shape_nvte = getTensorShape(*data_tensor);
-      const auto expected_shape = convertShape(expected_shape_nvte);
+      const auto expected_shape = getTensorShapeVector(*data_tensor);
       NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape,
                  ") and transpose (shape=", transpose_shape, ") do not match");
     }
   } else {  // Already checked has_data == true
-    shape = convertShape(getTensorShape(*data_tensor));
+    shape = getTensorShapeVector(*data_tensor);
   }
 
   // Coerce data tensor
@@ -432,7 +449,8 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::convert_and_
   // Tensor dimensions
   std::vector<size_t> shape;
   if (has_transpose) {
-    const auto transpose_shape_nvte = getTensorShape(*transpose_tensor);
-    const auto transpose_shape = convertShape(transpose_shape_nvte);
+    const auto transpose_shape = getTensorShapeVector(*transpose_tensor);
     if (transpose_shape.size() > 0) {
       for (size_t i = 1; i < transpose_shape.size(); ++i) {
         shape.push_back(transpose_shape[i]);
@@ -441,12 +457,13 @@
       shape.push_back(transpose_shape.front());
     }
     if (has_data) {
-      const auto expected_shape_nvte = getTensorShape(*data_tensor);
-      const auto expected_shape = convertShape(expected_shape_nvte);
+      const auto expected_shape = getTensorShapeVector(*data_tensor);
       NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape,
                  ") and transpose (shape=", transpose_shape, ") do not match");
     }
   } else {  // Already checked has_data == true
-    shape = convertShape(getTensorShape(*data_tensor));
+    shape = getTensorShapeVector(*data_tensor);
   }
 
   // Coerce data tensor in Python tensor
@@ -684,9 +699,9 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
       return std::vector<size_t>();
     }
     if (all_gather_usage) {
-      return convertShape(getTensorShape(*columnwise_data));
+      return getTensorShapeVector(*columnwise_data);
     }
-    std::vector<size_t> shape = convertShape(getTensorShape(*columnwise_data));
+    std::vector<size_t> shape = getTensorShapeVector(*columnwise_data);
     std::vector<size_t> shape_transposed(shape.size());
     for (size_t i = 0; i + 1 < shape.size(); ++i) {
       shape_transposed[i] = shape[i + 1];
@@ -698,7 +713,7 @@
   };
   std::vector<size_t> shape;
   if (rowwise_data) {
-    shape = convertShape(getTensorShape(*rowwise_data));
+    shape = getTensorShapeVector(*rowwise_data);
     if (columnwise_data) {
       auto expected_shape = get_columnwise_shape(all_gather_usage);
       NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape,
@@ -1008,14 +1023,14 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
   // Tensor dimensions
   std::vector<size_t> shape;
   if (columnwise_data) {
-    shape = convertShape(getTensorShape(*columnwise_data));
+    shape = getTensorShapeVector(*columnwise_data);
     if (rowwise_data) {
-      const auto expected_shape = convertShape(getTensorShape(*rowwise_data));
+      const auto expected_shape = getTensorShapeVector(*rowwise_data);
       NVTE_CHECK(shape == expected_shape, "MXFP8 row-wise data (shape=", expected_shape,
                  ") and column-wise data (shape=", shape, ") do not match");
     }
   } else {  // Already checked columnwise_data_tensor == true
-    shape = convertShape(getTensorShape(*rowwise_data));
+    shape = getTensorShapeVector(*rowwise_data);
   }
 
   // Coerce row-wise data
@@ -1324,15 +1339,14 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
   // Tensor dimensions, shape means original shape
   std::vector<size_t> shape;
   if (columnwise_data) {
-    shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*columnwise_data)), true);
+    shape = convert_shape_back_from_fp4(getTensorShapeVector(*columnwise_data), true);
     if (rowwise_data) {
-      auto expected_shape =
-          convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)), false);
+      auto expected_shape = convert_shape_back_from_fp4(getTensorShapeVector(*rowwise_data), false);
       NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape,
                  ") and column-wise data (shape=", shape, ") do not match");
     }
   } else {  // Already checked columnwise_data_tensor == true
-    shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)), false);
+    shape = convert_shape_back_from_fp4(getTensorShapeVector(*rowwise_data), false);
   }
 
   size_t flat_first_dim = 1;
diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp
index 780a08da7f8..48e9f06cc40 100644
--- a/transformer_engine/pytorch/csrc/type_converters.cpp
+++ b/transformer_engine/pytorch/csrc/type_converters.cpp
@@ -132,7 +132,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer)
   const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
   const auto &amax_rowwise = tensor.attr("_amax_rowwise").cast<at::Tensor>();
   ret.set_rowwise_data(data.data_ptr(), dtype,
-                       convert_shape_back_from_fp4(convertShape(getTensorShape(data)), false));
+                       convert_shape_back_from_fp4(getTensorShapeVector(data), false));
   ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv));
   ret.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise));
 }
@@ -143,7 +143,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer)
   const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
   const auto &amax_columnwise = tensor.attr("_amax_columnwise").cast<at::Tensor>();
   ret.set_columnwise_data(data.data_ptr(), DType::kFloat4E2M1,
-                          convert_shape_back_from_fp4(convertShape(getTensorShape(data)), false));
+                          convert_shape_back_from_fp4(getTensorShapeVector(data), false));
   ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3,
                                getTensorShape(scale_inv));
   ret.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32,
diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp
index 134185ac823..7fc04801e49 100644
--- a/transformer_engine/pytorch/csrc/util.cpp
+++ b/transformer_engine/pytorch/csrc/util.cpp
@@ -15,7 +15,7 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
 
   if (input.scaling_mode() == NVTE_INVALID_SCALING) {
     NVTE_ERROR("Invalid scaling mode for swizzle.");
-  } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING &&
+  } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING &&
              input.scaling_mode() != NVTE_NVFP4_1D_SCALING) {
     return std::nullopt;
   }
@@ -59,24 +59,24 @@
       (nvfp4) ? transformer_engine::DType::kFloat8E4M3
From d6ac3f1b28b2d3d2e8abbe0ab51cba7f782712b6 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 24 Dec 2025 04:47:56 +0000
Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 transformer_engine/pytorch/csrc/common.cpp    |  4 +--
 transformer_engine/pytorch/csrc/common.h      |  9 +++--
 .../pytorch/csrc/extensions/attention.cpp     | 10 ++++--
 .../pytorch/csrc/extensions/gemm.cpp          | 36 +++++++++----------
 transformer_engine/pytorch/csrc/quantizer.cpp | 14 ++++----
 5 files changed, 35 insertions(+), 38 deletions(-)

diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp
index b6a3853f6fd..d4ce064facf 100644
--- a/transformer_engine/pytorch/csrc/common.cpp
+++ b/transformer_engine/pytorch/csrc/common.cpp
@@ -26,9 +26,7 @@ std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape
   return ret;
 }

-NVTEShape getTensorShape(const at::Tensor& t) {
-  return convertTorchShape(t.sizes());
-}
+NVTEShape getTensorShape(const at::Tensor& t) { return convertTorchShape(t.sizes()); }

 std::vector<size_t> getTensorShapeVector(const at::Tensor& t) {
   std::vector<size_t> shape;
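
Both shape helpers now coexist in common.cpp: getTensorShape feeds the NVTE C API
without allocating, while getTensorShapeVector remains for callers that compare or
reorder dimensions. A quick sketch of the intended split, outside the patch, with a
hypothetical tensor `t`:

    at::Tensor t = at::zeros({8, 16});
    // POD shape for TensorWrapper setters and NVTE calls; no heap allocation.
    const NVTEShape nvte_shape = getTensorShape(t);
    // Allocating vector for element-wise comparison, as in the NVTE_CHECKs above.
    const std::vector<size_t> vec_shape = getTensorShapeVector(t);
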
diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h
index a9e7d895192..58e2acb6959 100644
--- a/transformer_engine/pytorch/csrc/common.h
+++ b/transformer_engine/pytorch/csrc/common.h
@@ -141,13 +141,12 @@ class NoneQuantizer : public Quantizer {
   std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
                                                      at::Tensor data) const;

-  std::pair<TensorWrapper, py::object> create_tensor(const NVTEShape& shape,
-                                                     DType dtype) const;
-
+  std::pair<TensorWrapper, py::object> create_tensor(const NVTEShape& shape, DType dtype) const;
+
   /*! @brief Construct a tensor with pre-initialized data */
   std::pair<TensorWrapper, py::object> create_tensor(const NVTEShape& shape, DType dtype,
-                                      at::Tensor data) const;
-
+                                                     at::Tensor data) const;
+
   std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object tensor) const override;

   void quantize(const TensorWrapper& input, TensorWrapper& out,
diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp
index 804a4667d71..1007dcb80c6 100644
--- a/transformer_engine/pytorch/csrc/extensions/attention.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp
@@ -478,10 +478,14 @@ std::vector<py::object> fused_attn_bwd(
     auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec();
     std::vector<size_t> cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(),
                                                    cu_seqlens_kv_padded_sizes.end()};
-    te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(),
-        nvte_make_shape(cu_seqlens_q_padded_shape.data(), cu_seqlens_q_padded_shape.size()), DType::kInt32);
+    te_cu_seqlens_q_padded = makeTransformerEngineTensor(
+        cu_seqlens_q_padded.value().data_ptr(),
+        nvte_make_shape(cu_seqlens_q_padded_shape.data(), cu_seqlens_q_padded_shape.size()),
+        DType::kInt32);
     te_cu_seqlens_kv_padded = makeTransformerEngineTensor(
-        cu_seqlens_kv_padded.value().data_ptr(), nvte_make_shape(cu_seqlens_kv_padded_shape.data(), cu_seqlens_kv_padded_shape.size()), DType::kInt32);
+        cu_seqlens_kv_padded.value().data_ptr(),
+        nvte_make_shape(cu_seqlens_kv_padded_shape.data(), cu_seqlens_kv_padded_shape.size()),
+        DType::kInt32);
   }

   // convert auxiliary tensors from forward to NVTETensors
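
The reflowed calls above all funnel torch sizes through nvte_make_shape. For
reference, a minimal sketch of that conversion pattern (the int64 size vector is a
hypothetical example, not taken from the patch):

    // Torch reports sizes as int64_t; NVTE wants size_t dims plus a count.
    std::vector<int64_t> sizes = {129};  // hypothetical cu_seqlens length
    std::vector<size_t> dims{sizes.begin(), sizes.end()};
    const NVTEShape shape = nvte_make_shape(dims.data(), dims.size());
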
diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
index b8928053d77..0e478ecd3ce 100644
--- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
@@ -40,8 +40,8 @@ bool is_low_precision(const DType type) {
   return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2;
 }

-NVTEShape getGemmOutputShape(const NVTEShape& A_shape, const bool transa,
-                             const NVTEShape& B_shape, const bool transb) {
+NVTEShape getGemmOutputShape(const NVTEShape& A_shape, const bool transa, const NVTEShape& B_shape,
+                             const bool transb) {
   // Flatten outer dims to get 2D matrices
   const size_t A0 = A_shape.ndim > 0 ? product(A_shape, 0, A_shape.ndim - 1) : 1;
   const size_t A1 = A_shape.ndim > 0 ? A_shape.data[A_shape.ndim - 1] : 1;
@@ -170,7 +170,8 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans

   if (unfused_quantization_needed) {
     NoneQuantizer q{none};
-    std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(convertShape(D_shape), output_dtype);
+    std::tie(unquantized_D_tensor, unquantized_out) =
+        q.create_tensor(convertShape(D_shape), output_dtype);
   }
   TensorWrapper& out_tensor = unfused_quantization_needed ? unquantized_D_tensor : D_tensor;

@@ -223,7 +224,8 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   NVTEShape workspace_shape;
   workspace_shape.ndim = 1;
   workspace_shape.data[0] = workspaceSize;
-  auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), workspace_shape, DType::kByte);
+  auto te_workspace =
+      makeTransformerEngineTensor(workspace.data_ptr(), workspace_shape, DType::kByte);

   // Set an external SM Margin to all the GEMMs.
   // This comes in handy when DP is overlapped with GEMMs
@@ -378,16 +380,14 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,

   const size_t A_shape_data[2] = {static_cast<size_t>(A.size(0)), static_cast<size_t>(A.size(1))};
   const NVTEShape A_shape = nvte_make_shape(A_shape_data, 2);
-  auto te_A = makeTransformerEngineTensor(
-      A.data_ptr(), A_shape, A_type,
-      nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse),
-      nvte_scaling_modeA);
+  auto te_A = makeTransformerEngineTensor(A.data_ptr(), A_shape, A_type, nullptr, nullptr,
+                                          A_scale_inverse.data_ptr(),
+                                          getTensorShape(A_scale_inverse), nvte_scaling_modeA);
   const size_t B_shape_data[2] = {static_cast<size_t>(B.size(0)), static_cast<size_t>(B.size(1))};
   const NVTEShape B_shape = nvte_make_shape(B_shape_data, 2);
-  auto te_B = makeTransformerEngineTensor(
-      B.data_ptr(), B_shape, B_type,
-      nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse),
-      nvte_scaling_modeB);
+  auto te_B = makeTransformerEngineTensor(B.data_ptr(), B_shape, B_type, nullptr, nullptr,
+                                          B_scale_inverse.data_ptr(),
+                                          getTensorShape(B_scale_inverse), nvte_scaling_modeB);
   // TODO: D_scale_inv cannot be nullptr when D_type is FP8.
   NVTEShape D_shape, D_scale_inv_shape;
   D_shape.ndim = 2;
@@ -395,20 +395,16 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
   D_scale_inv_shape.data[0] = 1;
   D_shape.data[0] = static_cast<size_t>(D.size(0));
   D_shape.data[1] = static_cast<size_t>(D.size(1));
-  auto te_D = makeTransformerEngineTensor(
-      D.data_ptr(),
-      D_shape, D_type,
-      D_amax.data_ptr(), D_scale.data_ptr(), nullptr, D_scale_inv_shape);
+  auto te_D = makeTransformerEngineTensor(D.data_ptr(), D_shape, D_type, D_amax.data_ptr(),
+                                          D_scale.data_ptr(), nullptr, D_scale_inv_shape);
   NVTEShape bias_shape;
   bias_shape.ndim = 1;
   bias_shape.data[0] = static_cast<size_t>(bias.size(0));
-  auto te_bias = makeTransformerEngineTensor(
-      bias.data_ptr(), bias_shape, bias_type);
+  auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), bias_shape, bias_type);
   NVTEShape counter_shape;
   counter_shape.ndim = 1;
   counter_shape.data[0] = static_cast<size_t>(counter.size(0));
-  auto te_counter = makeTransformerEngineTensor(
-      counter.data_ptr(), counter_shape, DType::kInt32);
+  auto te_counter = makeTransformerEngineTensor(counter.data_ptr(), counter_shape, DType::kInt32);

   NVTEShape gelu_shape;
   if (pre_gelu_out.data_ptr() == nullptr) {
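
To make the flattening in getGemmOutputShape concrete: all leading dimensions collapse
into the row count and only the innermost dimension survives as columns. A worked
example with illustrative dims, using the same product() helper as the code above:

    // A_shape = {4, 3, 8} flattens to a 12 x 8 matrix view:
    const size_t dims[3] = {4, 3, 8};
    const NVTEShape A_shape = nvte_make_shape(dims, 3);
    const size_t A0 = product(A_shape, 0, A_shape.ndim - 1);  // 4 * 3 = 12 rows
    const size_t A1 = A_shape.data[A_shape.ndim - 1];         // 8 columns
    // With transa = true, the GEMM then treats this 2D view as transposed.
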
diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp
index 64a6fa84766..0f8aa8381a8 100644
--- a/transformer_engine/pytorch/csrc/quantizer.cpp
+++ b/transformer_engine/pytorch/csrc/quantizer.cpp
@@ -78,7 +78,7 @@ std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const std::vec
 }

 std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const NVTEShape& shape,
-                                                  DType dtype) const {
+                                                                  DType dtype) const {
   std::vector<int64_t> shape_int64;
   for (size_t i = 0; i < shape.ndim; ++i) {
     shape_int64.push_back(static_cast<int64_t>(shape.data[i]));
@@ -97,12 +97,12 @@ std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const std::vec
 }

 std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const NVTEShape& shape,
-                                                  DType dtype,
-                                                  at::Tensor data) const {
-TensorWrapper out_cpp;
-out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape);
-set_quantization_params(&out_cpp);
-return {std::move(out_cpp), py::cast(data)};
+                                                                  DType dtype,
+                                                                  at::Tensor data) const {
+  TensorWrapper out_cpp;
+  out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape);
+  set_quantization_params(&out_cpp);
+  return {std::move(out_cpp), py::cast(data)};
 }

 std::pair<TensorWrapper, py::object> NoneQuantizer::convert_and_update_tensor(