From 33e2e478c88b290b834e14fa872a3b67ab208b09 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Wed, 3 Dec 2025 16:50:32 -0800 Subject: [PATCH 01/13] SWA (left, right) with FusedAttention changes cherry-picked from https://github.com/NVIDIA/TransformerEngine/pull/1369 Signed-off-by: Sudhakar Singh --- tests/pytorch/attention/test_attention.py | 7 +- .../common/fused_attn/fused_attn.cpp | 73 +++++++++-------- .../fused_attn_f16_arbitrary_seqlen.cu | 57 ++++++++----- .../fused_attn_f16_arbitrary_seqlen.h | 18 ++--- .../common/fused_attn/fused_attn_fp8.cu | 2 + transformer_engine/common/fused_attn/utils.h | 10 ++- .../include/transformer_engine/fused_attn.h | 22 ++++-- .../dot_product_attention/backends.py | 11 +++ .../dot_product_attention.py | 27 +++++++ .../attention/dot_product_attention/utils.py | 79 ++++++++++++------- .../pytorch/attention/multi_head_attention.py | 26 ++++++ .../pytorch/cpp_extensions/fused_attn.py | 10 +++ transformer_engine/pytorch/csrc/extensions.h | 10 +-- .../pytorch/csrc/extensions/attention.cpp | 26 +++--- transformer_engine/pytorch/transformer.py | 53 ++++++++++++- 15 files changed, 312 insertions(+), 119 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 4aedcff1b83..6dfeb705745 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -145,6 +145,7 @@ def test_dot_product_attention( if config.window_size == (-1, -1) and swa: config.window_size = [2, 2] + config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0] if qkv_format == "thd" and "padding" not in config.attn_mask_type: @@ -162,6 +163,7 @@ def test_dot_product_attention( is_training=is_training, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if not fused_attn_supported: is_training = False available_backends, _, fused_attn_backends = get_available_attention_backends( @@ -682,9 +684,10 @@ def test_dpa_bias_shapes(dtype, model_configs, model): @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_swa]) @pytest.mark.parametrize("model", model_configs_swa.keys()) -def test_dpa_sliding_window(dtype, model_configs, model): +@pytest.mark.parametrize("qkv_layout", ["thd_thd_thd", "sbhd_sbhd_sbhd"]) +def test_dpa_sliding_window(dtype, model_configs, model, qkv_layout): """Test DotProductAttention module with sliding window attention""" - test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False) + test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, True, False) model_configs_alibi_slopes = { diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 611beb7b84d..6e2e5a59b05 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -406,9 +406,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (window_size_right == -1 || window_size_right == 0)) || // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd} (cudnn_runtime_version >= 90200 && - ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || - ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + ((window_size_left == -1 && 
window_size_right == -1 && + attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) || + ((window_size_left == -1 || window_size_left >= 0) && window_size_right == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && max_seqlen_q == max_seqlen_kv)) && max_seqlen_q <= max_seqlen_kv && dropout == 0.0 && @@ -418,12 +420,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd} (cudnn_runtime_version >= 90600 && ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || - ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && + ((window_size_left >= 0 || window_size_left == -1) && + (window_size_right >= 0 || window_size_right == -1) && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && // TODO(cyang): fix bug for BRCM + cross-attention on sm100 (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && cudnn_runtime_version <= 90700) || cudnn_runtime_version > 90700)))) || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && @@ -504,7 +508,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } // NVTE fused attention FWD with packed QKV -// DEPRECATED: This API is deprecated. +// DEPRECATED: This API is deprecated. (Should there be a version by which this is going to be removed? @cyang) // Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, @@ -515,6 +519,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); using namespace transformer_engine; @@ -589,7 +594,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, fused_attn_arbitrary_seqlen_fwd( b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias, + window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view, input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state, wkspace, stream, handle); @@ -629,7 +634,8 @@ void nvte_fused_attn_bwd_qkvpacked( NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, bool cuda_graph, 
NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); using namespace transformer_engine; @@ -725,10 +731,11 @@ void nvte_fused_attn_bwd_qkvpacked( fused_attn_arbitrary_seqlen_bwd( b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, &Q_view, - &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, &dQ_view, - &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens, - input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); + attn_mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, &Q_view, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, + output_S, &dQ_view, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens, + input_cu_seqlens, input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, + wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.0 is required for BF16/FP16 fused attention " @@ -779,7 +786,8 @@ void nvte_fused_attn_fwd_kvpacked( size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { + int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -891,10 +899,10 @@ void nvte_fused_attn_fwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, - input_page_table_v, input_rng_state, wkspace, stream, handle); + window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view, + input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, + input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. 
\n"); @@ -933,8 +941,8 @@ void nvte_fused_attn_bwd_kvpacked( const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, - cudaStream_t stream) { + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -1040,11 +1048,11 @@ void nvte_fused_attn_bwd_kvpacked( fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, - input_Q, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, - output_dQ, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, - wkspace, stream, handle); + bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, deterministic, input_Q, &K_view, &V_view, input_O, input_dO, + input_Bias, input_SoftmaxOffset, output_S, output_dQ, &dK_view, &dV_view, output_dBias, + output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.3 is required for BF16/FP16 fused attention " @@ -1094,8 +1102,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -1183,10 +1191,10 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, - input_page_table_v, input_rng_state, wkspace, stream, handle); + window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, + input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, + input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for 
BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -1215,7 +1223,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; @@ -1289,8 +1298,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, - deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias, - input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, + bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, input_dO, + input_Bias, input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 14468b543a4..8278f051f84 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -55,10 +55,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, - void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, - void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, - void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void *devPtrQ, + void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, + void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, + void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -129,6 +129,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, true, tensorType, cudnn_frontend::DataType_t::NOT_SET, @@ -254,9 +255,17 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); + fe::DiagonalAlignment_t const &diagonal_alignment = + bottom_right_diagonal ? 
fe::DiagonalAlignment_t::BOTTOM_RIGHT
+                              : fe::DiagonalAlignment_t::TOP_LEFT;
+  sdpa_options.set_diagonal_alignment(diagonal_alignment);
   if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
     sdpa_options.set_diagonal_band_left_bound(window_size_left + 1);
   }
+  if (cudnn_runtime_version >= 90600 && window_size_right != -1) {
+    // TODO: confirm whether this should be window_size_right + 1 to match the +1 convention of the left bound above.
+    sdpa_options.set_diagonal_band_right_bound(window_size_right);
+  }

   sdpa_options.set_alibi_mask(is_alibi);

@@ -542,13 +551,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
     int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
     float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ,
-    void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats,
-    void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV,
-    void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed,
-    void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV,
-    void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
-    void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose,
+    void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrSoftmaxOffset,
+    void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
+    void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
+    void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ,
+    void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace,
+    size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;

   bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
@@ -621,6 +631,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                             softmax_type,
                             window_size_left,
                             window_size_right,
+                            bottom_right_diagonal,
                             deterministic,
                             tensorType,
                             cudnn_frontend::DataType_t::NOT_SET,
@@ -781,9 +792,18 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
       sdpa_backward_options.set_max_total_seq_len_kv(s_kv);
     }

+    fe::DiagonalAlignment_t const &diagonal_alignment =
+        bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT
+                              : fe::DiagonalAlignment_t::TOP_LEFT;
+    sdpa_backward_options.set_diagonal_alignment(diagonal_alignment);
+
     if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
       sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1);
     }
+    if (cudnn_runtime_version >= 90600 && window_size_right != -1) {
+      // TODO: confirm whether this should be window_size_right + 1 to match the +1 convention of the left bound above.
+ sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); + } if (cudnn_runtime_version >= 90000) { sdpa_backward_options.set_deterministic_algorithm(deterministic); @@ -1044,8 +1064,8 @@ void fused_attn_arbitrary_seqlen_fwd( size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, @@ -1180,7 +1200,7 @@ void fused_attn_arbitrary_seqlen_fwd( max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, + window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, @@ -1206,8 +1226,9 @@ void fused_attn_arbitrary_seqlen_bwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, @@ -1273,8 +1294,8 @@ void fused_attn_arbitrary_seqlen_bwd( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, - deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, - devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, + bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, + devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, 
devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 872b798bb40..5d1599512f1 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -25,8 +25,8 @@ void fused_attn_arbitrary_seqlen_fwd( size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, @@ -37,13 +37,13 @@ void fused_attn_arbitrary_seqlen_bwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, - Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, - Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, + Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); #endif // CUDNN_VERSION >= 8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 5d806290a9e..64e2f5fa9cd 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1707,6 +1707,7 @@ void fused_attn_fp8_fwd_impl_v1( 0, 0, true, + true, qkv_tensor_type, o_tensor_type, cudnn_frontend::DataType_t::NOT_SET, @@ -2035,6 +2036,7 @@ void 
fused_attn_fp8_bwd_impl_v1( NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, 0, 0, + true, false, qkv_tensor_type, o_tensor_type, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 72047a73f27..26874d3a795 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -110,6 +110,7 @@ struct FADescriptor_v1 { NVTE_Softmax_Type softmax_type; std::int64_t window_size_left; std::int64_t window_size_right; + bool bottom_right_diagonal; bool deterministic; cudnn_frontend::DataType_t qkv_tensor_type; cudnn_frontend::DataType_t o_tensor_type; @@ -121,14 +122,15 @@ struct FADescriptor_v1 { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, - window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type, - o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) < + window_size_left, window_size_right, bottom_right_diagonal, deterministic, + bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type, + generate_max_sum_exp) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type, - rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, - rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, + rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal, rhs.deterministic, + rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, rhs.dqkv_tensor_type, rhs.generate_max_sum_exp); } }; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 6622019280d..69db8e1d3b1 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -269,6 +269,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ @@ -284,6 +285,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed QKV input. @@ -332,6 +334,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). 
+ * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] workspace Workspace tensor. @@ -346,7 +349,8 @@ void nvte_fused_attn_bwd_qkvpacked( NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with packed KV input. @@ -409,6 +413,7 @@ void nvte_fused_attn_bwd_qkvpacked( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. @@ -425,7 +430,7 @@ void nvte_fused_attn_fwd_kvpacked( size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); + int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed KV input. * @@ -479,6 +484,7 @@ void nvte_fused_attn_fwd_kvpacked( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] workspace Workspace tensor. @@ -495,8 +501,8 @@ void nvte_fused_attn_bwd_kvpacked( const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, - cudaStream_t stream); + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, + bool cuda_graph, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with separate Q, K and V. * @@ -560,6 +566,7 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). 
+ * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ @@ -572,7 +579,8 @@ void nvte_fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * @@ -629,6 +637,7 @@ void nvte_fused_attn_fwd( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] workspace Workspace tensor. @@ -644,7 +653,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream); /*! \brief Update the RNG state with the seed and calculated offset. 
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index c1ff46c75aa..1e97e231748 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -261,6 +261,7 @@ def forward(
         attn_mask_type: str = "causal",
         attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
         window_size: Optional[Tuple[int, int]] = None,
+        bottom_right_diagonal: Optional[bool] = None,
         core_attention_bias_type: str = "no_bias",
         core_attention_bias: Optional[torch.Tensor] = None,
         alibi_slopes: Optional[torch.Tensor] = None,
@@ -449,6 +450,12 @@ def forward(
                     actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None,
                     actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None,
                     alibi_slopes=alibi_slopes,
-                    bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
+                    # Prefer the diagonal alignment passed in by the caller; fall back to
+                    # deriving it from the mask type when it is not provided.
+                    bottom_right_alignment=(
+                        bottom_right_diagonal
+                        if bottom_right_diagonal is not None
+                        else attn_mask_type not in ["causal", "padding_causal"]
+                    ),
                 )
             matmul_result = torch.baddbmm(
@@ -1110,6 +1112,7 @@ def forward(
         attn_mask_type,
         softmax_type,
         window_size,
+        bottom_right_diagonal,
         rng_gen,
         fused_attention_backend,
         use_FAv2_bwd,
@@ -1213,6 +1216,7 @@ def forward(
                 attn_mask_type,
                 softmax_type,
                 window_size,
+                bottom_right_diagonal,
                 rng_gen,
                 softmax_offset,
                 cuda_graph=is_graph_capturing(),
@@ -1290,6 +1294,7 @@ def forward(
                 attn_mask_type,
                 softmax_type,
                 window_size,
+                bottom_right_diagonal,
                 rng_gen,
                 softmax_offset,
                 return_max_logit,
@@ -1377,6 +1382,7 @@ def forward(
         ctx.attn_mask_type = attn_mask_type
         ctx.softmax_type = softmax_type
         ctx.window_size = window_size
+        ctx.bottom_right_diagonal = bottom_right_diagonal
         ctx.fused_attention_backend = (
             fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
         )
@@ -1527,6 +1533,7 @@ def backward(ctx, d_out, *_args):
                     ctx.attn_mask_type,
                     ctx.softmax_type,
                     ctx.window_size,
+                    ctx.bottom_right_diagonal,
                     ctx.deterministic,
                     is_graph_capturing(),
                 )
@@ -1592,6 +1599,7 @@ def backward(ctx, d_out, *_args):
                     ctx.attn_mask_type,
                     ctx.softmax_type,
                     ctx.window_size,
+                    ctx.bottom_right_diagonal,
                     ctx.deterministic,
                     is_graph_capturing(),
                 )
@@ -1631,6 +1639,7 @@ def backward(ctx, d_out, *_args):
             None,
             None,
             None,
+            None,
             d_softmax_offset,
             None,
             None,
@@ -1728,6 +1737,7 @@ def forward(
         attn_mask_type: str = "causal",
         attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
         window_size: Optional[Tuple[int, int]] = None,
+        bottom_right_diagonal: Optional[bool] = None,
        fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
         core_attention_bias_type: str = "no_bias",
         core_attention_bias: Optional[torch.Tensor] = None,
@@ -1935,6 +1945,7 @@ def forward(
                 attn_mask_type,
                 self.softmax_type,
                 window_size,
+                bottom_right_diagonal,
                 None,  # rng_gen
                 fused_attention_backend,
                 use_FAv2_bwd,
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
index f506035c1ef..00372fde8f3 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
@@ -228,6 +228,11 @@ class DotProductAttention(TransformerEngineBaseModule):
         map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them
based on ``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can be overridden by :attr:`window_size` in ``forward`` as well. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. attention_type : str, default = "self" type of attention, either ``"self"`` and ``"cross"``. layer_number : int, default = None @@ -324,6 +329,7 @@ def __init__( qkv_format: str = "sbhd", attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, sequence_parallel: bool = False, tp_size: int = 1, get_rng_state_tracker: Optional[Callable] = None, @@ -350,6 +356,7 @@ def __init__( attn_mask_type = "padding_causal" self.attn_mask_type = attn_mask_type self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) + self.bottom_right_diagonal = bottom_right_diagonal if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -811,6 +818,7 @@ def forward( max_seqlen_kv: int = None, attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, checkpoint_core_attention: bool = False, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, @@ -963,6 +971,11 @@ def forward( causal masks are aligned to the bottom right corner. window_size: Optional[Tuple[int, int]], default = None Sliding window size for local attention. + bottom_right_diagonal: Optional[bool], default = None + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. 
checkpoint_core_attention : bool, default = False If true, forward activations for attention are recomputed during the backward pass in order to save memory that would @@ -1081,6 +1094,15 @@ def forward( if window_size is None: window_size = self.window_size window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) + if bottom_right_diagonal is None: + bottom_right_diagonal = self.bottom_right_diagonal + if attn_mask_type in {"causal", "padding_causal"}: + bottom_right_diagonal = False + if bottom_right_diagonal is None or attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + bottom_right_diagonal = True # checks for qkv_format if qkv_format is None: @@ -1322,6 +1344,7 @@ def forward( head_dim_v=head_dim_v, attn_mask_type=attn_mask_type, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None, core_attention_bias_type=core_attention_bias_type, core_attention_bias_shape=core_attention_bias_shape, @@ -1474,6 +1497,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, fused_attention_backend=fused_attention_backend, core_attention_bias_type=fu_core_attention_bias_type, core_attention_bias=fu_core_attention_bias, @@ -1504,6 +1528,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, fused_attention_backend=fused_attention_backend, core_attention_bias_type=fu_core_attention_bias_type, core_attention_bias=fu_core_attention_bias, @@ -1538,6 +1563,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, @@ -1561,6 +1587,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 8c6b6afc903..3c8dccba9ac 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -200,6 +200,9 @@ class AttentionParams: `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`} window_size : Tuple[int, int], default = None Sliding window attention size. + bottom_right_diagonal: bool, default = `True` + Whether to align sliding window and ALiBi diagonal to the bottom right corner + of the softmax matrix. alibi_slopes_shape : Optional[Union[torch.Size, List]], default = None Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`. 
core_attention_bias_type : str, default = no_bias @@ -249,6 +252,7 @@ class AttentionParams: head_dim_v: int = 64 attn_mask_type: str = "no_mask" window_size: Union[Tuple[int, int], None] = None + bottom_right_diagonal: bool = True alibi_slopes_shape: Union[torch.Size, List, None] = None core_attention_bias_type: str = "no_bias" core_attention_bias_shape: str = "1hss" @@ -325,6 +329,7 @@ def get_attention_backend( head_dim_v = attention_params.head_dim_v attn_mask_type = attention_params.attn_mask_type window_size = attention_params.window_size + bottom_right_diagonal = attention_params.bottom_right_diagonal alibi_slopes_shape = attention_params.alibi_slopes_shape core_attention_bias_type = attention_params.core_attention_bias_type core_attention_bias_shape = attention_params.core_attention_bias_shape @@ -873,39 +878,47 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt # backend | window_size | diagonal alignment # --------------------------------------------------------------------------------- # FlashAttention | (-1, -1) or (>=0, >=0) | bottom right - # FusedAttention | (-1, 0) or (>=0, 0) | top left - # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both; + # FusedAttention | (-1, 0) or (>=0, >=0) | top left, bottom right + # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | top left, bottom right # | | converts window_size to an 'arbitrary' mask if window_size is None: window_size = check_set_window_size(attn_mask_type, window_size) - else: - if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention" - " for FP8" - ) - use_fused_attention = False - elif window_size[1] != 0 or attention_dropout != 0.0: - logger.debug( - "Disabling FusedAttention as it only supports sliding window attention " - "with (left, 0) and no dropout" - ) - use_fused_attention = False - elif max_seqlen_q > max_seqlen_kv: - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention " - "with s_q > s_kv for cross-attention" - ) - use_fused_attention = False - if use_flash_attention_2 and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if not FlashAttentionUtils.is_installed: - FlashAttentionUtils.version_required = PkgVersion("2.3") - elif not FlashAttentionUtils.v2_3_plus: - logger.debug( - "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+" - ) - use_flash_attention_2 = False + # (cyang: Why is window_size is being modified but then its value ignored + # in the following else block?) 
+ # else: + if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): + if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): + logger.debug( + "Disabling FusedAttention as it does not support sliding window attention" + " for FP8" + ) + use_fused_attention = False + elif attention_dropout != 0.0: + logger.debug( + "Disabling FusedAttention as it only supports sliding window attention " + "without dropout" + ) + use_fused_attention = False + elif max_seqlen_q > max_seqlen_kv: + logger.debug( + "Disabling FusedAttention as it does not support sliding window attention " + "with s_q > s_kv for cross-attention" + ) + use_fused_attention = False + if use_flash_attention_2 and (window_size[0] != -1 or window_size[1] not in [-1, 0]): + if not FlashAttentionUtils.is_installed: + FlashAttentionUtils.version_required = PkgVersion("2.3") + elif not FlashAttentionUtils.v2_3_plus: + logger.debug( + "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+" + ) + use_flash_attention_2 = False + elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv: + logger.debug( + "Disabling FlashAttention as it only supports sliding window with bottom right" + " diagonal alignment for cross-attention" + ) + use_flash_attention = False # Filter: Attention bias # backend | bias types | ALiBi diagonal alignment @@ -927,6 +940,12 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt elif not FlashAttentionUtils.v2_4_plus: logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+") use_flash_attention_2 = False + elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv: + logger.debug( + "Disabling FlashAttention as it only supports ALiBi with bottom right diagonal" + " alignment for cross-attention" + ) + use_flash_attention = False if ( core_attention_bias_type not in ["no_bias", "alibi"] diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index beb13b7f1e3..07878d63e93 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -32,6 +32,7 @@ from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb +from transformer_engine.pytorch.attention.dot_product_attention import utils as dpa_utils from transformer_engine.pytorch.cpu_offload import start_offload, is_cpu_offload_enabled @@ -93,6 +94,11 @@ class MultiheadAttention(torch.nn.Module): map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on ``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can be overridden by :attr:`window_size` in :meth:`forward` as well. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. num_gqa_groups : int, default = None number of GQA groups in the transformer layer. 
Grouped Query Attention is described in @@ -248,6 +254,7 @@ def __init__( layer_number: Optional[int] = None, attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, tp_group: Optional[dist_group_type] = None, tp_size: int = 1, num_gqa_groups: Optional[int] = None, @@ -286,6 +293,7 @@ def __init__( self.qkv_format = qkv_format self.attn_mask_type = attn_mask_type self.window_size = window_size + self.bottom_right_diagonal = bottom_right_diagonal self.layer_number = 1 if layer_number is None else layer_number self.input_layernorm = input_layernorm self.attention_type = attention_type @@ -621,6 +629,7 @@ def forward( encoder_output: Optional[torch.Tensor] = None, attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, is_first_microbatch: Optional[bool] = None, checkpoint_core_attention: bool = False, inference_params: Optional[InferenceParams] = None, @@ -667,6 +676,11 @@ def forward( aligned to the bottom right corner. window_size: Optional[Tuple[int, int]], default = None sliding window size for local attention. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. encoder_output : Optional[torch.Tensor], default = None Output of the encoder block to be fed into the decoder block if using ``layer_type="decoder"``. @@ -731,6 +745,17 @@ def forward( if window_size is None: window_size = self.window_size + window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) + if bottom_right_diagonal is None: + bottom_right_diagonal = self.bottom_right_diagonal + if attn_mask_type in {"causal", "padding_causal"}: + bottom_right_diagonal = False + if bottom_right_diagonal is None or attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + bottom_right_diagonal = True + if "padding" in attn_mask_type and attention_mask is not None: for mask in attention_mask: assert mask.dtype == torch.bool, "Attention mask must be in boolean type!" @@ -1004,6 +1029,7 @@ def forward( attention_mask=attention_mask, attn_mask_type=attn_mask_type, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, checkpoint_core_attention=checkpoint_core_attention, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 88c223eb462..9a1ec9b3ca1 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -137,6 +137,7 @@ def fused_attn_fwd( attn_mask_type: str = "padding", softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = True, rng_gen: torch.Generator = None, softmax_offset: torch.Tensor = None, return_max_logit: bool = False, @@ -212,6 +213,9 @@ def fused_attn_fwd( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. 
+ bottom_right_diagonal: bool, default = True + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. rng_gen : torch.Generator, default = None random number generator; if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen @@ -306,6 +310,7 @@ def fused_attn_fwd( AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], window_size, + bottom_right_diagonal, cu_seqlens_q, cu_seqlens_kv, q, @@ -370,6 +375,7 @@ def fused_attn_bwd( attn_mask_type: str = "padding", softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = True, deterministic: bool = False, cuda_graph: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: @@ -442,6 +448,9 @@ def fused_attn_bwd( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = True + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. deterministic : bool, default = False whether to execute the backward pass with deterministic behaviours. cuda_graph : bool, default = False @@ -500,6 +509,7 @@ def fused_attn_bwd( AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], window_size, + bottom_right_diagonal, deterministic, cu_seqlens_q, cu_seqlens_kv, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 80479dccf48..caeb225446d 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -87,7 +87,7 @@ std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - const std::vector window_size, const at::Tensor cu_seqlens_q, + const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, const at::ScalarType fake_dtype, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, @@ -99,10 +99,10 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, const std::vector window_size, bool deterministic, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const py::handle O, const py::handle dO, - const at::ScalarType fake_dtype, const DType dqkv_type, + NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, + bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, + const py::handle Q, const py::handle K, const py::handle V, const py::handle O, + const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp 
b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 2480d9aba9b..51888534876 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -100,7 +100,7 @@ std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - const std::vector window_size, const at::Tensor cu_seqlens_q, + const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, const at::ScalarType fake_dtype, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, @@ -235,7 +235,7 @@ std::vector fused_attn_fwd( te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], workspace.data(), + softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -295,7 +295,7 @@ std::vector fused_attn_fwd( te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], workspace.data(), + softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -310,11 +310,11 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, const std::vector window_size, bool deterministic, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const py::handle O, const py::handle dO, - const at::ScalarType fake_dtype, const DType dqkv_type, - const std::vector Aux_CTX_Tensors, + NVTE_Softmax_Type softmax_type, const std::vector window_size, + bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, + const py::handle V, const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, + const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph) { @@ -538,8 +538,9 @@ std::vector fused_attn_bwd( te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], deterministic, cuda_graph, - workspace.data(), at::cuda::getCurrentCUDAStream()); + softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + deterministic, cuda_graph, workspace.data(), + 
at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace @@ -555,8 +556,9 @@ std::vector fused_attn_bwd( te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], deterministic, cuda_graph, - workspace.data(), at::cuda::getCurrentCUDAStream()); + softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + deterministic, cuda_graph, workspace.data(), + at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index b3ad8ccc550..d33770c7f92 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -35,7 +35,7 @@ from transformer_engine.pytorch.distributed import get_distributed_world_size from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.module.base import TransformerEngineBaseModule - +import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils warnings.filterwarnings("module", category=DeprecationWarning, module="transformer") @@ -149,11 +149,21 @@ class TransformerLayer(torch.nn.Module): distinguishes them based on :attr:`self_attn_mask_type` or :attr:`enc_dec_attn_mask_type`. Similar to :attr:`self_attn_mask_type`, :attr:`window_size` can be overridden by :attr:`window_size` in :meth:`forward` as well. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `self_attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. enc_dec_attn_mask_type : {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'}, default = "no_mask" type of attention mask passed into softmax operation for decoder. enc_dec_window_size : Optional[Tuple[int, int]], default = None sliding window size for local attention in decoder. + enc_dec_bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the decoder. + If `None`, it will be set to `False` for `enc_dec_attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. 
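The defaulting rule stated in the two docstring entries above (and implemented later in this patch in forward()) can be summarized by a small helper. This is a sketch only; the helper name is hypothetical, and the module-level attribute fallback is folded into the `value` argument:

    from typing import Optional

    def resolve_bottom_right_diagonal(attn_mask_type: str, value: Optional[bool]) -> bool:
        """Resolve the effective diagonal alignment from the mask type and an optional override."""
        if attn_mask_type in ("causal", "padding_causal"):
            return False   # top-left aligned diagonal
        if value is None or attn_mask_type in ("causal_bottom_right", "padding_causal_bottom_right"):
            return True    # bottom-right aligned diagonal
        return value

    assert resolve_bottom_right_diagonal("causal", None) is False
    assert resolve_bottom_right_diagonal("padding", None) is True
    assert resolve_bottom_right_diagonal("no_mask", False) is False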
zero_centered_gamma : bool, default = False if set to ``True``, gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to @@ -302,7 +312,9 @@ def __init__( kv_channels: Optional[int] = None, self_attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, enc_dec_attn_mask_type: str = "no_mask", + enc_dec_bottom_right_diagonal: Optional[bool] = None, enc_dec_window_size: Optional[Tuple[int, int]] = None, tp_group: Optional[dist_group_type] = None, tp_size: int = 1, @@ -344,8 +356,10 @@ def __init__( self.self_attn_mask_type = self_attn_mask_type self.window_size = window_size + self.bottom_right_diagonal = bottom_right_diagonal self.enc_dec_attn_mask_type = enc_dec_attn_mask_type self.enc_dec_window_size = enc_dec_window_size + self.enc_dec_bottom_right_diagonal = enc_dec_bottom_right_diagonal params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad @@ -606,10 +620,12 @@ def forward( attention_mask: Optional[torch.Tensor] = None, self_attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, encoder_output: Optional[torch.Tensor] = None, enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, enc_dec_attn_mask_type: Optional[str] = None, enc_dec_window_size: Optional[Tuple[int, int]] = None, + enc_dec_bottom_right_diagonal: Optional[bool] = None, is_first_microbatch: Optional[bool] = None, checkpoint_core_attention: bool = False, inference_params: Optional[InferenceParams] = None, @@ -654,6 +670,11 @@ def forward( causal masks are aligned to the bottom right corner. window_size: Optional[Tuple[int, int]], default = None Sliding window size for local attention in encoder. + bottom_right_diagonal: Optional[bool] = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `self_attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. encoder_output : Optional[torch.Tensor], default = None Output of the encoder block to be fed into the decoder block if using :attr:`layer_type` = ``"decoder"``. @@ -670,6 +691,11 @@ def forward( Type of attention mask passed into softmax operation for decoder. enc_dec_window_size: Optional[Tuple[int, int]], default = None Sliding window size for local attention in decoder. + enc_dec_bottom_right_diagonal: Optional[bool] = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the decoder. + If `None`, it will be set to `False` for `enc_dec_attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. 
is_first_microbatch : {True, False, None}, default = None During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split @@ -736,10 +762,33 @@ def forward( self_attn_mask_type = self.self_attn_mask_type if window_size is None: window_size = self.window_size + window_size = dpa_utils.check_set_window_size(self_attn_mask_type, window_size) + if enc_dec_attn_mask_type is None: enc_dec_attn_mask_type = self.enc_dec_attn_mask_type if enc_dec_window_size is None: enc_dec_window_size = self.enc_dec_window_size + enc_dec_window_size = dpa_utils.check_set_window_size(enc_dec_attn_mask_type, enc_dec_window_size) + + if bottom_right_diagonal is None: + bottom_right_diagonal = self.bottom_right_diagonal + if self_attn_mask_type in {"causal", "padding_causal"}: + bottom_right_diagonal = False + if bottom_right_diagonal is None or self_attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + bottom_right_diagonal = True + + if enc_dec_bottom_right_diagonal is None: + enc_dec_bottom_right_diagonal = self.enc_dec_bottom_right_diagonal + if enc_dec_attn_mask_type in {"causal", "padding_causal"}: + enc_dec_bottom_right_diagonal = False + if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + enc_dec_bottom_right_diagonal = True assert ( self_attn_mask_type in AttnMaskTypes @@ -781,6 +830,7 @@ def forward( attention_mask=attention_mask, attn_mask_type=self_attn_mask_type, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, inference_params=inference_params, is_first_microbatch=is_first_microbatch, checkpoint_core_attention=checkpoint_core_attention, @@ -816,6 +866,7 @@ def forward( attention_mask=enc_dec_attn_mask, attn_mask_type=enc_dec_attn_mask_type, window_size=enc_dec_window_size, + bottom_right_diagonal=enc_dec_bottom_right_diagonal, encoder_output=encoder_output, inference_params=inference_params, is_first_microbatch=is_first_microbatch, From eab24bed89ee8a5d287ed3a37ae6a6a99f98039a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 4 Dec 2025 00:55:38 +0000 Subject: [PATCH 02/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/fused_attn.cpp | 42 +++++++------- .../fused_attn_f16_arbitrary_seqlen.cu | 28 ++++----- .../fused_attn_f16_arbitrary_seqlen.h | 7 ++- transformer_engine/common/fused_attn/utils.h | 6 +- .../include/transformer_engine/fused_attn.h | 57 +++++++++---------- .../attention/dot_product_attention/utils.py | 3 +- transformer_engine/pytorch/csrc/extensions.h | 15 ++--- .../pytorch/csrc/extensions/attention.cpp | 47 ++++++++------- transformer_engine/pytorch/transformer.py | 4 +- 9 files changed, 103 insertions(+), 106 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 6e2e5a59b05..06fa8120604 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -510,17 +510,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // NVTE fused attention FWD with packed QKV // DEPRECATED: This API is deprecated. (Should there be a version by which this is going to be removed? @cyang) // Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. 
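A minimal usage sketch for the new TransformerLayer knobs introduced above. It assumes a CUDA device and a TransformerEngine build with this series applied, and it elides unrelated constructor arguments; shapes and sizes are placeholders:

    import torch
    import transformer_engine.pytorch as te

    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        self_attn_mask_type="no_mask",
        window_size=(128, 128),          # SWA with both left and right context
        bottom_right_diagonal=True,      # align the window to the bottom-right corner
    ).cuda()
    x = torch.randn(512, 2, 1024, device="cuda")   # [seq, batch, hidden]
    y = layer(x, bottom_right_diagonal=True)       # the per-call override added in forward()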
-void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, - size_t max_seqlen, bool is_training, bool return_max_logit, - bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, - NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_fwd_qkvpacked( + const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, + bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); using namespace transformer_engine; @@ -594,10 +591,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, fused_attn_arbitrary_seqlen_fwd( b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, - input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state, - wkspace, stream, handle); + window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view, + input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_cu_seqlens, input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, + input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. 
\n"); @@ -635,8 +632,7 @@ void nvte_fused_attn_bwd_qkvpacked( size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream) { + bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); using namespace transformer_engine; @@ -732,10 +728,10 @@ void nvte_fused_attn_bwd_qkvpacked( fused_attn_arbitrary_seqlen_bwd( b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, - deterministic, &Q_view, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, - output_S, &dQ_view, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens, - input_cu_seqlens, input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, - wkspace, stream, handle); + deterministic, &Q_view, &K_view, &V_view, input_O, input_dO, input_Bias, + input_SoftmaxOffset, output_S, &dQ_view, &dK_view, &dV_view, output_dBias, + output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens, input_cu_seqlens_padded, + input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.0 is required for BF16/FP16 fused attention " @@ -1224,8 +1220,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, - bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { + bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 8278f051f84..65911ce830d 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -793,8 +793,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( } fe::DiagonalAlignment_t const &diagonal_alignment = - bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT - : fe::DiagonalAlignment_t::TOP_LEFT; + bottom_right_diagonal ? 
fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; sdpa_backward_options.set_diagonal_alignment(diagonal_alignment); if (cudnn_runtime_version >= 90200 && window_size_left != -1) { @@ -1200,11 +1200,11 @@ void fused_attn_arbitrary_seqlen_fwd( max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrBias, - devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, - devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); + window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, + devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1227,13 +1227,13 @@ void fused_attn_arbitrary_seqlen_bwd( size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - bool deterministic, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, - Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, - Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, + Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; void *devPtrQ = input_Q->data.dptr; diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 5d1599512f1..b8313c63c27 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -41,9 +41,10 @@ void fused_attn_arbitrary_seqlen_bwd( bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor 
*output_dQ, Tensor *output_dK, - Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); #endif // CUDNN_VERSION >= 8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 26874d3a795..5dc487f76bb 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -129,9 +129,9 @@ struct FADescriptor_v1 { rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type, - rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal, rhs.deterministic, - rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, - rhs.dqkv_tensor_type, rhs.generate_max_sum_exp); + rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal, + rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, + rhs.do_tensor_type, rhs.dqkv_tensor_type, rhs.generate_max_sum_exp); } }; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 69db8e1d3b1..98af0ddcd26 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -276,17 +276,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( [[deprecated( "nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate " "Q, K, V tensors instead.")]] -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, - size_t max_seqlen, bool is_training, bool return_max_logit, - bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, - NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd_qkvpacked( + const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, + bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed QKV input. 
* @@ -350,8 +347,7 @@ void nvte_fused_attn_bwd_qkvpacked( size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream); + bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with packed KV input. * @@ -430,7 +426,8 @@ void nvte_fused_attn_fwd_kvpacked( size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); + int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, + cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed KV input. * @@ -501,8 +498,8 @@ void nvte_fused_attn_bwd_kvpacked( const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, - bool cuda_graph, NVTETensor workspace, cudaStream_t stream); + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with separate Q, K and V. * @@ -570,17 +567,19 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. 
*/ -void nvte_fused_attn_fwd( - const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, - bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, + const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, + bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * @@ -654,8 +653,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, - bool cuda_graph, NVTETensor workspace, cudaStream_t stream); + bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream); /*! \brief Update the RNG state with the seed and calculated offset. 
* diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 3c8dccba9ac..81faa95bbda 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -889,8 +889,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): logger.debug( - "Disabling FusedAttention as it does not support sliding window attention" - " for FP8" + "Disabling FusedAttention as it does not support sliding window attention for FP8" ) use_fused_attention = False elif attention_dropout != 0.0: diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index caeb225446d..b263695895d 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -87,9 +87,10 @@ std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const at::ScalarType fake_dtype, const std::optional cu_seqlens_q_padded, + const std::vector window_size, bool bottom_right_diagonal, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const at::ScalarType fake_dtype, + const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, @@ -99,10 +100,10 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, - bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, - const py::handle Q, const py::handle K, const py::handle V, const py::handle O, - const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + NVTE_Softmax_Type softmax_type, const std::vector window_size, + bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 51888534876..2ac67dca0da 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -100,9 +100,10 @@ std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, 
float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const at::ScalarType fake_dtype, const std::optional cu_seqlens_q_padded, + const std::vector window_size, bool bottom_right_diagonal, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const at::ScalarType fake_dtype, + const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, @@ -312,9 +313,9 @@ std::vector fused_attn_bwd( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, - const py::handle V, const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, - const DType dqkv_type, const std::vector Aux_CTX_Tensors, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph) { @@ -532,15 +533,14 @@ std::vector fused_attn_bwd( // populate tensors with appropriate shapes and dtypes NVTE_SCOPED_GIL_RELEASE({ - nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), - te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), - te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, - max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], bottom_right_diagonal, - deterministic, cuda_graph, workspace.data(), - at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd( + te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), + &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), + te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], + window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), + at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace @@ -550,15 +550,14 @@ std::vector fused_attn_bwd( // execute kernel NVTE_SCOPED_GIL_RELEASE({ - nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), - te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), - te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), 
te_cu_seqlens_kv_padded.data(), max_seqlen_q, - max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], bottom_right_diagonal, - deterministic, cuda_graph, workspace.data(), - at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd( + te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), + &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), + te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], + window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), + at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index d33770c7f92..a37ac5162b3 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -768,7 +768,9 @@ def forward( enc_dec_attn_mask_type = self.enc_dec_attn_mask_type if enc_dec_window_size is None: enc_dec_window_size = self.enc_dec_window_size - enc_dec_window_size = dpa_utils.check_set_window_size(enc_dec_attn_mask_type, enc_dec_window_size) + enc_dec_window_size = dpa_utils.check_set_window_size( + enc_dec_attn_mask_type, enc_dec_window_size + ) if bottom_right_diagonal is None: bottom_right_diagonal = self.bottom_right_diagonal From e761a267ae0c1cbb999c8ed66b18eeac8aa2f6a1 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 5 Dec 2025 16:53:16 -0800 Subject: [PATCH 03/13] fix test_kv_cache failures Signed-off-by: Sudhakar Singh --- .../attention/dot_product_attention/dot_product_attention.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 00372fde8f3..21e305bd25c 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1166,6 +1166,8 @@ def forward( assert "padding" in attn_mask_type, "KV caching requires padding mask!" 
if attn_mask_type == "padding_causal": attn_mask_type = attn_mask_type + "_bottom_right" + # since attention mask is changed, set `bottom_right_diagonal` to True + bottom_right_diagonal = True self.attention_type = "cross" self.flash_attention.attention_type = self.attention_type From 93548fcc633b31468880c4be79dc05038794bb82 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 11 Dec 2025 14:06:52 -0800 Subject: [PATCH 04/13] remove unnecessary comments Signed-off-by: Sudhakar Singh --- .../common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 65911ce830d..8bc9d070da8 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -263,7 +263,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl( sdpa_options.set_diagonal_band_left_bound(window_size_left + 1); } if (cudnn_runtime_version >= 90600 && window_size_right != -1) { - // (remove comment when reviewed) Should it be `window_size_right + 1` instead? sdpa_options.set_diagonal_band_right_bound(window_size_right); } @@ -801,7 +800,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl( sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); } if (cudnn_runtime_version >= 90600 && window_size_right != -1) { - // (remove comment when reviewed) Should it be `window_size_right + 1` instead? sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); } From a545ebfe5816b412214c503d54a97066e6b01bef Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 12 Dec 2025 12:37:06 -0800 Subject: [PATCH 05/13] fix some more filter issues, address feedback Signed-off-by: Sudhakar Singh --- tests/pytorch/utils.py | 14 +++++++------- .../attention/dot_product_attention/backends.py | 6 ++++-- .../dot_product_attention/dot_product_attention.py | 7 +++---- .../attention/dot_product_attention/utils.py | 4 ++-- 4 files changed, 16 insertions(+), 15 deletions(-) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index bdf469c59a2..c4992d9e600 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -353,11 +353,11 @@ def test(): backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} if AttentionLogging._is_logging_setup is False: AttentionLogging.setup_logging() - with logging_context(highest_level=AttentionLogging._log_level): - for i in range(3): - os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) - _attention_backends["backend_selection_requires_update"] = True - available_backends, flash_attention_backend, fused_attention_backend = test() - if fused_attention_backend == FusedAttnBackend[backends[i]]: - fused_attn_backends.append(fused_attention_backend) + + for i in range(3): + os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) + _attention_backends["backend_selection_requires_update"] = True + available_backends, flash_attention_backend, fused_attention_backend = test() + if fused_attention_backend == FusedAttnBackend[backends[i]]: + fused_attn_backends.append(fused_attention_backend) return available_backends, flash_attention_backend, fused_attn_backends diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 1e97e231748..8d5437ff1e2 100644 --- 
a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -347,6 +347,8 @@ def forward( attention_mask=attention_mask, window_size=window_size, attention_type=self.attention_type, + bottom_right_alignment=(attn_mask_type not in ["causal", "padding_causal"] + if bottom_right_diagonal is None else bottom_right_diagonal) ) ) @@ -450,8 +452,8 @@ def forward( actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None, actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None, alibi_slopes=alibi_slopes, - # (This should be replaced with `bottom_right_diagonal` which is passed from the arguments) - bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], + bottom_right_alignment=(attn_mask_type not in ["causal", "padding_causal"] + if bottom_right_diagonal is None else bottom_right_diagonal) ) matmul_result = torch.baddbmm( matmul_result, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 21e305bd25c..03ca243ddf8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1280,7 +1280,6 @@ def forward( if self.layer_number == 1: _alibi_cache["_alibi_slopes_require_update"] = True _alibi_cache["_alibi_bias_require_update"] = True - bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],) if core_attention_bias_type == "alibi": assert ( core_attention_bias is None @@ -1289,7 +1288,7 @@ def forward( _alibi_cache["_num_heads"] != query_layer.shape[-2] or _alibi_cache["_max_seqlen_q"] != max_seqlen_q or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv - or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment + or _alibi_cache["_bottom_right_alignment"] != bottom_right_diagonal or _alibi_cache["_alibi_slopes"] is None ): _alibi_cache["_alibi_slopes_require_update"] = True @@ -1471,7 +1470,7 @@ def forward( fu_core_attention_bias_type = core_attention_bias_type fu_core_attention_bias = core_attention_bias if core_attention_bias_type == "alibi" and ( - alibi_slopes is not None or max_seqlen_q != max_seqlen_kv + alibi_slopes is not None ): fu_core_attention_bias_type = "post_scale_bias" _, fu_core_attention_bias = dpa_utils.get_alibi( @@ -1481,7 +1480,7 @@ def forward( max_seqlen_kv, alibi_slopes=alibi_slopes, bias_dtype=query_layer.dtype, - bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], + bottom_right_alignment=bottom_right_diagonal, ) if checkpoint_core_attention: return self._checkpointed_attention_forward( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 81faa95bbda..027d69f23a7 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -200,7 +200,7 @@ class AttentionParams: `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`} window_size : Tuple[int, int], default = None Sliding window attention size. - bottom_right_diagonal: bool, default = `True` + bottom_right_diagonal: bool, default = `None` Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. 
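For intuition on the `bottom_right_alignment` flag now threaded through `get_alibi()` above: it decides on which diagonal the ALiBi bias is zero. A rough illustration, assuming the simplified form bias[i, j] = slope * (j - diagonal(i)); the real get_alibi() adds batch/head dimensions, slopes per head, and dtype handling:

    import torch

    def alibi_bias(slope: float, seqlen_q: int, seqlen_kv: int,
                   bottom_right_alignment: bool) -> torch.Tensor:
        rows = torch.arange(seqlen_q).view(-1, 1)
        cols = torch.arange(seqlen_kv).view(1, -1)
        shift = seqlen_kv - seqlen_q if bottom_right_alignment else 0
        return slope * (cols - (rows + shift)).float()

    # Bottom-right alignment puts bias 0 on the diagonal that ends in the
    # bottom-right corner, which is the convention for cross-attention (s_q < s_kv).
    assert alibi_bias(0.5, 2, 4, True)[-1, -1] == 0
    assert alibi_bias(0.5, 2, 4, False)[0, 0] == 0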
alibi_slopes_shape : Optional[Union[torch.Size, List]], default = None @@ -962,7 +962,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if ( use_fused_attention and core_attention_bias_type == "alibi" - and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv) + and (alibi_slopes_shape is not None) ): fu_core_attention_bias_type = "post_scale_bias" fu_core_attention_bias_requires_grad = False From 341f0f7cf5a55e60d0a33b6ede68d1829a2a3ade Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 18 Dec 2025 13:30:47 -0800 Subject: [PATCH 06/13] fix for local test case failures - `bottom_right_diagonal` should be calculated in `fused_attn_fwd` call as well Signed-off-by: Sudhakar Singh --- .../fused_attn_f16_arbitrary_seqlen.cu | 2 ++ .../attention/dot_product_attention/utils.py | 3 --- .../pytorch/cpp_extensions/fused_attn.py | 18 +++++++++++++++--- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 4c5f95be549..a8f99f5c0c4 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -75,6 +75,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; + bottom_right_diagonal = false; } bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (is_training && dropout_probability != 0.0f); @@ -572,6 +573,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; + bottom_right_diagonal = false; } bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (dropout_probability != 0.0f); diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 030e63c95f9..664ed7d554f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -877,9 +877,6 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt # | | converts window_size to an 'arbitrary' mask if window_size is None: window_size = check_set_window_size(attn_mask_type, window_size) - # (cyang: Why is window_size is being modified but then its value ignored - # in the following else block?) 
- # else: if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): logger.debug( diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 9a1ec9b3ca1..6d6a7274804 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -137,7 +137,7 @@ def fused_attn_fwd( attn_mask_type: str = "padding", softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), - bottom_right_diagonal: bool = True, + bottom_right_diagonal: bool = None, rng_gen: torch.Generator = None, softmax_offset: torch.Tensor = None, return_max_logit: bool = False, @@ -213,7 +213,7 @@ def fused_attn_fwd( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. - bottom_right_diagonal: bool, default = True + bottom_right_diagonal: bool, default = None whether to align sliding window and ALiBi diagonal to the top left (False) or bottom right (True) corner of the softmax matrix. rng_gen : torch.Generator, default = None @@ -259,6 +259,12 @@ def fused_attn_fwd( max_logit : if return_max_logit = True, shape [h] and same data type as O; otherwise None """ + bottom_right_diagonal = ( + bottom_right_diagonal + if bottom_right_diagonal is not None + else "bottom_right" in attn_mask_type + ) + if attn_scale is None: d = q.size(-1) attn_scale = 1.0 / math.sqrt(d) @@ -375,7 +381,7 @@ def fused_attn_bwd( attn_mask_type: str = "padding", softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), - bottom_right_diagonal: bool = True, + bottom_right_diagonal: bool = None, deterministic: bool = False, cuda_graph: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: @@ -471,6 +477,12 @@ def fused_attn_bwd( gradient tensor of softmax offset of shape [1, h_q, 1, 1]. See softmax_type in DotProductAttention for details. 
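For context on the `bottom_right_diagonal = false` canonicalization added to the fwd/bwd impls earlier in this commit: when the query and key sequence lengths match and there is no padding, the bottom-right-aligned causal mask is identical to the ordinary top-left causal mask, so the kernel can fall back to the plain causal path. An illustrative check, not part of the patch:

    import torch

    def causal_mask(seqlen_q: int, seqlen_kv: int, bottom_right: bool) -> torch.Tensor:
        rows = torch.arange(seqlen_q).view(-1, 1)
        cols = torch.arange(seqlen_kv).view(1, -1)
        shift = seqlen_kv - seqlen_q if bottom_right else 0
        return cols <= rows + shift        # True where attention is allowed

    assert torch.equal(causal_mask(5, 5, True), causal_mask(5, 5, False))        # s_q == s_kv: same mask
    assert not torch.equal(causal_mask(3, 5, True), causal_mask(3, 5, False))    # cross-attention differs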
""" + bottom_right_diagonal = ( + bottom_right_diagonal + if bottom_right_diagonal is not None + else "bottom_right" in attn_mask_type + ) + if attn_scale is None: d = q.size(-1) attn_scale = 1.0 / math.sqrt(d) From 1fd4985da48a33297fa2a65b5c3d1ab29ecd307a Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 18 Dec 2025 15:51:27 -0800 Subject: [PATCH 07/13] make conditions more accurate Signed-off-by: Sudhakar Singh --- .../attention/dot_product_attention/backends.py | 14 ++++++++++---- .../dot_product_attention/dot_product_attention.py | 4 +--- .../pytorch/cpp_extensions/fused_attn.py | 13 ++++++++----- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 8d5437ff1e2..56cd8f63692 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -347,8 +347,11 @@ def forward( attention_mask=attention_mask, window_size=window_size, attention_type=self.attention_type, - bottom_right_alignment=(attn_mask_type not in ["causal", "padding_causal"] - if bottom_right_diagonal is None else bottom_right_diagonal) + bottom_right_alignment=( + attn_mask_type not in ["causal", "padding_causal"] + if bottom_right_diagonal is None + else bottom_right_diagonal + ), ) ) @@ -452,8 +455,11 @@ def forward( actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None, actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None, alibi_slopes=alibi_slopes, - bottom_right_alignment=(attn_mask_type not in ["causal", "padding_causal"] - if bottom_right_diagonal is None else bottom_right_diagonal) + bottom_right_alignment=( + attn_mask_type not in ["causal", "padding_causal"] + if bottom_right_diagonal is None + else bottom_right_diagonal + ), ) matmul_result = torch.baddbmm( matmul_result, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 03ca243ddf8..8f51a5fc22e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1469,9 +1469,7 @@ def forward( if use_fused_attention: fu_core_attention_bias_type = core_attention_bias_type fu_core_attention_bias = core_attention_bias - if core_attention_bias_type == "alibi" and ( - alibi_slopes is not None - ): + if core_attention_bias_type == "alibi" and (alibi_slopes is not None): fu_core_attention_bias_type = "post_scale_bias" _, fu_core_attention_bias = dpa_utils.get_alibi( _alibi_cache, diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 6d6a7274804..f34c8f19504 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -259,11 +259,14 @@ def fused_attn_fwd( max_logit : if return_max_logit = True, shape [h] and same data type as O; otherwise None """ - bottom_right_diagonal = ( - bottom_right_diagonal - if bottom_right_diagonal is not None - else "bottom_right" in attn_mask_type - ) + if bottom_right_diagonal is None: + if attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + bottom_right_diagonal = True + else: + 
bottom_right_diagonal = False if attn_scale is None: d = q.size(-1) From 95bbec1ecb8ac219a0772528a874307b0c304dd1 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 19 Dec 2025 10:45:03 -0800 Subject: [PATCH 08/13] add cp tests to test swa (left, right) Signed-off-by: Sudhakar Singh --- tests/pytorch/attention/test_attention_with_cp.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 2d4fe69e329..7209606c141 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -148,6 +148,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): ), # MHA "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA + "cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA "cp_2_2": ModelConfig( @@ -165,6 +166,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_2_4": ModelConfig( 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) ), # GQA + "cp_2_5": ModelConfig( + 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512) + ), # GQA "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA "cp_3_2": ModelConfig( @@ -187,7 +191,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: - configs = ["cp_1_0", "cp_1_1", "cp_1_4", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"] + configs = ["cp_1_0", "cp_1_1", "cp_1_4", "cp_1_5", "cp_2_0", "cp_2_2", "cp_2_5", "cp_3_2", "cp_4_2"] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} dtypes = ["bf16", "fp8"] qkv_formats = ["sbhd", "thd"] From 13aee94aa5c53ef61321208e53776b414e1daaca Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Dec 2025 18:46:24 +0000 Subject: [PATCH 09/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention_with_cp.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 7209606c141..29ea437b758 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -191,7 +191,17 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: - configs = ["cp_1_0", "cp_1_1", "cp_1_4", "cp_1_5", "cp_2_0", "cp_2_2", "cp_2_5", "cp_3_2", "cp_4_2"] + configs = [ + "cp_1_0", + "cp_1_1", + "cp_1_4", + "cp_1_5", + "cp_2_0", + "cp_2_2", + "cp_2_5", + "cp_3_2", + "cp_4_2", + ] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} dtypes = ["bf16", "fp8"] qkv_formats = ["sbhd", "thd"] From 
67c5d39e978b1d4ab95d92eea6dafd255482c98f Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 19 Dec 2025 14:59:20 -0800 Subject: [PATCH 10/13] remove dead code and make conditions better Signed-off-by: Sudhakar Singh --- .../attention/dot_product_attention/utils.py | 5 ++-- .../pytorch/cpp_extensions/fused_attn.py | 27 ++++++++++--------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 664ed7d554f..222986563c5 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -957,9 +957,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt ): fu_core_attention_bias_type = "post_scale_bias" fu_core_attention_bias_requires_grad = False - if alibi_slopes_shape is None: - fu_core_attention_bias_shape = "1hss" - elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads: + + if len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads: fu_core_attention_bias_shape = "1hss" elif ( len(alibi_slopes_shape) == 2 diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index f34c8f19504..d32c26a1726 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -260,13 +260,13 @@ def fused_attn_fwd( """ if bottom_right_diagonal is None: - if attn_mask_type in { - "causal_bottom_right", - "padding_causal_bottom_right", - }: - bottom_right_diagonal = True - else: - bottom_right_diagonal = False + bottom_right_diagonal = ( + True + if attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right" + } else False + ) if attn_scale is None: d = q.size(-1) @@ -480,11 +480,14 @@ def fused_attn_bwd( gradient tensor of softmax offset of shape [1, h_q, 1, 1]. See softmax_type in DotProductAttention for details. 
""" - bottom_right_diagonal = ( - bottom_right_diagonal - if bottom_right_diagonal is not None - else "bottom_right" in attn_mask_type - ) + if bottom_right_diagonal is None: + bottom_right_diagonal = ( + True + if attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right" + } else False + ) if attn_scale is None: d = q.size(-1) From 5e145797926d6215bf86b23e0b4dfc0bfcc17395 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Dec 2025 23:01:04 +0000 Subject: [PATCH 11/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/cpp_extensions/fused_attn.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index d32c26a1726..309e607d780 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -262,10 +262,8 @@ def fused_attn_fwd( if bottom_right_diagonal is None: bottom_right_diagonal = ( True - if attn_mask_type in { - "causal_bottom_right", - "padding_causal_bottom_right" - } else False + if attn_mask_type in {"causal_bottom_right", "padding_causal_bottom_right"} + else False ) if attn_scale is None: @@ -483,10 +481,8 @@ def fused_attn_bwd( if bottom_right_diagonal is None: bottom_right_diagonal = ( True - if attn_mask_type in { - "causal_bottom_right", - "padding_causal_bottom_right" - } else False + if attn_mask_type in {"causal_bottom_right", "padding_causal_bottom_right"} + else False ) if attn_scale is None: From da4b3d6038bb27a29d9f181f2f831495f7c69a11 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 19 Dec 2025 15:24:38 -0800 Subject: [PATCH 12/13] fix lint Signed-off-by: Sudhakar Singh --- .../pytorch/cpp_extensions/fused_attn.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 309e607d780..fa1528b348c 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -260,11 +260,10 @@ def fused_attn_fwd( """ if bottom_right_diagonal is None: - bottom_right_diagonal = ( - True - if attn_mask_type in {"causal_bottom_right", "padding_causal_bottom_right"} - else False - ) + bottom_right_diagonal = attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right" + } if attn_scale is None: d = q.size(-1) @@ -479,11 +478,10 @@ def fused_attn_bwd( See softmax_type in DotProductAttention for details. 
""" if bottom_right_diagonal is None: - bottom_right_diagonal = ( - True - if attn_mask_type in {"causal_bottom_right", "padding_causal_bottom_right"} - else False - ) + bottom_right_diagonal = attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right" + } if attn_scale is None: d = q.size(-1) From dbff75ab2911d1b37d511ab62872f98a1847289d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Dec 2025 23:25:23 +0000 Subject: [PATCH 13/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/cpp_extensions/fused_attn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index fa1528b348c..12c62712044 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -262,7 +262,7 @@ def fused_attn_fwd( if bottom_right_diagonal is None: bottom_right_diagonal = attn_mask_type in { "causal_bottom_right", - "padding_causal_bottom_right" + "padding_causal_bottom_right", } if attn_scale is None: @@ -480,7 +480,7 @@ def fused_attn_bwd( if bottom_right_diagonal is None: bottom_right_diagonal = attn_mask_type in { "causal_bottom_right", - "padding_causal_bottom_right" + "padding_causal_bottom_right", } if attn_scale is None: