Changes from all commits (21 commits)
feb65b7  [PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization  (xiaoxi-wangfj, Jul 3, 2025)
a7de66c  [PyTorch/Common] Fuse permute+pad and unpermute+unpad support with_me…  (xiaoxi-wangfj, Dec 11, 2025)
f550684  [PyTorch]format code  (xiaoxi-wangfj, Dec 11, 2025)
6069277  [Common]perf expert_idx loaded once  (xiaoxi-wangfj, Dec 11, 2025)
053abee  Merge branch 'main' into fused_perm_pad  (xiaoxi-wangfj, Dec 12, 2025)
1ea08f7  fix: pad_offsets can be None  (xiaoxi-wangfj, Dec 17, 2025)
ac12a91  Merge branch 'main' into fused_perm_pad  (xiaoxi-wangfj, Dec 17, 2025)
230939c  add padding + merging probs bwd support. Not tested  (tdophung, Dec 11, 2025)
f301462  Fix garbage initialized act grad  (tdophung, Dec 11, 2025)
7ed584c  all test passing for jax permutation + pad  (tdophung, Dec 17, 2025)
7998ce8  change tokens_per_experts APIs to num_out_tokens with conservative a…  (tdophung, Dec 17, 2025)
dd5c72a  change test permutation to reduce test time  (tdophung, Dec 19, 2025)
ce187b6  triggering PR refresh  (tdophung, Dec 19, 2025)
7dc9ccb  format code  (tdophung, Dec 20, 2025)
1fbe99c  Remove some tests cases from pytorch side. Add a separate toekn_dispa…  (tdophung, Dec 20, 2025)
592f675  format code  (tdophung, Dec 20, 2025)
1d43279  remove chance for inefficiency in moving between CPU and GPU, remove …  (tdophung, Dec 20, 2025)
4169a4e  fix lint in jax  (tdophung, Dec 22, 2025)
c619adf  account for both jax newer and older than version 0.8.2. Adjusted gpu…  (tdophung, Dec 22, 2025)
405b341  format code  (tdophung, Dec 22, 2025)
7cad5c5  fix typo  (tdophung, Dec 23, 2025)
897 changes: 565 additions & 332 deletions tests/jax/test_permutation.py

Large diffs are not rendered by default.

658 changes: 649 additions & 9 deletions tests/pytorch/test_permutation.py

Large diffs are not rendered by default.

38 changes: 29 additions & 9 deletions transformer_engine/common/triton/permutation.py
@@ -200,6 +200,7 @@ def _permute_kernel(
     probs_ptr,
     scale_ptr,
     permuted_scale_ptr,
+    pad_offsets_ptr,
     # sizes
     scale_hidden_dim,
     # strides
@@ -224,8 +225,11 @@
     hidden_size: tl.constexpr,
     PERMUTE_PROBS: tl.constexpr,
     PERMUTE_SCALE: tl.constexpr,
+    FUSION_PAD: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
+    expert_idx = 0
+
     pid_t = tl.program_id(0)
     pid_h = tl.program_id(1)
     cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@@ -246,18 +250,22 @@
         dst_row = tl.load(
             row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
         ).to(tl.int64)
+        if FUSION_PAD or PERMUTE_PROBS:
+            expert_idx = tl.load(
+                row_id_map_ptr
+                + pid_t * stride_row_id_map_token
+                + (num_experts + idx) * stride_row_id_map_expert
+            )
+        if FUSION_PAD:
+            pad_off = tl.load(pad_offsets_ptr + expert_idx)
+            dst_row = dst_row + pad_off
         output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
         if PERMUTE_SCALE:
             permuted_scale_off = (
                 dst_row * stride_permuted_scale_token + cur_off * stride_permuted_scale_hidden
             )
             tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
         if PERMUTE_PROBS:
-            expert_idx = tl.load(
-                row_id_map_ptr
-                + pid_t * stride_row_id_map_token
-                + (num_experts + idx) * stride_row_id_map_expert
-            )
             prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert
             prob = tl.load(probs_ptr + prob_off)
             if pid_h == 0:
@@ -297,6 +305,7 @@ def _unpermute_kernel(
     row_id_map_ptr,
     merging_probs_ptr,
     permuted_probs_ptr,
+    pad_offsets_ptr,
     # strides
     stride_row_id_map_token,
     stride_row_id_map_expert,
@@ -318,10 +327,12 @@
     PROBS_LOAD_WIDTH: tl.constexpr,
     WITH_MERGING_PROBS: tl.constexpr,
     PERMUTE_PROBS: tl.constexpr,
+    FUSION_UNPAD: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     data_type = input_ptr.dtype.element_ty
     compute_type = tl.float32
+    expert_idx = 0

     pid_t = tl.program_id(0)
     pid_h = tl.program_id(1)
@@ -348,15 +359,19 @@
         src_row = tl.load(
             row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
         ).to(tl.int64)
-        input_off = src_row * stride_input_token + current_offset * stride_input_hidden
-        inp = tl.load(input_ptr + input_off, mask=mask)
-        inp = inp.to(compute_type)
-        if WITH_MERGING_PROBS:
+        if FUSION_UNPAD or WITH_MERGING_PROBS:
             expert_idx = tl.load(
                 row_id_map_ptr
                 + pid_t * stride_row_id_map_token
                 + (num_experts + idx) * stride_row_id_map_expert
             )
+            if FUSION_UNPAD:
+                pad_off = tl.load(pad_offsets_ptr + expert_idx)
+                src_row = src_row + pad_off
+        input_off = src_row * stride_input_token + current_offset * stride_input_hidden
+        inp = tl.load(input_ptr + input_off, mask=mask)
+        inp = inp.to(compute_type)
+        if WITH_MERGING_PROBS:
             merging_prob_off = (
                 pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
             )
@@ -407,6 +422,7 @@ def _unpermute_bwd_with_merging_probs_kernel(
     fwd_input_ptr,
     merging_probs_ptr,
     row_id_map_ptr,
+    pad_offsets_ptr,
     # strides
     stride_row_id_map_token,
     stride_row_id_map_expert,
@@ -427,6 +443,7 @@
     num_experts: tl.constexpr,
     hidden_size: tl.constexpr,
     PROBS_LOAD_WIDTH: tl.constexpr,
+    FUSION_UNPAD: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     data_type = fwd_output_grad_ptr.dtype.element_ty
@@ -450,6 +467,9 @@
             + pid * stride_row_id_map_token
             + (num_experts + idx) * stride_row_id_map_expert
         )
+        if FUSION_UNPAD:
+            pad_off = tl.load(pad_offsets_ptr + expert_idx)
+            dst_row = dst_row + pad_off
         prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
         current_start = 0
         while current_start < hidden_size:
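For readers following the Triton indexing above, the snippet below emulates the FUSION_PAD branch of _permute_kernel in plain Python. The row_id_map layout (destination rows in the first num_experts columns, expert indices in the next num_experts columns, with negative entries treated as "not routed") is inferred from the loads shown in the diff, so treat it as an illustration of the row remapping rather than the library's implementation.

# Hedged, eager reference of the fused permute+pad row remapping. The row_id_map layout
# and the negative "not routed" sentinel are assumptions inferred from the kernel's loads;
# only the `dst_row += pad_offsets[expert_idx]` step mirrors the FUSION_PAD lines above.
import torch


def permute_with_pad_reference(inp, row_id_map, pad_offsets, num_experts, out_rows):
    out = inp.new_zeros((out_rows, inp.shape[1]))
    for t in range(inp.shape[0]):  # plays the role of pid_t
        for idx in range(num_experts):  # the kernel's per-slot loop
            dst_row = int(row_id_map[t, idx])
            if dst_row < 0:  # assumed sentinel: token not routed to this slot
                continue
            expert_idx = int(row_id_map[t, num_experts + idx])
            dst_row += int(pad_offsets[expert_idx])  # the FUSION_PAD shift into the padded layout
            out[dst_row] = inp[t]
    return out

The unpermute kernels apply the same per-expert shift to src_row before loading, so on the way back the padded rows are simply skipped.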