
Conversation

@xiaoxi-wangfj
Contributor

@xiaoxi-wangfj xiaoxi-wangfj commented Jul 3, 2025

1. Fused moe_permute_with_probs + Fp8Padding and moe_unpermute + Fp8Unpadding, removing the explicit padding/unpadding in the MoE experts module, improving performance and reducing peak GPU memory usage.
2. Added tests for the fused permute/pad and unpermute/unpad operations.

Description

This PR optimizes FP8 MoE permute and pad operations by:

  1. Fusing moe_permute_with_probs + Fp8Padding into moe_permute_and_pad_with_probs
  2. Fusing moe_unpermute + Fp8Unpadding into moe_unpermute with pad_offsets argument
  3. Thereby removing the explicit padding/unpadding steps in the MoE experts module
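
As a rough sketch of the change at the API level (argument and return names follow the sequence diagram later in this thread; run_experts is a hypothetical stand-in for the grouped-GEMM expert computation, so this is illustrative rather than the exact TE code):

```python
from transformer_engine.pytorch.permutation import (
    moe_permute_and_pad_with_probs,  # new fused permute + pad
    moe_unpermute,                   # now accepts pad_offsets for fused unpermute + unpad
)

def fp8_expert_path(inp, probs, routing_map, tokens_per_expert, align_size):
    # One fused call permutes tokens and writes them at FP8-aligned per-expert
    # offsets, so the separate Fp8Padding / Fp8Unpadding steps are not needed.
    out, permuted_probs, row_id_map, pad_offsets, padded_tokens_per_expert = (
        moe_permute_and_pad_with_probs(inp, probs, routing_map, tokens_per_expert, align_size)
    )
    expert_out = run_experts(out, padded_tokens_per_expert)  # hypothetical expert compute
    # Unpermute and drop the padding rows in a single fused call.
    return moe_unpermute(
        expert_out, row_id_map, merging_probs=permuted_probs, pad_offsets=pad_offsets
    )
```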

Results:

  • 1.1x~1.6x speedup for fused permute-and-pad operations
  • 1.7x~3x speedup for fused unpermute-and-unpad operations (measured by tests/pytorch/test_permutation.py)
  • Verified in end-to-end FP8 model training with the Megatron framework: +0.4% MFU uplift and ~1 GB peak GPU memory reduction in a typical ~600B-parameter setup.

Performance data

Tests covering a wide range of model-training configurations were performed, comparing the fused operations ("Fused:") against the original version ("Orig:"). Running times (in milliseconds) are summarized in the table below, together with the speedup, computed as the ratio of the original running time to the fused running time. All tests were carried out with the tests/pytorch/test_permutation.py benchmark script.

(Image: Fused-perm-pad benchmark table)

Usage in Megatron-LM

  1. Megatron-LM/megatron/core/transformer/moe/moe_utils.py: Added Support for Fused Operations

```python

# Added fused function import
from megatron.core.extensions.transformer_engine import (
    ...,
    fused_permute_and_pad_with_probs,  # [!code ++]
)

def permute(
    ...,
    tokens_per_expert: Optional[torch.Tensor] = None,  # [!code ++]
    align_size: int = -1  # [!code ++]
):
  ...
  if fused and probs is not None:
      if not HAVE_TE or fused_permute_with_probs is None:
          raise ValueError(
              "fused_permute_with_probs is not available. Please install TE >= 2.1.0."
          )
      if tokens_per_expert is not None and align_size > 0:  # [!code ++]
          # Use fused permute+pad operation [!code ++]
          return fused_permute_and_pad_with_probs(tokens, probs, routing_map, tokens_per_expert, align_size)   # [!code ++]
      else:
          # Fallback to original implementation
          ...


def unpermute(
    ...,
    pad_offsets: Optional[torch.Tensor] = None  # [!code ++]
):
    return fused_unpermute(
        ...,
        pad_offsets=pad_offsets  # [!code ++]
    )

```

  2. Megatron-LM/megatron/core/transformer/moe/token_dispatcher.py: Scheduler Integration

```python

class _DeepepManager(_DispatchManager):
    def __init__(...):
        self.pad_offsets = None  # [!code ++] Store padding offsets
    
    def get_permuted_hidden_states_by_experts(...):
        ...
        if self.config.moe_permute_padding_for_fp8:  # [!code ++]
            # Use fused path [!code ++]
            (                                                          # [!code ++]
                hidden_states,                                         # [!code ++]
                permuted_probs,                                        # [!code ++]
                self.reversed_mapping_for_combine,                     # [!code ++]
                self.pad_offsets,                                      # [!code ++]
                self.tokens_per_expert                                 # [!code ++]
            ) = permute(                                               # [!code ++]
                hidden_states,                                         # [!code ++]
                self.dispatched_routing_map,                           # [!code ++]
                probs=self.dispatched_probs,                           # [!code ++]
                fused=self.permute_fusion,                             # [!code ++]
                tokens_per_expert=self.tokens_per_expert,              # [!code ++]
                align_size=get_fp8_align_size(self.config.fp8_recipe), # [!code ++]
            )                                                          # [!code ++]
        else:
            # Original path
            ...
    
    def get_restored_hidden_states_by_experts(...):
        hidden_states = unpermute(
            ...,
            pad_offsets=self.pad_offsets if self.config.moe_permute_padding_for_fp8 else None, # [!code ++]
        )
        ...

```

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added a moe_permute_and_pad_with_probs API for fused permute and pad, and extended the moe_unpermute API with a pad_offsets argument for fused unpermute and unpad, in transformer_engine/pytorch/permutation.py
  • Added tests in tests/pytorch/test_permutation.py

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@xiaoxi-wangfj xiaoxi-wangfj marked this pull request as draft July 4, 2025 04:49
@xiaoxi-wangfj xiaoxi-wangfj marked this pull request as ready for review July 4, 2025 05:03
@yaox12
Member

yaox12 commented Jul 11, 2025

Thanks for your contribution.
We also noticed that there are redundant reads and writes with the current "permute and then pad" routine. We plan to tackle it by padding the routing map before permutation. Refer to this commit in Megatron-LM.
For permute fusion, we try to avoid tiny CUDA kernels as much as possible. I went through your PR and found that these lines may introduce some of them.
We prefer padding the routing map instead of fusing permutation and padding, because tokens_per_expert is produced by preprocess and we don't want it to be changed later; otherwise it may cause some confusion.
So we won't merge this PR. Thanks for your contribution again.

Never mind, I agree that fusing permutation with padding is a better solution.

@xiaoxi-wangfj
Contributor Author

xiaoxi-wangfj commented Jul 11, 2025

> Thanks for your contribution. We also noticed that there are redundant reads and writes with the current "permute and then pad" routine. We plan to tackle it by padding the routing map before permutation. Refer to this commit in Megatron-LM. For permute fusion, we try to avoid tiny CUDA kernels as much as possible. I went through your PR and found that these lines may introduce some of them. We prefer padding the routing map instead of fusing permutation and padding, because tokens_per_expert is produced by preprocess and we don't want it to be changed later; otherwise it may cause some confusion. So we won't merge this PR. Thanks for your contribution again.

@yaox12
Thank you for your response.

  1. The moe_router_padding_for_fp8 and fused_permute_pad_for_fp8 configurations are compatible. Within the moe_router_padding_for_fp8 logic, if not_enough_tokens_to_pad is triggered, execution falls back to the fused_permute_pad_for_fp8 computational path; otherwise, pad_offsets is set to None (see the sketch after this list).

  2. We previously enabled the Megatron-LM commit configuration you mentioned, but found that padding the routing map may cause loss instability. During pre-training, especially with larger FP8 align-size values such as 128 (due to our 1x128 blockwise setting), it frequently triggered the not_enough_tokens_to_pad warning, which forced a fallback to explicit padding within GroupedMLP; this halved iteration performance when it occurred and sometimes resulted in loss deterioration. We then disabled padding of the routing map.
    The fused_permute_pad implementation updates tokens_per_expert itself, so no configuration changes are needed for Fp8Padding/Fp8Unpadding. Do you have any suggestions for refining this fused approach? I would like to continue improving this optimization.
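
For clarity, a rough sketch of the fallback described in point 1 (flag and helper names follow this thread and are illustrative, not the exact Megatron-LM code):

```python
from transformer_engine.pytorch.permutation import (
    moe_permute_with_probs,
    moe_permute_and_pad_with_probs,
)

def select_permute_path(config, hidden_states, probs, routing_map,
                        tokens_per_expert, align_size, not_enough_tokens_to_pad):
    # Illustrative sketch only: config flags are named as in this discussion.
    if config.moe_router_padding_for_fp8 and not not_enough_tokens_to_pad:
        # Routing-map padding already aligned tokens_per_expert; no fused padding needed.
        out, permuted_probs, row_id_map = moe_permute_with_probs(hidden_states, probs, routing_map)
        pad_offsets = None
    else:
        # Fall back to the fused permute+pad path from this PR.
        out, permuted_probs, row_id_map, pad_offsets, tokens_per_expert = (
            moe_permute_and_pad_with_probs(
                hidden_states, probs, routing_map, tokens_per_expert, align_size
            )
        )
    return out, permuted_probs, row_id_map, pad_offsets, tokens_per_expert
```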

@shcho1118

shcho1118 commented Jul 22, 2025

We have also implemented permute + pad and unpermute + unpad fusion that works with DeepGEMM grouped GEMM for training, and we are using it.
The implementation is a middle ground between what @yaox12 mentioned and this PR (we pad the routing map, but there are some tiny kernels).
The loss instability mentioned by @xiaoxi-wangfj could be due to garbage values in the padded gradients from the unpermute backward (the same as the permute gradients).
We had a similar problem, and the solution was to zero-init the gradient tensors when allocating them.
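
For illustration, a minimal sketch of the zero-initialization being described (the helper name and shapes are hypothetical; the point is only the allocation choice for the padded gradient buffer):

```python
import torch

def alloc_padded_grad(padded_num_tokens, hidden_size, dtype, device):
    # The padded rows are never written by the unpermute backward kernel, so
    # torch.empty would leave garbage there that can leak into the permute
    # gradients; zero-init keeps the padded gradient rows at exactly 0.
    return torch.zeros((padded_num_tokens, hidden_size), dtype=dtype, device=device)
```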

@nvMelissa nvMelissa added the community-contribution label (PRs from external contributors outside the core maintainers, representing community-driven work) on Oct 9, 2025
@xiaoxi-wangfj
Contributor Author

During pretraining of a 600B-parameter MoE model using the blockwise recipe, I observed that the pad_routing_map behavior leads to performance degradation at larger expert parallelism (EP) scales. Therefore, I evaluated the performance of the fused_permute_pad optimization under these settings. The results are as follows:
I found that fused_permute_pad demonstrates better scalability and delivers consistent performance gains over the baseline across different EP configurations.
(Image: fused_permute_pad EP-scaling benchmark results)

@greptile-apps
Contributor

greptile-apps bot commented Dec 10, 2025

Greptile Summary

  • Introduces fused permute+pad and unpermute+unpad operations for FP8 MoE optimization that combine token permutation with memory alignment in single operations
  • Adds new moe_permute_and_pad_with_probs function to PyTorch API and extends existing moe_unpermute with pad_offsets parameter support
  • Implements comprehensive Triton kernel modifications and JAX primitives to support the fused operations with FUSION_PAD/FUSION_UNPAD flags

Important Files Changed

  • transformer_engine/common/triton/permutation.py: core Triton kernels modified with pad_offsets support and FUSION_PAD/FUSION_UNPAD flags for the fused operations
  • transformer_engine/pytorch/permutation.py: new moe_permute_and_pad_with_probs function and existing functions extended with a pad_offsets parameter
  • transformer_engine/jax/triton_extensions/permutation.py: new JAX primitives for the fused operations with proper MLIR lowering support
  • tests/pytorch/test_permutation.py: comprehensive test coverage for the fused operations, with correctness validation against naive implementations

Confidence score: 4/5

  • This PR appears safe to merge with significant performance benefits and comprehensive test coverage
  • Score reflects the complexity of the implementation across multiple layers (kernels, PyTorch/JAX APIs, tests) and potential for subtle bugs in fused operations, though the extensive testing mitigates this risk
  • Pay close attention to the Triton kernel changes in transformer_engine/common/triton/permutation.py as they handle memory indexing logic that could cause memory access issues if incorrect

Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant API as "moe_permute_and_pad_with_probs"
    participant AutogradFn as "_moe_permute_mask_map"
    participant TritonKernel as "permute_with_mask_map"
    participant UnpermuteAPI as "moe_unpermute"
    participant UnpermuteAutograd as "_moe_unpermute_mask_map"
    participant UnpermuteKernel as "unpermute_with_mask_map"

    User->>API: "moe_permute_and_pad_with_probs(inp, probs, routing_map, tokens_per_expert, align_size)"
    API->>API: "Calculate target_tokens_per_expert and pad_offsets"
    API->>AutogradFn: "apply(inp, routing_map, num_out_tokens, probs, pad_offsets)"
    AutogradFn->>TritonKernel: "permute_with_mask_map(..., pad_offsets, ...)"
    TritonKernel->>AutogradFn: "Return permuted and padded output, permuted_probs"
    AutogradFn->>API: "Return output, row_id_map, permuted_probs"
    API->>User: "Return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert"

    User->>UnpermuteAPI: "moe_unpermute(inp, row_id_map, merging_probs, pad_offsets=pad_offsets)"
    UnpermuteAPI->>UnpermuteAutograd: "apply(inp, row_id_map, merging_probs, restore_shape, pad_offsets)"
    UnpermuteAutograd->>UnpermuteKernel: "unpermute_with_mask_map(..., pad_offsets, ...)"
    UnpermuteKernel->>UnpermuteAutograd: "Return unpermuted and unpadded output"
    UnpermuteAutograd->>UnpermuteAPI: "Return unpermuted_output"
    UnpermuteAPI->>User: "Return unpermuted_output"
```

@greptile-apps greptile-apps bot left a comment

Additional Comments (2)

  1. transformer_engine/pytorch/triton/permutation.py, line 438-470 (link)

    syntax: syntax error: duplicated and malformed tensor allocation code

    lines 438-456 allocate output, permuted_probs, and permuted_scale with if-else blocks. then lines 457-470 duplicate this logic but leave line 456's if-block incomplete (missing the else branch allocation) and line 467 is a stray incomplete statement.

  2. transformer_engine/pytorch/permutation.py, line 632-633 (link)

    logic: inconsistent device placement

    torch.cat([torch.zeros(1, dtype=torch.int32), cum_pad[:-1]]) creates tensors on CPU but cum_pad is derived from tokens_per_expert which may be on CPU. explicitly calling .cuda() afterwards can cause unnecessary data transfers and potential device mismatch errors if tokens_per_expert is already on CUDA.
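
    For reference, a device-consistent construction along the lines suggested here might look like the following (cum_pad as computed in the PR's permutation.py; sketch only):

```python
import torch

def build_pad_offsets(cum_pad: torch.Tensor) -> torch.Tensor:
    # Allocate the leading zero on the same device and with the same dtype as
    # cum_pad, avoiding an implicit CPU allocation followed by a .cuda() copy.
    zero = torch.zeros(1, dtype=cum_pad.dtype, device=cum_pad.device)
    return torch.cat([zero, cum_pad[:-1]])
```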

4 files reviewed, 2 comments

@greptile-apps greptile-apps bot left a comment

Additional Comments (4)

  1. transformer_engine/pytorch/permutation.py, line 594 (link)

    syntax: Return type annotation incorrect - function returns 5 values but annotation shows only 2

  2. transformer_engine/pytorch/permutation.py, line 615-616 (link)

    style: Missing documentation for return values. Function returns 5 values: (output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert) but docstring doesn't describe them.

  3. transformer_engine/pytorch/permutation.py, line 632-633 (link)

    logic: Potential device mismatch - tokens_per_expert is on CPU (line 702 in test shows .cpu()), so cum_pad will be on CPU too, but then .cuda() is called. If user's input tokens_per_expert is already on CUDA or on a specific CUDA device, this could fail or move to wrong device. Use tokens_per_expert.device instead of hardcoded .cuda()

  4. transformer_engine/pytorch/triton/permutation.py, line 350-357 (link)

    style: Duplicate expert_idx load when both FUSION_PAD and PERMUTE_PROBS are true. The same value is loaded again at lines 365-369. Consider loading once and reusing to improve performance.

4 files reviewed, 4 comments

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/pytorch/permutation.py, line 632-633 (link)

    logic: hardcoded .cuda() call may fail if input tensors are on different devices

5 files reviewed, 1 comment

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/pytorch/permutation.py, line 631 (link)

    syntax: dtype mismatch: cum_pad[:-1] is torch.long but torch.zeros(1, dtype=torch.int32) is torch.int32

5 files reviewed, 1 comment


1. Fused `moe_permute_with_probs` + `Fp8Padding` and `moe_unpermute` + `Fp8Unpadding`,
  removing the explicit padding/unpadding in the MoE experts module, improving performance and reducing peak GPU memory usage.
2. Added tests for the fused permute/pad and unpermute/unpad operations.

Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
@greptile-apps greptile-apps bot left a comment

5 files reviewed, no comments


@tdophung
Collaborator

tdophung commented Dec 11, 2025

Also, since you do not allow write permission, the automatic pre-commit.ci service cannot format the code for you, so the pre-commit.ci - pr check fails. You could try running the qa/format.sh script manually to format the code before pushing. Otherwise, please enable write permission (as requested above for JAX compatibility purposes).

…rging_probs

Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/common/triton/permutation.py, line 251-270 (link)

    style: expert_idx loaded twice when both FUSION_PAD and PERMUTE_PROBS are enabled - once at line 252-256 and again at lines 266-270. Consider loading once and reusing.

5 files reviewed, 1 comment


Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
@greptile-apps greptile-apps bot left a comment

Additional Comments (2)

  1. transformer_engine/pytorch/permutation.py, line 625 (link)

    style: the torch.ceil operation inherits device from input - if tokens_per_expert is on CPU (as in the test), target_tokens_per_expert will also be on CPU, which is correct for host-side calculations

  2. tests/pytorch/test_permutation.py, line 1490-1510 (link)

    style: consider adding an empty input test case for the fused operations, similar to test_permutation_mask_map_empty_input at line 1514

5 files reviewed, 2 comments

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/pytorch/permutation.py, line 625 (link)

    syntax: needs .float() before division to avoid integer division issues

5 files reviewed, 1 comment


Signed-off-by: tdophung <tdophung@nvidia.com>
…tch test for sanity in case combine accidentally undoes an error from dispatch in the roundtrip test. Add a distinction between L0 and L2 test cases in JAX

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
@greptile-apps greptile-apps bot left a comment

Additional Comments (2)

  1. transformer_engine/pytorch/permutation.py, line 630 (link)

    style: potential device transfer inefficiency if tokens_per_expert is already on GPU

    tokens_per_expert may already be on GPU in some use cases. The .to(inp.device) call after computing target_tokens_per_expert - tokens_per_expert forces a device transfer. Consider moving the entire computation to the input device upfront:

    or check device once at the start of the function.

  2. transformer_engine/pytorch/permutation.py, line 625 (link)

    logic: missing validation for align_size parameter

    align_size should be validated to ensure it's positive before division. Invalid values (zero or negative) would cause division by zero or incorrect alignment calculations.

9 files reviewed, 2 comments


…redundant primitive using a new static bool for padding, add assert for align size

Signed-off-by: tdophung <tdophung@nvidia.com>
@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/common/triton/permutation.py, line 253-258 (link)

    style: optimized to load expert_idx only once when either FUSION_PAD or PERMUTE_PROBS is enabled, avoiding duplicate loads of the same value from row_id_map

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

9 files reviewed, 1 comment


@xiaoxi-wangfj
Contributor Author

@tdophung, the CI failure (Lint / JAX Python) seems to be triggered by the changes under the jax/ directory.
Would you mind fixing the issue on your side?

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung
Collaborator

/te-ci

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. tests/pytorch/test_permutation.py, line 969-970 (link)

    style: inconsistent None comparison - use is None instead of == None

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

9 files reviewed, 1 comment

@jberchtold-nvidia jberchtold-nvidia (Collaborator) left a comment

LGTM pending the JAX version check, thanks!

… triton binding accordingly

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
@greptile-apps greptile-apps bot left a comment

Additional Comments (6)

  1. transformer_engine/pytorch/permutation.py, line 632-633 (link)

    logic: the pad_offsets check should happen before using torch.equal() since torch.equal() returns False when comparing tensors of different shapes, but the real issue is when padding isn't needed

  2. transformer_engine/pytorch/permutation.py, line 635-639 (link)

    style: cum_pad[:-1] can be empty when there's only one expert. This will cause torch.cat to concatenate a single-element zero tensor with an empty tensor, which works but is inefficient

  3. transformer_engine/pytorch/permutation.py, line 630 (link)

    logic: using .long() after ceil() creates int64 tensors, but the rest of the codebase uses int32 for indexing (see line 52, 149). This inconsistency could cause issues or unnecessary dtype conversions

    Should target_tokens_per_expert use int32 to match the rest of the codebase's indexing dtypes, or is int64 needed for the .sum().item() call on line 642?

  4. transformer_engine/common/triton/permutation.py, line 253-258 (link)

    style: loading expert_idx twice when both FUSION_PAD and PERMUTE_PROBS are true - the load at line 254-258 happens, then another identical load occurs at line 269-273 (not shown but in original code). Move the expert_idx load outside the conditional to avoid redundant memory accesses

  5. transformer_engine/pytorch/triton/permutation.py, line 160-162 (link)

    style: using torch.zeros for allocation when pad_offsets is provided ensures padding regions are zeroed, but this is more expensive than torch.empty. Since the kernel writes zeros explicitly for padding slots (line 277 in common/triton/permutation.py), using torch.empty would be more efficient. Does the kernel reliably write to all padding positions, or are there edge cases where torch.zeros is necessary for correctness?

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

  6. tests/pytorch/test_permutation.py, line 1791 (link)

    logic: removed topK=1 from test parameters, but single-expert routing (topK=1) is a common edge case that should be tested to ensure the fused padding logic works correctly when each token routes to exactly one expert

9 files reviewed, 6 comments

@jberchtold-nvidia jberchtold-nvidia (Collaborator) left a comment

LGTM, thanks!

timmoon10
timmoon10 previously approved these changes Dec 22, 2025
@timmoon10 timmoon10 (Collaborator) left a comment

LGTM

@tdophung
Collaborator

/te-ci

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung dismissed stale reviews from timmoon10 and jberchtold-nvidia via 7cad5c5 December 23, 2025 02:40
@greptile-apps
Contributor

greptile-apps bot commented Dec 23, 2025

Greptile's behavior is changing!

From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

@yaox12
Member

yaox12 commented Dec 23, 2025

/te-ci

@tdophung
Collaborator

/te-ci jax

@xiaoxi-wangfj
Contributor Author

@tdophung
Is the CI failure caused by test_permutation_and_padding_mask_map taking too long?
Perhaps we can reduce the configuration values for "num_tokens, num_expert, hidden_size, topK", making them consistent with the configuration used in test_permutation_mask_map, for example:

```python
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [2, 5])
```

In addition, we could remove the test_permutation_and_padding_with_merging_probs test to reduce test time, since the with_merging_probs scenarios are already covered in test_permutation_and_padding_mask_map.
Do you think this is appropriate?
