-
Notifications
You must be signed in to change notification settings - Fork 589
Add support for SWA (left, right) with FusedAttention #2477
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
sudhakarsingh27
wants to merge
20
commits into
NVIDIA:main
Choose a base branch
from
sudhakarsingh27:swa_padding_brcm_try2
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
33e2e47
SWA (left, right) with FusedAttention changes cherry-picked from http…
sudhakarsingh27 eab24be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] e761a26
fix test_kv_cache failures
sudhakarsingh27 48e4f4d
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into swa_p…
sudhakarsingh27 172ebbe
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into swa_p…
sudhakarsingh27 93548fc
remove unnecessary comments
sudhakarsingh27 a545ebf
fix some more filter issues, address feedback
sudhakarsingh27 a6904ea
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into swa_p…
sudhakarsingh27 341f0f7
fix for local test case failures - `bottom_right_diagonal` should be …
sudhakarsingh27 c42c555
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into swa_p…
sudhakarsingh27 1fd4985
make conditions more accurate
sudhakarsingh27 19ce4da
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into HEAD
sudhakarsingh27 95bbec1
add cp tests to test swa (left, right)
sudhakarsingh27 13aee94
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 67c5d39
remove dead code and make conditions better
sudhakarsingh27 0c81e6a
Merge branch 'swa_padding_brcm_try2' of github.com:sudhakarsingh27/Tr…
sudhakarsingh27 4f25bf2
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into swa_p…
sudhakarsingh27 5e14579
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] da4b3d6
fix lint
sudhakarsingh27 dbff75a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -406,9 +406,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( | |
| (window_size_right == -1 || window_size_right == 0)) || | ||
| // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd} | ||
| (cudnn_runtime_version >= 90200 && | ||
| ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || | ||
| ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && | ||
| (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || | ||
| ((window_size_left == -1 && window_size_right == -1 && | ||
| attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) || | ||
| ((window_size_left == -1 || window_size_left >= 0) && window_size_right == 0 && | ||
| (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || | ||
| attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || | ||
| (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && | ||
| max_seqlen_q == max_seqlen_kv)) && | ||
| max_seqlen_q <= max_seqlen_kv && dropout == 0.0 && | ||
|
|
@@ -418,12 +420,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( | |
| // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd} | ||
| (cudnn_runtime_version >= 90600 && | ||
| ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || | ||
| ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && | ||
| ((window_size_left >= 0 || window_size_left == -1) && | ||
| (window_size_right >= 0 || window_size_right == -1) && | ||
| ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && | ||
| // TODO(cyang): fix bug for BRCM + cross-attention on sm100 | ||
| (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && | ||
| cudnn_runtime_version <= 90700) || | ||
| cudnn_runtime_version > 90700)))) || | ||
| attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || | ||
| attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || | ||
| (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && | ||
| (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && | ||
|
|
@@ -504,18 +508,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( | |
| } | ||
|
|
||
| // NVTE fused attention FWD with packed QKV | ||
| // DEPRECATED: This API is deprecated. | ||
| // DEPRECATED: This API is deprecated. (Should there be a version by which this is going to be removed? @cyang) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I made some changes in #2272, but will see if I can make the 2.11 deadline. |
||
| // Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. | ||
| void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, | ||
| const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, | ||
| NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, | ||
| const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, | ||
| size_t max_seqlen, bool is_training, bool return_max_logit, | ||
| bool cuda_graph, float attn_scale, float dropout, | ||
| NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, | ||
| NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, | ||
| int64_t window_size_left, int64_t window_size_right, | ||
| NVTETensor workspace, cudaStream_t stream) { | ||
| void nvte_fused_attn_fwd_qkvpacked( | ||
| const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, | ||
| NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, | ||
| const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, | ||
| bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, | ||
| NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, | ||
| NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, | ||
| bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { | ||
| NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); | ||
| using namespace transformer_engine; | ||
|
|
||
|
|
@@ -589,10 +591,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, | |
| fused_attn_arbitrary_seqlen_fwd( | ||
| b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training, | ||
| return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, | ||
| window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias, | ||
| input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, | ||
| input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state, | ||
| wkspace, stream, handle); | ||
| window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view, | ||
| input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, | ||
| input_cu_seqlens, input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, | ||
| input_rng_state, wkspace, stream, handle); | ||
| #else | ||
| NVTE_ERROR( | ||
| "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); | ||
|
|
@@ -629,8 +631,8 @@ void nvte_fused_attn_bwd_qkvpacked( | |
| NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, | ||
| size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, | ||
| NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, | ||
| int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, | ||
| NVTETensor workspace, cudaStream_t stream) { | ||
| int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, | ||
| bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { | ||
| NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); | ||
| using namespace transformer_engine; | ||
|
|
||
|
|
@@ -725,10 +727,11 @@ void nvte_fused_attn_bwd_qkvpacked( | |
|
|
||
| fused_attn_arbitrary_seqlen_bwd( | ||
| b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type, | ||
| attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, &Q_view, | ||
| &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, &dQ_view, | ||
| &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens, | ||
| input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); | ||
| attn_mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, | ||
| deterministic, &Q_view, &K_view, &V_view, input_O, input_dO, input_Bias, | ||
| input_SoftmaxOffset, output_S, &dQ_view, &dK_view, &dV_view, output_dBias, | ||
| output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens, input_cu_seqlens_padded, | ||
| input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); | ||
| #else | ||
| const char *err_msg = | ||
| "cuDNN 8.9.0 is required for BF16/FP16 fused attention " | ||
|
|
@@ -779,7 +782,8 @@ void nvte_fused_attn_fwd_kvpacked( | |
| size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, | ||
| float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, | ||
| NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, | ||
| int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { | ||
| int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, | ||
| cudaStream_t stream) { | ||
| NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); | ||
| using namespace transformer_engine; | ||
| const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); | ||
|
|
@@ -891,10 +895,10 @@ void nvte_fused_attn_fwd_kvpacked( | |
| b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v, | ||
| page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, | ||
| return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, | ||
| window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias, | ||
| input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, | ||
| input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, | ||
| input_page_table_v, input_rng_state, wkspace, stream, handle); | ||
| window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view, | ||
| input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, | ||
| input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, | ||
| input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); | ||
| #else | ||
| NVTE_ERROR( | ||
| "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); | ||
|
|
@@ -933,8 +937,8 @@ void nvte_fused_attn_bwd_kvpacked( | |
| const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, | ||
| float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, | ||
| NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, | ||
| int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, | ||
| cudaStream_t stream) { | ||
| int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, | ||
| NVTETensor workspace, cudaStream_t stream) { | ||
| NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); | ||
| using namespace transformer_engine; | ||
| const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); | ||
|
|
@@ -1040,11 +1044,11 @@ void nvte_fused_attn_bwd_kvpacked( | |
|
|
||
| fused_attn_arbitrary_seqlen_bwd( | ||
| b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout, qkv_layout, | ||
| bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, | ||
| input_Q, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, | ||
| output_dQ, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, | ||
| input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, | ||
| wkspace, stream, handle); | ||
| bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, | ||
| bottom_right_diagonal, deterministic, input_Q, &K_view, &V_view, input_O, input_dO, | ||
| input_Bias, input_SoftmaxOffset, output_S, output_dQ, &dK_view, &dV_view, output_dBias, | ||
| output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, | ||
| input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); | ||
| #else | ||
| const char *err_msg = | ||
| "cuDNN 8.9.3 is required for BF16/FP16 fused attention " | ||
|
|
@@ -1094,8 +1098,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso | |
| bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, | ||
| NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, | ||
| NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, | ||
| int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, | ||
| cudaStream_t stream) { | ||
| int64_t window_size_left, int64_t window_size_right, | ||
| bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { | ||
| NVTE_API_CALL(nvte_flash_attn_fwd); | ||
| using namespace transformer_engine; | ||
| const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); | ||
|
|
@@ -1183,10 +1187,10 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso | |
| b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, | ||
| page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, | ||
| return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, | ||
| window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, | ||
| input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, | ||
| input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, | ||
| input_page_table_v, input_rng_state, wkspace, stream, handle); | ||
| window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, | ||
| input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, | ||
| input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, | ||
| input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); | ||
| #else | ||
| NVTE_ERROR( | ||
| "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); | ||
|
|
@@ -1215,8 +1219,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso | |
| size_t max_seqlen_kv, float attn_scale, float dropout, | ||
| NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, | ||
| NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, | ||
| int64_t window_size_left, int64_t window_size_right, bool deterministic, | ||
| bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { | ||
| int64_t window_size_left, int64_t window_size_right, | ||
| bool bottom_right_diagonal, bool deterministic, bool cuda_graph, | ||
| NVTETensor workspace, cudaStream_t stream) { | ||
| NVTE_API_CALL(nvte_flash_attn_bwd); | ||
| using namespace transformer_engine; | ||
| const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); | ||
|
|
@@ -1289,8 +1294,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso | |
| fused_attn_arbitrary_seqlen_bwd( | ||
| b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, | ||
| qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, | ||
| deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias, | ||
| input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, | ||
| bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, input_dO, | ||
| input_Bias, input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, | ||
| output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, | ||
| input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); | ||
| #else | ||
|
|
||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a couple of SWA tests to the CP tests as well? I think it's just a matter of replacing (left,0) with (left, right) and test them out. Thanks!