diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 4aedcff1b83..6dfeb705745 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -145,6 +145,7 @@ def test_dot_product_attention( if config.window_size == (-1, -1) and swa: config.window_size = [2, 2] + config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0] if qkv_format == "thd" and "padding" not in config.attn_mask_type: @@ -162,6 +163,7 @@ def test_dot_product_attention( is_training=is_training, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if not fused_attn_supported: is_training = False available_backends, _, fused_attn_backends = get_available_attention_backends( @@ -682,9 +684,10 @@ def test_dpa_bias_shapes(dtype, model_configs, model): @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_swa]) @pytest.mark.parametrize("model", model_configs_swa.keys()) -def test_dpa_sliding_window(dtype, model_configs, model): +@pytest.mark.parametrize("qkv_layout", ["thd_thd_thd", "sbhd_sbhd_sbhd"]) +def test_dpa_sliding_window(dtype, model_configs, model, qkv_layout): """Test DotProductAttention module with sliding window attention""" - test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False) + test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, True, False) model_configs_alibi_slopes = { diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 2d4fe69e329..29ea437b758 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -148,6 +148,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): ), # MHA "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA + "cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA "cp_2_2": ModelConfig( @@ -165,6 +166,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_2_4": ModelConfig( 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) ), # GQA + "cp_2_5": ModelConfig( + 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512) + ), # GQA "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA "cp_3_2": ModelConfig( @@ -187,7 +191,17 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: - configs = ["cp_1_0", "cp_1_1", "cp_1_4", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"] + configs = [ + "cp_1_0", + "cp_1_1", + "cp_1_4", + "cp_1_5", + "cp_2_0", + "cp_2_2", + "cp_2_5", + "cp_3_2", + "cp_4_2", + ] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} dtypes = ["bf16", "fp8"] qkv_formats = ["sbhd", "thd"] diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py 
index bdf469c59a2..c4992d9e600 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -353,11 +353,11 @@ def test(): backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} if AttentionLogging._is_logging_setup is False: AttentionLogging.setup_logging() - with logging_context(highest_level=AttentionLogging._log_level): - for i in range(3): - os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) - _attention_backends["backend_selection_requires_update"] = True - available_backends, flash_attention_backend, fused_attention_backend = test() - if fused_attention_backend == FusedAttnBackend[backends[i]]: - fused_attn_backends.append(fused_attention_backend) + + for i in range(3): + os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) + _attention_backends["backend_selection_requires_update"] = True + available_backends, flash_attention_backend, fused_attention_backend = test() + if fused_attention_backend == FusedAttnBackend[backends[i]]: + fused_attn_backends.append(fused_attention_backend) return available_backends, flash_attention_backend, fused_attn_backends diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 611beb7b84d..06fa8120604 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -406,9 +406,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (window_size_right == -1 || window_size_right == 0)) || // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd} (cudnn_runtime_version >= 90200 && - ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || - ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + ((window_size_left == -1 && window_size_right == -1 && + attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) || + ((window_size_left == -1 || window_size_left >= 0) && window_size_right == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && max_seqlen_q == max_seqlen_kv)) && max_seqlen_q <= max_seqlen_kv && dropout == 0.0 && @@ -418,12 +420,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd} (cudnn_runtime_version >= 90600 && ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || - ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && + ((window_size_left >= 0 || window_size_left == -1) && + (window_size_right >= 0 || window_size_right == -1) && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && // TODO(cyang): fix bug for BRCM + cross-attention on sm100 (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && cudnn_runtime_version <= 90700) || cudnn_runtime_version > 90700)))) || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && @@ -504,18 +508,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } // NVTE fused attention FWD with packed QKV // DEPRECATED: This API is deprecated.
// Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, - size_t max_seqlen, bool is_training, bool return_max_logit, - bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_fwd_qkvpacked( + const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, + bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); using namespace transformer_engine; @@ -589,10 +591,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, fused_attn_arbitrary_seqlen_fwd( b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, - input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state, - wkspace, stream, handle); + window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view, + input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_cu_seqlens, input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, + input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. 
\n"); @@ -629,8 +631,8 @@ void nvte_fused_attn_bwd_qkvpacked( NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); using namespace transformer_engine; @@ -725,10 +727,11 @@ void nvte_fused_attn_bwd_qkvpacked( fused_attn_arbitrary_seqlen_bwd( b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, &Q_view, - &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, &dQ_view, - &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens, - input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); + attn_mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, &Q_view, &K_view, &V_view, input_O, input_dO, input_Bias, + input_SoftmaxOffset, output_S, &dQ_view, &dK_view, &dV_view, output_dBias, + output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens, input_cu_seqlens_padded, + input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.0 is required for BF16/FP16 fused attention " @@ -779,7 +782,8 @@ void nvte_fused_attn_fwd_kvpacked( size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { + int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -891,10 +895,10 @@ void nvte_fused_attn_fwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, - input_page_table_v, input_rng_state, wkspace, stream, handle); + window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view, + input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, + input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. 
\n"); @@ -933,8 +937,8 @@ void nvte_fused_attn_bwd_kvpacked( const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, - cudaStream_t stream) { + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -1040,11 +1044,11 @@ void nvte_fused_attn_bwd_kvpacked( fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, - input_Q, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, - output_dQ, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, - wkspace, stream, handle); + bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, deterministic, input_Q, &K_view, &V_view, input_O, input_dO, + input_Bias, input_SoftmaxOffset, output_S, output_dQ, &dK_view, &dV_view, output_dBias, + output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.3 is required for BF16/FP16 fused attention " @@ -1094,8 +1098,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -1183,10 +1187,10 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, - input_page_table_v, input_rng_state, wkspace, stream, handle); + window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, + input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, + input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for 
BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -1215,8 +1219,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, - bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -1289,8 +1294,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, - deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias, - input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, + bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, input_dO, + input_Bias, input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index efa4c784390..a8f99f5c0c4 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -55,10 +55,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, - void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, - void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, - void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void *devPtrQ, + void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, + void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, + void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -75,6 +75,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; + bottom_right_diagonal = false; } bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (is_training && dropout_probability != 0.0f); 
@@ -129,6 +130,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, true, tensorType, cudnn_frontend::DataType_t::NOT_SET, @@ -254,9 +256,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); + fe::DiagonalAlignment_t const &diagonal_alignment = + bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_options.set_diagonal_alignment(diagonal_alignment); if (cudnn_runtime_version >= 90200 && window_size_left != -1) { sdpa_options.set_diagonal_band_left_bound(window_size_left + 1); } + if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + sdpa_options.set_diagonal_band_right_bound(window_size_right); + } sdpa_options.set_alibi_mask(is_alibi); @@ -542,13 +551,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl( int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ, - void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, - void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV, - void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, - void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, - void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, - void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose, + void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrSoftmaxOffset, + void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, + void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset, + void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, + void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, + size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -563,6 +573,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; + bottom_right_diagonal = false; } bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (dropout_probability != 0.0f); @@ -621,6 +632,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, deterministic, tensorType, cudnn_frontend::DataType_t::NOT_SET, @@ -781,9 +793,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl( sdpa_backward_options.set_max_total_seq_len_kv(s_kv); } + fe::DiagonalAlignment_t const &diagonal_alignment = + bottom_right_diagonal ? 
fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_backward_options.set_diagonal_alignment(diagonal_alignment); + if (cudnn_runtime_version >= 90200 && window_size_left != -1) { sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); } + if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); + } if (cudnn_runtime_version >= 90000) { sdpa_backward_options.set_deterministic_algorithm(deterministic); @@ -1044,8 +1064,8 @@ void fused_attn_arbitrary_seqlen_fwd( size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, @@ -1180,11 +1200,11 @@ void fused_attn_arbitrary_seqlen_fwd( max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, - devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, - devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); + window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, + devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1206,13 +1226,14 @@ void fused_attn_arbitrary_seqlen_bwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, - Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, - Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + int64_t 
window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, + Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; void *devPtrQ = input_Q->data.dptr; @@ -1273,8 +1294,8 @@ void fused_attn_arbitrary_seqlen_bwd( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, - deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, - devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, + bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, + devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 872b798bb40..b8313c63c27 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -25,8 +25,8 @@ void fused_attn_arbitrary_seqlen_fwd( size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, @@ -37,13 +37,14 @@ void fused_attn_arbitrary_seqlen_bwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, - Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor 
*output_dBias, - Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, + Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); #endif // CUDNN_VERSION >= 8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 5d806290a9e..64e2f5fa9cd 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1707,6 +1707,7 @@ void fused_attn_fp8_fwd_impl_v1( 0, 0, true, + true, qkv_tensor_type, o_tensor_type, cudnn_frontend::DataType_t::NOT_SET, @@ -2035,6 +2036,7 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, 0, 0, + true, false, qkv_tensor_type, o_tensor_type, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 72047a73f27..5dc487f76bb 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -110,6 +110,7 @@ struct FADescriptor_v1 { NVTE_Softmax_Type softmax_type; std::int64_t window_size_left; std::int64_t window_size_right; + bool bottom_right_diagonal; bool deterministic; cudnn_frontend::DataType_t qkv_tensor_type; cudnn_frontend::DataType_t o_tensor_type; @@ -121,15 +122,16 @@ struct FADescriptor_v1 { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, - window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type, - o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) < + window_size_left, window_size_right, bottom_right_diagonal, deterministic, + bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type, + generate_max_sum_exp) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type, - rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, - rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, - rhs.dqkv_tensor_type, rhs.generate_max_sum_exp); + rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal, + rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, + rhs.do_tensor_type, rhs.dqkv_tensor_type, rhs.generate_max_sum_exp); } }; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h 
b/transformer_engine/common/include/transformer_engine/fused_attn.h index 6622019280d..98af0ddcd26 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -269,22 +269,21 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ [[deprecated( "nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate " "Q, K, V tensors instead.")]] -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, - size_t max_seqlen, bool is_training, bool return_max_logit, - bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd_qkvpacked( + const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, + bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed QKV input. * @@ -332,6 +331,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] workspace Workspace tensor. @@ -346,8 +346,8 @@ void nvte_fused_attn_bwd_qkvpacked( NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream); + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with packed KV input. 
* @@ -409,6 +409,7 @@ void nvte_fused_attn_bwd_qkvpacked( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. @@ -425,7 +426,8 @@ void nvte_fused_attn_fwd_kvpacked( size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); + int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, + cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed KV input. * @@ -479,6 +481,7 @@ void nvte_fused_attn_fwd_kvpacked( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] workspace Workspace tensor. @@ -495,8 +498,8 @@ void nvte_fused_attn_bwd_kvpacked( const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, - cudaStream_t stream); + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with separate Q, K and V. * @@ -560,19 +563,23 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. 
*/ -void nvte_fused_attn_fwd( - const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, - bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, + const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, + bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * @@ -629,6 +636,7 @@ void nvte_fused_attn_fwd( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] workspace Workspace tensor. @@ -644,8 +652,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, - bool cuda_graph, NVTETensor workspace, cudaStream_t stream); + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream); /*! \brief Update the RNG state with the seed and calculated offset. 
* diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index c1ff46c75aa..56cd8f63692 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -261,6 +261,7 @@ def forward( attn_mask_type: str = "causal", attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, @@ -346,6 +347,11 @@ def forward( attention_mask=attention_mask, window_size=window_size, attention_type=self.attention_type, + bottom_right_alignment=( + attn_mask_type not in ["causal", "padding_causal"] + if bottom_right_diagonal is None + else bottom_right_diagonal + ), ) ) @@ -449,7 +455,11 @@ def forward( actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None, actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None, alibi_slopes=alibi_slopes, - bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], + bottom_right_alignment=( + attn_mask_type not in ["causal", "padding_causal"] + if bottom_right_diagonal is None + else bottom_right_diagonal + ), ) matmul_result = torch.baddbmm( matmul_result, @@ -1110,6 +1120,7 @@ def forward( attn_mask_type, softmax_type, window_size, + bottom_right_diagonal, rng_gen, fused_attention_backend, use_FAv2_bwd, @@ -1213,6 +1224,7 @@ def forward( attn_mask_type, softmax_type, window_size, + bottom_right_diagonal, rng_gen, softmax_offset, cuda_graph=is_graph_capturing(), @@ -1290,6 +1302,7 @@ def forward( attn_mask_type, softmax_type, window_size, + bottom_right_diagonal, rng_gen, softmax_offset, return_max_logit, @@ -1377,6 +1390,7 @@ def forward( ctx.attn_mask_type = attn_mask_type ctx.softmax_type = softmax_type ctx.window_size = window_size + ctx.bottom_right_diagonal = bottom_right_diagonal ctx.fused_attention_backend = ( fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] ) @@ -1527,6 +1541,7 @@ def backward(ctx, d_out, *_args): ctx.attn_mask_type, ctx.softmax_type, ctx.window_size, + ctx.bottom_right_diagonal, ctx.deterministic, is_graph_capturing(), ) @@ -1592,6 +1607,7 @@ def backward(ctx, d_out, *_args): ctx.attn_mask_type, ctx.softmax_type, ctx.window_size, + ctx.bottom_right_diagonal, ctx.deterministic, is_graph_capturing(), ) @@ -1631,6 +1647,7 @@ def backward(ctx, d_out, *_args): None, None, None, + None, d_softmax_offset, None, None, @@ -1728,6 +1745,7 @@ def forward( attn_mask_type: str = "causal", attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, @@ -1935,6 +1953,7 @@ def forward( attn_mask_type, self.softmax_type, window_size, + bottom_right_diagonal, None, # rng_gen fused_attention_backend, use_FAv2_bwd, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 
f506035c1ef..8f51a5fc22e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -228,6 +228,11 @@ class DotProductAttention(TransformerEngineBaseModule): map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on ``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can be overridden by :attr:`window_size` in ``forward`` as well. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. attention_type : str, default = "self" type of attention, either ``"self"`` and ``"cross"``. layer_number : int, default = None @@ -324,6 +329,7 @@ def __init__( qkv_format: str = "sbhd", attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, sequence_parallel: bool = False, tp_size: int = 1, get_rng_state_tracker: Optional[Callable] = None, @@ -350,6 +356,7 @@ def __init__( attn_mask_type = "padding_causal" self.attn_mask_type = attn_mask_type self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) + self.bottom_right_diagonal = bottom_right_diagonal if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -811,6 +818,7 @@ def forward( max_seqlen_kv: int = None, attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, checkpoint_core_attention: bool = False, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, @@ -963,6 +971,11 @@ def forward( causal masks are aligned to the bottom right corner. window_size: Optional[Tuple[int, int]], default = None Sliding window size for local attention. + bottom_right_diagonal: Optional[bool], default = None + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. checkpoint_core_attention : bool, default = False If true, forward activations for attention are recomputed during the backward pass in order to save memory that would @@ -1081,6 +1094,15 @@ def forward( if window_size is None: window_size = self.window_size window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) + if bottom_right_diagonal is None: + bottom_right_diagonal = self.bottom_right_diagonal + if attn_mask_type in {"causal", "padding_causal"}: + bottom_right_diagonal = False + if bottom_right_diagonal is None or attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + bottom_right_diagonal = True # checks for qkv_format if qkv_format is None: @@ -1144,6 +1166,8 @@ def forward( assert "padding" in attn_mask_type, "KV caching requires padding mask!" 
if attn_mask_type == "padding_causal": attn_mask_type = attn_mask_type + "_bottom_right" + # since attention mask is changed, set `bottom_right_diagonal` to True + bottom_right_diagonal = True self.attention_type = "cross" self.flash_attention.attention_type = self.attention_type @@ -1256,7 +1280,6 @@ def forward( if self.layer_number == 1: _alibi_cache["_alibi_slopes_require_update"] = True _alibi_cache["_alibi_bias_require_update"] = True - bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],) if core_attention_bias_type == "alibi": assert ( core_attention_bias is None @@ -1265,7 +1288,7 @@ def forward( _alibi_cache["_num_heads"] != query_layer.shape[-2] or _alibi_cache["_max_seqlen_q"] != max_seqlen_q or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv - or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment + or _alibi_cache["_bottom_right_alignment"] != bottom_right_diagonal or _alibi_cache["_alibi_slopes"] is None ): _alibi_cache["_alibi_slopes_require_update"] = True @@ -1322,6 +1345,7 @@ def forward( head_dim_v=head_dim_v, attn_mask_type=attn_mask_type, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None, core_attention_bias_type=core_attention_bias_type, core_attention_bias_shape=core_attention_bias_shape, @@ -1445,9 +1469,7 @@ def forward( if use_fused_attention: fu_core_attention_bias_type = core_attention_bias_type fu_core_attention_bias = core_attention_bias - if core_attention_bias_type == "alibi" and ( - alibi_slopes is not None or max_seqlen_q != max_seqlen_kv - ): + if core_attention_bias_type == "alibi" and (alibi_slopes is not None): fu_core_attention_bias_type = "post_scale_bias" _, fu_core_attention_bias = dpa_utils.get_alibi( _alibi_cache, @@ -1456,7 +1478,7 @@ def forward( max_seqlen_kv, alibi_slopes=alibi_slopes, bias_dtype=query_layer.dtype, - bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], + bottom_right_alignment=bottom_right_diagonal, ) if checkpoint_core_attention: return self._checkpointed_attention_forward( @@ -1474,6 +1496,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, fused_attention_backend=fused_attention_backend, core_attention_bias_type=fu_core_attention_bias_type, core_attention_bias=fu_core_attention_bias, @@ -1504,6 +1527,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, fused_attention_backend=fused_attention_backend, core_attention_bias_type=fu_core_attention_bias_type, core_attention_bias=fu_core_attention_bias, @@ -1538,6 +1562,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, @@ -1561,6 +1586,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 10a06ed9655..222986563c5 100644 --- 
a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -200,6 +200,9 @@ class AttentionParams: `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`} window_size : Tuple[int, int], default = None Sliding window attention size. + bottom_right_diagonal: bool, default = `None` + Whether to align sliding window and ALiBi diagonal to the bottom right corner + of the softmax matrix. alibi_slopes_shape : Optional[Union[torch.Size, List]], default = None Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`. core_attention_bias_type : str, default = no_bias @@ -249,6 +252,7 @@ class AttentionParams: head_dim_v: int = 64 attn_mask_type: str = "no_mask" window_size: Union[Tuple[int, int], None] = None + bottom_right_diagonal: bool = True alibi_slopes_shape: Union[torch.Size, List, None] = None core_attention_bias_type: str = "no_bias" core_attention_bias_shape: str = "1hss" @@ -325,6 +329,7 @@ def get_attention_backend( head_dim_v = attention_params.head_dim_v attn_mask_type = attention_params.attn_mask_type window_size = attention_params.window_size + bottom_right_diagonal = attention_params.bottom_right_diagonal alibi_slopes_shape = attention_params.alibi_slopes_shape core_attention_bias_type = attention_params.core_attention_bias_type core_attention_bias_shape = attention_params.core_attention_bias_shape @@ -867,39 +872,43 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt # backend | window_size | diagonal alignment # --------------------------------------------------------------------------------- # FlashAttention | (-1, -1) or (>=0, >=0) | bottom right - # FusedAttention | (-1, 0) or (>=0, 0) | top left - # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both; + # FusedAttention | (-1, 0) or (>=0, >=0) | top left, bottom right + # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | top left, bottom right # | | converts window_size to an 'arbitrary' mask if window_size is None: window_size = check_set_window_size(attn_mask_type, window_size) - else: - if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention" - " for FP8" - ) - use_fused_attention = False - elif window_size[1] != 0 or attention_dropout != 0.0: - logger.debug( - "Disabling FusedAttention as it only supports sliding window attention " - "with (left, 0) and no dropout" - ) - use_fused_attention = False - elif max_seqlen_q > max_seqlen_kv: - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention " - "with s_q > s_kv for cross-attention" - ) - use_fused_attention = False - if use_flash_attention_2 and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if not FlashAttentionUtils.is_installed: - FlashAttentionUtils.version_required = PkgVersion("2.3") - elif not FlashAttentionUtils.v2_3_plus: - logger.debug( - "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+" - ) - use_flash_attention_2 = False + if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): + if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): + logger.debug( + "Disabling FusedAttention as it does not support sliding window attention for FP8" + ) + use_fused_attention = False + elif attention_dropout != 0.0: + 
logger.debug( + "Disabling FusedAttention as it only supports sliding window attention " + "without dropout" + ) + use_fused_attention = False + elif max_seqlen_q > max_seqlen_kv: + logger.debug( + "Disabling FusedAttention as it does not support sliding window attention " + "with s_q > s_kv for cross-attention" + ) + use_fused_attention = False + if use_flash_attention_2 and (window_size[0] != -1 or window_size[1] not in [-1, 0]): + if not FlashAttentionUtils.is_installed: + FlashAttentionUtils.version_required = PkgVersion("2.3") + elif not FlashAttentionUtils.v2_3_plus: + logger.debug( + "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+" + ) + use_flash_attention_2 = False + elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv: + logger.debug( + "Disabling FlashAttention as it only supports sliding window with bottom right" + " diagonal alignment for cross-attention" + ) + use_flash_attention = False # Filter: Attention bias # backend | bias types | ALiBi diagonal alignment @@ -921,6 +930,12 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt elif not FlashAttentionUtils.v2_4_plus: logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+") use_flash_attention_2 = False + elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv: + logger.debug( + "Disabling FlashAttention as it only supports ALiBi with bottom right diagonal" + " alignment for cross-attention" + ) + use_flash_attention = False if ( core_attention_bias_type not in ["no_bias", "alibi"] @@ -938,13 +953,12 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if ( use_fused_attention and core_attention_bias_type == "alibi" - and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv) + and (alibi_slopes_shape is not None) ): fu_core_attention_bias_type = "post_scale_bias" fu_core_attention_bias_requires_grad = False - if alibi_slopes_shape is None: - fu_core_attention_bias_shape = "1hss" - elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads: + + if len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads: fu_core_attention_bias_shape = "1hss" elif ( len(alibi_slopes_shape) == 2 diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index beb13b7f1e3..07878d63e93 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -32,6 +32,7 @@ from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb +from transformer_engine.pytorch.attention.dot_product_attention import utils as dpa_utils from transformer_engine.pytorch.cpu_offload import start_offload, is_cpu_offload_enabled @@ -93,6 +94,11 @@ class MultiheadAttention(torch.nn.Module): map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on ``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can be overridden by :attr:`window_size` in :meth:`forward` as well. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. 
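Note (not part of the patch): the two new "Disabling FlashAttention" branches above both come down to how the sliding-window and ALiBi diagonal is anchored when max_seqlen_q != max_seqlen_kv; flash-attn only implements the bottom-right alignment, so a top-left request on cross-attention has to fall back to another backend. A reference-only sketch of the two alignments, following the window definition quoted in the fused_attn_fwd docstring later in this patch; the swa_mask helper is made up for illustration and is not a TE API:

    import torch

    def swa_mask(s_q, s_kv, window_size, bottom_right_diagonal=True):
        """True entries may be attended; -1 in window_size means unbounded on that side."""
        left, right = window_size
        i = torch.arange(s_q).view(s_q, 1)    # query positions
        j = torch.arange(s_kv).view(1, s_kv)  # key positions
        # bottom-right alignment shifts the diagonal by s_kv - s_q, i.e.
        # "j in [i + seqlen_k - seqlen_q - window_size[0], ...]" from the docstring
        shift = (s_kv - s_q) if bottom_right_diagonal else 0
        lo = i + shift - (left if left >= 0 else s_kv)
        hi = i + shift + (right if right >= 0 else s_kv)
        return (j >= lo) & (j <= hi)

    # s_q < s_kv: the two alignments disagree, which is exactly the case the new
    # filters guard against.
    print(swa_mask(2, 4, (1, 0), bottom_right_diagonal=True))
    print(swa_mask(2, 4, (1, 0), bottom_right_diagonal=False))

The same shift applies to the ALiBi bias, which is why the ALiBi filter a few lines below repeats the check.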
+ If `None`, it will be set to `False` for `attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. num_gqa_groups : int, default = None number of GQA groups in the transformer layer. Grouped Query Attention is described in @@ -248,6 +254,7 @@ def __init__( layer_number: Optional[int] = None, attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, tp_group: Optional[dist_group_type] = None, tp_size: int = 1, num_gqa_groups: Optional[int] = None, @@ -286,6 +293,7 @@ def __init__( self.qkv_format = qkv_format self.attn_mask_type = attn_mask_type self.window_size = window_size + self.bottom_right_diagonal = bottom_right_diagonal self.layer_number = 1 if layer_number is None else layer_number self.input_layernorm = input_layernorm self.attention_type = attention_type @@ -621,6 +629,7 @@ def forward( encoder_output: Optional[torch.Tensor] = None, attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, is_first_microbatch: Optional[bool] = None, checkpoint_core_attention: bool = False, inference_params: Optional[InferenceParams] = None, @@ -667,6 +676,11 @@ def forward( aligned to the bottom right corner. window_size: Optional[Tuple[int, int]], default = None sliding window size for local attention. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. encoder_output : Optional[torch.Tensor], default = None Output of the encoder block to be fed into the decoder block if using ``layer_type="decoder"``. @@ -731,6 +745,17 @@ def forward( if window_size is None: window_size = self.window_size + window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) + if bottom_right_diagonal is None: + bottom_right_diagonal = self.bottom_right_diagonal + if attn_mask_type in {"causal", "padding_causal"}: + bottom_right_diagonal = False + if bottom_right_diagonal is None or attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + bottom_right_diagonal = True + if "padding" in attn_mask_type and attention_mask is not None: for mask in attention_mask: assert mask.dtype == torch.bool, "Attention mask must be in boolean type!" @@ -1004,6 +1029,7 @@ def forward( attention_mask=attention_mask, attn_mask_type=attn_mask_type, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, checkpoint_core_attention=checkpoint_core_attention, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 88c223eb462..12c62712044 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -137,6 +137,7 @@ def fused_attn_fwd( attn_mask_type: str = "padding", softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = None, rng_gen: torch.Generator = None, softmax_offset: torch.Tensor = None, return_max_logit: bool = False, @@ -212,6 +213,9 @@ def fused_attn_fwd( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. 
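Note (not part of the patch): a usage sketch of the new knob at the module level, assuming the MultiheadAttention constructor and forward signatures shown in this hunk; hidden sizes and sequence lengths are arbitrary and a CUDA device is assumed:

    import torch
    import transformer_engine.pytorch as te

    mha = te.MultiheadAttention(
        hidden_size=1024,
        num_attention_heads=16,
        attn_mask_type="no_mask",
        window_size=(128, 128),
        bottom_right_diagonal=True,  # module-level default for the diagonal alignment
    ).cuda()

    x = torch.randn(512, 2, 1024, device="cuda")  # (seq, batch, hidden) with the default qkv_format
    out = mha(x)
    # Like attn_mask_type and window_size, the setting can be overridden per call:
    out = mha(x, bottom_right_diagonal=False)

With self-attention (s_q == s_kv) the two alignments coincide; the override only changes behaviour for cross-attention or cached-KV cases where the lengths differ.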
Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = None + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. rng_gen : torch.Generator, default = None random number generator; if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen @@ -255,6 +259,12 @@ max_logit : if return_max_logit = True, shape [h] and same data type as O; otherwise None """ + if bottom_right_diagonal is None: + bottom_right_diagonal = attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + } + if attn_scale is None: d = q.size(-1) attn_scale = 1.0 / math.sqrt(d) @@ -306,6 +316,7 @@ AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], window_size, + bottom_right_diagonal, cu_seqlens_q, cu_seqlens_kv, q, @@ -370,6 +381,7 @@ def fused_attn_bwd( attn_mask_type: str = "padding", softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = None, deterministic: bool = False, cuda_graph: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: @@ -442,6 +454,9 @@ in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = None + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. deterministic : bool, default = False whether to execute the backward pass with deterministic behaviours. cuda_graph : bool, default = False @@ -462,6 +477,12 @@ gradient tensor of softmax offset of shape [1, h_q, 1, 1]. See softmax_type in DotProductAttention for details.
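Note (not part of the patch): the default these wrappers apply when bottom_right_diagonal is left as None, restated as a standalone helper for clarity; the function name is made up:

    def resolve_bottom_right_diagonal(attn_mask_type: str, bottom_right_diagonal=None) -> bool:
        # Same rule as the `if bottom_right_diagonal is None:` blocks in
        # fused_attn_fwd / fused_attn_bwd above.
        if bottom_right_diagonal is None:
            bottom_right_diagonal = attn_mask_type in {
                "causal_bottom_right",
                "padding_causal_bottom_right",
            }
        return bottom_right_diagonal

    assert resolve_bottom_right_diagonal("causal") is False
    assert resolve_bottom_right_diagonal("padding_causal_bottom_right") is True
    assert resolve_bottom_right_diagonal("no_mask", True) is True  # explicit value wins

Worth noting: at this level None with a plain no_mask or padding mask resolves to False, whereas the module-level logic added in multi_head_attention.py and transformer.py resolves None to True for non-causal masks. If I read the DotProductAttention changes above correctly, it always passes an explicit value down, so the difference only matters for direct callers of these wrappers.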
""" + if bottom_right_diagonal is None: + bottom_right_diagonal = attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + } + if attn_scale is None: d = q.size(-1) attn_scale = 1.0 / math.sqrt(d) @@ -500,6 +521,7 @@ def fused_attn_bwd( AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], window_size, + bottom_right_diagonal, deterministic, cu_seqlens_q, cu_seqlens_kv, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 80479dccf48..b263695895d 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -87,9 +87,10 @@ std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - const std::vector window_size, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const at::ScalarType fake_dtype, const std::optional cu_seqlens_q_padded, + const std::vector window_size, bool bottom_right_diagonal, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const at::ScalarType fake_dtype, + const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, @@ -99,10 +100,10 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, const std::vector window_size, bool deterministic, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const py::handle O, const py::handle dO, - const at::ScalarType fake_dtype, const DType dqkv_type, + NVTE_Softmax_Type softmax_type, const std::vector window_size, + bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 2480d9aba9b..2ac67dca0da 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -100,9 +100,10 @@ std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - const std::vector window_size, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const at::ScalarType fake_dtype, const std::optional cu_seqlens_q_padded, + const std::vector window_size, bool bottom_right_diagonal, + const at::Tensor cu_seqlens_q, const 
at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const at::ScalarType fake_dtype, + const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, @@ -235,7 +236,7 @@ std::vector fused_attn_fwd( te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], workspace.data(), + softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -295,7 +296,7 @@ std::vector fused_attn_fwd( te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], workspace.data(), + softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -310,10 +311,10 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, const std::vector window_size, bool deterministic, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const py::handle O, const py::handle dO, - const at::ScalarType fake_dtype, const DType dqkv_type, + NVTE_Softmax_Type softmax_type, const std::vector window_size, + bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, @@ -532,14 +533,14 @@ std::vector fused_attn_bwd( // populate tensors with appropriate shapes and dtypes NVTE_SCOPED_GIL_RELEASE({ - nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), - te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), - te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, - max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], deterministic, cuda_graph, - workspace.data(), at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd( + te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), + &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), + te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, 
softmax_type, window_size[0], + window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), + at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace @@ -549,14 +550,14 @@ std::vector fused_attn_bwd( // execute kernel NVTE_SCOPED_GIL_RELEASE({ - nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), - te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), - te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, - max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], deterministic, cuda_graph, - workspace.data(), at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd( + te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), + &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), + te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], + window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), + at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index b3ad8ccc550..a37ac5162b3 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -35,7 +35,7 @@ from transformer_engine.pytorch.distributed import get_distributed_world_size from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.module.base import TransformerEngineBaseModule - +import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils warnings.filterwarnings("module", category=DeprecationWarning, module="transformer") @@ -149,11 +149,21 @@ class TransformerLayer(torch.nn.Module): distinguishes them based on :attr:`self_attn_mask_type` or :attr:`enc_dec_attn_mask_type`. Similar to :attr:`self_attn_mask_type`, :attr:`window_size` can be overridden by :attr:`window_size` in :meth:`forward` as well. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `self_attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. enc_dec_attn_mask_type : {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'}, default = "no_mask" type of attention mask passed into softmax operation for decoder. enc_dec_window_size : Optional[Tuple[int, int]], default = None sliding window size for local attention in decoder. + enc_dec_bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the decoder. + If `None`, it will be set to `False` for `enc_dec_attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. 
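Note (not part of the patch): because the flag changes the positional argument order of the C++ fused_attn_fwd/fused_attn_bwd entry points, a Python-level parity smoke test between backends is a cheap way to catch a mis-threaded argument. A sketch using the public DotProductAttention module and the standard NVTE_FUSED_ATTN/NVTE_FLASH_ATTN toggles; the script name and tolerances are arbitrary, and since backend selection is cached per process, run it once per setting in a fresh process (the test utilities in this patch force re-selection instead) and compare the saved outputs:

    import sys
    import torch
    import transformer_engine.pytorch as te

    def run_once(out_path):
        torch.manual_seed(0)
        dpa = te.DotProductAttention(
            num_attention_heads=16,
            kv_channels=64,
            attn_mask_type="causal_bottom_right",
            window_size=(64, 0),
        ).cuda()
        q = torch.randn(128, 2, 16, 64, device="cuda", dtype=torch.bfloat16)
        k, v = torch.randn_like(q), torch.randn_like(q)
        torch.save(dpa(q, k, v).cpu(), out_path)

    if __name__ == "__main__":
        # NVTE_FUSED_ATTN=1 NVTE_FLASH_ATTN=0 python parity_check.py fused.pt
        # NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python parity_check.py unfused.pt
        # then: torch.testing.assert_close(torch.load("fused.pt"), torch.load("unfused.pt"),
        #                                  atol=2.5e-2, rtol=2.5e-2)
        run_once(sys.argv[1])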
zero_centered_gamma : bool, default = False if set to ``True``, gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to @@ -302,7 +312,9 @@ def __init__( kv_channels: Optional[int] = None, self_attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, enc_dec_attn_mask_type: str = "no_mask", + enc_dec_bottom_right_diagonal: Optional[bool] = None, enc_dec_window_size: Optional[Tuple[int, int]] = None, tp_group: Optional[dist_group_type] = None, tp_size: int = 1, @@ -344,8 +356,10 @@ def __init__( self.self_attn_mask_type = self_attn_mask_type self.window_size = window_size + self.bottom_right_diagonal = bottom_right_diagonal self.enc_dec_attn_mask_type = enc_dec_attn_mask_type self.enc_dec_window_size = enc_dec_window_size + self.enc_dec_bottom_right_diagonal = enc_dec_bottom_right_diagonal params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad @@ -606,10 +620,12 @@ def forward( attention_mask: Optional[torch.Tensor] = None, self_attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, encoder_output: Optional[torch.Tensor] = None, enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, enc_dec_attn_mask_type: Optional[str] = None, enc_dec_window_size: Optional[Tuple[int, int]] = None, + enc_dec_bottom_right_diagonal: Optional[bool] = None, is_first_microbatch: Optional[bool] = None, checkpoint_core_attention: bool = False, inference_params: Optional[InferenceParams] = None, @@ -654,6 +670,11 @@ def forward( causal masks are aligned to the bottom right corner. window_size: Optional[Tuple[int, int]], default = None Sliding window size for local attention in encoder. + bottom_right_diagonal: Optional[bool] = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `self_attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. encoder_output : Optional[torch.Tensor], default = None Output of the encoder block to be fed into the decoder block if using :attr:`layer_type` = ``"decoder"``. @@ -670,6 +691,11 @@ def forward( Type of attention mask passed into softmax operation for decoder. enc_dec_window_size: Optional[Tuple[int, int]], default = None Sliding window size for local attention in decoder. + enc_dec_bottom_right_diagonal: Optional[bool] = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the decoder. + If `None`, it will be set to `False` for `enc_dec_attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. 
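Note (not part of the patch): a usage sketch at the TransformerLayer level, assuming the constructor arguments documented in this hunk; the decoder form shows why the encoder and decoder get separate knobs (sizes are arbitrary, CUDA assumed):

    import torch
    import transformer_engine.pytorch as te

    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        layer_type="decoder",
        self_attn_mask_type="causal",        # forces bottom_right_diagonal=False for self-attention
        window_size=(256, 0),
        enc_dec_attn_mask_type="no_mask",
        enc_dec_window_size=(256, 256),
        enc_dec_bottom_right_diagonal=True,  # cross-attention window anchored bottom right
    ).cuda()

    x = torch.randn(512, 2, 1024, device="cuda")    # decoder stream
    mem = torch.randn(640, 2, 1024, device="cuda")  # encoder output, different length
    y = layer(x, encoder_output=mem)

Keeping the encoder (bottom_right_diagonal) and decoder (enc_dec_bottom_right_diagonal) settings separate mirrors the existing split between self- and cross-attention mask types and window sizes.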
is_first_microbatch : {True, False, None}, default = None During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split @@ -736,10 +762,35 @@ def forward( self_attn_mask_type = self.self_attn_mask_type if window_size is None: window_size = self.window_size + window_size = dpa_utils.check_set_window_size(self_attn_mask_type, window_size) + if enc_dec_attn_mask_type is None: enc_dec_attn_mask_type = self.enc_dec_attn_mask_type if enc_dec_window_size is None: enc_dec_window_size = self.enc_dec_window_size + enc_dec_window_size = dpa_utils.check_set_window_size( + enc_dec_attn_mask_type, enc_dec_window_size + ) + + if bottom_right_diagonal is None: + bottom_right_diagonal = self.bottom_right_diagonal + if self_attn_mask_type in {"causal", "padding_causal"}: + bottom_right_diagonal = False + if bottom_right_diagonal is None or self_attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + bottom_right_diagonal = True + + if enc_dec_bottom_right_diagonal is None: + enc_dec_bottom_right_diagonal = self.enc_dec_bottom_right_diagonal + if enc_dec_attn_mask_type in {"causal", "padding_causal"}: + enc_dec_bottom_right_diagonal = False + if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + enc_dec_bottom_right_diagonal = True assert ( self_attn_mask_type in AttnMaskTypes @@ -781,6 +832,7 @@ def forward( attention_mask=attention_mask, attn_mask_type=self_attn_mask_type, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, inference_params=inference_params, is_first_microbatch=is_first_microbatch, checkpoint_core_attention=checkpoint_core_attention, @@ -816,6 +868,7 @@ def forward( attention_mask=enc_dec_attn_mask, attn_mask_type=enc_dec_attn_mask_type, window_size=enc_dec_window_size, + bottom_right_diagonal=enc_dec_bottom_right_diagonal, encoder_output=encoder_output, inference_params=inference_params, is_first_microbatch=is_first_microbatch,