@@ -52,15 +52,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
     int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
     int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v,
     int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k,
-    int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training,
-    bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK,
-    void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2,
-    void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
-    void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
-    void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
-    void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
+    int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv,
+    bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability,
+    NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
+    void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset,
+    void *devPtrS1, void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed,
+    void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV,
+    void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ,
+    void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace,
+    size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
 
   bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
@@ -120,6 +121,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
                              max_pages_per_seq_v,
                              bias_b,
                              bias_h,
+                             bias_sq,
+                             bias_skv,
                              scaling_factor,
                              is_training,
                              dropout_probability,
@@ -261,10 +264,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
   sdpa_options.set_alibi_mask(is_alibi);
 
   if (is_bias) {
-    bias = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                 .set_name("bias")
-                                 .set_dim({bias_b, bias_h, s_q, s_kv})
-                                 .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+    bias = mha_graph->tensor(
+        fe::graph::Tensor_attributes()
+            .set_name("bias")
+            .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+            .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
     sdpa_options.set_bias(bias);
   }
 
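The forward graph now takes the bias shape from the bias tensor itself instead of hard-coding `s_q`/`s_kv`, so broadcastable biases (for example `bias_sq == 1`) get a correct description. Below is a minimal sketch of the packed-stride rule that the `.set_stride(...)` call above applies; the helper name is hypothetical and not part of this change:

```cpp
#include <array>
#include <cstdint>

// Hypothetical helper, for illustration only: a packed row-major rank-4
// tensor with dims {d0, d1, d2, d3} has strides {d1*d2*d3, d2*d3, d3, 1},
// which is exactly what .set_stride() receives above for
// {bias_b, bias_h, bias_sq, bias_skv}.
std::array<std::int64_t, 4> contiguous_strides(const std::array<std::int64_t, 4> &dims) {
  return {dims[1] * dims[2] * dims[3], dims[2] * dims[3], dims[3], 1};
}
```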
@@ -540,15 +544,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
 void fused_attn_arbitrary_seqlen_bwd_impl(
     int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
     int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
-    float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ,
-    void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats,
-    void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV,
-    void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed,
-    void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV,
-    void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
-    void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
+    int64_t bias_sq, int64_t bias_skv, float scaling_factor, float dropout_probability,
+    NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
+    bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose,
+    void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrSoftmaxOffset,
+    void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
+    void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
+    void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ,
+    void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace,
+    size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
 
   bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
@@ -612,6 +617,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                              0,
                              bias_b,
                              bias_h,
+                             bias_sq,
+                             bias_skv,
                              scaling_factor,
                              true,
                              dropout_probability,
@@ -792,14 +799,16 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
   sdpa_backward_options.set_alibi_mask(is_alibi);
 
   if (is_bias) {
-    bias = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                 .set_name("bias")
-                                 .set_dim({bias_b, bias_h, s_q, s_kv})
-                                 .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
-    dBias = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                  .set_name("dBias")
-                                  .set_dim({bias_b, bias_h, s_q, s_kv})
-                                  .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+    bias = mha_graph->tensor(
+        fe::graph::Tensor_attributes()
+            .set_name("bias")
+            .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+            .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
+    dBias = mha_graph->tensor(
+        fe::graph::Tensor_attributes()
+            .set_name("dBias")
+            .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+            .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
     sdpa_backward_options.set_bias(bias);
     // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s]
     // are not supported for dbias calculation but they are
@@ -1064,10 +1073,14 @@ void fused_attn_arbitrary_seqlen_fwd(
   void *devPtrBias = nullptr;
   size_t bias_b = 0;
   size_t bias_h = 0;
+  size_t bias_sq = 0;
+  size_t bias_skv = 0;
   if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
     devPtrBias = input_Bias->data.dptr;
     bias_b = input_Bias->data.shape[0];
     bias_h = input_Bias->data.shape[1];
+    bias_sq = input_Bias->data.shape[2];
+    bias_skv = input_Bias->data.shape[3];
   }
   void *devPtrSoftmaxOffset = nullptr;
   if (softmax_type != NVTE_VANILLA_SOFTMAX) {
@@ -1133,7 +1146,7 @@
     if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
       Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
       output_bias->data.dptr = nullptr;
-      output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
+      output_bias->data.shape = {bias_b, bias_h, bias_sq, bias_skv};
       output_bias->data.dtype = QKV_type;
     }
 
@@ -1178,9 +1191,9 @@
   fused_attn_arbitrary_seqlen_fwd_impl(
       batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
       max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
-      page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
-      return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type,
-      window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias,
+      page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv,
+      is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
+      softmax_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias,
       devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
       devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ,
       devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
@@ -1224,11 +1237,15 @@ void fused_attn_arbitrary_seqlen_bwd(
   void *devPtrdBias = nullptr;
   size_t bias_b = 0;
   size_t bias_h = 0;
+  size_t bias_sq = 0;
+  size_t bias_skv = 0;
   if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
     devPtrBias = input_Bias->data.dptr;
     devPtrdBias = output_dBias->data.dptr;
     bias_b = output_dBias->data.shape[0];
     bias_h = output_dBias->data.shape[1];
+    bias_sq = input_Bias->data.shape[2];
+    bias_skv = input_Bias->data.shape[3];
   }
 
   size_t max_batch_size = 0;
@@ -1271,10 +1288,10 @@
 
   fused_attn_arbitrary_seqlen_bwd_impl(
       batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
-      max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
-      qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
-      deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
-      devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
+      max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale,
+      p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
+      window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
+      devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
       devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
       devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
       workspace->data.dptr, &workspace_size, stream, handle);
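Both public entry points now read the sequence dims from the bias tensor (forward) or from `input_Bias` alongside `output_dBias` (backward), so the graph description matches the buffer that is actually provided. A worked example with assumed numbers, not taken from the PR, showing what changes for a per-key bias broadcast over query positions:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Assumed shapes for illustration: 16 heads, 512 x 512 attention,
  // and a bias of shape [1, 16, 1, 512] that broadcasts over queries.
  const std::int64_t s_q = 512, s_kv = 512;
  const std::int64_t bias_sq = 1, bias_skv = 512;

  // Old description (hard-coded): {bias_b, bias_h, s_q, s_kv}
  // New description:              {bias_b, bias_h, bias_sq, bias_skv}
  std::cout << "elements per head, old dims: " << s_q * s_kv << '\n';          // 262144
  std::cout << "elements per head, new dims: " << bias_sq * bias_skv << '\n';  // 512
  return 0;
}
```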
20 changes: 14 additions & 6 deletions transformer_engine/common/fused_attn/fused_attn_fp8.cu
@@ -1671,6 +1671,8 @@ void fused_attn_fp8_fwd_impl_v1(
   bool is_dropout = (is_training && dropout_probability != 0.0f);
   auto bias_b = b;
   auto bias_h = h;
+  auto bias_sq = s_q;
+  auto bias_skv = s_kv;
   NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
   NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!");
   bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF ||
@@ -1697,6 +1699,8 @@
                              0,
                              bias_b,
                              bias_h,
+                             bias_sq,
+                             bias_skv,
                              scaling_factor,
                              is_training,
                              dropout_probability,
@@ -1817,8 +1821,8 @@ void fused_attn_fp8_fwd_impl_v1(
   // if (is_bias) {
   //   bias = mha_graph->tensor(fe::graph::Tensor_attributes()
   //                                .set_name("bias")
-  //                                .set_dim({bias_b, bias_h, s_q, s_kv})
-  //                                .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+  //                                .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+  //                                .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
   //   sdpa_options.set_bias(bias);
   // }
 
@@ -1998,6 +2002,8 @@ void fused_attn_fp8_bwd_impl_v1(
   bool is_dropout = (dropout_probability != 0.0f);
   auto bias_b = b;
   auto bias_h = h;
+  auto bias_sq = s_q;
+  auto bias_skv = s_kv;
   NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
   NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!");
   bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF ||
@@ -2026,6 +2032,8 @@
                              0,
                              bias_b,
                              bias_h,
+                             bias_sq,
+                             bias_skv,
                              scaling_factor,
                              true,
                              dropout_probability,
@@ -2192,12 +2200,12 @@ void fused_attn_fp8_bwd_impl_v1(
   // if (is_bias) {
   //   bias = mha_graph->tensor(fe::graph::Tensor_attributes()
   //                                .set_name("bias")
-  //                                .set_dim({bias_b, bias_h, s_q, s_kv})
-  //                                .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+  //                                .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+  //                                .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
   //   dBias = mha_graph->tensor(fe::graph::Tensor_attributes()
   //                                 .set_name("dBias")
-  //                                 .set_dim({bias_b, bias_h, s_q, s_kv})
-  //                                 .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+  //                                 .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+  //                                 .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
   //   sdpa_backward_options.set_bias(bias);
   //   // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s]
   //   // are not supported for dbias calculation but they are
21 changes: 12 additions & 9 deletions transformer_engine/common/fused_attn/utils.h
@@ -101,6 +101,8 @@ struct FADescriptor_v1 {
   std::int64_t max_pages_per_seq_v;
   std::int64_t bias_b;
   std::int64_t bias_h;
+  std::int64_t bias_sq;
+  std::int64_t bias_skv;
   float attnScale;
   bool isTraining;
   float dropoutProbability;
@@ -119,17 +121,18 @@
 
   bool operator<(const FADescriptor_v1 &rhs) const {
     return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
-                    page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
-                    attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
-                    window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type,
-                    o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) <
+                    page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq,
+                    bias_skv, attnScale, isTraining, dropoutProbability, layout, mask_type,
+                    softmax_type, window_size_left, window_size_right, deterministic, bias_type,
+                    qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type,
+                    generate_max_sum_exp) <
            std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
                     rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
-                    rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
-                    rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
-                    rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
-                    rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
-                    rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
+                    rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv,
+                    rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout,
+                    rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right,
+                    rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type,
+                    rhs.do_tensor_type, rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
   }
 };
 
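`FADescriptor_v1` orders every field through `std::tie`, which compares the listed members lexicographically, so the struct can serve as a lookup key for cached graphs; any field that affects graph construction has to appear in both tuples. A minimal sketch of the failure mode the two new fields prevent; the cache and `GraphPlan` are placeholders for illustration, not quoted from the codebase:

```cpp
#include <map>
#include <memory>

struct GraphPlan {};  // placeholder for a built cudnn_frontend graph

// std::map only ever calls operator<; two keys a, b are treated as equal
// when !(a < b) && !(b < a). If bias_sq/bias_skv were missing from the
// tuples, two descriptors differing only in bias sequence dims would
// collide, and the second caller would execute a graph built for the
// wrong bias shape.
std::map<FADescriptor_v1, std::shared_ptr<GraphPlan>> plan_cache;
```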