[PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization #1921
base: main
Conversation
Thanks for your contribution. Never mind, I agree that fusing permutation with padding is a better solution.
@yaox12
We have also implemented permute + pad and unpermute + unpad fusion that works with DeepGEMM grouped GEMM for training, and we are using it.
Greptile Summary
Important Files Changed
Confidence score: 4/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant API as "moe_permute_and_pad_with_probs"
    participant AutogradFn as "_moe_permute_mask_map"
    participant TritonKernel as "permute_with_mask_map"
    participant UnpermuteAPI as "moe_unpermute"
    participant UnpermuteAutograd as "_moe_unpermute_mask_map"
    participant UnpermuteKernel as "unpermute_with_mask_map"
    User->>API: "moe_permute_and_pad_with_probs(inp, probs, routing_map, tokens_per_expert, align_size)"
    API->>API: "Calculate target_tokens_per_expert and pad_offsets"
    API->>AutogradFn: "apply(inp, routing_map, num_out_tokens, probs, pad_offsets)"
    AutogradFn->>TritonKernel: "permute_with_mask_map(..., pad_offsets, ...)"
    TritonKernel->>AutogradFn: "Return permuted and padded output, permuted_probs"
    AutogradFn->>API: "Return output, row_id_map, permuted_probs"
    API->>User: "Return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert"
    User->>UnpermuteAPI: "moe_unpermute(inp, row_id_map, merging_probs, pad_offsets=pad_offsets)"
    UnpermuteAPI->>UnpermuteAutograd: "apply(inp, row_id_map, merging_probs, restore_shape, pad_offsets)"
    UnpermuteAutograd->>UnpermuteKernel: "unpermute_with_mask_map(..., pad_offsets, ...)"
    UnpermuteKernel->>UnpermuteAutograd: "Return unpermuted and unpadded output"
    UnpermuteAutograd->>UnpermuteAPI: "Return unpermuted_output"
    UnpermuteAPI->>User: "Return unpermuted_output"
```
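A minimal usage sketch of the call flow above. The tensor shapes, the boolean `routing_map` layout, and the import path are assumptions based on this PR's description rather than verified against the final code.

```python
import torch
from transformer_engine.pytorch.permutation import (  # import path assumed
    moe_permute_and_pad_with_probs,
    moe_unpermute,
)

num_tokens, hidden_size, num_experts, align_size = 16, 64, 4, 8
inp = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")
probs = torch.rand(num_tokens, num_experts, device="cuda")
# Assumed layout: routing_map[t, e] is True when token t is routed to expert e.
routing_map = torch.ones(num_tokens, num_experts, dtype=torch.bool, device="cuda")
tokens_per_expert = routing_map.sum(dim=0).cpu()

# Fused permute + pad: each expert's token block is padded to a multiple of align_size.
out, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert = (
    moe_permute_and_pad_with_probs(inp, probs, routing_map, tokens_per_expert, align_size)
)

# ... grouped GEMM over `out` using `target_tokens_per_expert` ...

# Fused unpermute + unpad: pad_offsets tells the kernel which padded rows to drop.
# merging_probs is passed positionally, mirroring the diagram above.
restored = moe_unpermute(out, row_id_map, probs, pad_offsets=pad_offsets)
```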
Additional Comments (2)
- `transformer_engine/pytorch/triton/permutation.py`, lines 438-470 (syntax): duplicated and malformed tensor allocation code. Lines 438-456 allocate `output`, `permuted_probs`, and `permuted_scale` with if-else blocks; lines 457-470 then duplicate this logic but leave line 456's if-block incomplete (missing the else-branch allocation), and line 467 is a stray incomplete statement.
- `transformer_engine/pytorch/permutation.py`, lines 632-633 (logic): inconsistent device placement. `torch.cat([torch.zeros(1, dtype=torch.int32), cum_pad[:-1]])` creates tensors on CPU, and `cum_pad` is derived from `tokens_per_expert`, which may be on CPU; explicitly calling `.cuda()` afterwards can cause unnecessary data transfers and potential device-mismatch errors if `tokens_per_expert` is already on CUDA (see the sketch below).
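A minimal sketch of a device- and dtype-consistent alternative, assuming `cum_pad` is derived from `tokens_per_expert` roughly as below; the helper name and the exact rounding are illustrative, not the PR's implementation.

```python
import torch

def compute_pad_offsets(tokens_per_expert: torch.Tensor, align_size: int) -> torch.Tensor:
    # Round each expert's token count up to a multiple of align_size.
    padded = ((tokens_per_expert + align_size - 1) // align_size) * align_size
    cum_pad = torch.cumsum(padded - tokens_per_expert, dim=0)
    # Allocate the leading zero with the same dtype and device as cum_pad,
    # instead of building it on CPU and calling .cuda() afterwards.
    zero = torch.zeros(1, dtype=cum_pad.dtype, device=cum_pad.device)
    return torch.cat([zero, cum_pad[:-1]])

# Works the same whether tokens_per_expert lives on CPU or on a CUDA device.
print(compute_pad_offsets(torch.tensor([5, 3, 7]), align_size=4))  # tensor([0, 3, 4])
```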
4 files reviewed, 2 comments
Force-pushed from eb8498c to 6df72fe
Additional Comments (4)
- `transformer_engine/pytorch/permutation.py`, line 594 (syntax): return type annotation is incorrect; the function returns 5 values but the annotation shows only 2.
- `transformer_engine/pytorch/permutation.py`, lines 615-616 (style): missing documentation for return values. The function returns 5 values, `(output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert)`, but the docstring doesn't describe them (a signature sketch follows this list).
- `transformer_engine/pytorch/permutation.py`, lines 632-633 (logic): potential device mismatch. `tokens_per_expert` is on CPU (line 702 in the test shows `.cpu()`), so `cum_pad` will be on CPU too, but then `.cuda()` is called. If the user's input `tokens_per_expert` is already on CUDA or on a specific CUDA device, this could fail or move data to the wrong device. Use `tokens_per_expert.device` instead of a hardcoded `.cuda()`.
- `transformer_engine/pytorch/triton/permutation.py`, lines 350-357 (style): duplicate `expert_idx` load when both `FUSION_PAD` and `PERMUTE_PROBS` are true; the same value is loaded again at lines 365-369. Consider loading once and reusing to improve performance.
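A sketch of what the annotation and docstring fix suggested above could look like; the parameter list follows the sequence diagram earlier in this thread, and the wording is illustrative.

```python
from typing import Tuple
import torch

def moe_permute_and_pad_with_probs(
    inp: torch.Tensor,
    probs: torch.Tensor,
    routing_map: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    align_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fused permute + pad for MoE activations.

    Returns:
        output: permuted and padded activations.
        permuted_probs: probs reordered to match ``output``.
        row_id_map: mapping consumed later by ``moe_unpermute``.
        pad_offsets: per-expert padding offsets used by the unpermute kernel.
        target_tokens_per_expert: per-expert token counts after alignment padding.
    """
    ...  # body omitted in this sketch
```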
4 files reviewed, 4 comments
Force-pushed from 6df72fe to e2af4a7
Additional Comments (1)
- `transformer_engine/pytorch/permutation.py`, lines 632-633 (logic): the hardcoded `.cuda()` call may fail if input tensors are on different devices.
5 files reviewed, 1 comment
Force-pushed from e2af4a7 to 8834b0f
Additional Comments (1)
- `transformer_engine/pytorch/permutation.py`, line 631 (syntax): dtype mismatch: `cum_pad[:-1]` is `torch.long` but `torch.zeros(1, dtype=torch.int32)` is `torch.int32`.
5 files reviewed, 1 comment
1. Fused `moe_permute_with_probs` + `Fp8Padding` and fused `moe_unpermute` + `Fp8Unpadding`, which removes the explicit padding/unpadding of MoE experts, improving performance and reducing peak GPU memory usage. 2. Added tests of fused permute/pad and unpermute/unpad. Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
Force-pushed from 8834b0f to feb65b7
5 files reviewed, no comments
Also, since you do not allow write permission, the automatic pre-commit.ci service cannot format the code for you, and the pre-commit.ci - pr check therefore fails. You could try running the qa/format.sh script manually to format the code before pushing. Otherwise, please enable write permission (as requested for JAX compatibility purposes above).
…rging_probs Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
Additional Comments (1)
- `transformer_engine/common/triton/permutation.py`, lines 251-270 (style): `expert_idx` is loaded twice when both `FUSION_PAD` and `PERMUTE_PROBS` are enabled: once at lines 252-256 and again at lines 266-270. Consider loading once and reusing.
5 files reviewed, 1 comment
Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
Additional Comments (2)
- `transformer_engine/pytorch/permutation.py`, line 625 (style): the `torch.ceil` operation inherits its device from the input; if `tokens_per_expert` is on CPU (as in the test), `target_tokens_per_expert` will also be on CPU, which is correct for host-side calculations.
- `tests/pytorch/test_permutation.py`, lines 1490-1510 (style): consider adding an empty-input test case for the fused operations, similar to `test_permutation_mask_map_empty_input` at line 1514.
5 files reviewed, 2 comments
Additional Comments (1)
- `transformer_engine/pytorch/permutation.py`, line 625 (syntax): needs `.float()` before the division to avoid integer-division issues (see the sketch below).
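A small sketch of the suggested fix, with illustrative values; the cast makes the ceil rounding explicit and keeps the final result in an integer dtype.

```python
import torch

tokens_per_expert = torch.tensor([5, 3, 7], dtype=torch.int32)
align_size = 4

# Cast to float before the ceil-division, then round each count up to a
# multiple of align_size and convert back to the original integer dtype.
target_tokens_per_expert = (
    torch.ceil(tokens_per_expert.float() / align_size) * align_size
).to(tokens_per_expert.dtype)
print(target_tokens_per_expert)  # tensor([8, 4, 8], dtype=torch.int32)
```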
5 files reviewed, 1 comment
Signed-off-by: tdophung <tdophung@nvidia.com>
…tch test for sanity, in case combine accidentally undoes an error on dispatch in the roundtrip test. Add distinction between L0 and L2 in test cases in JAX Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Additional Comments (2)
- `transformer_engine/pytorch/permutation.py`, line 630 (style): potential device-transfer inefficiency if `tokens_per_expert` is already on GPU. `tokens_per_expert` may already be on GPU in some use cases, and the `.to(inp.device)` call after computing `target_tokens_per_expert - tokens_per_expert` forces a device transfer. Consider moving the entire computation to the input device upfront, or checking the device once at the start of the function.
- `transformer_engine/pytorch/permutation.py`, line 625 (logic): missing validation for the `align_size` parameter. `align_size` should be validated to be positive before the division; invalid values (zero or negative) would cause division by zero or incorrect alignment calculations (see the sketch after this list).
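A minimal sketch of such a check; the helper name and message are illustrative.

```python
def _check_align_size(align_size: int) -> None:
    # Reject zero and negative values before they reach the ceil-division.
    assert isinstance(align_size, int) and align_size > 0, (
        f"align_size must be a positive integer, got {align_size!r}"
    )

_check_align_size(16)    # ok
# _check_align_size(0)   # would raise before causing a division-by-zero later
```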
9 files reviewed, 2 comments
…redundant primitive using a new static bool for padding, add assert for align size Signed-off-by: tdophung <tdophung@nvidia.com>
Additional Comments (1)
- `transformer_engine/common/triton/permutation.py`, lines 253-258 (style): optimized to load `expert_idx` only once when either `FUSION_PAD` or `PERMUTE_PROBS` is enabled, avoiding duplicate loads of the same value from `row_id_map`. Note: if this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
9 files reviewed, 1 comment
@tdophung, the CI failure (Lint / JAX Python) seems to be triggered by the changes under the jax/ directory.
Signed-off-by: tdophung <tdophung@nvidia.com>
/te-ci
Additional Comments (1)
- `tests/pytorch/test_permutation.py`, lines 969-970 (style): inconsistent None comparison; use `is None` instead of `== None`. Note: if this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
9 files reviewed, 1 comment
jberchtold-nvidia
left a comment
LGTM pending the JAX version check, thanks!
… triton binding accordingly Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Additional Comments (6)
- `transformer_engine/pytorch/permutation.py`, lines 632-633 (logic): the `pad_offsets` check should happen before using `torch.equal()`, since `torch.equal()` returns False when comparing tensors of different shapes, but the real issue is when padding isn't needed.
- `transformer_engine/pytorch/permutation.py`, lines 635-639 (style): `cum_pad[:-1]` can be empty when there is only one expert. This causes `torch.cat` to concatenate a single-element zero tensor with an empty tensor, which works but is inefficient.
- `transformer_engine/pytorch/permutation.py`, line 630 (logic): using `.long()` after `ceil()` creates int64 tensors, but the rest of the codebase uses int32 for indexing (see lines 52, 149). This inconsistency could cause issues or unnecessary dtype conversions. Should `target_tokens_per_expert` use int32 to match the rest of the codebase's indexing dtypes, or is int64 needed for the `.sum().item()` call on line 642?
- `transformer_engine/common/triton/permutation.py`, lines 253-258 (style): `expert_idx` is loaded twice when both `FUSION_PAD` and `PERMUTE_PROBS` are true; the load at lines 254-258 happens, then another identical load occurs at lines 269-273 (not shown, but in the original code). Move the `expert_idx` load outside the conditional to avoid redundant memory accesses.
- `transformer_engine/pytorch/triton/permutation.py`, lines 160-162 (style): using `torch.zeros` for allocation when `pad_offsets` is provided ensures padding regions are zeroed, but it is more expensive than `torch.empty`. Since the kernel writes zeros explicitly for padding slots (line 277 in common/triton/permutation.py), using `torch.empty` would be more efficient. Does the kernel reliably write to all padding positions, or are there edge cases where `torch.zeros` is necessary for correctness? (See the sketch after this list.) Note: if this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
- `tests/pytorch/test_permutation.py`, line 1791 (logic): removed `topK=1` from test parameters, but single-expert routing (topK=1) is a common edge case that should be tested to ensure the fused padding logic works correctly when each token routes to exactly one expert.
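A sketch of the allocation trade-off raised in the `torch.zeros` vs `torch.empty` comment above; the names and shapes are illustrative, and the right choice depends on whether the kernel zero-fills every padding slot itself.

```python
import torch

def allocate_output(num_out_tokens: int, hidden_size: int, pad_offsets, device="cuda"):
    if pad_offsets is not None:
        # zeros guarantees padded rows read back as zero even if the kernel
        # skips some padding slots, at the cost of an extra memset
        return torch.zeros(num_out_tokens, hidden_size, device=device)
    # empty is cheaper when the kernel is known to write every row
    return torch.empty(num_out_tokens, hidden_size, device=device)
```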
9 files reviewed, 6 comments
jberchtold-nvidia
left a comment
LGTM, thanks!
timmoon10
left a comment
LGTM
/te-ci
7cad5c5
Greptile's behavior is changing! From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section. This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".
/te-ci
/te-ci jax
@tdophung `@pytest.mark.parametrize("num_tokens", [4096])` In addition, we could remove the `test_permutation_and_padding_with_merging_probs` test to reduce test time, since the test scenarios for the `with_merging_probs` args are already covered in `test_permutation_and_padding_mask_map`.

1. Fused `moe_permute_with_probs` + `Fp8Padding` and fused `moe_unpermute` + `Fp8Unpadding`, which removes the explicit padding/unpadding in the MoE experts module, improving performance and reducing peak GPU memory usage.
2. Added tests of the fused permute/pad and unpermute/unpad operations.
Description
This PR optimizes FP8 MoE permute and pad operations by:
- Fusing `moe_permute_with_probs` + `Fp8Padding` into `moe_permute_and_pad_with_probs`
- Fusing `moe_unpermute` + `Fp8Unpadding` into `moe_unpermute` with a `pad_offsets` argument

Results:
- Improved performance and reduced peak GPU memory usage; tests added in `tests/pytorch/test_permutation.py`

Performance data
Tests covering a wide range of model training configurations were performed, comparing the fused operations ("Fused:") with the original version ("Orig:"). Running times (in milliseconds) are summarized in the table below, and the speedup, measured as the reciprocal of the ratio of running times, is also provided. All tests were carried out with the tests/pytorch/test_permutation.py benchmark script.
The usage in Megatron-LM
- `Megatron-LM/megatron/core/transformer/moe/moe_utils.py`: added support for the fused operations
- `Megatron-LM/megatron/core/transformer/moe/token_dispatcher.py`: scheduler integration
Type of change
Documentation change (change only to the documentation, either a fix or a new content)
Changes
- Added the `moe_permute_and_pad_with_probs` API for fused permute and pad, and modified the `moe_unpermute` API with a `pad_offsets` argument for fused unpermute and unpad, in `transformer_engine/pytorch/permutation.py`
- Added tests in `tests/pytorch/test_permutation.py`

Checklist: