Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 63 additions & 6 deletions transformer_engine/pytorch/csrc/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,17 @@ std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape
return ret;
}

std::vector<size_t> getTensorShape(const at::Tensor& t) {
NVTEShape getTensorShape(const at::Tensor& t) {
  // Single code path for Torch-size -> NVTEShape conversion.
  return convertTorchShape(t.sizes());
}

std::vector<size_t> getTensorShapeVector(const at::Tensor& t) {
  // Copy every dimension of the Torch size list into a size_t vector
  // (each int64_t extent converts element-wise).
  const auto sizes = t.sizes();
  return std::vector<size_t>(sizes.begin(), sizes.end());
}

NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) {
NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape) {
NVTEShape ret;
ret.ndim = torch_shape.size();
constexpr int max_dimensions = sizeof(ret.data) / sizeof(size_t);
Expand Down Expand Up @@ -119,10 +121,7 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(

transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor) {
transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type());
std::vector<size_t> shape;
for (auto s : tensor.sizes()) {
shape.push_back(s);
}
NVTEShape shape = getTensorShape(tensor);
return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype);
}

Expand Down Expand Up @@ -178,6 +177,41 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
return ret;
}

/*! @brief Wrap pre-allocated rowwise data and quantization metadata in a TE tensor.
 *
 *  The amax and scale buffers are treated as scalar ({1}-shaped) FP32 values.
 *  The scale-inverse dtype follows the scaling mode: MXFP8 stores E8M0
 *  exponents, every other mode uses FP32.
 */
transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr,
    void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
    NVTEScalingMode scaling_mode) {
  TensorWrapper ret(scaling_mode);
  ret.set_rowwise_data(data_ptr, type, shape);
  // Scalar shape shared by the per-tensor amax/scale metadata.
  // (Dropped the unused `meta_shape_data` array that previously shadowed this.)
  NVTEShape meta_shape;
  meta_shape.ndim = 1;
  meta_shape.data[0] = 1;
  ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
  ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
  auto scale_inv_dtype =
      (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32;
  ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape);
  return ret;
}

/*! @brief Wrap pre-allocated rowwise data (vector shape) plus quantization
 *         metadata in a TE tensor.
 *
 *  amax/scale are scalar ({1}-shaped) FP32 buffers; the scale-inverse dtype
 *  is E8M0 for MXFP8 scaling and FP32 otherwise.
 */
transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
    NVTEScalingMode scaling_mode) {
  // Scalar ({1}) shape shared by the per-tensor amax and scale buffers.
  NVTEShape unit_shape;
  unit_shape.ndim = 1;
  unit_shape.data[0] = 1;
  // MXFP8 stores scale-inverses as E8M0 exponents; other modes use FP32.
  const auto inv_dtype =
      (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32;
  TensorWrapper wrapped(scaling_mode);
  wrapped.set_rowwise_data(data_ptr, type, shape);
  wrapped.set_amax(amax_ptr, DType::kFloat32, unit_shape);
  wrapped.set_scale(scale_ptr, DType::kFloat32, unit_shape);
  wrapped.set_rowwise_scale_inv(scale_inv_ptr, inv_dtype, scale_inv_shape);
  return wrapped;
}

transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* data_ptr, void* columnwise_data_ptr, const std::vector<size_t>& shape,
const std::vector<size_t>& columnwise_shape, const transformer_engine::DType type,
Expand All @@ -199,6 +233,29 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
return ret;
}

/*! @brief Wrap pre-allocated rowwise and columnwise data (NVTEShape form)
 *         plus quantization metadata in a TE tensor.
 *
 *  amax/scale are scalar ({1}-shaped) FP32 buffers. Both scale-inverse views
 *  share one dtype chosen by the scaling recipe.
 */
transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape,
    const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr,
    void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
    const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape,
    NVTEScalingMode scaling_mode) {
  // Scale-inverse element type per recipe:
  // MXFP8 -> E8M0 exponents, NVFP4 -> E4M3, everything else -> FP32.
  DType inv_dtype = DType::kFloat32;
  if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
    inv_dtype = DType::kFloat8E8M0;
  } else if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
    inv_dtype = DType::kFloat8E4M3;
  }
  // Scalar ({1}) shape shared by the per-tensor amax/scale metadata.
  NVTEShape unit_shape;
  unit_shape.ndim = 1;
  unit_shape.data[0] = 1;
  TensorWrapper wrapped(scaling_mode);
  wrapped.set_rowwise_data(data_ptr, type, shape);
  wrapped.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape);
  wrapped.set_amax(amax_ptr, DType::kFloat32, unit_shape);
  wrapped.set_scale(scale_ptr, DType::kFloat32, unit_shape);
  wrapped.set_rowwise_scale_inv(scale_inv_ptr, inv_dtype, scale_inv_shape);
  wrapped.set_columnwise_scale_inv(columnwise_scale_inv_ptr, inv_dtype,
                                   columnwise_scale_inv_shape);
  return wrapped;
}

transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, at::Tensor amax,
const at::Tensor scale,
at::Tensor scale_inv,
Expand Down
29 changes: 27 additions & 2 deletions transformer_engine/pytorch/csrc/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,12 @@ class NoneQuantizer : public Quantizer {
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
at::Tensor data) const;

std::pair<TensorWrapper, py::object> create_tensor(const NVTEShape& shape, DType dtype) const;

/*! @brief Construct a tensor with pre-initialized data */
std::pair<TensorWrapper, py::object> create_tensor(const NVTEShape& shape, DType dtype,
at::Tensor data) const;

std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object tensor) const override;

void quantize(const TensorWrapper& input, TensorWrapper& out,
Expand Down Expand Up @@ -339,7 +345,9 @@ class NVFP4Quantizer : public Quantizer {

std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);

std::vector<size_t> getTensorShape(const at::Tensor& t);
NVTEShape getTensorShape(const at::Tensor& t);

std::vector<size_t> getTensorShapeVector(const at::Tensor& t);

transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
const std::string& fp8_recipe);
Expand Down Expand Up @@ -432,6 +440,16 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector<size_t> scale_inv_shape = {1},
NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr,
void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* data_ptr, void* columnwise_data_ptr, const std::vector<size_t>& shape,
const std::vector<size_t>& columnwise_shape, const transformer_engine::DType type,
Expand All @@ -440,6 +458,13 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
const std::vector<size_t>& columnwise_scale_inv_shape = {1},
NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape,
const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr,
void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape,
NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
const NVTEShape& shape,
const transformer_engine::DType type);
Expand Down Expand Up @@ -479,7 +504,7 @@ std::vector<size_t> convertShape(const NVTEShape& shape);

size_t roundup(const size_t value, const size_t multiple);

NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape);

std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape, bool transpose);

Expand Down
10 changes: 7 additions & 3 deletions transformer_engine/pytorch/csrc/extensions/attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,10 +478,14 @@ std::vector<py::object> fused_attn_bwd(
auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec();
std::vector<size_t> cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(),
cu_seqlens_kv_padded_sizes.end()};
te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(),
cu_seqlens_q_padded_shape, DType::kInt32);
te_cu_seqlens_q_padded = makeTransformerEngineTensor(
cu_seqlens_q_padded.value().data_ptr(),
nvte_make_shape(cu_seqlens_q_padded_shape.data(), cu_seqlens_q_padded_shape.size()),
DType::kInt32);
te_cu_seqlens_kv_padded = makeTransformerEngineTensor(
cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32);
cu_seqlens_kv_padded.value().data_ptr(),
nvte_make_shape(cu_seqlens_kv_padded_shape.data(), cu_seqlens_kv_padded_shape.size()),
DType::kInt32);
}

// convert auxiliary tensors from forward to NVTETensors
Expand Down
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/csrc/extensions/bias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ std::vector<py::object> bgrad_quantize(const at::Tensor &grad_output, py::handle
// Grad output tensor
auto grad_output_torch = grad_output.contiguous();
const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch);
const auto shape = getTensorShape(grad_output_torch);
const auto shape = getTensorShapeVector(grad_output_torch);
auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type());

// Construct grad bias tensor
Expand Down Expand Up @@ -116,11 +116,11 @@ std::vector<py::object> dact_dbias(
// Grad output and activation input tensors
grad_output_torch = grad_output_torch.contiguous();
const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch);
const auto output_shape = getTensorShape(grad_output_torch);
const auto output_shape = getTensorShapeVector(grad_output_torch);
auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type());
act_input_torch = act_input_torch.contiguous();
const TensorWrapper &act_input_nvte = makeTransformerEngineTensor(act_input_torch);
const auto input_shape = getTensorShape(act_input_torch);
const auto input_shape = getTensorShapeVector(act_input_torch);

// Construct tensors
auto quantizer_cpp = convert_quantizer(quantizer_py);
Expand Down
Loading