diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index e054424dd4..d4ce064fac 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -26,7 +26,9 @@ std::vector convert_shape_back_from_fp4(const std::vector& shape return ret; } -std::vector getTensorShape(const at::Tensor& t) { +NVTEShape getTensorShape(const at::Tensor& t) { return convertTorchShape(t.sizes()); } + +std::vector getTensorShapeVector(const at::Tensor& t) { std::vector shape; for (auto s : t.sizes()) { shape.push_back(s); @@ -34,7 +36,7 @@ std::vector getTensorShape(const at::Tensor& t) { return shape; } -NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) { +NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape) { NVTEShape ret; ret.ndim = torch_shape.size(); constexpr int max_dimensions = sizeof(ret.data) / sizeof(size_t); @@ -119,10 +121,7 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor) { transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); - std::vector shape; - for (auto s : tensor.sizes()) { - shape.push_back(s); - } + NVTEShape shape = getTensorShape(tensor); return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype); } @@ -178,6 +177,41 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( return ret; } +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr, + void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, + NVTEScalingMode scaling_mode) { + TensorWrapper ret(scaling_mode); + ret.set_rowwise_data(data_ptr, type, shape); + const size_t meta_shape_data[1] = {1}; + NVTEShape meta_shape; + meta_shape.ndim = 1; + meta_shape.data[0] = 1; + ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); + ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); + auto scale_inv_dtype = + (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32; + ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); + return ret; +} + +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const std::vector& shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, + NVTEScalingMode scaling_mode) { + TensorWrapper ret(scaling_mode); + ret.set_rowwise_data(data_ptr, type, shape); + NVTEShape meta_shape; + meta_shape.ndim = 1; + meta_shape.data[0] = 1; + ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); + ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); + auto scale_inv_dtype = + (scaling_mode == NVTE_MXFP8_1D_SCALING) ? 
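// --- Illustrative sketch (not part of the patch) ---
// The fixed-capacity fill that convertTorchShape/getTensorShape perform above
// when turning torch sizes into an NVTEShape. ShapeSketch is a local stand-in
// for NVTEShape and its capacity of 8 is an assumption for illustration; the
// real code derives the limit from sizeof(NVTEShape::data) / sizeof(size_t).
#include <cstddef>
#include <cstdint>
#include <vector>

struct ShapeSketch {
  size_t data[8];   // assumed capacity, sketch only
  size_t ndim;
};

inline ShapeSketch shape_from_torch_sizes(const std::vector<int64_t>& sizes) {
  ShapeSketch ret{};
  constexpr size_t max_dims = sizeof(ret.data) / sizeof(size_t);
  // Clamp for the sketch; the real code checks the dimension count instead.
  ret.ndim = sizes.size() < max_dims ? sizes.size() : max_dims;
  for (size_t i = 0; i < ret.ndim; ++i) {
    ret.data[i] = static_cast<size_t>(sizes[i]);  // torch sizes are int64_t
  }
  return ret;
}
// --- end sketch ---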
DType::kFloat8E8M0 : DType::kFloat32; + ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); + return ret; +} + transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, const std::vector& columnwise_shape, const transformer_engine::DType type, @@ -199,6 +233,29 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( return ret; } +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, + const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr, + void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape, + NVTEScalingMode scaling_mode) { + TensorWrapper ret(scaling_mode); + ret.set_rowwise_data(data_ptr, type, shape); + ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape); + NVTEShape meta_shape; + meta_shape.ndim = 1; + meta_shape.data[0] = 1; + ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); + ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); + auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 + : (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat8E4M3 + : DType::kFloat32; + ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); + ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype, + columnwise_scale_inv_shape); + return ret; +} + transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, at::Tensor amax, const at::Tensor scale, at::Tensor scale_inv, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 978bee52dc..58e2acb695 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -141,6 +141,12 @@ class NoneQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const; + std::pair create_tensor(const NVTEShape& shape, DType dtype) const; + + /*! 
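// --- Illustrative sketch (not part of the patch) ---
// The scale-inverse dtype choice repeated by the overloads above: MXFP8 1D
// scaling stores E8M0 exponents, NVFP4 1D scaling (handled by the combined
// row-/column-wise overload) stores E4M3 factors, and tensor scaling keeps
// FP32. The enums are local stand-ins, and folding the choice into a shared
// helper like this is only a possible refactor, not something the patch adds.
enum class ScalingModeSketch { kDelayedTensor, kMXFP8_1D, kNVFP4_1D };
enum class DTypeSketch { kFloat32, kFloat8E8M0, kFloat8E4M3 };

inline DTypeSketch scale_inv_dtype_for(ScalingModeSketch mode) {
  switch (mode) {
    case ScalingModeSketch::kMXFP8_1D: return DTypeSketch::kFloat8E8M0;
    case ScalingModeSketch::kNVFP4_1D: return DTypeSketch::kFloat8E4M3;
    default:                           return DTypeSketch::kFloat32;
  }
}
// --- end sketch ---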
@brief Construct a tensor with pre-initialized data */ + std::pair create_tensor(const NVTEShape& shape, DType dtype, + at::Tensor data) const; + std::pair convert_and_update_tensor(py::object tensor) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, @@ -339,7 +345,9 @@ class NVFP4Quantizer : public Quantizer { std::unique_ptr convert_quantizer(py::handle quantizer); -std::vector getTensorShape(const at::Tensor& t); +NVTEShape getTensorShape(const at::Tensor& t); + +std::vector getTensorShapeVector(const at::Tensor& t); transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); @@ -432,6 +440,16 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector scale_inv_shape = {1}, NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr, + void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); + +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const std::vector& shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); + transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, const std::vector& columnwise_shape, const transformer_engine::DType type, @@ -440,6 +458,13 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( const std::vector& columnwise_scale_inv_shape = {1}, NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, + const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr, + void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); + transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type); @@ -479,7 +504,7 @@ std::vector convertShape(const NVTEShape& shape); size_t roundup(const size_t value, const size_t multiple); -NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); +NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape); std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 2480d9aba9..1007dcb80c 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -478,10 +478,14 @@ std::vector fused_attn_bwd( auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), cu_seqlens_kv_padded_sizes.end()}; - te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32); + te_cu_seqlens_q_padded = makeTransformerEngineTensor( + 
cu_seqlens_q_padded.value().data_ptr(), + nvte_make_shape(cu_seqlens_q_padded_shape.data(), cu_seqlens_q_padded_shape.size()), + DType::kInt32); te_cu_seqlens_kv_padded = makeTransformerEngineTensor( - cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32); + cu_seqlens_kv_padded.value().data_ptr(), + nvte_make_shape(cu_seqlens_kv_padded_shape.data(), cu_seqlens_kv_padded_shape.size()), + DType::kInt32); } // convert auxiliary tensors from forward to NVTETensors diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index b0435d2723..c3e89ed085 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -26,7 +26,7 @@ std::vector bgrad_quantize(const at::Tensor &grad_output, py::handle // Grad output tensor auto grad_output_torch = grad_output.contiguous(); const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); - const auto shape = getTensorShape(grad_output_torch); + const auto shape = getTensorShapeVector(grad_output_torch); auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); // Construct grad bias tensor @@ -116,11 +116,11 @@ std::vector dact_dbias( // Grad output and activation input tensors grad_output_torch = grad_output_torch.contiguous(); const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); - const auto output_shape = getTensorShape(grad_output_torch); + const auto output_shape = getTensorShapeVector(grad_output_torch); auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); act_input_torch = act_input_torch.contiguous(); const TensorWrapper &act_input_nvte = makeTransformerEngineTensor(act_input_torch); - const auto input_shape = getTensorShape(act_input_torch); + const auto input_shape = getTensorShapeVector(act_input_torch); // Construct tensors auto quantizer_cpp = convert_quantizer(quantizer_py); diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 335052296f..0e478ecd3c 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -40,8 +40,8 @@ bool is_low_precision(const DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } -std::vector getGemmOutputShape(const NVTEShape& A_shape, const bool transa, - const NVTEShape& B_shape, const bool transb) { +NVTEShape getGemmOutputShape(const NVTEShape& A_shape, const bool transa, const NVTEShape& B_shape, + const bool transb) { // Flatten outer dims to get 2D matrices const size_t A0 = A_shape.ndim > 0 ? product(A_shape, 0, A_shape.ndim - 1) : 1; const size_t A1 = A_shape.ndim > 0 ? 
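// --- Illustrative sketch (not part of the patch) ---
// The wrap-without-copy pattern used for the padded cu_seqlens tensors above:
// a contiguous size_t buffer plus a count is handed to a shape constructor
// instead of building yet another container. make_shape_sketch is a local
// stand-in for nvte_make_shape; SeqShapeSketch mirrors NVTEShape with an
// assumed capacity of 8.
#include <cstddef>
#include <vector>

struct SeqShapeSketch {
  size_t data[8];
  size_t ndim;
};

inline SeqShapeSketch make_shape_sketch(const size_t* dims, size_t ndim) {
  SeqShapeSketch s{};
  constexpr size_t capacity = sizeof(s.data) / sizeof(size_t);
  s.ndim = ndim < capacity ? ndim : capacity;
  for (size_t i = 0; i < s.ndim; ++i) s.data[i] = dims[i];
  return s;
}

// Usage (illustrative): wrap an existing std::vector<size_t> directly.
//   std::vector<size_t> cu_seqlens_shape = {batch_size + 1};
//   SeqShapeSketch shape =
//       make_shape_sketch(cu_seqlens_shape.data(), cu_seqlens_shape.size());
// --- end sketch ---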
A_shape.data[A_shape.ndim - 1] : 1; @@ -53,27 +53,29 @@ std::vector getGemmOutputShape(const NVTEShape& A_shape, const bool tran A0, ",", A1, "), transa=", transa, ", B=(", B0, ",", B1, "), transb=", transb, ")"); // Construct output dims - std::vector ret; + NVTEShape ret; + size_t idx = 0; if (transb) { - ret.emplace_back(B1); + ret.data[idx++] = B1; } else { // Unflatten B0 for (size_t i = 0; i < B_shape.ndim - 1; ++i) { - ret.emplace_back(B_shape.data[i]); + ret.data[idx++] = B_shape.data[i]; } } if (transa) { - ret.emplace_back(A0); + ret.data[idx++] = A0; } else { - ret.emplace_back(A1); + ret.data[idx++] = A1; } + ret.ndim = idx; return ret; } -bool checkGemmShape(const std::vector& expected, const NVTEShape& actual) { - if (expected.size() != actual.ndim) return false; - for (size_t i = 0; i < expected.size(); ++i) { - if (expected[i] != actual.data[i]) return false; +bool checkGemmShape(const NVTEShape& expected, const NVTEShape& actual) { + if (expected.ndim != actual.ndim) return false; + for (size_t i = 0; i < expected.ndim; ++i) { + if (expected.data[i] != actual.data[i]) return false; } return true; } @@ -117,7 +119,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Check tensor dimensions const auto& A_shape = A_tensor.shape(); const auto& B_shape = B_tensor.shape(); - const auto& D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb); + const NVTEShape D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb); NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension"); NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension"); @@ -138,7 +140,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Output tensor TensorWrapper D_tensor; if (D.is_none()) { - std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer); + std::tie(D_tensor, D) = createOutputTensor(convertShape(D_shape), output_dtype, quantizer); } else { D_tensor = makeTransformerEngineTensor(D, quantizer); NVTE_CHECK(detail::checkGemmShape(D_shape, D_tensor.shape()), @@ -168,7 +170,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans if (unfused_quantization_needed) { NoneQuantizer q{none}; - std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(D_shape, output_dtype); + std::tie(unquantized_D_tensor, unquantized_out) = + q.create_tensor(convertShape(D_shape), output_dtype); } TensorWrapper& out_tensor = unfused_quantization_needed ? unquantized_D_tensor : D_tensor; @@ -197,8 +200,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans auto dtype = GetATenDType(gelu_type); auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); std::vector torch_shape; - for (auto v : D_shape) { - torch_shape.push_back(v); + for (size_t i = 0; i < D_shape.ndim; ++i) { + torch_shape.push_back(static_cast(D_shape.data[i])); } pre_gelu_out = at::empty(torch_shape, opts); } else { @@ -207,14 +210,22 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } } } - const auto gelu_shape = gelu ? 
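// --- Illustrative sketch (not part of the patch) ---
// The output-shape rule getGemmOutputShape implements above, rewritten against
// a local stand-in type so it can be compiled and tested on its own: flatten A
// and B to 2D, keep B's outer dims unless transb, then append the flattened A
// dimension selected by transa. GemmShapeSketch assumes at most 8 dims.
#include <cstddef>

struct GemmShapeSketch {
  size_t data[8];
  size_t ndim;
};

inline size_t product_sketch(const GemmShapeSketch& s, size_t begin, size_t end) {
  size_t p = 1;
  for (size_t i = begin; i < end; ++i) p *= s.data[i];
  return p;
}

inline GemmShapeSketch gemm_output_shape_sketch(const GemmShapeSketch& A, bool transa,
                                                const GemmShapeSketch& B, bool transb) {
  const size_t A0 = A.ndim > 0 ? product_sketch(A, 0, A.ndim - 1) : 1;
  const size_t A1 = A.ndim > 0 ? A.data[A.ndim - 1] : 1;
  const size_t B1 = B.ndim > 0 ? B.data[B.ndim - 1] : 1;
  GemmShapeSketch ret{};
  size_t idx = 0;
  if (transb) {
    ret.data[idx++] = B1;
  } else {
    for (size_t i = 0; i + 1 < B.ndim; ++i) ret.data[idx++] = B.data[i];  // keep B's outer dims
  }
  ret.data[idx++] = transa ? A0 : A1;
  ret.ndim = idx;
  return ret;
}

// Example: A = {5, 4} with transa=true and B = {2, 3, 4} with transb=false
// gives {2, 3, 5} -- B's outer dims followed by A flattened to its outer size.
// --- end sketch ---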
D_shape : std::vector{0}; + NVTEShape gelu_shape; + gelu_shape.ndim = 1; + gelu_shape.data[0] = 0; + if (gelu) { + gelu_shape = D_shape; + } auto te_pre_gelu_out = makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type); // Workspace - auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), - std::vector{workspaceSize}, DType::kByte); + NVTEShape workspace_shape; + workspace_shape.ndim = 1; + workspace_shape.data[0] = workspaceSize; + auto te_workspace = + makeTransformerEngineTensor(workspace.data_ptr(), workspace_shape, DType::kByte); // Set an external SM Margin to all the GEMMs. // This comes in handy when DP is overlapped with GEMMs @@ -263,8 +274,10 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans if (extra_output.has_value()) { extra_output_tensor = makeTransformerEngineTensor(*extra_output); } else { + NVTEShape extra_output_shape; + extra_output_shape.ndim = 0; extra_output_tensor = - makeTransformerEngineTensor(nullptr, std::vector{0}, DType::kByte); + makeTransformerEngineTensor(nullptr, extra_output_shape, DType::kByte); } // Direct GEMM call to the correct overlap @@ -365,28 +378,43 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING; NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING; - auto te_A = makeTransformerEngineTensor( - A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, - nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse), - nvte_scaling_modeA); - auto te_B = makeTransformerEngineTensor( - B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, - nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse), - nvte_scaling_modeB); + const size_t A_shape_data[2] = {static_cast(A.size(0)), static_cast(A.size(1))}; + const NVTEShape A_shape = nvte_make_shape(A_shape_data, 2); + auto te_A = makeTransformerEngineTensor(A.data_ptr(), A_shape, A_type, nullptr, nullptr, + A_scale_inverse.data_ptr(), + getTensorShape(A_scale_inverse), nvte_scaling_modeA); + const size_t B_shape_data[2] = {static_cast(B.size(0)), static_cast(B.size(1))}; + const NVTEShape B_shape = nvte_make_shape(B_shape_data, 2); + auto te_B = makeTransformerEngineTensor(B.data_ptr(), B_shape, B_type, nullptr, nullptr, + B_scale_inverse.data_ptr(), + getTensorShape(B_scale_inverse), nvte_scaling_modeB); // TODO: D_scale_inv cannot be nullptr when D_type is FP8. - auto te_D = makeTransformerEngineTensor( - D.data_ptr(), - std::vector{static_cast(D.size(0)), static_cast(D.size(1))}, D_type, - D_amax.data_ptr(), D_scale.data_ptr(), nullptr); - auto te_bias = makeTransformerEngineTensor( - bias.data_ptr(), std::vector{static_cast(bias.size(0))}, bias_type); - auto te_counter = makeTransformerEngineTensor( - counter.data_ptr(), std::vector{static_cast(counter.size(0))}, DType::kInt32); - - const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr - ? 
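// --- Illustrative sketch (not part of the patch) ---
// The shape-to-torch-sizes conversion used above for the pre-GELU allocation
// and for the convertShape(D_shape) calls: each size_t dimension is widened
// into the int64_t vector that at::empty expects. SizesSketch is a local
// stand-in for NVTEShape.
#include <cstddef>
#include <cstdint>
#include <vector>

struct SizesSketch {
  size_t data[8];
  size_t ndim;
};

inline std::vector<int64_t> to_torch_sizes_sketch(const SizesSketch& shape) {
  std::vector<int64_t> sizes;
  sizes.reserve(shape.ndim);
  for (size_t i = 0; i < shape.ndim; ++i) {
    sizes.push_back(static_cast<int64_t>(shape.data[i]));
  }
  return sizes;  // e.g. handed to at::empty(sizes, opts)
}
// --- end sketch ---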
std::vector{static_cast(pre_gelu_out.size(0))} - : std::vector{static_cast(pre_gelu_out.size(0)), - static_cast(pre_gelu_out.size(1))}; + NVTEShape D_shape, D_scale_inv_shape; + D_shape.ndim = 2; + D_scale_inv_shape.ndim = 1; + D_scale_inv_shape.data[0] = 1; + D_shape.data[0] = static_cast(D.size(0)); + D_shape.data[1] = static_cast(D.size(1)); + auto te_D = makeTransformerEngineTensor(D.data_ptr(), D_shape, D_type, D_amax.data_ptr(), + D_scale.data_ptr(), nullptr, D_scale_inv_shape); + NVTEShape bias_shape; + bias_shape.ndim = 1; + bias_shape.data[0] = static_cast(bias.size(0)); + auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), bias_shape, bias_type); + NVTEShape counter_shape; + counter_shape.ndim = 1; + counter_shape.data[0] = static_cast(counter.size(0)); + auto te_counter = makeTransformerEngineTensor(counter.data_ptr(), counter_shape, DType::kInt32); + + NVTEShape gelu_shape; + if (pre_gelu_out.data_ptr() == nullptr) { + gelu_shape.ndim = 1; + gelu_shape.data[0] = static_cast(pre_gelu_out.size(0)); + } else { + gelu_shape.ndim = 2; + gelu_shape.data[0] = static_cast(pre_gelu_out.size(0)); + gelu_shape.data[1] = static_cast(pre_gelu_out.size(1)); + } auto te_pre_gelu_out = makeTransformerEngineTensor( pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), @@ -430,12 +458,13 @@ std::optional> te_general_grouped_gemm( // if there is single output at::Tensor out_tensor; - auto size_t_shape = + const NVTEShape nvte_D_shape = pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb); bool D_numel_is_zero = false; std::vector D_shape; - for (size_t t : size_t_shape) { - D_shape.push_back(t); + for (size_t j = 0; j < nvte_D_shape.ndim; ++j) { + const size_t t = nvte_D_shape.data[j]; + D_shape.push_back(static_cast(t)); if (t == 0) { D_numel_is_zero = true; } diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cpp b/transformer_engine/pytorch/csrc/extensions/padding.cpp index d4b64a485c..389308405b 100644 --- a/transformer_engine/pytorch/csrc/extensions/padding.cpp +++ b/transformer_engine/pytorch/csrc/extensions/padding.cpp @@ -34,7 +34,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, input_row_list[tensor_id] * input.size(1) * input.element_size(); input_char_ptr += input_dptr_offset; d_input_ptr = reinterpret_cast(input_char_ptr); - + NVTEShape input_shape = {input_row_list[tensor_id], static_cast(input.size(1))}; input_shape_list.push_back({input_row_list[tensor_id], static_cast(input.size(1))}); input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 7dfdf99547..5ace996afc 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -19,7 +19,8 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional transpose_shape_int64; if (shape.size() > 0) { transpose_shape_int64.push_back(shape.back()); @@ -60,7 +61,8 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional out) { // Allocate output tensor if needed if (!out) { - auto in_shape = getTensorShape(input); + const auto in_shape_nvte = getTensorShape(input); + const auto in_shape = convertShape(in_shape_nvte); NVTE_CHECK(in_shape.size() >= 2, "Invalid input tensor dimensions (shape=", in_shape, ")"); std::vector 
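// --- Illustrative sketch (not part of the patch) ---
// Hypothetical helpers that would fold the repeated "ndim = ...; data[0] = ..."
// boilerplate in te_atomic_gemm above into one-liners. No such helpers exist
// in TE; nvte_make_shape (or a braced initializer) could serve the same
// purpose. AtomicShapeSketch is a local stand-in for NVTEShape.
#include <cstddef>

struct AtomicShapeSketch {
  size_t data[8];
  size_t ndim;
};

inline AtomicShapeSketch make_shape_1d(size_t d0) {
  AtomicShapeSketch s{};
  s.ndim = 1;
  s.data[0] = d0;
  return s;
}

inline AtomicShapeSketch make_shape_2d(size_t d0, size_t d1) {
  AtomicShapeSketch s{};
  s.ndim = 2;
  s.data[0] = d0;
  s.data[1] = d1;
  return s;
}

// Usage (illustrative):
//   bias_shape = make_shape_1d(static_cast<size_t>(bias.size(0)));
//   D_shape    = make_shape_2d(static_cast<size_t>(D.size(0)),
//                              static_cast<size_t>(D.size(1)));
// --- end sketch ---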
out_shape_int64(in_shape.begin(), in_shape.end()); out_shape_int64[0] = static_cast(in_shape[1]); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index fd748d1b21..0f8aa8381a 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -77,6 +77,16 @@ std::pair NoneQuantizer::create_tensor(const std::vec return create_tensor(shape, dtype, at::empty(shape_int64, opts)); } +std::pair NoneQuantizer::create_tensor(const NVTEShape& shape, + DType dtype) const { + std::vector shape_int64; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_int64.push_back(static_cast(shape.data[i])); + } + const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA); + return create_tensor(shape, dtype, at::empty(shape_int64, opts)); +} + std::pair NoneQuantizer::create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const { @@ -86,6 +96,15 @@ std::pair NoneQuantizer::create_tensor(const std::vec return {std::move(out_cpp), py::cast(data)}; } +std::pair NoneQuantizer::create_tensor(const NVTEShape& shape, + DType dtype, + at::Tensor data) const { + TensorWrapper out_cpp; + out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape); + set_quantization_params(&out_cpp); + return {std::move(out_cpp), py::cast(data)}; +} + std::pair NoneQuantizer::convert_and_update_tensor( py::object tensor) const { auto tensor_pyt = tensor.cast(); @@ -209,7 +228,7 @@ std::pair Float8Quantizer::convert_and_update_tensor( // Tensor dimensions std::vector shape; if (has_transpose) { - const auto transpose_shape = getTensorShape(*transpose_tensor); + const auto transpose_shape = getTensorShapeVector(*transpose_tensor); if (transpose_shape.size() > 0) { for (size_t i = 1; i < transpose_shape.size(); ++i) { shape.push_back(transpose_shape[i]); @@ -217,12 +236,12 @@ std::pair Float8Quantizer::convert_and_update_tensor( shape.push_back(transpose_shape.front()); } if (has_data) { - auto expected_shape = getTensorShape(*data_tensor); + const auto expected_shape = getTensorShapeVector(*data_tensor); NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, ") and transpose (shape=", transpose_shape, ") do not match"); } } else { // Already checked has_data == true - shape = getTensorShape(*data_tensor); + shape = getTensorShapeVector(*data_tensor); } // Coerce data tensor @@ -430,7 +449,7 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ // Tensor dimensions std::vector shape; if (has_transpose) { - const auto transpose_shape = getTensorShape(*transpose_tensor); + const auto transpose_shape = getTensorShapeVector(*transpose_tensor); if (transpose_shape.size() > 0) { for (size_t i = 1; i < transpose_shape.size(); ++i) { shape.push_back(transpose_shape[i]); @@ -438,12 +457,12 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ shape.push_back(transpose_shape.front()); } if (has_data) { - auto expected_shape = getTensorShape(*data_tensor); + const auto expected_shape = getTensorShapeVector(*data_tensor); NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, ") and transpose (shape=", transpose_shape, ") do not match"); } } else { // Already checked has_data == true - shape = getTensorShape(*data_tensor); + shape = getTensorShapeVector(*data_tensor); } // Coerce data tensor in Python tensor @@ -680,9 +699,9 @@ std::pair Float8BlockQuantizer::convert_and_update_te return std::vector(); } if (all_gather_usage) { - return getTensorShape(*columnwise_data); + 
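// --- Illustrative sketch (not part of the patch) ---
// The output-size computation in swap_first_dims above: the first two
// dimensions trade places and the remaining dimensions are kept. Written
// against std::vector<int64_t> so it stands alone; the caller is assumed to
// have already checked that there are at least two dimensions, as the code
// above does.
#include <cstdint>
#include <vector>

inline std::vector<int64_t> swapped_first_dims_sizes(const std::vector<int64_t>& in) {
  std::vector<int64_t> out(in);
  out[0] = in[1];
  out[1] = in[0];
  return out;
}
// --- end sketch ---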
return getTensorShapeVector(*columnwise_data); } - std::vector shape = getTensorShape(*columnwise_data); + std::vector shape = getTensorShapeVector(*columnwise_data); std::vector shape_transposed(shape.size()); for (size_t i = 0; i + 1 < shape.size(); ++i) { shape_transposed[i] = shape[i + 1]; @@ -694,7 +713,7 @@ std::pair Float8BlockQuantizer::convert_and_update_te }; std::vector shape; if (rowwise_data) { - shape = getTensorShape(*rowwise_data); + shape = getTensorShapeVector(*rowwise_data); if (columnwise_data) { auto expected_shape = get_columnwise_shape(all_gather_usage); NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape, @@ -1004,14 +1023,14 @@ std::pair MXFP8Quantizer::convert_and_update_tensor( // Tensor dimensions std::vector shape; if (columnwise_data) { - shape = getTensorShape(*columnwise_data); + shape = getTensorShapeVector(*columnwise_data); if (rowwise_data) { - auto expected_shape = getTensorShape(*rowwise_data); + const auto expected_shape = getTensorShapeVector(*rowwise_data); NVTE_CHECK(shape == expected_shape, "MXFP8 row-wise data (shape=", expected_shape, ") and column-wise data (shape=", shape, ") do not match"); } } else { // Already checked columnwise_data_tensor == true - shape = getTensorShape(*rowwise_data); + shape = getTensorShapeVector(*rowwise_data); } // Coerce row-wise data @@ -1320,14 +1339,14 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( // Tensor dimensions, shape means original shape std::vector shape; if (columnwise_data) { - shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); + shape = convert_shape_back_from_fp4(getTensorShapeVector(*columnwise_data), true); if (rowwise_data) { - auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); + auto expected_shape = convert_shape_back_from_fp4(getTensorShapeVector(*rowwise_data), false); NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape, ") and column-wise data (shape=", shape, ") do not match"); } } else { // Already checked columnwise_data_tensor == true - shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); + shape = convert_shape_back_from_fp4(getTensorShapeVector(*rowwise_data), false); } size_t flat_first_dim = 1; diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 368e9dcdfa..48e9f06cc4 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -132,7 +132,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast(); const auto &amax_rowwise = tensor.attr("_amax_rowwise").cast(); ret.set_rowwise_data(data.data_ptr(), dtype, - convert_shape_back_from_fp4(getTensorShape(data), false)); + convert_shape_back_from_fp4(getTensorShapeVector(data), false)); ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv)); ret.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise)); } @@ -143,7 +143,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast(); const auto &amax_columnwise = tensor.attr("_amax_columnwise").cast(); ret.set_columnwise_data(data.data_ptr(), DType::kFloat4E2M1, - convert_shape_back_from_fp4(getTensorShape(data), false)); + 
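// --- Illustrative sketch (not part of the patch) ---
// The shape "rotation" the quantizer code above uses to map a transposed
// (column-wise) data shape back to the logical tensor shape: the leading
// dimension moves to the back and everything else shifts left, e.g.
// {1024, 16, 128} -> {16, 128, 1024}.
#include <cstddef>
#include <vector>

inline std::vector<size_t> rotate_shape_left(const std::vector<size_t>& shape) {
  if (shape.empty()) return shape;
  std::vector<size_t> out;
  out.reserve(shape.size());
  for (size_t i = 1; i < shape.size(); ++i) out.push_back(shape[i]);
  out.push_back(shape.front());
  return out;
}
// --- end sketch ---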
convert_shape_back_from_fp4(getTensorShapeVector(data), false)); ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv)); ret.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index ce547d302e..be4c34b75a 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -15,7 +15,7 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap if (input.scaling_mode() == NVTE_INVALID_SCALING) { NVTE_ERROR("Invalid scaling mode for swizzle."); - } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING && + } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING || input.scaling_mode() != NVTE_NVFP4_1D_SCALING) { return std::nullopt; } @@ -59,24 +59,24 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap (nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0; if (rowwise) { - input_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape); - input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); - output_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape); - output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + input_cu.set_rowwise_data(input.dptr(), input_dtype, nvte_input_shape); + input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv.shape); + output_cu.set_rowwise_data(input.dptr(), input_dtype, nvte_input_shape); + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv.shape); } else { - input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape); - input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); - output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape); - output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, nvte_input_shape); + input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv.shape); + output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, nvte_input_shape); + output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv.shape); } // Launch kernel nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); if (rowwise) { - input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv.shape); } else { - input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv.shape); } return swizzled_scale_inv; diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b65f7005eb..965367ac31 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -96,24 +96,26 @@ def forward( ( is_first_microbatch, + cpu_offloading, + is_grad_enabled, + fp8_output, + fp8_grad, + module, + skip_fp8_weight_update, + debug, + ) = non_tensor_args + + ( fp8, fp8_calibration, wgrad_store, - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, 
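// --- Illustrative sketch (not part of the patch) ---
// The early-exit predicate the swizzle guard in util.cpp above is meant to
// express: scale swizzling only applies to the MXFP8 and NVFP4 1D scaling
// modes, so every other mode returns std::nullopt. Ruling out both modes
// requires combining the inequalities with && (equivalently
// !(mode == MXFP8 || mode == NVFP4)); with || the condition holds for every
// mode. Enum values are local stand-ins.
enum class SwizzleModeSketch { kDelayedTensor, kMXFP8_1D, kNVFP4_1D, kInvalid };

inline bool needs_swizzle(SwizzleModeSketch mode) {
  return mode == SwizzleModeSketch::kMXFP8_1D || mode == SwizzleModeSketch::kNVFP4_1D;
}

// Guard form: if (!needs_swizzle(mode)) return std::nullopt;
// --- end sketch ---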
fuse_wgrad_accumulation, - cpu_offloading, tp_group, tp_size, sequence_parallel, tensor_parallel, activation_dtype, parallel_mode, - is_grad_enabled, ub_overlap_rs_fprop, ub_overlap_ag_dgrad, ub_overlap_ag_fprop, @@ -121,14 +123,46 @@ def forward( ub_bulk_dgrad, ub_bulk_wgrad, ub_name, - fp8_output, # pylint: disable=unused-variable fsdp_group, - module, - skip_fp8_weight_update, symmetric_ar_type, save_original_input, - debug, - ) = non_tensor_args + ) = ( + module.fp8, + module.fp8_calibration, + module.wgrad_store, + module.fuse_wgrad_accumulation, + module.tp_group, + module.tp_size, + module.sequence_parallel, + module.tp_size > 1, + module.activation_dtype, + module.parallel_mode, + module.ub_overlap_rs_fprop, + module.ub_overlap_ag_dgrad, + module.ub_overlap_ag_fprop, + module.ub_overlap_rs_dgrad, + module.ub_bulk_dgrad, + module.ub_bulk_wgrad, + module.ub_name, + module.fsdp_group, + module.symmetric_ar_type, + module.save_original_input, + ) + quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + + if debug: + quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + if module.no_debug_features_active(quantizers): + debug = False + quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers # NVTX label for profiling nvtx_label = "transformer_engine._Linear.forward" @@ -1343,7 +1377,6 @@ def reset_parameters(self, defer_init=False): elif self.parallel_mode == "column": set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) - @no_torch_dynamo() def forward( self, inp: torch.Tensor, @@ -1401,28 +1434,7 @@ def forward( inp, allow_non_contiguous=isinstance(inp, QuantizedTensor), ) as inp: - weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - - quantizers = ( - self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - if not debug - else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ) - if debug: - if self.no_debug_features_active(quantizers): - debug = False - quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - - ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - ) = quantizers - if is_grad_enabled: linear_fn = _Linear.apply autograd_ctx = [] @@ -1432,37 +1444,12 @@ def forward( non_tensor_args = ( is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - self.fuse_wgrad_accumulation, is_cpu_offload_enabled(), - self.tp_group, - self.tp_size, - self.sequence_parallel, - self.tp_size > 1, - self.activation_dtype, - self.parallel_mode, is_grad_enabled, - self.ub_overlap_rs_fprop, - self.ub_overlap_ag_dgrad, - self.ub_overlap_ag_fprop, - self.ub_overlap_rs_dgrad, - self.ub_bulk_dgrad, - self.ub_bulk_wgrad, - self.ub_name, fp8_output, - self.fsdp_group, + fp8_grad, self, skip_fp8_weight_update, - self.symmetric_ar_type, - self.save_original_input, debug, ) out = linear_fn( @@ -1687,3 +1674,10 @@ def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Reci self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].all_gather_usage = True + + +# disable torch dynamo just once to reduce wrapped function overhead on each +# forward call of te 
Linear.
+if torch.__version__ >= "2":
+    Linear.forward._torchdynamo_disable = True
+    Linear.forward._torchdynamo_disable_msg = None
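# --- Illustrative sketch (not part of the patch) ---
# The same "disable TorchDynamo once at import time" idea as the block above,
# expressed with the torch._dynamo API instead of setting the private
# _torchdynamo_disable attributes directly. MyLinear is a placeholder module,
# not te.Linear.
import torch
import torch.nn as nn


class MyLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.linear(inp, self.weight)


if torch.__version__ >= "2":
    # Wrap the unbound method a single time so every forward call skips dynamo
    # tracing without paying per-call decorator overhead.
    MyLinear.forward = torch._dynamo.disable(MyLinear.forward, recursive=False)
# --- end sketch ---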