From 8da325205e3d57207d071773642ac2c95a82fa1a Mon Sep 17 00:00:00 2001
From: Kshitij Lakhani
Date: Sat, 20 Dec 2025 00:54:50 +0000
Subject: [PATCH 1/2] Plumbing correct bias dims from TE to cudnn

Signed-off-by: Kshitij Lakhani
---
 .../fused_attn_f16_arbitrary_seqlen.cu       | 34 +++++++++++++------
 .../common/fused_attn/fused_attn_fp8.cu      | 20 +++++++----
 transformer_engine/common/fused_attn/utils.h |  6 ++--
 3 files changed, 41 insertions(+), 19 deletions(-)

diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
index efa4c784390..e31f142e90e 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
@@ -52,7 +52,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
     int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
     int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v,
     int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k,
-    int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training,
+    int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, bool is_training,
     bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
     int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK,
@@ -120,6 +120,8 @@
                              max_pages_per_seq_v,
                              bias_b,
                              bias_h,
+                             bias_sq,
+                             bias_skv,
                              scaling_factor,
                              is_training,
                              dropout_probability,
@@ -263,8 +265,8 @@
   if (is_bias) {
     bias = mha_graph->tensor(fe::graph::Tensor_attributes()
                                  .set_name("bias")
-                                 .set_dim({bias_b, bias_h, s_q, s_kv})
-                                 .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+                                 .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+                                 .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
     sdpa_options.set_bias(bias);
   }

@@ -539,7 +541,7 @@

 void fused_attn_arbitrary_seqlen_bwd_impl(
     int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
-    int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
+    int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv,
     float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
     int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ,
@@ -612,6 +614,8 @@
                              0,
                              bias_b,
                              bias_h,
+                             bias_sq,
+                             bias_skv,
                              scaling_factor,
                              true,
                              dropout_probability,
@@ -794,12 +798,12 @@
   if (is_bias) {
     bias = mha_graph->tensor(fe::graph::Tensor_attributes()
                                  .set_name("bias")
-                                 .set_dim({bias_b, bias_h, s_q, s_kv})
-                                 .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+                                 .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+                                 .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
     dBias = mha_graph->tensor(fe::graph::Tensor_attributes()
                                   .set_name("dBias")
-                                  .set_dim({bias_b, bias_h, s_q, s_kv})
-                                  .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+                                  .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+                                  .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
     sdpa_backward_options.set_bias(bias);
     // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s]
     // are not supported for dbias calculation but they are
@@ -1064,10 +1068,14 @@
   void *devPtrBias = nullptr;
   size_t bias_b = 0;
   size_t bias_h = 0;
+  size_t bias_sq = 0;
+  size_t bias_skv = 0;
   if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
     devPtrBias = input_Bias->data.dptr;
     bias_b = input_Bias->data.shape[0];
     bias_h = input_Bias->data.shape[1];
+    bias_sq = input_Bias->data.shape[2];
+    bias_skv = input_Bias->data.shape[3];
   }
   void *devPtrSoftmaxOffset = nullptr;
   if (softmax_type != NVTE_VANILLA_SOFTMAX) {
@@ -1133,7 +1141,7 @@
     if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
       Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
       output_bias->data.dptr = nullptr;
-      output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
+      output_bias->data.shape = {bias_b, bias_h, bias_sq, bias_skv};
       output_bias->data.dtype = QKV_type;
     }

@@ -1178,7 +1186,7 @@
   fused_attn_arbitrary_seqlen_fwd_impl(
       batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
       max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
-      page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
+      page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, is_training,
       return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type,
       window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias,
       devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
@@ -1224,11 +1232,15 @@
   void *devPtrdBias = nullptr;
   size_t bias_b = 0;
   size_t bias_h = 0;
+  size_t bias_sq = 0;
+  size_t bias_skv = 0;
   if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
     devPtrBias = input_Bias->data.dptr;
     devPtrdBias = output_dBias->data.dptr;
     bias_b = output_dBias->data.shape[0];
     bias_h = output_dBias->data.shape[1];
+    bias_sq = input_Bias->data.shape[2];
+    bias_skv = input_Bias->data.shape[3];
   }

   size_t max_batch_size = 0;
@@ -1271,7 +1283,7 @@

   fused_attn_arbitrary_seqlen_bwd_impl(
       batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
-      max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
+      max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale, p_dropout,
       qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
       deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
       devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu
index 5d806290a9e..54a1bb9a656 100644
--- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu
@@ -1671,6 +1671,8 @@
   bool is_dropout = (is_training && dropout_probability != 0.0f);
   auto bias_b = b;
   auto bias_h = h;
+  auto bias_sq = s_q;
+  auto bias_skv = s_kv;
   NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
   NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!");
   bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF ||
@@ -1697,6 +1699,8 @@
                              0,
                              bias_b,
                              bias_h,
+                             bias_sq,
+                             bias_skv,
                              scaling_factor,
                              is_training,
                              dropout_probability,
@@ -1817,8 +1821,8 @@
   // if (is_bias) {
   //   bias = mha_graph->tensor(fe::graph::Tensor_attributes()
   //                                .set_name("bias")
-  //                                .set_dim({bias_b, bias_h, s_q, s_kv})
-  //                                .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+  //                                .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+  //                                .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
   //   sdpa_options.set_bias(bias);
   // }

@@ -1998,6 +2002,8 @@
   bool is_dropout = (dropout_probability != 0.0f);
   auto bias_b = b;
   auto bias_h = h;
+  auto bias_sq = s_q;
+  auto bias_skv = s_kv;
   NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
   NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!");
   bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF ||
@@ -2026,6 +2032,8 @@
                              0,
                              bias_b,
                              bias_h,
+                             bias_sq,
+                             bias_skv,
                              scaling_factor,
                              true,
                              dropout_probability,
@@ -2192,12 +2200,12 @@
   // if (is_bias) {
   //   bias = mha_graph->tensor(fe::graph::Tensor_attributes()
   //                                .set_name("bias")
-  //                                .set_dim({bias_b, bias_h, s_q, s_kv})
-  //                                .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+  //                                .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+  //                                .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
   //   dBias = mha_graph->tensor(fe::graph::Tensor_attributes()
   //                                 .set_name("dBias")
-  //                                 .set_dim({bias_b, bias_h, s_q, s_kv})
-  //                                 .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+  //                                 .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+  //                                 .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
   //   sdpa_backward_options.set_bias(bias);
   //   // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s]
   //   // are not supported for dbias calculation but they are
diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h
index 72047a73f27..c1faa8c5bfd 100644
--- a/transformer_engine/common/fused_attn/utils.h
+++ b/transformer_engine/common/fused_attn/utils.h
@@ -101,6 +101,8 @@ struct FADescriptor_v1 {
   std::int64_t max_pages_per_seq_v;
   std::int64_t bias_b;
   std::int64_t bias_h;
+  std::int64_t bias_sq;
+  std::int64_t bias_skv;
   float attnScale;
   bool isTraining;
   float dropoutProbability;
@@ -119,13 +121,13 @@

   bool operator<(const FADescriptor_v1 &rhs) const {
     return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
-                    page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
+                    page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv,
                     attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
                     window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type,
                     o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) <
            std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
                     rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
-                    rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
+                    rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv, rhs.attnScale, rhs.isTraining,
                     rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
                     rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
                     rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,

From d3f15bfcf083e4bb70598e9b7d4c75b825291353 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 22 Dec 2025 18:22:38 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../fused_attn_f16_arbitrary_seqlen.cu       | 81 ++++++++++---------
 transformer_engine/common/fused_attn/utils.h | 19 ++---
 2 files changed, 53 insertions(+), 47 deletions(-)

diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
index e31f142e90e..07aa5d1654a 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
@@ -52,15 +52,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
     int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
     int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v,
     int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k,
-    int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, bool is_training,
-    bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK,
-    void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2,
-    void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
-    void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
-    void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
-    void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
+    int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv,
+    bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability,
+    NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
+    void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset,
+    void *devPtrS1, void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed,
+    void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV,
+    void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ,
+    void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace,
+    size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;

   bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
@@ -263,10 +264,11 @@
   sdpa_options.set_alibi_mask(is_alibi);

   if (is_bias) {
-    bias = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                 .set_name("bias")
-                                 .set_dim({bias_b, bias_h, bias_sq, bias_skv})
-                                 .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
+    bias = mha_graph->tensor(
+        fe::graph::Tensor_attributes()
+            .set_name("bias")
+            .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+            .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
     sdpa_options.set_bias(bias);
   }

@@ -541,16 +543,17 @@

 void fused_attn_arbitrary_seqlen_bwd_impl(
     int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
-    int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv,
-    float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ,
-    void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats,
-    void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV,
-    void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed,
-    void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV,
-    void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
-    void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
+    int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
+    int64_t bias_sq, int64_t bias_skv, float scaling_factor, float dropout_probability,
+    NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
+    bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose,
+    void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrSoftmaxOffset,
+    void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
+    void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
+    void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ,
+    void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace,
+    size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;

   bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
@@ -796,14 +799,16 @@
   sdpa_backward_options.set_alibi_mask(is_alibi);

   if (is_bias) {
-    bias = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                 .set_name("bias")
-                                 .set_dim({bias_b, bias_h, bias_sq, bias_skv})
-                                 .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
-    dBias = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                  .set_name("dBias")
-                                  .set_dim({bias_b, bias_h, bias_sq, bias_skv})
-                                  .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
+    bias = mha_graph->tensor(
+        fe::graph::Tensor_attributes()
+            .set_name("bias")
+            .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+            .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
+    dBias = mha_graph->tensor(
+        fe::graph::Tensor_attributes()
+            .set_name("dBias")
+            .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+            .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
     sdpa_backward_options.set_bias(bias);
     // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s]
     // are not supported for dbias calculation but they are
@@ -1186,9 +1191,9 @@ void fused_attn_arbitrary_seqlen_fwd(
   fused_attn_arbitrary_seqlen_fwd_impl(
       batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
       max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
-      page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, is_training,
-      return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type,
-      window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias,
+      page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv,
+      is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
+      softmax_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias,
       devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
       devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ,
       devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
@@ -1283,10 +1288,10 @@

   fused_attn_arbitrary_seqlen_bwd_impl(
       batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
-      max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale, p_dropout,
-      qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
-      deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
-      devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
+      max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale,
+      p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
+      window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
+      devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
       devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
       devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
       workspace->data.dptr, &workspace_size, stream, handle);
diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h
index c1faa8c5bfd..1ba8a60f03a 100644
--- a/transformer_engine/common/fused_attn/utils.h
+++ b/transformer_engine/common/fused_attn/utils.h
@@ -121,17 +121,18 @@

   bool operator<(const FADescriptor_v1 &rhs) const {
     return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
-                    page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv,
-                    attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
-                    window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type,
-                    o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) <
+                    page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq,
+                    bias_skv, attnScale, isTraining, dropoutProbability, layout, mask_type,
+                    softmax_type, window_size_left, window_size_right, deterministic, bias_type,
+                    qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type,
+                    generate_max_sum_exp) <
           std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
                    rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
-                   rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv, rhs.attnScale, rhs.isTraining,
-                   rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
-                   rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
-                   rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
-                   rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
+                   rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv,
+                   rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout,
+                   rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right,
+                   rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type,
+                   rhs.do_tensor_type, rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
   }
 };
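
Note on the change, for readers following the series: patch 1 stops telling cuDNN that the bias tensor spans [bias_b, bias_h, max_seqlen_q, max_seqlen_kv] and instead reads the two trailing (sequence-length) dims from the bias tensor's own shape, threading them through FADescriptor_v1 into the cudnn-frontend graph; patch 2 is only clang-format rewrapping. Below is a minimal self-contained sketch of that dim/stride derivation. It is illustrative only: Shape, BiasDims, bias_dims_from_shape, and bias_strides are hypothetical names for this note, not Transformer Engine or cuDNN types.

    #include <array>
    #include <cstdint>

    // Hypothetical stand-in for a 4-D tensor shape; TE keeps this on
    // input_Bias->data.shape and passes the four values through
    // FADescriptor_v1 as bias_b, bias_h, bias_sq, bias_skv.
    using Shape = std::array<std::int64_t, 4>;

    struct BiasDims {
      std::int64_t b, h, sq, skv;
    };

    // Read every bias dim from the bias tensor itself rather than assuming
    // the trailing two equal the attention problem's s_q/s_kv.
    inline BiasDims bias_dims_from_shape(const Shape &bias_shape) {
      return {bias_shape[0], bias_shape[1], bias_shape[2], bias_shape[3]};
    }

    // Packed row-major strides for a contiguous [b, h, sq, skv] bias,
    // mirroring the set_stride({bias_h * bias_sq * bias_skv, ...}) calls
    // added in the patch above.
    inline Shape bias_strides(const BiasDims &d) {
      return {d.h * d.sq * d.skv, d.sq * d.skv, d.skv, 1};
    }

Because the strides are derived from the bias tensor's own dims, the dim/stride pairs handed to set_dim/set_stride stay mutually consistent for the bias layouts the in-diff comment mentions ([1, 1, s, s], [b, 1, s, s], [b, h, s, s]), rather than silently inheriting the attention problem's seq lengths.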