[JAX] Refactor and trim TE JAX Attn testing #2542
Draft
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
- Currently, the tests run every combination where `dp*cp*tp <= num gpus`.
  [Image: L1 dist timing for TE 2.11 (B200x8)]
- This PR runs only those L1 and L2 combinations where `dp*cp*tp == num gpus`.
  [Image: L1 dist timing for this PR (B200x8)]
- This change reduces the number of tests collected by 1020 (2157 - 1137).
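The difference between the two collection rules can be sketched in a few lines of Python. This is only an illustration, not the code in this PR: the helper `mesh_configs`, its `exact` flag, and the candidate sizes `(1, 2, 4, 8)` are all hypothetical, and the resulting counts describe this toy grid, not the real test suite.

```python
from itertools import product


def mesh_configs(num_gpus, sizes=(1, 2, 4, 8), exact=True):
    """Return (dp, cp, tp) triples allowed on `num_gpus` devices.

    Hypothetical helper for illustration only.
    exact=False mimics the old rule: dp*cp*tp <= num_gpus.
    exact=True mimics the rule in this PR: dp*cp*tp == num_gpus.
    """
    configs = []
    for dp, cp, tp in product(sizes, repeat=3):
        total = dp * cp * tp
        keep = (total == num_gpus) if exact else (total <= num_gpus)
        if keep:
            configs.append((dp, cp, tp))
    return configs


old_rule = mesh_configs(8, exact=False)  # everything that fits on 8 GPUs
new_rule = mesh_configs(8, exact=True)   # only configs that use all 8 GPUs

print(len(old_rule), len(new_rule))
```

Every config kept by the exact rule is also kept by the old rule, so the change only removes cases that leave some GPUs idle on a given machine.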
For the `test_context_parallel` tests, the number of tests is halved in this PR, because only the test cases for `dp*cp*tp==8` are collected and not those for `dp*cp*tp==4` and `dp*cp*tp==2`. This is not a big problem in CI, since we also run H100x4 and GB200x4, so the test cases for `dp*cp*tp==4` are covered there.

Cons of this change:
- `dp*cp*tp==2` tests will not be covered in their current form. TODO: If coverage for this is needed, CI could set `CUDA_VISIBLE_DEVICES=0,1` for any of these configs and run these tests as well.
- The current tests run `dp*cp*tp<=8` for B200; with this PR, we will only run the `dp*cp*tp==8` cases.
- The current tests run `dp*cp*tp<=4` for H100; with this PR, we will only run the `dp*cp*tp==4` cases.
- The current tests run `dp*cp*tp<=4` for GB200; with this PR, we will only run the `dp*cp*tp==4` cases.
- Overall, the set of test cases stays the same, but a given CI config (GPU arch) no longer runs every combination.

Checklist: