From feb65b74f9633a00f9765c56f84183456581879f Mon Sep 17 00:00:00 2001 From: xiaoxi-wangfj <690912414@qq.com> Date: Thu, 3 Jul 2025 10:42:45 +0000 Subject: [PATCH 01/19] [PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization 1. Fuse `moe_permute_with_probs` + `Fp8Padding` and `moe_unpermute` + `Fp8Unpadding`, which removes the explicit padding/unpadding of MoE experts, improving performance and reducing peak GPU memory usage. 2. Add tests for fused permute/pad and unpermute/unpad. Signed-off-by: xiaoxi-wangfj <690912414@qq.com> --- tests/pytorch/test_permutation.py | 355 ++++++++++++++++++ .../common/triton/permutation.py | 20 + transformer_engine/pytorch/__init__.py | 1 + transformer_engine/pytorch/permutation.py | 91 ++++- .../pytorch/triton/permutation.py | 36 +- 5 files changed, 480 insertions(+), 23 deletions(-) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index e8a7bedc873..5a650a21680 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -13,6 +13,7 @@ from transformer_engine.pytorch import ( moe_permute as te_permute, moe_permute_with_probs as te_permute_with_probs, + moe_permute_and_pad_with_probs as te_permute_and_pad_with_probs, moe_unpermute as te_unpermute, moe_sort_chunks_by_index as te_sort_chunks_by_index, moe_sort_chunks_by_index_with_probs as te_sort_chunks_by_index_with_probs, @@ -24,6 +25,7 @@ MXFP8Quantizer, ) import transformer_engine_torch as tex +from transformer_engine.pytorch import Fp8Padding, Fp8Unpadding import copy seed = 1234 @@ -653,6 +655,303 @@ def _test_permutation_mask_map( print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") +def _test_permutation_and_padding_mask_map( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, + align_size=16, + BENCHMARK=False, +): + if topK > num_expert: + pytest.skip("topK should be smaller than the number of experts.") + + if num_out_tokens == None: + num_out_tokens 
= num_tokens * topK + + print( + "permutation and padding:" + f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} align_size:{align_size} {te_dtype}" + ) + + # Convert TE dtypes to PyTorch dtypes + if te_dtype == tex.DType.kFloat32: + dtype = torch.float32 + elif te_dtype == tex.DType.kFloat16: + dtype = torch.float16 + elif te_dtype == tex.DType.kBFloat16: + dtype = torch.bfloat16 + else: + pytest.skip("Invalid dtype.") + + _tmp_tensor = torch.zeros((num_tokens * num_expert,)) + _tmp_tensor[: int(num_out_tokens)] = 1.0 + _tmp_idx = torch.randperm(num_tokens * num_expert) + routing_map = ( + torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda() + ) + + probs = torch.rand(num_tokens, num_expert).cuda() * routing_map + row_sums = probs.sum(dim=1, keepdim=True) + probs = probs / row_sums + probs = probs.to(dtype) + probs.requires_grad_(True) + + tokens_per_expert = routing_map.sum(dim=0).cpu() + target_tokens_per_expert = ( + torch.ceil(tokens_per_expert / align_size) * align_size + ).long() + num_permute_pad_out_tokens = target_tokens_per_expert.sum().item() + + permute_pad_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + permute_pad_bwd_input = torch.rand( + (num_permute_pad_out_tokens, hidden_size), dtype=dtype + ).cuda() + unpermute_unpad_bwd_input = torch.rand( + (num_tokens, hidden_size), dtype=dtype + ).cuda() + permute_pad_fwd_input.requires_grad_(True) + + restore_shape = permute_pad_fwd_input.shape + ################################################################################################################################### + # + # moe_permute_with_probs and Fp8Padding, moe_unpermute and Fp8Unpadding + # + ################################################################################################################################### + # permute + padding + permuted_output, permuted_probs, row_id_map = te_permute_with_probs( + permute_pad_fwd_input, + probs, + routing_map, + 
num_out_tokens=num_out_tokens, + ) + tokens_per_expert_list = tokens_per_expert.tolist() + fp8_padding = Fp8Padding(num_expert, align_size) + permuted_paded_output, _ = fp8_padding(permuted_output, tokens_per_expert_list) + permuted_paded_probs, _ = fp8_padding( + permuted_probs.unsqueeze(-1), tokens_per_expert_list + ) + + permuted_paded_output.backward(permute_pad_bwd_input, retain_graph=True) + + # unpadding + unpermute + + unpermute_unpad_fwd_input = permuted_paded_output.detach() + unpermute_unpad_fwd_input.requires_grad_(True) + + fp8_unpadding = Fp8Unpadding(num_expert, align_size) + unpaded_output = fp8_unpadding(unpermute_unpad_fwd_input, tokens_per_expert_list) + unpermuted_unpaded_output = te_unpermute( + unpaded_output, row_id_map, restore_shape=restore_shape + ) + + unpermuted_unpaded_output.backward(unpermute_unpad_bwd_input, retain_graph=True) + + ################################################################################################################################### + # + # fusion moe_permute_with_probs and Fp8Padding, fusion fusion moe_unpermute and Fp8Unpadding + # + ################################################################################################################################### + # fusion permute_and_pad + fusion_permute_and_pad_fwd_input = permute_pad_fwd_input.detach() + fusion_permute_and_pad_fwd_input.requires_grad_(True) + probs = probs.detach() + probs.requires_grad_(True) + + ( + fusion_permuted_padded_output, + fusion_permuted_padded_probs, + row_id_map, + pad_offsets, + target_tokens_per_expert, + ) = te_permute_and_pad_with_probs( + fusion_permute_and_pad_fwd_input, + probs, + routing_map, + tokens_per_expert, + align_size, + ) + fusion_permuted_padded_probs = fusion_permuted_padded_probs.unsqueeze(-1) + + fusion_permute_pad_bwd_input = permute_pad_bwd_input.detach() + fusion_permuted_padded_output.backward( + fusion_permute_pad_bwd_input, retain_graph=True + ) + + # fusion unpad and unpermute + 
fusion_unpermute_unpad_fwd_input = fusion_permuted_padded_output.detach() + fusion_unpermute_unpad_fwd_input.requires_grad_(True) + + fusion_unpermuted_unpaded_output = te_unpermute( + fusion_unpermute_unpad_fwd_input, + row_id_map, + restore_shape=restore_shape, + pad_offsets=pad_offsets, + ) + + fusion_unpermute_bwd_input = unpermute_unpad_bwd_input.detach() + fusion_unpermuted_unpaded_output.backward( + fusion_unpermute_bwd_input, retain_graph=True + ) + + ################################################################################################################################### + # + # Results Check + # + ################################################################################################################################### + tols = dtype_tols(te_dtype) + + permuted_paded_output_ = permuted_paded_output.float() + fusion_permuted_padded_output_ = fusion_permuted_padded_output.float() + permute_pad_fwd_input_grad = permute_pad_fwd_input.grad.float() + fusion_permute_and_pad_fwd_input_grad = ( + fusion_permute_and_pad_fwd_input.grad.float() + ) + + unpermuted_unpaded_output_ = unpermuted_unpaded_output.float() + fusion_unpermuted_unpaded_output_ = fusion_unpermuted_unpaded_output.float() + unpermute_unpad_fwd_input_grad = unpermute_unpad_fwd_input.grad.float() + fusion_unpermute_unpad_fwd_input_grad = ( + fusion_unpermute_unpad_fwd_input.grad.float() + ) + + if not BENCHMARK: + torch.testing.assert_close( + permuted_paded_output_, + fusion_permuted_padded_output_, + msg=f"Mismatch in te_permute_and_pad fwd", + **tols, + ) + torch.testing.assert_close( + permute_pad_fwd_input_grad, + fusion_permute_and_pad_fwd_input_grad, + msg=f"Mismatch in te_permute_and_pad bwd", + **tols, + ) + torch.testing.assert_close( + unpermuted_unpaded_output_, + fusion_unpermuted_unpaded_output_, + msg=f"Mismatch in te_unpermute fwd", + **tols, + ) + torch.testing.assert_close( + unpermute_unpad_fwd_input_grad, + fusion_unpermute_unpad_fwd_input_grad, + msg=f"Mismatch 
in te_unpermute bwd", + **tols, + ) + torch.testing.assert_close( + permuted_paded_probs.float(), + fusion_permuted_padded_probs.float(), + msg=f"Mismatch in te_permute_and_pad bwd", + **tols, + ) + + ################################################################################################################################### + # + # Benchmark + # + ################################################################################################################################### + if BENCHMARK: + + def permute_and_pad(): + permuted_output, permuted_probs, row_id_map = te_permute_with_probs( + permute_pad_fwd_input, + probs, + routing_map, + num_out_tokens=num_out_tokens, + ) + fp8_padding(permuted_output, tokens_per_expert_list) + fp8_padding(permuted_probs.unsqueeze(-1), tokens_per_expert_list) + + def fusion_permute_and_pad(): + ( + fusion_permuted_padded_output, + fusion_permuted_padded_probs, + row_id_map, + pad_offsets, + target_tokens_per_expert, + ) = te_permute_and_pad_with_probs( + fusion_permute_and_pad_fwd_input, + probs, + routing_map, + tokens_per_expert, + align_size, + ) + fusion_permuted_padded_probs = fusion_permuted_padded_probs.unsqueeze(-1) + + t1 = perf_test_cuda_kernel(lambda: permute_and_pad()) + + t2 = perf_test_cuda_kernel(lambda: fusion_permute_and_pad()) + + print(f"permute_and_pad\t\tfwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: backward_wrapper( + permuted_paded_output, + permute_pad_bwd_input, + forward_input=[permute_pad_fwd_input], + retain_graph=True, + accumulate_grad=False, + ) + ) + t2 = perf_test_cuda_kernel( + lambda: backward_wrapper( + fusion_permuted_padded_output, + fusion_permute_pad_bwd_input, + forward_input=[fusion_permute_and_pad_fwd_input], + retain_graph=True, + accumulate_grad=False, + ) + ) + print(f"permute_and_pad\t\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms") + + def unpad_unpermute(): + unpaded_output = fp8_unpadding( + unpermute_unpad_fwd_input, 
tokens_per_expert_list + ) + unpermuted_unpaded_output = te_unpermute( + unpaded_output, row_id_map, restore_shape=restore_shape + ) + + unpermuted_unpaded_output.backward( + unpermute_unpad_bwd_input, retain_graph=True + ) + + t1 = perf_test_cuda_kernel(lambda: unpad_unpermute()) + t2 = perf_test_cuda_kernel( + lambda: te_unpermute( + fusion_unpermute_unpad_fwd_input, + row_id_map, + restore_shape=restore_shape, + pad_offsets=pad_offsets, + ) + ) + print(f"unpermute_and_unpad\tfwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: backward_wrapper( + unpermuted_unpaded_output, + unpermute_unpad_bwd_input, + forward_input=([unpermute_unpad_fwd_input, probs]), + retain_graph=True, + accumulate_grad=False, + ) + ) + t2 = perf_test_cuda_kernel( + lambda: backward_wrapper( + fusion_unpermuted_unpaded_output, + fusion_unpermute_bwd_input, + forward_input=([fusion_unpermute_unpad_fwd_input, probs]), + retain_graph=True, + accumulate_grad=False, + ) + ) + print(f"unpermute_and_unpad\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms") + + def _test_permutation_mask_map_fp8( te_dtype, num_tokens, @@ -1180,6 +1479,40 @@ def test_permutation_mask_map( ) +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_out_tokens", [None]) +@pytest.mark.parametrize( + "num_tokens, num_expert, hidden_size, topK", + [ + (4096, 64, 1280, 7), + (4096, 64, 2048, 6), + (4096, 160, 5120, 6), + (4096, 256, 7168, 8), + (4096, 384, 8192, 8), + (4096, 512, 9216, 8), + ], +) +def test_permutation_and_padding_mask_map( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, +): + BENCHMARK = False + + _test_permutation_and_padding_mask_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + BENCHMARK=BENCHMARK, + ) + + @pytest.mark.parametrize("te_dtype", _te_dtypes) def test_permutation_mask_map_empty_input(te_dtype): 
with_probs = True @@ -1413,6 +1746,16 @@ def test_permutation_single_case(): BENCHMARK=Benchmark, ) + _test_permutation_and_padding_mask_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + BENCHMARK=Benchmark, + ) + _test_moe_chunk_sort( te_dtype=te_dtype, num_tokens=num_tokens, @@ -1479,6 +1822,18 @@ def benchmark_single_case( ) torch.cuda.nvtx.range_pop() + torch.cuda.nvtx.range_push("permutation_and_padding_mask_map") + _test_permutation_and_padding_mask_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + BENCHMARK=True, + ) + torch.cuda.nvtx.range_pop() + torch.cuda.nvtx.range_push("permutation_mask_map_alongside_probs") _test_permutation_mask_map_alongside_probs( te_dtype=te_dtype, diff --git a/transformer_engine/common/triton/permutation.py b/transformer_engine/common/triton/permutation.py index 87a9c245334..8376f5ea92e 100644 --- a/transformer_engine/common/triton/permutation.py +++ b/transformer_engine/common/triton/permutation.py @@ -200,6 +200,7 @@ def _permute_kernel( probs_ptr, scale_ptr, permuted_scale_ptr, + pad_offsets_ptr, # sizes scale_hidden_dim, # strides @@ -224,6 +225,7 @@ def _permute_kernel( hidden_size: tl.constexpr, PERMUTE_PROBS: tl.constexpr, PERMUTE_SCALE: tl.constexpr, + FUSION_PAD: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): pid_t = tl.program_id(0) @@ -246,6 +248,14 @@ def _permute_kernel( dst_row = tl.load( row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert ).to(tl.int64) + if FUSION_PAD: + expert_idx = tl.load( + row_id_map_ptr + + pid_t * stride_row_id_map_token + + (num_experts + idx) * stride_row_id_map_expert + ) + pad_off = tl.load(pad_offsets_ptr + expert_idx) + dst_row = dst_row + pad_off output_off = dst_row * stride_output_token + cur_off * stride_output_hidden if PERMUTE_SCALE: permuted_scale_off = 
( @@ -297,6 +307,7 @@ def _unpermute_kernel( row_id_map_ptr, merging_probs_ptr, permuted_probs_ptr, + pad_offsets_ptr, # strides stride_row_id_map_token, stride_row_id_map_expert, @@ -318,6 +329,7 @@ def _unpermute_kernel( PROBS_LOAD_WIDTH: tl.constexpr, WITH_MERGING_PROBS: tl.constexpr, PERMUTE_PROBS: tl.constexpr, + FUSION_UNPAD: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): data_type = input_ptr.dtype.element_ty @@ -348,6 +360,14 @@ def _unpermute_kernel( src_row = tl.load( row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert ).to(tl.int64) + if FUSION_UNPAD: + expert_idx = tl.load( + row_id_map_ptr + + pid_t * stride_row_id_map_token + + (num_experts + idx) * stride_row_id_map_expert + ) + pad_off = tl.load(pad_offsets_ptr + expert_idx) + src_row = src_row + pad_off input_off = src_row * stride_input_token + current_offset * stride_input_hidden inp = tl.load(input_ptr + input_off, mask=mask) inp = inp.to(compute_type) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 5341af3d742..9f4a9678eb9 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -34,6 +34,7 @@ from transformer_engine.pytorch.permutation import ( moe_permute, moe_permute_with_probs, + moe_permute_and_pad_with_probs, moe_unpermute, moe_sort_chunks_by_index, moe_sort_chunks_by_index_with_probs, diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 9fce9cefcf7..7e8de84c58d 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -191,6 +191,7 @@ def forward( routing_map: torch.Tensor, num_out_tokens: int, probs: torch.Tensor, + pad_offsets: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: # pylint: disable=missing-function-docstring if not inp.numel(): @@ -201,6 +202,8 @@ def forward( assert routing_map.is_cuda, "TransformerEngine needs CUDA." 
if probs is not None: assert probs.is_cuda, "TransformerEngine needs CUDA." + if pad_offsets is not None: + assert pad_offsets.is_cuda, "TransformerEngine needs CUDA." assert inp.size(0) == routing_map.size(0), "Permute not possible" num_tokens, hidden_size = inp.size() @@ -250,6 +253,7 @@ def forward( row_id_map, probs, fp8_scale, + pad_offsets, num_tokens, num_experts, num_out_tokens, @@ -292,7 +296,7 @@ def forward( requires_grad=output.requires_grad, ) - ctx.save_for_backward(row_id_map) + ctx.save_for_backward(row_id_map, pad_offsets) ctx.num_experts = num_experts ctx.num_tokens = num_tokens ctx.hidden_size = hidden_size @@ -307,12 +311,12 @@ def backward( ) -> Tuple[torch.Tensor, ...]: # pylint: disable=missing-function-docstring if not permuted_act_grad.numel(): - return permuted_act_grad, None, None, ctx.probs + return permuted_act_grad, None, None, ctx.probs, None act_grad = None probs_grad = None if ctx.needs_input_grad[0]: - (row_id_map,) = ctx.saved_tensors + row_id_map, pad_offsets = ctx.saved_tensors assert not isinstance( permuted_act_grad, QuantizedTensor ), "The backward of moe_permute does not support FP8." @@ -321,13 +325,14 @@ def backward( row_id_map, None, permuted_probs_grad, + pad_offsets, ctx.num_tokens, ctx.num_experts, ctx.hidden_size, ) if not ctx.needs_input_grad[3]: probs_grad = None - return act_grad, None, None, probs_grad + return act_grad, None, None, probs_grad, None class _moe_unpermute_mask_map(torch.autograd.Function): @@ -340,6 +345,7 @@ def forward( row_id_map: torch.Tensor, merging_probs: Optional[torch.Tensor], restore_shape: Optional[torch.Size], + pad_offsets: Optional[torch.Tensor], ) -> torch.Tensor: # pylint: disable=missing-function-docstring if not inp.numel(): @@ -358,6 +364,8 @@ def forward( # Device check assert inp.is_cuda, "TransformerEngine needs CUDA." assert row_id_map.is_cuda, "TransformerEngine needs CUDA." + if pad_offsets is not None: + assert pad_offsets.is_cuda, "TransformerEngine needs CUDA." 
assert not isinstance( inp, QuantizedTensor @@ -367,6 +375,7 @@ def forward( row_id_map, merging_probs, None, + pad_offsets, num_tokens, num_experts, hidden_size, @@ -375,7 +384,7 @@ def forward( if with_probs: ctx.save_for_backward(inp, row_id_map, merging_probs) else: - ctx.save_for_backward(row_id_map) + ctx.save_for_backward(row_id_map, pad_offsets) ctx.num_experts = num_experts ctx.num_tokens = num_tokens ctx.num_permuted_tokens = inp.size(0) @@ -387,7 +396,7 @@ def forward( def backward(ctx, unpermuted_act_grad): # pylint: disable=missing-function-docstring if not unpermuted_act_grad.numel(): - return unpermuted_act_grad, None, ctx.merging_probs, None + return unpermuted_act_grad, None, ctx.merging_probs, None, None act_grad = None probs_grad = None @@ -395,7 +404,7 @@ def backward(ctx, unpermuted_act_grad): if ctx.with_probs: fwd_input, row_id_map, merging_probs = ctx.saved_tensors else: - (row_id_map,) = ctx.saved_tensors + row_id_map, pad_offsets = ctx.saved_tensors fp8 = isinstance(unpermuted_act_grad, QuantizedTensor) per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor) @@ -453,6 +462,7 @@ def backward(ctx, unpermuted_act_grad): row_id_map, None, fp8_scale, + pad_offsets, ctx.num_tokens, ctx.num_experts, ctx.num_permuted_tokens, @@ -497,7 +507,7 @@ def backward(ctx, unpermuted_act_grad): if not ctx.needs_input_grad[2]: probs_grad = None - return act_grad, None, probs_grad, None + return act_grad, None, probs_grad, None, None def moe_permute( @@ -537,7 +547,7 @@ def moe_permute( if map_type == "index": return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num) if map_type == "mask": - output, row_id_map, _ = _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens, None) + output, row_id_map, _ = _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens, None, None) return output, row_id_map raise ValueError("map_type should be one of 'mask' or 'index'") @@ -570,11 +580,63 @@ def moe_permute_with_probs( By 
default, set to '-1', meaning no tokens are dropped. """ output, row_id_map, permuted_probs = _moe_permute_mask_map.apply( - inp, routing_map, num_out_tokens, probs + inp, routing_map, num_out_tokens, probs, None ) return output, permuted_probs, row_id_map +def moe_permute_and_pad_with_probs( + inp: torch.Tensor, + probs: torch.Tensor, + routing_map: torch.Tensor, + tokens_per_expert: torch.Tensor, + align_size: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + """ + Permute the tokens and probs based on the routing_map. + Token with the same index will be grouped together. + Tokens with the same designated expert will be grouped together. + The routing_map indicates which experts were selected by each token. + + Parameters + ---------- + inp: torch.Tensor + Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. + probs: torch.Tensor + The tensor of probabilities corresponding to the permuted tokens and is + of shape [num_tokens, num_experts]. It will be permuted with the tokens + according to the routing_map. + routing_map: torch.Tensor + The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'. + The values in it: 1 means the token is routed to this expert and 0 means not. + tokens_per_expert : torch.Tensor + Tensor of shape `[num_experts]` containing actual token counts per expert. + align_size : int + the alignment size for the input tensor. + """ + assert ( + tokens_per_expert is not None + ), "tokens_per_expert must be provided to the fused permute padding function." 
+ + # Calculate aligned token counts per expert + target_tokens_per_expert = ( + torch.ceil(tokens_per_expert / align_size) * align_size + ).long() + + if torch.equal(tokens_per_expert, target_tokens_per_expert): + pad_offsets = None + else: + pad_lengths = target_tokens_per_expert - tokens_per_expert + cum_pad = torch.cumsum(pad_lengths, dim=0) + pad_offsets = torch.cat([torch.zeros(1, dtype=cum_pad.dtype), cum_pad[:-1]]) + pad_offsets = pad_offsets.to(inp.device) + + output, row_id_map, permuted_probs = _moe_permute_mask_map.apply( + inp, routing_map, target_tokens_per_expert.sum().item(), probs, pad_offsets + ) + return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert + + def moe_unpermute( inp: torch.Tensor, row_id_map: torch.Tensor, @@ -582,6 +644,7 @@ def moe_unpermute( restore_shape: Optional[torch.Size] = None, map_type: str = "mask", probs: Optional[torch.Tensor] = None, + pad_offsets: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Unpermute a tensor with permuted tokens, and optionally merge the tokens with their @@ -605,6 +668,10 @@ def moe_unpermute( Options are: 'mask', 'index'. probs : torch.Tensor, default = None Renamed to merging_probs. Keep for backward compatibility. + pad_offsets : torch.Tensor, default = None + Tensor of per-expert cumulative padding offsets used to remove padding added + during permutation. This is the fourth output of `moe_permute_and_pad_with_probs` + and is required when unpermuting padded outputs. 
""" if probs is not None: if merging_probs is not None: @@ -616,7 +683,9 @@ def moe_unpermute( if map_type == "index": return _moe_unpermute_index_map.apply(inp, row_id_map, merging_probs) if map_type == "mask": - return _moe_unpermute_mask_map.apply(inp, row_id_map, merging_probs, restore_shape) + return _moe_unpermute_mask_map.apply( + inp, row_id_map, merging_probs, restore_shape, pad_offsets + ) raise ValueError("map_type should be one of 'mask' or 'index'") diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 8f953e9c31d..efe9bb19dfc 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -123,6 +123,7 @@ def permute_with_mask_map( row_id_map: torch.Tensor, probs: torch.Tensor, scale: torch.Tensor, + pad_offsets: torch.Tensor, num_tokens: int, num_experts: int, num_out_tokens: int, @@ -142,6 +143,9 @@ def permute_with_mask_map( The probabilities of the input tensor. If it is not None, it will be permuted. scale : torch.Tensor The scale of the input tensor. If it is not None, it will be permuted. + pad_offsets : torch.Tensor + The padding offsets for FP8 fused padding. If it is not None, it will be allocated output + buffers with aligned sizes. num_tokens : int Number of tokens in the input tensor. num_experts : int @@ -153,18 +157,18 @@ def permute_with_mask_map( scale_hidden_dim : int Hidden size of the scale tensor. 
""" - output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda") - if probs is not None: - permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda") - else: - permuted_probs = None - - if scale is not None: - permuted_scale = torch.empty( - (num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda" - ) - else: - permuted_scale = None + alloc = torch.zeros if pad_offsets is not None else torch.empty + output = alloc((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda") + permuted_probs = ( + alloc((num_out_tokens,), dtype=probs.dtype, device="cuda") + if probs is not None + else None + ) + permuted_scale = ( + torch.empty((num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda") + if scale is not None + else None + ) # pylint: disable=unnecessary-lambda-assignment grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"])) _permute_kernel[grid]( @@ -173,6 +177,7 @@ def permute_with_mask_map( probs, scale, permuted_scale, + pad_offsets, scale_hidden_dim, row_id_map.stride(0), row_id_map.stride(1), @@ -193,6 +198,7 @@ def permute_with_mask_map( hidden_size, PERMUTE_PROBS=probs is not None, PERMUTE_SCALE=scale is not None, + FUSION_PAD=pad_offsets is not None, ) return output, permuted_scale, permuted_probs @@ -202,6 +208,7 @@ def unpermute_with_mask_map( row_id_map: torch.Tensor, merging_probs: Union[torch.Tensor, None], permuted_probs: Union[torch.Tensor, None], + pad_offsets: Union[torch.Tensor, None], num_tokens: int, num_experts: int, hidden_size: int, @@ -220,6 +227,9 @@ def unpermute_with_mask_map( to reduce the unpermuted tokens. permuted_probs : torch.Tensor The permuted probabilities of the input tensor. If it is not None, it will be unpermuted. + pad_offsets : torch.Tensor + The padding offsets used for FP8 fused unpadding. If it is not None, it will remove the + previously fused padding. num_tokens : int Number of tokens in the permuted tensor. 
num_experts : int @@ -241,6 +251,7 @@ def unpermute_with_mask_map( row_id_map, merging_probs, permuted_probs, + pad_offsets, row_id_map.stride(0), row_id_map.stride(1), inp.stride(0), @@ -259,6 +270,7 @@ def unpermute_with_mask_map( PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), WITH_MERGING_PROBS=merging_probs is not None, PERMUTE_PROBS=permuted_probs is not None, + FUSION_UNPAD=pad_offsets is not None, ) return output, unpermuted_probs From a7de66c08c7c745d3c3b5614b6127e7acd0d603c Mon Sep 17 00:00:00 2001 From: xiaoxi-wangfj <690912414@qq.com> Date: Thu, 11 Dec 2025 06:31:06 +0000 Subject: [PATCH 02/19] [PyTorch/Common] Fuse permute+pad and unpermute+unpad support with_merging_probs Signed-off-by: xiaoxi-wangfj <690912414@qq.com> --- tests/pytorch/test_permutation.py | 29 +++++++++++++++---- .../common/triton/permutation.py | 5 ++++ transformer_engine/pytorch/permutation.py | 5 ++-- .../pytorch/triton/permutation.py | 24 +++++++++++---- 4 files changed, 49 insertions(+), 14 deletions(-) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 5a650a21680..967ec4d8966 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -662,18 +662,19 @@ def _test_permutation_and_padding_mask_map( hidden_size, topK, num_out_tokens, + with_merging_probs=False, align_size=16, BENCHMARK=False, ): if topK > num_expert: pytest.skip("topK should be smaller than the number of experts.") - if num_out_tokens == None: + if num_out_tokens is None: num_out_tokens = num_tokens * topK print( "permutation and padding:" - f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} align_size:{align_size} {te_dtype}" + f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} with_probs:{with_merging_probs} align_size:{align_size} {te_dtype}" ) # Convert TE dtypes to PyTorch dtypes @@ -743,8 +744,13 @@ def _test_permutation_and_padding_mask_map( fp8_unpadding = 
Fp8Unpadding(num_expert, align_size) unpaded_output = fp8_unpadding(unpermute_unpad_fwd_input, tokens_per_expert_list) + + probs_naive = probs unpermuted_unpaded_output = te_unpermute( - unpaded_output, row_id_map, restore_shape=restore_shape + unpaded_output, + row_id_map, + merging_probs=probs_naive if with_merging_probs else None, + restore_shape=restore_shape, ) unpermuted_unpaded_output.backward(unpermute_unpad_bwd_input, retain_graph=True) @@ -757,8 +763,8 @@ def _test_permutation_and_padding_mask_map( # fusion permute_and_pad fusion_permute_and_pad_fwd_input = permute_pad_fwd_input.detach() fusion_permute_and_pad_fwd_input.requires_grad_(True) - probs = probs.detach() - probs.requires_grad_(True) + probs_fusion = probs_naive.detach().clone() + probs_fusion.requires_grad_(True) ( fusion_permuted_padded_output, @@ -768,7 +774,7 @@ def _test_permutation_and_padding_mask_map( target_tokens_per_expert, ) = te_permute_and_pad_with_probs( fusion_permute_and_pad_fwd_input, - probs, + probs_fusion, routing_map, tokens_per_expert, align_size, @@ -787,6 +793,7 @@ def _test_permutation_and_padding_mask_map( fusion_unpermuted_unpaded_output = te_unpermute( fusion_unpermute_unpad_fwd_input, row_id_map, + merging_probs=probs_fusion if with_merging_probs else None, restore_shape=restore_shape, pad_offsets=pad_offsets, ) @@ -848,6 +855,13 @@ def _test_permutation_and_padding_mask_map( msg=f"Mismatch in te_permute_and_pad bwd", **tols, ) + if with_merging_probs: + torch.testing.assert_close( + probs_naive.grad.float(), + probs_fusion.grad.float(), + msg=f"Mismatch in te_unpermute bwd", + **tols, + ) ################################################################################################################################### # @@ -1492,6 +1506,7 @@ def test_permutation_mask_map( (4096, 512, 9216, 8), ], ) +@pytest.mark.parametrize("with_merging_probs", [True, False]) def test_permutation_and_padding_mask_map( te_dtype, num_tokens, @@ -1499,6 +1514,7 @@ def 
test_permutation_and_padding_mask_map( hidden_size, topK, num_out_tokens, + with_merging_probs, ): BENCHMARK = False @@ -1509,6 +1525,7 @@ def test_permutation_and_padding_mask_map( hidden_size=hidden_size, topK=topK, num_out_tokens=num_out_tokens, + with_merging_probs=with_merging_probs, BENCHMARK=BENCHMARK, ) diff --git a/transformer_engine/common/triton/permutation.py b/transformer_engine/common/triton/permutation.py index 8376f5ea92e..b15f1190325 100644 --- a/transformer_engine/common/triton/permutation.py +++ b/transformer_engine/common/triton/permutation.py @@ -427,6 +427,7 @@ def _unpermute_bwd_with_merging_probs_kernel( fwd_input_ptr, merging_probs_ptr, row_id_map_ptr, + pad_offsets_ptr, # strides stride_row_id_map_token, stride_row_id_map_expert, @@ -447,6 +448,7 @@ def _unpermute_bwd_with_merging_probs_kernel( num_experts: tl.constexpr, hidden_size: tl.constexpr, PROBS_LOAD_WIDTH: tl.constexpr, + FUSION_PAD: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): data_type = fwd_output_grad_ptr.dtype.element_ty @@ -470,6 +472,9 @@ def _unpermute_bwd_with_merging_probs_kernel( + pid * stride_row_id_map_token + (num_experts + idx) * stride_row_id_map_expert ) + if FUSION_PAD: + pad_off = tl.load(pad_offsets_ptr + expert_idx) + dst_row = dst_row + pad_off prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type) current_start = 0 while current_start < hidden_size: diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 7e8de84c58d..810c056373f 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -382,7 +382,7 @@ def forward( ) if with_probs: - ctx.save_for_backward(inp, row_id_map, merging_probs) + ctx.save_for_backward(inp, row_id_map, merging_probs, pad_offsets) else: ctx.save_for_backward(row_id_map, pad_offsets) ctx.num_experts = num_experts @@ -402,7 +402,7 @@ def backward(ctx, unpermuted_act_grad): probs_grad = None if ctx.needs_input_grad[0]: if 
ctx.with_probs: - fwd_input, row_id_map, merging_probs = ctx.saved_tensors + fwd_input, row_id_map, merging_probs, pad_offsets = ctx.saved_tensors else: row_id_map, pad_offsets = ctx.saved_tensors @@ -450,6 +450,7 @@ def backward(ctx, unpermuted_act_grad): row_id_map, fwd_input, merging_probs, + pad_offsets, ctx.num_tokens, ctx.num_experts, ctx.num_permuted_tokens, diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index efe9bb19dfc..bd8d3cd6012 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -144,8 +144,8 @@ def permute_with_mask_map( scale : torch.Tensor The scale of the input tensor. If it is not None, it will be permuted. pad_offsets : torch.Tensor - The padding offsets for FP8 fused padding. If it is not None, it will be allocated output - buffers with aligned sizes. + Per-expert padding offsets of shape `[num_experts]` for FP8 fused padding. + If it is not None, it will be allocated output buffers with aligned sizes. num_tokens : int Number of tokens in the input tensor. num_experts : int @@ -228,8 +228,8 @@ def unpermute_with_mask_map( permuted_probs : torch.Tensor The permuted probabilities of the input tensor. If it is not None, it will be unpermuted. pad_offsets : torch.Tensor - The padding offsets used for FP8 fused unpadding. If it is not None, it will remove the - previously fused padding. + Per-expert padding offsets of shape `[num_experts]` for FP8 fused unpadding. + If it is not None, it will remove the previously fused padding. num_tokens : int Number of tokens in the permuted tensor. 
num_experts : int @@ -280,6 +280,7 @@ def unpermute_with_mask_map_bwd_with_merging_probs( row_id_map: torch.Tensor, fwd_input: torch.Tensor, merging_probs: torch.Tensor, + pad_offsets: torch.Tensor, num_tokens: int, num_experts: int, num_out_tokens: int, @@ -298,6 +299,9 @@ def unpermute_with_mask_map_bwd_with_merging_probs( The input tensor of the forward pass of shape `[num_out_tokens, hidden_size]`. merging_probs : torch.Tensor The merging probabilities of the input tensor of shape `[num_tokens, num_experts]`. + pad_offsets : torch.Tensor + Per-expert padding offsets of shape `[num_experts]` for FP8 fused padding. + If it is not None, it will be allocated output buffers with aligned sizes. num_tokens : int Number of tokens in the permuted tensor. num_experts : int @@ -307,8 +311,14 @@ def unpermute_with_mask_map_bwd_with_merging_probs( hidden_size : int Hidden size of the output tensor. """ - act_grad = torch.empty( - (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda" + act_grad = ( + torch.empty( + (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda" + ) + if pad_offsets is None + else torch.zeros( + (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda" + ) ) merging_probs_grad = torch.empty( (num_tokens, num_experts), dtype=merging_probs.dtype, device="cuda" @@ -319,6 +329,7 @@ def unpermute_with_mask_map_bwd_with_merging_probs( fwd_input, merging_probs, row_id_map, + pad_offsets, row_id_map.stride(0), row_id_map.stride(1), fwd_output_grad.stride(0), @@ -336,6 +347,7 @@ def unpermute_with_mask_map_bwd_with_merging_probs( num_experts, hidden_size, PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), + FUSION_PAD=pad_offsets is not None, ) return act_grad, merging_probs_grad From f550684e39980ac8daa4962798ab4e4b1f077760 Mon Sep 17 00:00:00 2001 From: xiaoxi-wangfj <690912414@qq.com> Date: Thu, 11 Dec 2025 06:40:11 +0000 Subject: [PATCH 03/19] [PyTorch]format code Signed-off-by: xiaoxi-wangfj 
<690912414@qq.com> --- tests/pytorch/test_permutation.py | 40 +++++-------------- transformer_engine/pytorch/permutation.py | 8 ++-- .../pytorch/triton/permutation.py | 12 ++---- 3 files changed, 17 insertions(+), 43 deletions(-) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 967ec4d8966..893906304f6 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -690,9 +690,7 @@ def _test_permutation_and_padding_mask_map( _tmp_tensor = torch.zeros((num_tokens * num_expert,)) _tmp_tensor[: int(num_out_tokens)] = 1.0 _tmp_idx = torch.randperm(num_tokens * num_expert) - routing_map = ( - torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda() - ) + routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda() probs = torch.rand(num_tokens, num_expert).cuda() * routing_map row_sums = probs.sum(dim=1, keepdim=True) @@ -701,18 +699,14 @@ def _test_permutation_and_padding_mask_map( probs.requires_grad_(True) tokens_per_expert = routing_map.sum(dim=0).cpu() - target_tokens_per_expert = ( - torch.ceil(tokens_per_expert / align_size) * align_size - ).long() + target_tokens_per_expert = (torch.ceil(tokens_per_expert / align_size) * align_size).long() num_permute_pad_out_tokens = target_tokens_per_expert.sum().item() permute_pad_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() permute_pad_bwd_input = torch.rand( (num_permute_pad_out_tokens, hidden_size), dtype=dtype ).cuda() - unpermute_unpad_bwd_input = torch.rand( - (num_tokens, hidden_size), dtype=dtype - ).cuda() + unpermute_unpad_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() permute_pad_fwd_input.requires_grad_(True) restore_shape = permute_pad_fwd_input.shape @@ -731,9 +725,7 @@ def _test_permutation_and_padding_mask_map( tokens_per_expert_list = tokens_per_expert.tolist() fp8_padding = Fp8Padding(num_expert, align_size) permuted_paded_output, _ = 
fp8_padding(permuted_output, tokens_per_expert_list) - permuted_paded_probs, _ = fp8_padding( - permuted_probs.unsqueeze(-1), tokens_per_expert_list - ) + permuted_paded_probs, _ = fp8_padding(permuted_probs.unsqueeze(-1), tokens_per_expert_list) permuted_paded_output.backward(permute_pad_bwd_input, retain_graph=True) @@ -782,9 +774,7 @@ def _test_permutation_and_padding_mask_map( fusion_permuted_padded_probs = fusion_permuted_padded_probs.unsqueeze(-1) fusion_permute_pad_bwd_input = permute_pad_bwd_input.detach() - fusion_permuted_padded_output.backward( - fusion_permute_pad_bwd_input, retain_graph=True - ) + fusion_permuted_padded_output.backward(fusion_permute_pad_bwd_input, retain_graph=True) # fusion unpad and unpermute fusion_unpermute_unpad_fwd_input = fusion_permuted_padded_output.detach() @@ -799,9 +789,7 @@ def _test_permutation_and_padding_mask_map( ) fusion_unpermute_bwd_input = unpermute_unpad_bwd_input.detach() - fusion_unpermuted_unpaded_output.backward( - fusion_unpermute_bwd_input, retain_graph=True - ) + fusion_unpermuted_unpaded_output.backward(fusion_unpermute_bwd_input, retain_graph=True) ################################################################################################################################### # @@ -813,16 +801,12 @@ def _test_permutation_and_padding_mask_map( permuted_paded_output_ = permuted_paded_output.float() fusion_permuted_padded_output_ = fusion_permuted_padded_output.float() permute_pad_fwd_input_grad = permute_pad_fwd_input.grad.float() - fusion_permute_and_pad_fwd_input_grad = ( - fusion_permute_and_pad_fwd_input.grad.float() - ) + fusion_permute_and_pad_fwd_input_grad = fusion_permute_and_pad_fwd_input.grad.float() unpermuted_unpaded_output_ = unpermuted_unpaded_output.float() fusion_unpermuted_unpaded_output_ = fusion_unpermuted_unpaded_output.float() unpermute_unpad_fwd_input_grad = unpermute_unpad_fwd_input.grad.float() - fusion_unpermute_unpad_fwd_input_grad = ( - 
fusion_unpermute_unpad_fwd_input.grad.float() - ) + fusion_unpermute_unpad_fwd_input_grad = fusion_unpermute_unpad_fwd_input.grad.float() if not BENCHMARK: torch.testing.assert_close( @@ -923,16 +907,12 @@ def fusion_permute_and_pad(): print(f"permute_and_pad\t\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms") def unpad_unpermute(): - unpaded_output = fp8_unpadding( - unpermute_unpad_fwd_input, tokens_per_expert_list - ) + unpaded_output = fp8_unpadding(unpermute_unpad_fwd_input, tokens_per_expert_list) unpermuted_unpaded_output = te_unpermute( unpaded_output, row_id_map, restore_shape=restore_shape ) - unpermuted_unpaded_output.backward( - unpermute_unpad_bwd_input, retain_graph=True - ) + unpermuted_unpaded_output.backward(unpermute_unpad_bwd_input, retain_graph=True) t1 = perf_test_cuda_kernel(lambda: unpad_unpermute()) t2 = perf_test_cuda_kernel( diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 810c056373f..74353573c4f 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -548,7 +548,9 @@ def moe_permute( if map_type == "index": return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num) if map_type == "mask": - output, row_id_map, _ = _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens, None, None) + output, row_id_map, _ = _moe_permute_mask_map.apply( + inp, routing_map, num_out_tokens, None, None + ) return output, row_id_map raise ValueError("map_type should be one of 'mask' or 'index'") @@ -620,9 +622,7 @@ def moe_permute_and_pad_with_probs( ), "tokens_per_expert must be provided to the fused permute padding function." 
# Calculate aligned token counts per expert - target_tokens_per_expert = ( - torch.ceil(tokens_per_expert / align_size) * align_size - ).long() + target_tokens_per_expert = (torch.ceil(tokens_per_expert / align_size) * align_size).long() if torch.equal(tokens_per_expert, target_tokens_per_expert): pad_offsets = None diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index bd8d3cd6012..1621d0e9fcb 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -160,9 +160,7 @@ def permute_with_mask_map( alloc = torch.zeros if pad_offsets is not None else torch.empty output = alloc((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda") permuted_probs = ( - alloc((num_out_tokens,), dtype=probs.dtype, device="cuda") - if probs is not None - else None + alloc((num_out_tokens,), dtype=probs.dtype, device="cuda") if probs is not None else None ) permuted_scale = ( torch.empty((num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda") @@ -312,13 +310,9 @@ def unpermute_with_mask_map_bwd_with_merging_probs( Hidden size of the output tensor. 
""" act_grad = ( - torch.empty( - (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda" - ) + torch.empty((num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda") if pad_offsets is None - else torch.zeros( - (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda" - ) + else torch.zeros((num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda") ) merging_probs_grad = torch.empty( (num_tokens, num_experts), dtype=merging_probs.dtype, device="cuda" From 6069277a9e426b972acfe4d581f06907078ba546 Mon Sep 17 00:00:00 2001 From: xiaoxi-wangfj <690912414@qq.com> Date: Thu, 11 Dec 2025 07:13:26 +0000 Subject: [PATCH 04/19] [Common]perf expert_idx loaded once Signed-off-by: xiaoxi-wangfj <690912414@qq.com> --- tests/pytorch/test_permutation.py | 4 +++- .../common/triton/permutation.py | 19 +++++++------------ 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 893906304f6..b0eac2adf7e 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -674,7 +674,8 @@ def _test_permutation_and_padding_mask_map( print( "permutation and padding:" - f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} with_probs:{with_merging_probs} align_size:{align_size} {te_dtype}" + f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK}" + f" with_merging_probs:{with_merging_probs} align_size:{align_size} {te_dtype}" ) # Convert TE dtypes to PyTorch dtypes @@ -1478,6 +1479,7 @@ def test_permutation_mask_map( @pytest.mark.parametrize( "num_tokens, num_expert, hidden_size, topK", [ + (0, 8, 1280, 2), (4096, 64, 1280, 7), (4096, 64, 2048, 6), (4096, 160, 5120, 6), diff --git a/transformer_engine/common/triton/permutation.py b/transformer_engine/common/triton/permutation.py index b15f1190325..b3a383f4230 100644 --- a/transformer_engine/common/triton/permutation.py 
+++ b/transformer_engine/common/triton/permutation.py @@ -228,6 +228,8 @@ def _permute_kernel( FUSION_PAD: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): + expert_idx = 0 + pid_t = tl.program_id(0) pid_h = tl.program_id(1) cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) @@ -248,12 +250,13 @@ def _permute_kernel( dst_row = tl.load( row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert ).to(tl.int64) - if FUSION_PAD: + if FUSION_PAD or PERMUTE_PROBS: expert_idx = tl.load( row_id_map_ptr + pid_t * stride_row_id_map_token + (num_experts + idx) * stride_row_id_map_expert ) + if FUSION_PAD: pad_off = tl.load(pad_offsets_ptr + expert_idx) dst_row = dst_row + pad_off output_off = dst_row * stride_output_token + cur_off * stride_output_hidden @@ -263,11 +266,6 @@ def _permute_kernel( ) tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale) if PERMUTE_PROBS: - expert_idx = tl.load( - row_id_map_ptr - + pid_t * stride_row_id_map_token - + (num_experts + idx) * stride_row_id_map_expert - ) prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert prob = tl.load(probs_ptr + prob_off) if pid_h == 0: @@ -334,6 +332,7 @@ def _unpermute_kernel( ): data_type = input_ptr.dtype.element_ty compute_type = tl.float32 + expert_idx = 0 pid_t = tl.program_id(0) pid_h = tl.program_id(1) @@ -360,23 +359,19 @@ def _unpermute_kernel( src_row = tl.load( row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert ).to(tl.int64) - if FUSION_UNPAD: + if FUSION_UNPAD or WITH_MERGING_PROBS: expert_idx = tl.load( row_id_map_ptr + pid_t * stride_row_id_map_token + (num_experts + idx) * stride_row_id_map_expert ) + if FUSION_UNPAD: pad_off = tl.load(pad_offsets_ptr + expert_idx) src_row = src_row + pad_off input_off = src_row * stride_input_token + current_offset * stride_input_hidden inp = tl.load(input_ptr + input_off, mask=mask) inp = inp.to(compute_type) if WITH_MERGING_PROBS: - expert_idx = tl.load( - 
row_id_map_ptr - + pid_t * stride_row_id_map_token - + (num_experts + idx) * stride_row_id_map_expert - ) merging_prob_off = ( pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert ) From 1ea08f71b5e6c62d9e433a442f808d2f9085b898 Mon Sep 17 00:00:00 2001 From: xiaoxi-wangfj <690912414@qq.com> Date: Wed, 17 Dec 2025 13:15:10 +0800 Subject: [PATCH 05/19] fix: pad_offsets can be None Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: xiaoxi-wangfj <690912414@qq.com> --- transformer_engine/pytorch/permutation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 74353573c4f..f2809413a01 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -"""MoE Permutaion API""" +"""MoE Permutation API""" import warnings from typing import Optional, Tuple import torch @@ -191,7 +191,7 @@ def forward( routing_map: torch.Tensor, num_out_tokens: int, probs: torch.Tensor, - pad_offsets: torch.Tensor, + pad_offsets: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # pylint: disable=missing-function-docstring if not inp.numel(): From 230939ce19dbedc60a3fecb76b7dd112c84b1e03 Mon Sep 17 00:00:00 2001 From: tdophung Date: Wed, 10 Dec 2025 17:18:46 -0800 Subject: [PATCH 06/19] add padding + merging probs bwd support. 
Not tested Signed-off-by: tdophung --- tests/pytorch/test_permutation.py | 286 ++++++++++++++++++ .../common/triton/permutation.py | 4 +- .../pytorch/triton/permutation.py | 4 +- 3 files changed, 290 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index b0eac2adf7e..9e440b6795c 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -947,6 +947,237 @@ def unpad_unpermute(): print(f"unpermute_and_unpad\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms") +def _test_permutation_and_padding_with_merging_probs( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, + align_size=16, + BENCHMARK=False, +): + """ + Test the combination of merging_probs AND pad_offsets together in moe_unpermute. + This specifically tests the backward pass fix where pad_offsets must be used + when computing gradients with merging_probs. + """ + if topK > num_expert: + pytest.skip("topK should be smaller than the number of experts.") + + if num_out_tokens == None: + num_out_tokens = num_tokens * topK + + print( + "permutation and padding with merging probs:" + f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} align_size:{align_size} {te_dtype}" + ) + + # Convert TE dtypes to PyTorch dtypes + if te_dtype == tex.DType.kFloat32: + dtype = torch.float32 + elif te_dtype == tex.DType.kFloat16: + dtype = torch.float16 + elif te_dtype == tex.DType.kBFloat16: + dtype = torch.bfloat16 + else: + pytest.skip("Invalid dtype.") + + _tmp_tensor = torch.zeros((num_tokens * num_expert,)) + _tmp_tensor[: int(num_out_tokens)] = 1.0 + _tmp_idx = torch.randperm(num_tokens * num_expert) + routing_map = ( + torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda() + ) + + probs = torch.rand(num_tokens, num_expert).cuda() * routing_map + row_sums = probs.sum(dim=1, keepdim=True) + probs = probs / row_sums + probs = probs.to(dtype) + 
probs.requires_grad_(True) + + tokens_per_expert = routing_map.sum(dim=0).cpu() + target_tokens_per_expert = ( + torch.ceil(tokens_per_expert / align_size) * align_size + ).long() + num_permute_pad_out_tokens = target_tokens_per_expert.sum().item() + + permute_pad_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + permute_pad_bwd_input = torch.rand( + (num_permute_pad_out_tokens, hidden_size), dtype=dtype + ).cuda() + unpermute_unpad_bwd_input = torch.rand( + (num_tokens, hidden_size), dtype=dtype + ).cuda() + permute_pad_fwd_input.requires_grad_(True) + + restore_shape = permute_pad_fwd_input.shape + ################################################################################################################################### + # + # Reference: moe_permute_with_probs + Fp8Padding, then Fp8Unpadding + moe_unpermute with merging_probs + # + ################################################################################################################################### + # permute + padding + permuted_output, permuted_probs, row_id_map = te_permute_with_probs( + permute_pad_fwd_input, + probs, + routing_map, + num_out_tokens=num_out_tokens, + ) + tokens_per_expert_list = tokens_per_expert.tolist() + fp8_padding = Fp8Padding(num_expert, align_size) + permuted_paded_output, _ = fp8_padding(permuted_output, tokens_per_expert_list) + + permuted_paded_output.backward(permute_pad_bwd_input, retain_graph=True) + + # Reference: unpadding + unpermute WITH merging_probs + ref_unpermute_fwd_input = permuted_paded_output.detach() + ref_unpermute_fwd_input.requires_grad_(True) + + ref_probs = probs.detach() + ref_probs.requires_grad_(True) + + fp8_unpadding = Fp8Unpadding(num_expert, align_size) + unpaded_output = fp8_unpadding(ref_unpermute_fwd_input, tokens_per_expert_list) + ref_unpermuted_output = te_unpermute( + unpaded_output, row_id_map, ref_probs, restore_shape=restore_shape + ) + + ref_unpermuted_output.backward(unpermute_unpad_bwd_input, 
retain_graph=True) + + ################################################################################################################################### + # + # Fused: moe_permute_and_pad_with_probs, then moe_unpermute with BOTH merging_probs AND pad_offsets + # + ################################################################################################################################### + # fusion permute_and_pad + fusion_permute_fwd_input = permute_pad_fwd_input.detach() + fusion_permute_fwd_input.requires_grad_(True) + fusion_probs = probs.detach() + fusion_probs.requires_grad_(True) + + ( + fusion_permuted_padded_output, + fusion_permuted_padded_probs, + fused_row_id_map, + pad_offsets, + _, + ) = te_permute_and_pad_with_probs( + fusion_permute_fwd_input, + fusion_probs, + routing_map, + tokens_per_expert, + align_size, + ) + + fusion_permute_pad_bwd_input = permute_pad_bwd_input.detach() + fusion_permuted_padded_output.backward( + fusion_permute_pad_bwd_input, retain_graph=True + ) + + # Fused: unpermute with BOTH merging_probs AND pad_offsets + fusion_unpermute_fwd_input = fusion_permuted_padded_output.detach() + fusion_unpermute_fwd_input.requires_grad_(True) + + fusion_merging_probs = probs.detach() + fusion_merging_probs.requires_grad_(True) + + fusion_unpermuted_output = te_unpermute( + fusion_unpermute_fwd_input, + fused_row_id_map, + fusion_merging_probs, + restore_shape=restore_shape, + pad_offsets=pad_offsets, + ) + + fusion_unpermute_bwd_input = unpermute_unpad_bwd_input.detach() + fusion_unpermuted_output.backward(fusion_unpermute_bwd_input, retain_graph=True) + + ################################################################################################################################### + # + # Results Check + # + ################################################################################################################################### + tols = dtype_tols(te_dtype) + + # Check forward pass + ref_unpermuted_output_ = 
ref_unpermuted_output.float() + fusion_unpermuted_output_ = fusion_unpermuted_output.float() + + if not BENCHMARK: + torch.testing.assert_close( + ref_unpermuted_output_, + fusion_unpermuted_output_, + msg=f"Mismatch in te_unpermute with merging_probs and pad_offsets fwd", + **tols, + ) + + # Check backward pass - activation gradients + ref_unpermute_fwd_input_grad = ref_unpermute_fwd_input.grad.float() + fusion_unpermute_fwd_input_grad = fusion_unpermute_fwd_input.grad.float() + + torch.testing.assert_close( + ref_unpermute_fwd_input_grad, + fusion_unpermute_fwd_input_grad, + msg=f"Mismatch in te_unpermute with merging_probs and pad_offsets bwd (act_grad)", + **tols, + ) + + # Check backward pass - probs gradients + ref_probs_grad = ref_probs.grad.float() + fusion_probs_grad = fusion_merging_probs.grad.float() + + torch.testing.assert_close( + ref_probs_grad, + fusion_probs_grad, + msg=f"Mismatch in te_unpermute with merging_probs and pad_offsets bwd (probs_grad)", + **tols, + ) + + ################################################################################################################################### + # + # Benchmark + # + ################################################################################################################################### + if BENCHMARK: + def ref_unpad_unpermute(): + unpaded = fp8_unpadding(ref_unpermute_fwd_input, tokens_per_expert_list) + return te_unpermute(unpaded, row_id_map, ref_probs, restore_shape=restore_shape) + + def fused_unpermute(): + return te_unpermute( + fusion_unpermute_fwd_input, + fused_row_id_map, + fusion_merging_probs, + restore_shape=restore_shape, + pad_offsets=pad_offsets, + ) + + t1 = perf_test_cuda_kernel(lambda: ref_unpad_unpermute()) + t2 = perf_test_cuda_kernel(lambda: fused_unpermute()) + print(f"unpermute_unpad_with_probs\tfwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: backward_wrapper( + ref_unpermuted_output, + unpermute_unpad_bwd_input, + 
forward_input=[ref_unpermute_fwd_input, ref_probs], + retain_graph=True, + accumulate_grad=False, + ) + ) + t2 = perf_test_cuda_kernel( + lambda: backward_wrapper( + fusion_unpermuted_output, + fusion_unpermute_bwd_input, + forward_input=[fusion_unpermute_fwd_input, fusion_merging_probs], + retain_graph=True, + accumulate_grad=False, + ) + ) + print(f"unpermute_unpad_with_probs\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms") + + def _test_permutation_mask_map_fp8( te_dtype, num_tokens, @@ -1512,6 +1743,39 @@ def test_permutation_and_padding_mask_map( ) +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_out_tokens", [None]) +@pytest.mark.parametrize( + "num_tokens, num_expert, hidden_size, topK", + [ + (4096, 64, 1280, 7), + (4096, 64, 2048, 6), + (4096, 160, 5120, 6), + (4096, 256, 7168, 8), + ], +) +def test_permutation_and_padding_with_merging_probs( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, +): + """Test moe_unpermute backward pass with BOTH merging_probs AND pad_offsets.""" + BENCHMARK = False + + _test_permutation_and_padding_with_merging_probs( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + BENCHMARK=BENCHMARK, + ) + + @pytest.mark.parametrize("te_dtype", _te_dtypes) def test_permutation_mask_map_empty_input(te_dtype): with_probs = True @@ -1755,6 +2019,16 @@ def test_permutation_single_case(): BENCHMARK=Benchmark, ) + _test_permutation_and_padding_with_merging_probs( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + BENCHMARK=Benchmark, + ) + _test_moe_chunk_sort( te_dtype=te_dtype, num_tokens=num_tokens, @@ -1833,6 +2107,18 @@ def benchmark_single_case( ) torch.cuda.nvtx.range_pop() + torch.cuda.nvtx.range_push("permutation_and_padding_with_merging_probs") + 
_test_permutation_and_padding_with_merging_probs( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + BENCHMARK=True, + ) + torch.cuda.nvtx.range_pop() + torch.cuda.nvtx.range_push("permutation_mask_map_alongside_probs") _test_permutation_mask_map_alongside_probs( te_dtype=te_dtype, diff --git a/transformer_engine/common/triton/permutation.py b/transformer_engine/common/triton/permutation.py index b3a383f4230..de30c7c532d 100644 --- a/transformer_engine/common/triton/permutation.py +++ b/transformer_engine/common/triton/permutation.py @@ -443,7 +443,7 @@ def _unpermute_bwd_with_merging_probs_kernel( num_experts: tl.constexpr, hidden_size: tl.constexpr, PROBS_LOAD_WIDTH: tl.constexpr, - FUSION_PAD: tl.constexpr, + FUSION_UNPAD: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): data_type = fwd_output_grad_ptr.dtype.element_ty @@ -467,7 +467,7 @@ def _unpermute_bwd_with_merging_probs_kernel( + pid * stride_row_id_map_token + (num_experts + idx) * stride_row_id_map_expert ) - if FUSION_PAD: + if FUSION_UNPAD: pad_off = tl.load(pad_offsets_ptr + expert_idx) dst_row = dst_row + pad_off prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type) diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 1621d0e9fcb..5fe4c4d8b40 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -278,7 +278,7 @@ def unpermute_with_mask_map_bwd_with_merging_probs( row_id_map: torch.Tensor, fwd_input: torch.Tensor, merging_probs: torch.Tensor, - pad_offsets: torch.Tensor, + pad_offsets: Union[torch.Tensor, None], num_tokens: int, num_experts: int, num_out_tokens: int, @@ -341,7 +341,7 @@ def unpermute_with_mask_map_bwd_with_merging_probs( num_experts, hidden_size, PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), - FUSION_PAD=pad_offsets is not None, + 
FUSION_UNPAD=pad_offsets is not None, ) return act_grad, merging_probs_grad From f3014625d5393e653a4f32efa507f311d91d0cf4 Mon Sep 17 00:00:00 2001 From: tdophung Date: Wed, 10 Dec 2025 17:47:42 -0800 Subject: [PATCH 07/19] Fix garbage initialized act grad Signed-off-by: tdophung --- transformer_engine/pytorch/triton/permutation.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 5fe4c4d8b40..fc98b8da083 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -309,10 +309,12 @@ def unpermute_with_mask_map_bwd_with_merging_probs( hidden_size : int Hidden size of the output tensor. """ - act_grad = ( - torch.empty((num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda") - if pad_offsets is None - else torch.zeros((num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda") + # Use zeros when pad_offsets is used because padding slots won't be written to + # by the kernel. This matches the behavior of Fp8Unpadding.backward which zeros + # out the padding slots. 
+ alloc = torch.zeros if pad_offsets is not None else torch.empty + act_grad = alloc( + (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda" ) merging_probs_grad = torch.empty( (num_tokens, num_experts), dtype=merging_probs.dtype, device="cuda" From 7ed584c5231f80e0eb5ea7d31cc4808c344b6262 Mon Sep 17 00:00:00 2001 From: tdophung Date: Wed, 17 Dec 2025 14:33:38 -0800 Subject: [PATCH 08/19] all test passing for jax permutation + pad Signed-off-by: tdophung --- tests/jax/test_permutation.py | 311 ++++++- transformer_engine/jax/permutation.py | 420 ++++++++-- .../jax/triton_extensions/permutation.py | 789 ++++++++++++++++-- .../jax/triton_extensions/utils.py | 18 +- 4 files changed, 1389 insertions(+), 149 deletions(-) diff --git a/tests/jax/test_permutation.py b/tests/jax/test_permutation.py index 23d9f506090..7f73920a1e8 100644 --- a/tests/jax/test_permutation.py +++ b/tests/jax/test_permutation.py @@ -443,7 +443,7 @@ def test_token_dispatch(self, num_tokens, num_experts, hidden_size, tokens_per_e # Define loss functions def loss_fn(x): - output, _, _ = token_dispatch(x, routing_map, num_out_tokens) + output, _, _, _, _ = token_dispatch(x, routing_map, num_out_tokens) return jnp.sum(output**2) def ref_loss_fn(x): @@ -454,7 +454,7 @@ def ref_loss_fn(x): ref_loss_val, ref_grad = jax.value_and_grad(ref_loss_fn)(inp) # Compare forward outputs - output, _, _ = token_dispatch(inp, routing_map, num_out_tokens) + output, _, _, _, _ = token_dispatch(inp, routing_map, num_out_tokens) ref_output, _, _ = reference_token_dispatch(inp, routing_map, num_out_tokens) assert_allclose(output, ref_output) @@ -496,7 +496,7 @@ def test_token_dispatch_with_probs( # Define loss function that uses token_dispatch with probs # We compute gradients w.r.t. 
both inp and probs def loss_fn(x, p): - output, permuted_probs, _ = token_dispatch(x, routing_map, num_out_tokens, probs=p) + output, permuted_probs, _, _, _ = token_dispatch(x, routing_map, num_out_tokens, probs=p) return jnp.sum(output**2) + jnp.sum(permuted_probs**2) def ref_loss_fn(x, p): @@ -510,7 +510,7 @@ def ref_loss_fn(x, p): ref_loss_fn, argnums=(0, 1) )(inp, probs) - output, permuted_probs, _ = token_dispatch(inp, routing_map, num_out_tokens, probs=probs) + output, permuted_probs, _, _, _ = token_dispatch(inp, routing_map, num_out_tokens, probs=probs) ref_output, ref_permuted_probs, _ = reference_token_dispatch( inp, routing_map, num_out_tokens, probs=probs @@ -684,11 +684,310 @@ def test_dispatch_combine_roundtrip( jnp.sum(routing_map, axis=1, keepdims=True), 1.0 ) - # Dispatch tokens to experts (returns output, permuted_probs, row_id_map) - dispatched, _, row_id_map = token_dispatch(inp, routing_map, num_out_tokens) + # Dispatch tokens to experts (returns output, permuted_probs, row_id_map, ...) 
+ dispatched, _, row_id_map, _, _ = token_dispatch(inp, routing_map, num_out_tokens) # Combine tokens back (with uniform merging) (new signature) combined = token_combine(dispatched, row_id_map, merging_probs) # Compare with original input assert_allclose(combined, inp) + + # ========================================================================= + # token_dispatch with padding tests (using unified API) + # ========================================================================= + + @pytest.mark.parametrize( + "num_tokens,num_experts,hidden_size,tokens_per_expert,align_size", + [ + (32, 8, 256, 2, 16), + (64, 16, 512, 3, 32), + (128, 8, 128, 4, 64), + ], + ) + @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) + def test_token_dispatch_with_padding( + self, num_tokens, num_experts, hidden_size, tokens_per_expert, align_size, dtype + ): + """Test token_dispatch with padding forward and backward pass""" + key = jax.random.PRNGKey(42) + + # Generate routing map + routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key) + tokens_per_expert_arr = jnp.sum(routing_map, axis=0).astype(jnp.int32) + num_out_tokens = int(jnp.sum(routing_map)) # Ignored when using padding + + # Generate input data + key, inp_key = jax.random.split(key) + inp = jax.random.uniform( + inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 + ) + + # Test forward pass with padding (using unified API) + # Note: num_out_tokens is not needed when using padding - it's computed internally + output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert = token_dispatch( + inp, + routing_map, + tokens_per_expert=tokens_per_expert_arr, + align_size=align_size, + ) + + # Check output shape - should be padded + expected_padded_tokens = int(jnp.sum(target_tokens_per_expert)) + assert output.shape == (expected_padded_tokens, hidden_size) + assert permuted_probs is None # No probs provided + + # Check that each expert's tokens are 
aligned + for expert_idx in range(num_experts): + expert_tokens = int(target_tokens_per_expert[expert_idx]) + assert expert_tokens % align_size == 0 or expert_tokens == 0 + + # Test backward pass + def loss_fn(x): + out, _, _, _, _ = token_dispatch( + x, + routing_map, + tokens_per_expert=tokens_per_expert_arr, + align_size=align_size, + ) + return jnp.sum(out**2) + + grad = jax.grad(loss_fn)(inp) + assert grad.shape == inp.shape + assert not jnp.any(jnp.isnan(grad)) + + @pytest.mark.parametrize( + "num_tokens,num_experts,hidden_size,tokens_per_expert,align_size", + [ + (32, 8, 256, 2, 16), + (64, 16, 512, 3, 32), + ], + ) + @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) + def test_token_dispatch_with_padding_and_probs( + self, num_tokens, num_experts, hidden_size, tokens_per_expert, align_size, dtype + ): + """Test token_dispatch with padding and probs""" + key = jax.random.PRNGKey(42) + + # Generate routing map + routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key) + tokens_per_expert_arr = jnp.sum(routing_map, axis=0).astype(jnp.int32) + num_out_tokens = int(jnp.sum(routing_map)) # Ignored when using padding + + # Generate input data and probs + key, inp_key, prob_key = jax.random.split(key, 3) + inp = jax.random.uniform( + inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 + ) + probs = jax.random.uniform( + prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.0, maxval=1.0 + ) + + # Test forward pass with padding and probs + # Note: num_out_tokens is not needed when using padding - it's computed internally + output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert = token_dispatch( + inp, + routing_map, + probs=probs, + tokens_per_expert=tokens_per_expert_arr, + align_size=align_size, + ) + + # Check output shape + expected_padded_tokens = int(jnp.sum(target_tokens_per_expert)) + assert output.shape == (expected_padded_tokens, hidden_size) + assert permuted_probs 
is not None + assert permuted_probs.shape == (expected_padded_tokens,) + + # Test backward pass + def loss_fn(x, p): + out, perm_probs, _, _, _ = token_dispatch( + x, + routing_map, + probs=p, + tokens_per_expert=tokens_per_expert_arr, + align_size=align_size, + ) + return jnp.sum(out**2) + jnp.sum(perm_probs**2) + + (inp_grad, probs_grad) = jax.grad(loss_fn, argnums=(0, 1))(inp, probs) + assert inp_grad.shape == inp.shape + assert probs_grad.shape == probs.shape + assert not jnp.any(jnp.isnan(inp_grad)) + assert not jnp.any(jnp.isnan(probs_grad)) + + # ========================================================================= + # token_combine with unpad tests (using unified API) + # ========================================================================= + + @pytest.mark.parametrize( + "num_tokens,num_experts,hidden_size,tokens_per_expert,align_size", + [ + (32, 8, 256, 2, 16), + (64, 16, 512, 3, 32), + ], + ) + @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) + @pytest.mark.parametrize("with_merging_probs", [True, False]) + def test_token_combine_with_unpad( + self, + num_tokens, + num_experts, + hidden_size, + tokens_per_expert, + align_size, + dtype, + with_merging_probs, + ): + """Test token_combine with unpad forward and backward pass""" + key = jax.random.PRNGKey(42) + + # Generate routing map + routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key) + tokens_per_expert_arr = jnp.sum(routing_map, axis=0).astype(jnp.int32) + num_out_tokens = int(jnp.sum(routing_map)) + + # Generate input and dispatch with padding to get row_id_map and pad_offsets + key, inp_key = jax.random.split(key) + inp = jax.random.uniform( + inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 + ) + + _, _, row_id_map, pad_offsets, target_tokens_per_expert = token_dispatch( + inp, + routing_map, + tokens_per_expert=tokens_per_expert_arr, + align_size=align_size, + ) + + # Generate expert output data (padded) + 
expected_padded_tokens = int(jnp.sum(target_tokens_per_expert)) + key, expert_key, merge_key = jax.random.split(key, 3) + expert_output = jax.random.uniform( + expert_key, (expected_padded_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 + ) + + if with_merging_probs: + merging_probs = jax.random.uniform( + merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.0, maxval=1.0 + ) + # Normalize per token + merging_probs = merging_probs / (jnp.sum(merging_probs, axis=1, keepdims=True) + 1e-8) + else: + merging_probs = None + + # Test forward pass with unpad (using unified API) + output = token_combine(expert_output, row_id_map, merging_probs, pad_offsets) + assert output.shape == (num_tokens, hidden_size) + + # Test backward pass + def loss_fn(x): + out = token_combine(x, row_id_map, merging_probs, pad_offsets) + return jnp.sum(out**2) + + grad = jax.grad(loss_fn)(expert_output) + assert grad.shape == expert_output.shape + assert not jnp.any(jnp.isnan(grad)) + + # ========================================================================= + # Round-trip tests with padding + # ========================================================================= + + @pytest.mark.parametrize( + "num_tokens,num_experts,hidden_size,tokens_per_expert,align_size", + [ + (32, 8, 256, 2, 16), + (64, 16, 512, 3, 32), + (128, 8, 128, 4, 64), + ], + ) + @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) + def test_dispatch_combine_with_padding_roundtrip( + self, num_tokens, num_experts, hidden_size, tokens_per_expert, align_size, dtype + ): + """Test that token_dispatch with padding followed by token_combine with unpad recovers input""" + key = jax.random.PRNGKey(42) + + # Generate routing map + routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key) + tokens_per_expert_arr = jnp.sum(routing_map, axis=0).astype(jnp.int32) + num_out_tokens = int(jnp.sum(routing_map)) + + # Generate input data + key, inp_key = jax.random.split(key) + 
inp = jax.random.uniform( + inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 + ) + + # Create uniform merging probs (equal weight for all routed experts) + merging_probs = routing_map.astype(dtype) / jnp.maximum( + jnp.sum(routing_map, axis=1, keepdims=True), 1.0 + ) + + # Dispatch tokens to experts with padding + # Note: num_out_tokens is not needed when using padding - it's computed internally + dispatched, _, row_id_map, pad_offsets, _ = token_dispatch( + inp, + routing_map, + tokens_per_expert=tokens_per_expert_arr, + align_size=align_size, + ) + + # Combine tokens back with unpadding + combined = token_combine(dispatched, row_id_map, merging_probs, pad_offsets) + + # Compare with original input + assert_allclose(combined, inp) + + @pytest.mark.parametrize( + "num_tokens,num_experts,hidden_size,tokens_per_expert,align_size", + [ + (32, 8, 256, 2, 16), + (64, 16, 512, 3, 32), + ], + ) + @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) + def test_dispatch_combine_with_padding_gradient_flow( + self, num_tokens, num_experts, hidden_size, tokens_per_expert, align_size, dtype + ): + """Test gradient flow through dispatch with padding -> combine with unpad""" + key = jax.random.PRNGKey(42) + + # Generate routing map + routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key) + tokens_per_expert_arr = jnp.sum(routing_map, axis=0).astype(jnp.int32) + num_out_tokens = int(jnp.sum(routing_map)) + + # Generate input data + key, inp_key = jax.random.split(key) + inp = jax.random.uniform( + inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 + ) + + # Create uniform merging probs + merging_probs = routing_map.astype(dtype) / jnp.maximum( + jnp.sum(routing_map, axis=1, keepdims=True), 1.0 + ) + + # Define end-to-end function + def forward(x): + dispatched, _, row_id_map, pad_offsets, _ = token_dispatch( + x, + routing_map, + tokens_per_expert=tokens_per_expert_arr, + 
align_size=align_size, + ) + # Simulate some expert processing (e.g., scaling) + processed = dispatched * 2.0 + combined = token_combine(processed, row_id_map, merging_probs, pad_offsets) + return jnp.sum(combined**2) + + # Test gradient computation + grad = jax.grad(forward)(inp) + assert grad.shape == inp.shape + assert not jnp.any(jnp.isnan(grad)) + + # Verify gradient is non-zero for inputs that are routed + routed_mask = jnp.any(routing_map > 0, axis=1) + assert jnp.any(grad[routed_mask] != 0) diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 55a59a1650a..24b85cc03e3 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -16,6 +16,7 @@ - Backward: Permute gradients (scatter to experts) """ +import warnings from functools import partial from typing import Optional, Tuple @@ -25,8 +26,11 @@ from transformer_engine.jax.triton_extensions.permutation import ( make_row_id_map, permute_with_mask_map, + permute_with_mask_map_and_pad, unpermute_with_mask_map, + unpermute_with_mask_map_and_unpad, unpermute_bwd_with_merging_probs, + unpermute_bwd_with_merging_probs_and_unpad, make_chunk_sort_map, sort_chunks_by_map, ) @@ -41,9 +45,17 @@ def token_dispatch( inp: jnp.ndarray, routing_map: jnp.ndarray, - num_out_tokens: int, + num_out_tokens: Optional[int] = None, probs: Optional[jnp.ndarray] = None, -) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]: + tokens_per_expert: Optional[jnp.ndarray] = None, + align_size: Optional[int] = None, +) -> Tuple[ + jnp.ndarray, + Optional[jnp.ndarray], + jnp.ndarray, + Optional[jnp.ndarray], + Optional[jnp.ndarray], +]: """ Dispatch tokens to experts based on routing map. @@ -51,6 +63,10 @@ def token_dispatch( to their designated experts according to the routing map. The row_id_map is computed internally from the routing_map. + Optionally supports fused padding for alignment when both `tokens_per_expert` + and `align_size` are provided. 
This is useful for efficient matrix multiplications + that require aligned tensor dimensions. + Parameters ---------- inp : jnp.ndarray @@ -58,37 +74,120 @@ def token_dispatch( routing_map : jnp.ndarray Routing mask of shape [batch, sequence, num_experts] or [num_tokens, num_experts]. Values: 1 = routed, 0 = not routed. - num_out_tokens : int - The number of output tokens after permutation. This should equal the sum of - routing_map and must be provided explicitly for JIT compatibility. + num_out_tokens : Optional[int], default = None + The number of output tokens after permutation. For the dropless case, this should be equal to + the sum of routing_map and must be provided explicitly for JIT compatibility when NOT + using padding. + When using padding (tokens_per_expert and align_size provided), this value + is ignored and computed internally based on aligned sizes. If provided along + with padding parameters, a warning will be issued. probs : Optional[jnp.ndarray] Optional routing probabilities of shape [batch, sequence, num_experts] or [num_tokens, num_experts]. If provided, permuted_probs will be returned. + tokens_per_expert : Optional[jnp.ndarray] + Optional tensor of shape [num_experts] containing actual token counts per expert. + Required for fused padding. If provided along with align_size, outputs will be + padded to align each expert's tokens, and num_out_tokens will be computed internally. + align_size : Optional[int] + Optional alignment size for padding. Required for fused padding. + Each expert's tokens will be padded to a multiple of this size. Returns ------- output : jnp.ndarray - Permuted output tensor of shape [num_out_tokens, hidden_size]. + Permuted output tensor of shape [num_out_tokens, hidden_size] + (or [num_out_tokens_padded, hidden_size] when using padding fusion). permuted_probs : Optional[jnp.ndarray] Permuted probabilities of shape [num_out_tokens], or None if probs was not provided. 
row_id_map : jnp.ndarray Row ID map for use in token_combine (shape [num_tokens, num_experts * 2 + 1]). + pad_offsets : Optional[jnp.ndarray] + Per-expert cumulative padding offsets of shape [num_experts] when using padding, + None otherwise. Pass this to token_combine when unpadding is needed. + target_tokens_per_expert : Optional[jnp.ndarray] + Aligned token counts per expert of shape [num_experts] when using padding, + None otherwise. + + Note + ---- + **JIT Compatibility with Fused Padding:** + + When using fused padding (tokens_per_expert and align_size provided), the output + size is computed from `tokens_per_expert` values. This requires concrete (non-traced) + values at compile time because JAX needs to know output shapes during tracing. + + If `tokens_per_expert` contains traced values (e.g., computed from traced inputs + inside a JIT-compiled function), a ValueError will be raised with instructions. + + To ensure compatibility, compute `tokens_per_expert` outside the JIT boundary + and pass it as a concrete array argument to the JIT-compiled function. + + Without padding (only `num_out_tokens` provided), the function is fully JIT-compatible + since `num_out_tokens` is a Python int known at trace time. """ - return _token_dispatch(inp, routing_map, probs, num_out_tokens) + # Check that both or neither padding parameters are provided + use_padding = tokens_per_expert is not None and align_size is not None + if (tokens_per_expert is None) != (align_size is None): + raise ValueError( + "Both tokens_per_expert and align_size must be provided together for fused padding, " + "or both must be None." + ) + + # Validate num_out_tokens usage + if use_padding: + if num_out_tokens is not None: + warnings.warn( + "num_out_tokens is ignored when using fused padding (tokens_per_expert and " + "align_size are provided). 
The output token count will be computed internally " + "based on the aligned tokens_per_expert.", + UserWarning, + stacklevel=2, + ) + # Set a dummy value - will be recomputed in the forward rule + actual_num_out_tokens = -1 + else: + if num_out_tokens is None: + raise ValueError( + "num_out_tokens must be provided when not using fused padding. " + "Either provide num_out_tokens, or provide both tokens_per_expert and align_size " + "for fused padding." + ) + actual_num_out_tokens = num_out_tokens + + return _token_dispatch( + inp, routing_map, probs, actual_num_out_tokens, tokens_per_expert, align_size, use_padding + ) -@partial(jax.custom_vjp, nondiff_argnums=(1, 3)) +@partial(jax.custom_vjp, nondiff_argnums=(1, 3, 4, 5, 6)) def _token_dispatch( inp: jnp.ndarray, routing_map: jnp.ndarray, probs: Optional[jnp.ndarray], num_out_tokens: int, -) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]: + tokens_per_expert: Optional[jnp.ndarray], + align_size: Optional[int], + use_padding: bool, +) -> Tuple[ + jnp.ndarray, + Optional[jnp.ndarray], + jnp.ndarray, + Optional[jnp.ndarray], + Optional[jnp.ndarray], +]: """Internal token_dispatch with custom VJP.""" - (output, permuted_probs, row_id_map), _ = _token_dispatch_fwd_rule( - inp, routing_map, probs, num_out_tokens + (output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert), _ = ( + _token_dispatch_fwd_rule( + inp, + routing_map, + probs, + num_out_tokens, + tokens_per_expert, + align_size, + use_padding, + ) ) - return output, permuted_probs, row_id_map + return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert def _token_dispatch_fwd_rule( @@ -96,9 +195,18 @@ def _token_dispatch_fwd_rule( routing_map: jnp.ndarray, probs: Optional[jnp.ndarray], num_out_tokens: int, + tokens_per_expert: Optional[jnp.ndarray], + align_size: Optional[int], + use_padding: bool, ) -> Tuple[ - Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray], - Tuple[jnp.ndarray, int, int, int, bool], + 
Tuple[ + jnp.ndarray, + Optional[jnp.ndarray], + jnp.ndarray, + Optional[jnp.ndarray], + Optional[jnp.ndarray], + ], + Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool], ]: """Forward pass rule for token_dispatch.""" # Validate input dimensions @@ -126,42 +234,116 @@ def _token_dispatch_fwd_rule( with_probs = probs is not None - output, permuted_probs = permute_with_mask_map( - inp, - row_id_map, - probs, - num_tokens, - num_experts, - num_out_tokens, - hidden_size, - ) + if use_padding: + # Ensure tokens_per_expert contains concrete values (not traced). + # This is required because the output shape depends on the sum of aligned token counts. + # Using jax.ensure_compile_time_eval will raise a clear ConcretizationTypeError + # if tokens_per_expert is a traced array. + try: + with jax.ensure_compile_time_eval(): + # Calculate aligned token counts per expert + target_tokens_per_expert = ( + jnp.ceil(tokens_per_expert / align_size) * align_size + ).astype(jnp.int32) + + # Always compute pad_offsets when use_padding=True + # This ensures deterministic control flow for JIT compilation. + # If no padding is actually needed (tokens already aligned), pad_offsets + # will be all zeros, and the kernel handles this correctly (adding 0 is a no-op). + pad_lengths = target_tokens_per_expert - tokens_per_expert + cum_pad = jnp.cumsum(pad_lengths) + pad_offsets = jnp.concatenate( + [jnp.array([0], dtype=cum_pad.dtype), cum_pad[:-1]] + ) + + actual_num_out_tokens = int(jnp.sum(target_tokens_per_expert)) + except jax.errors.ConcretizationTypeError as e: + raise ValueError( + "tokens_per_expert must contain concrete (non-traced) values when using " + "fused padding. The output shape depends on the sum of aligned token counts, " + "which must be known at compile time. " + "Ensure tokens_per_expert is computed outside the JIT boundary or passed as " + "a concrete array to the JIT-compiled function." 
+ ) from e + + # Always use the padded kernel when use_padding=True (static branch) + output, permuted_probs = permute_with_mask_map_and_pad( + inp, + row_id_map, + probs, + pad_offsets, + num_tokens, + num_experts, + actual_num_out_tokens, + hidden_size, + ) + else: + # No padding + pad_offsets = None + target_tokens_per_expert = None + + output, permuted_probs = permute_with_mask_map( + inp, + row_id_map, + probs, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + ) # Return (primals, residuals) - # Include with_probs flag to know how to handle backward pass - residuals = (row_id_map, num_tokens, num_experts, hidden_size, with_probs) - return (output, permuted_probs, row_id_map), residuals + residuals = (row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs) + return ( + output, + permuted_probs, + row_id_map, + pad_offsets, + target_tokens_per_expert, + ), residuals def _token_dispatch_bwd_rule( _routing_map: jnp.ndarray, _num_out_tokens: int, - residuals: Tuple[jnp.ndarray, int, int, int, bool], - g: Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray], + _tokens_per_expert: Optional[jnp.ndarray], + _align_size: Optional[int], + _use_padding: bool, + residuals: Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool], + g: Tuple[ + jnp.ndarray, + Optional[jnp.ndarray], + jnp.ndarray, + Optional[jnp.ndarray], + Optional[jnp.ndarray], + ], ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: """Backward pass rule for token_dispatch.""" - row_id_map, num_tokens, num_experts, hidden_size, with_probs = residuals - output_grad, permuted_probs_grad, _ = g # Ignore row_id_map gradient + row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs = residuals + output_grad, permuted_probs_grad, _, _, _ = g # Ignore row_id_map, pad_offsets, target grads # Backward: unpermute gradients (gather from experts back to tokens) - inp_grad, probs_grad = unpermute_with_mask_map( - output_grad, - row_id_map, - None, # No merging 
probs - permuted_probs_grad if with_probs else None, - num_tokens, - num_experts, - hidden_size, - ) + if pad_offsets is not None: + inp_grad, probs_grad = unpermute_with_mask_map_and_unpad( + output_grad, + row_id_map, + None, # No merging probs + permuted_probs_grad if with_probs else None, + pad_offsets, + num_tokens, + num_experts, + hidden_size, + ) + else: + inp_grad, probs_grad = unpermute_with_mask_map( + output_grad, + row_id_map, + None, # No merging probs + permuted_probs_grad if with_probs else None, + num_tokens, + num_experts, + hidden_size, + ) return inp_grad, probs_grad if with_probs else None @@ -178,6 +360,7 @@ def token_combine( inp: jnp.ndarray, row_id_map: jnp.ndarray, merging_probs: Optional[jnp.ndarray] = None, + pad_offsets: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: """ Combine tokens from experts back to original token positions. @@ -185,33 +368,42 @@ def token_combine( This is the forward pass of MoE unpermutation. Tokens are gathered from experts and merged (optionally weighted by merging_probs). + Optionally supports fused unpadding when `pad_offsets` is provided (from + token_dispatch with padding enabled). + Parameters ---------- inp : jnp.ndarray - Input tensor from experts of shape [num_out_tokens, hidden_size]. + Input tensor from experts of shape [num_out_tokens, hidden_size] + (or [num_out_tokens_padded, hidden_size] when using unpadding). row_id_map : jnp.ndarray Row ID map from token_dispatch of shape [num_tokens, num_experts * 2 + 1]. merging_probs : Optional[jnp.ndarray] Merging weights of shape [batch, sequence, num_experts] or [num_tokens, num_experts]. If provided, tokens from different experts are weighted-summed. If None, tokens are summed directly. + pad_offsets : Optional[jnp.ndarray] + Per-expert cumulative padding offsets of shape [num_experts] from token_dispatch. + If provided, fused unpadding will be performed. This should be the pad_offsets + returned by token_dispatch when using padding. 
Returns ------- output : jnp.ndarray Combined output tensor of shape [num_tokens, hidden_size]. """ - return _token_combine(inp, row_id_map, merging_probs) + return _token_combine(inp, row_id_map, merging_probs, pad_offsets) -@partial(jax.custom_vjp, nondiff_argnums=(1,)) +@jax.custom_vjp def _token_combine( inp: jnp.ndarray, row_id_map: jnp.ndarray, merging_probs: Optional[jnp.ndarray], + pad_offsets: Optional[jnp.ndarray], ) -> jnp.ndarray: """Internal token_combine with custom VJP.""" - output, _ = _token_combine_fwd_rule(inp, row_id_map, merging_probs) + output, _ = _token_combine_fwd_rule(inp, row_id_map, merging_probs, pad_offsets) return output @@ -219,7 +411,20 @@ def _token_combine_fwd_rule( inp: jnp.ndarray, row_id_map: jnp.ndarray, merging_probs: Optional[jnp.ndarray], -) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], int, int, int, int]]: + pad_offsets: Optional[jnp.ndarray], +) -> Tuple[ + jnp.ndarray, + Tuple[ + jnp.ndarray, + Optional[jnp.ndarray], + jnp.ndarray, + Optional[jnp.ndarray], + int, + int, + int, + int, + ], +]: """Forward pass rule for token_combine.""" # Infer dimensions from row_id_map shape: [num_tokens, num_experts * 2 + 1] num_tokens = row_id_map.shape[0] @@ -227,21 +432,34 @@ def _token_combine_fwd_rule( hidden_size = inp.shape[-1] num_out_tokens = inp.shape[0] - # Call triton extension - output, _ = unpermute_with_mask_map( - inp, - row_id_map, - merging_probs, - None, # No permuted probs to unpermute - num_tokens, - num_experts, - hidden_size, - ) + # Call triton extension with or without unpadding + if pad_offsets is not None: + output, _ = unpermute_with_mask_map_and_unpad( + inp, + row_id_map, + merging_probs, + None, # No permuted probs to unpermute + pad_offsets, + num_tokens, + num_experts, + hidden_size, + ) + else: + output, _ = unpermute_with_mask_map( + inp, + row_id_map, + merging_probs, + None, # No permuted probs to unpermute + num_tokens, + num_experts, + hidden_size, + ) # Return 
(primal, residuals) # Include inp in residuals for backward with merging_probs residuals = ( row_id_map, + pad_offsets, inp, merging_probs, num_tokens, @@ -253,13 +471,26 @@ def _token_combine_fwd_rule( def _token_combine_bwd_rule( - row_id_map: jnp.ndarray, - residuals: Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], int, int, int, int], + residuals: Tuple[ + jnp.ndarray, + Optional[jnp.ndarray], + jnp.ndarray, + Optional[jnp.ndarray], + int, + int, + int, + int, + ], g: jnp.ndarray, -) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: - """Backward pass rule for token_combine.""" +) -> Tuple[jnp.ndarray, None, Optional[jnp.ndarray], None]: + """Backward pass rule for token_combine. + + Returns gradients for: (inp, row_id_map, merging_probs, pad_offsets) + row_id_map and pad_offsets are integer arrays, so their gradients are None. + """ ( row_id_map, + pad_offsets, fwd_input, merging_probs, num_tokens, @@ -273,30 +504,63 @@ def _token_combine_bwd_rule( if with_merging_probs: # Use specialized backward kernel that properly scales by merging_probs - inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs( - output_grad, - row_id_map, - fwd_input, - merging_probs, - num_tokens, - num_experts, - num_out_tokens, - hidden_size, - ) + if pad_offsets is not None: + inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs_and_unpad( + output_grad, + row_id_map, + fwd_input, + merging_probs, + pad_offsets, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + ) + # The backward kernel only writes to positions that tokens map to. + # Padded positions may contain uninitialized (NaN) values - replace with zeros. 
+ inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad) + else: + inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs( + output_grad, + row_id_map, + fwd_input, + merging_probs, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + ) else: # Simple case: just permute gradients back - inp_grad, _ = permute_with_mask_map( - output_grad, - row_id_map, - None, - num_tokens, - num_experts, - num_out_tokens, - hidden_size, - ) + if pad_offsets is not None: + inp_grad, _ = permute_with_mask_map_and_pad( + output_grad, + row_id_map, + None, + pad_offsets, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + ) + # The permute kernel only writes to positions that tokens map to. + # Padded positions may contain uninitialized (NaN) values - replace with zeros. + inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad) + else: + inp_grad, _ = permute_with_mask_map( + output_grad, + row_id_map, + None, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + ) merging_probs_grad = None - return inp_grad, merging_probs_grad + # Return gradients for: inp, row_id_map, merging_probs, pad_offsets + # row_id_map and pad_offsets are integer arrays, so their gradients are None + return inp_grad, None, merging_probs_grad, None _token_combine.defvjp(_token_combine_fwd_rule, _token_combine_bwd_rule) diff --git a/transformer_engine/jax/triton_extensions/permutation.py b/transformer_engine/jax/triton_extensions/permutation.py index 4f59f65a87e..206164665da 100644 --- a/transformer_engine/jax/triton_extensions/permutation.py +++ b/transformer_engine/jax/triton_extensions/permutation.py @@ -27,8 +27,11 @@ __all__ = [ "make_row_id_map", "permute_with_mask_map", + "permute_with_mask_map_and_pad", "unpermute_with_mask_map", + "unpermute_with_mask_map_and_unpad", "unpermute_bwd_with_merging_probs", + "unpermute_bwd_with_merging_probs_and_unpad", "make_chunk_sort_map", "sort_chunks_by_map", ] @@ -248,14 +251,14 @@ class 
PermuteWithMaskMapPrimitive(BasePrimitive): name = "te_permute_with_mask_map_triton" multiple_results = True - # scale and permuted_scale are dummy inputs (not used when PERMUTE_SCALE=False) + # scale, permuted_scale, and pad_offsets are dummy inputs (not used when PERMUTE_SCALE=False, FUSION_PAD=False) # but they need to be in the signature for the kernel call impl_static_args = ( - 5, 6, 7, 8, 9, + 10, ) # num_tokens, num_experts, num_out_tokens, hidden_size, with_probs inner_primitive = None outer_primitive = None @@ -267,6 +270,7 @@ def abstract( probs_aval, scale_aval, # dummy, same shape as inp permuted_scale_aval, # dummy, same shape as inp + pad_offsets_aval, # dummy, not used when FUSION_PAD=False *, num_tokens, num_experts, @@ -275,7 +279,7 @@ def abstract( with_probs, ): """Shape/dtype inference for permute.""" - del row_id_map_aval, scale_aval, permuted_scale_aval + del row_id_map_aval, scale_aval, permuted_scale_aval, pad_offsets_aval del num_tokens, num_experts output_shape = (num_out_tokens, hidden_size) @@ -295,6 +299,7 @@ def impl( probs, scale, permuted_scale, + pad_offsets, num_tokens, num_experts, num_out_tokens, @@ -309,6 +314,7 @@ def impl( probs, scale, permuted_scale, + pad_offsets, num_tokens=num_tokens, num_experts=num_experts, num_out_tokens=num_out_tokens, @@ -324,6 +330,7 @@ def lowering( probs, scale, permuted_scale, + pad_offsets, *, num_tokens, num_experts, @@ -359,6 +366,7 @@ def lowering( block_size = _get_min_block_size(_permute_kernel) grid = (num_tokens, triton.cdiv(hidden_size, block_size)) + # Pass all 6 inputs including pad_offsets (even though FUSION_PAD=False) return triton_call_lowering( ctx, _permute_kernel, @@ -367,6 +375,7 @@ def lowering( probs, scale, permuted_scale, + pad_offsets, grid=grid, constexprs={ "scale_hidden_dim": 0, @@ -387,6 +396,7 @@ def lowering( "hidden_size": hidden_size, "PERMUTE_PROBS": with_probs, "PERMUTE_SCALE": False, + "FUSION_PAD": False, "BLOCK_SIZE": block_size, }, ) @@ -395,6 +405,167 @@ 
def lowering( register_primitive(PermuteWithMaskMapPrimitive) +class PermuteWithMaskMapAndPadPrimitive(BasePrimitive): + """ + Permute the input tensor based on the row_id_map with fused padding. + """ + + name = "te_permute_with_mask_map_and_pad_triton" + multiple_results = True + # scale and permuted_scale are dummy inputs (not used when PERMUTE_SCALE=False) + # Order must match kernel: inp, row_id_map, probs, scale, permuted_scale, pad_offsets + impl_static_args = ( + 6, + 7, + 8, + 9, + 10, + ) # num_tokens, num_experts, num_out_tokens, hidden_size, with_probs + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + inp_aval, + row_id_map_aval, + probs_aval, + scale_aval, # dummy, same shape as inp + permuted_scale_aval, # dummy, same shape as inp + pad_offsets_aval, + *, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + with_probs, + ): + """Shape/dtype inference for permute with padding.""" + del row_id_map_aval, scale_aval, permuted_scale_aval, pad_offsets_aval + del num_tokens, num_experts + + output_shape = (num_out_tokens, hidden_size) + output_aval = jax.core.ShapedArray(output_shape, inp_aval.dtype) + + if with_probs: + permuted_probs_aval = jax.core.ShapedArray((num_out_tokens,), probs_aval.dtype) + else: + permuted_probs_aval = jax.core.ShapedArray((0,), inp_aval.dtype) + + return output_aval, permuted_probs_aval + + @staticmethod + def impl( + inp, + row_id_map, + probs, + scale, + permuted_scale, + pad_offsets, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + with_probs, + ): + """Forward to inner primitive.""" + assert PermuteWithMaskMapAndPadPrimitive.inner_primitive is not None + return PermuteWithMaskMapAndPadPrimitive.inner_primitive.bind( + inp, + row_id_map, + probs, + scale, + permuted_scale, + pad_offsets, + num_tokens=num_tokens, + num_experts=num_experts, + num_out_tokens=num_out_tokens, + hidden_size=hidden_size, + with_probs=with_probs, + ) + + @staticmethod + def lowering( + 
ctx, + inp, + row_id_map, + probs, + scale, + permuted_scale, + pad_offsets, + *, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + with_probs, + ): + """MLIR lowering using triton_call_lowering.""" + del num_out_tokens + inp_stride_token = hidden_size + inp_stride_hidden = 1 + output_stride_token = hidden_size + output_stride_hidden = 1 + row_id_stride_token = num_experts * 2 + 1 + row_id_stride_expert = 1 + permuted_probs_stride_token = 1 + + if with_probs: + # Check if probs is 2D [num_tokens, num_experts] or 1D [num_tokens] + probs_aval = ctx.avals_in[2] + if len(probs_aval.shape) > 1: + probs_stride_token = num_experts + probs_stride_expert = 1 + else: + probs_stride_token = 1 + probs_stride_expert = 1 + else: + probs_stride_token = 0 + probs_stride_expert = 0 + + # Grid function equivalent: (num_tokens, cdiv(hidden_size, BLOCK_SIZE)) + # Use minimum BLOCK_SIZE from autotune configs to ensure grid covers all elements + block_size = _get_min_block_size(_permute_kernel) + grid = (num_tokens, triton.cdiv(hidden_size, block_size)) + + # Args order must match kernel: inp, row_id_map, probs, scale, permuted_scale, pad_offsets + return triton_call_lowering( + ctx, + _permute_kernel, + inp, + row_id_map, + probs, + scale, + permuted_scale, + pad_offsets, + grid=grid, + constexprs={ + "scale_hidden_dim": 0, + "stride_row_id_map_token": row_id_stride_token, + "stride_row_id_map_expert": row_id_stride_expert, + "stride_input_token": inp_stride_token, + "stride_input_hidden": inp_stride_hidden, + "stride_output_token": output_stride_token, + "stride_output_hidden": output_stride_hidden, + "stride_probs_token": probs_stride_token, + "stride_probs_expert": probs_stride_expert, + "stride_scale_token": hidden_size, + "stride_scale_hidden": 1, + "stride_permuted_probs_token": permuted_probs_stride_token, + "stride_permuted_scale_token": hidden_size, + "stride_permuted_scale_hidden": 1, + "num_experts": num_experts, + "hidden_size": hidden_size, + "PERMUTE_PROBS": 
with_probs, + "PERMUTE_SCALE": False, + "FUSION_PAD": True, + "BLOCK_SIZE": block_size, + }, + ) + + +register_primitive(PermuteWithMaskMapAndPadPrimitive) + + class UnpermuteWithMaskMapPrimitive(BasePrimitive): """ Unpermute the input tensor based on the row_id_map. @@ -403,11 +574,11 @@ class UnpermuteWithMaskMapPrimitive(BasePrimitive): name = "te_unpermute_with_mask_map_triton" multiple_results = True impl_static_args = ( - 4, 5, 6, 7, 8, + 9, ) # num_tokens, num_experts, hidden_size, with_merging_probs, with_probs inner_primitive = None outer_primitive = None @@ -418,6 +589,7 @@ def abstract( row_id_map_aval, merging_probs_aval, permuted_probs_aval, + pad_offsets_aval, # dummy, not used when FUSION_UNPAD=False *, num_tokens, num_experts, @@ -426,7 +598,7 @@ def abstract( with_probs, ): """Shape/dtype inference for unpermute.""" - del row_id_map_aval, merging_probs_aval, with_merging_probs + del row_id_map_aval, merging_probs_aval, with_merging_probs, pad_offsets_aval output_shape = (num_tokens, hidden_size) output_aval = jax.core.ShapedArray(output_shape, inp_aval.dtype) @@ -447,6 +619,7 @@ def impl( row_id_map, merging_probs, permuted_probs, + pad_offsets, num_tokens, num_experts, hidden_size, @@ -459,95 +632,389 @@ def impl( inp, row_id_map, merging_probs, - permuted_probs, + permuted_probs, + pad_offsets, + num_tokens=num_tokens, + num_experts=num_experts, + hidden_size=hidden_size, + with_merging_probs=with_merging_probs, + with_probs=with_probs, + ) + + @staticmethod + def lowering( + ctx, + inp, + row_id_map, + merging_probs, + permuted_probs, + pad_offsets, + *, + num_tokens, + num_experts, + hidden_size, + with_merging_probs, + with_probs, + ): + """MLIR lowering using triton_call_lowering.""" + # Compute strides + inp_stride_token = hidden_size + inp_stride_hidden = 1 + output_stride_token = hidden_size + output_stride_hidden = 1 + row_id_stride_token = num_experts * 2 + 1 + row_id_stride_expert = 1 + + if with_merging_probs: + 
merging_probs_stride_token = num_experts + merging_probs_stride_expert = 1 + else: + merging_probs_stride_token = 0 + merging_probs_stride_expert = 0 + + permuted_probs_stride_token = 1 + unpermuted_probs_stride_token = num_experts + unpermuted_probs_stride_expert = 1 + + # Grid - use minimum BLOCK_SIZE from autotune configs + block_size = _get_min_block_size(_unpermute_kernel) + grid = (num_tokens, triton.cdiv(hidden_size, block_size)) + + # Pass all 5 inputs including pad_offsets (even though FUSION_UNPAD=False) + return triton_call_lowering( + ctx, + _unpermute_kernel, + inp, + row_id_map, + merging_probs, + permuted_probs, + pad_offsets, + grid=grid, + constexprs={ + "stride_row_id_map_token": row_id_stride_token, + "stride_row_id_map_expert": row_id_stride_expert, + "stride_input_token": inp_stride_token, + "stride_input_hidden": inp_stride_hidden, + "stride_output_token": output_stride_token, + "stride_output_hidden": output_stride_hidden, + "stride_merging_probs_token": merging_probs_stride_token, + "stride_merging_probs_expert": merging_probs_stride_expert, + "stride_permuted_probs_token": permuted_probs_stride_token, + "stride_unpermuted_probs_token": unpermuted_probs_stride_token, + "stride_unpermuted_probs_expert": unpermuted_probs_stride_expert, + "num_experts": num_experts, + "hidden_size": hidden_size, + "PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts), + "WITH_MERGING_PROBS": with_merging_probs, + "PERMUTE_PROBS": with_probs, + "FUSION_UNPAD": False, + "BLOCK_SIZE": block_size, + }, + ) + + +register_primitive(UnpermuteWithMaskMapPrimitive) + + +class UnpermuteWithMaskMapAndUnpadPrimitive(BasePrimitive): + """ + Unpermute the input tensor based on the row_id_map with fused unpadding. 
+ """ + + name = "te_unpermute_with_mask_map_and_unpad_triton" + multiple_results = True + impl_static_args = ( + 5, + 6, + 7, + 8, + 9, + ) # num_tokens, num_experts, hidden_size, with_merging_probs, with_probs + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + inp_aval, + row_id_map_aval, + merging_probs_aval, + permuted_probs_aval, + pad_offsets_aval, + *, + num_tokens, + num_experts, + hidden_size, + with_merging_probs, + with_probs, + ): + """Shape/dtype inference for unpermute with unpadding.""" + del row_id_map_aval, merging_probs_aval, with_merging_probs, pad_offsets_aval + + output_shape = (num_tokens, hidden_size) + output_aval = jax.core.ShapedArray(output_shape, inp_aval.dtype) + + if with_probs: + unpermuted_probs_shape = (num_tokens, num_experts) + unpermuted_probs_aval = jax.core.ShapedArray( + unpermuted_probs_shape, permuted_probs_aval.dtype + ) + else: + unpermuted_probs_aval = jax.core.ShapedArray((0,), inp_aval.dtype) + + return output_aval, unpermuted_probs_aval + + @staticmethod + def impl( + inp, + row_id_map, + merging_probs, + permuted_probs, + pad_offsets, + num_tokens, + num_experts, + hidden_size, + with_merging_probs, + with_probs, + ): + """Forward to inner primitive.""" + assert UnpermuteWithMaskMapAndUnpadPrimitive.inner_primitive is not None + return UnpermuteWithMaskMapAndUnpadPrimitive.inner_primitive.bind( + inp, + row_id_map, + merging_probs, + permuted_probs, + pad_offsets, + num_tokens=num_tokens, + num_experts=num_experts, + hidden_size=hidden_size, + with_merging_probs=with_merging_probs, + with_probs=with_probs, + ) + + @staticmethod + def lowering( + ctx, + inp, + row_id_map, + merging_probs, + permuted_probs, + pad_offsets, + *, + num_tokens, + num_experts, + hidden_size, + with_merging_probs, + with_probs, + ): + """MLIR lowering using triton_call_lowering.""" + # Compute strides + inp_stride_token = hidden_size + inp_stride_hidden = 1 + output_stride_token = hidden_size + 
output_stride_hidden = 1 + row_id_stride_token = num_experts * 2 + 1 + row_id_stride_expert = 1 + + if with_merging_probs: + merging_probs_stride_token = num_experts + merging_probs_stride_expert = 1 + else: + merging_probs_stride_token = 0 + merging_probs_stride_expert = 0 + + permuted_probs_stride_token = 1 + unpermuted_probs_stride_token = num_experts + unpermuted_probs_stride_expert = 1 + + # Grid - use minimum BLOCK_SIZE from autotune configs + block_size = _get_min_block_size(_unpermute_kernel) + grid = (num_tokens, triton.cdiv(hidden_size, block_size)) + + return triton_call_lowering( + ctx, + _unpermute_kernel, + inp, + row_id_map, + merging_probs, + permuted_probs, + pad_offsets, + grid=grid, + constexprs={ + "stride_row_id_map_token": row_id_stride_token, + "stride_row_id_map_expert": row_id_stride_expert, + "stride_input_token": inp_stride_token, + "stride_input_hidden": inp_stride_hidden, + "stride_output_token": output_stride_token, + "stride_output_hidden": output_stride_hidden, + "stride_merging_probs_token": merging_probs_stride_token, + "stride_merging_probs_expert": merging_probs_stride_expert, + "stride_permuted_probs_token": permuted_probs_stride_token, + "stride_unpermuted_probs_token": unpermuted_probs_stride_token, + "stride_unpermuted_probs_expert": unpermuted_probs_stride_expert, + "num_experts": num_experts, + "hidden_size": hidden_size, + "PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts), + "WITH_MERGING_PROBS": with_merging_probs, + "PERMUTE_PROBS": with_probs, + "FUSION_UNPAD": True, + "BLOCK_SIZE": block_size, + }, + ) + + +register_primitive(UnpermuteWithMaskMapAndUnpadPrimitive) + + +class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive): + """ + Backward pass for unpermute with merging probabilities. + + This kernel computes gradients for both the input and merging_probs. 
+ """ + + name = "te_unpermute_bwd_with_merging_probs_triton" + multiple_results = True + impl_static_args = (5, 6, 7, 8) # num_tokens, num_experts, num_out_tokens, hidden_size + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + fwd_output_grad_aval, + fwd_input_aval, + merging_probs_aval, + row_id_map_aval, + pad_offsets_aval, # dummy, not used when FUSION_UNPAD=False + *, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + ): + """Shape/dtype inference for unpermute backward with merging probs.""" + del fwd_input_aval, row_id_map_aval, pad_offsets_aval + + # fwd_input_grad has same shape as fwd_input + fwd_input_grad_shape = (num_out_tokens, hidden_size) + fwd_input_grad_aval = jax.core.ShapedArray(fwd_input_grad_shape, fwd_output_grad_aval.dtype) + + # merging_probs_grad has same shape as merging_probs + merging_probs_grad_shape = (num_tokens, num_experts) + merging_probs_grad_aval = jax.core.ShapedArray( + merging_probs_grad_shape, merging_probs_aval.dtype + ) + + return fwd_input_grad_aval, merging_probs_grad_aval + + @staticmethod + def impl( + fwd_output_grad, + fwd_input, + merging_probs, + row_id_map, + pad_offsets, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + ): + """Forward to inner primitive.""" + assert UnpermuteBwdWithMergingProbsPrimitive.inner_primitive is not None + return UnpermuteBwdWithMergingProbsPrimitive.inner_primitive.bind( + fwd_output_grad, + fwd_input, + merging_probs, + row_id_map, + pad_offsets, num_tokens=num_tokens, num_experts=num_experts, + num_out_tokens=num_out_tokens, hidden_size=hidden_size, - with_merging_probs=with_merging_probs, - with_probs=with_probs, ) @staticmethod def lowering( ctx, - inp, - row_id_map, + fwd_output_grad, + fwd_input, merging_probs, - permuted_probs, + row_id_map, + pad_offsets, *, num_tokens, num_experts, + num_out_tokens, hidden_size, - with_merging_probs, - with_probs, ): """MLIR lowering using triton_call_lowering.""" + del 
num_out_tokens + # Compute strides - inp_stride_token = hidden_size - inp_stride_hidden = 1 - output_stride_token = hidden_size - output_stride_hidden = 1 row_id_stride_token = num_experts * 2 + 1 row_id_stride_expert = 1 + fwd_output_grad_stride_token = hidden_size + fwd_output_grad_stride_hidden = 1 + fwd_input_grad_stride_token = hidden_size + fwd_input_grad_stride_hidden = 1 + fwd_input_stride_token = hidden_size + fwd_input_stride_hidden = 1 + merging_probs_stride_token = num_experts + merging_probs_stride_expert = 1 + merging_probs_grad_stride_token = num_experts + merging_probs_grad_stride_expert = 1 - if with_merging_probs: - merging_probs_stride_token = num_experts - merging_probs_stride_expert = 1 - else: - merging_probs_stride_token = 0 - merging_probs_stride_expert = 0 - - permuted_probs_stride_token = 1 - unpermuted_probs_stride_token = num_experts - unpermuted_probs_stride_expert = 1 + # Grid - one program per token + grid = (num_tokens,) - # Grid - use minimum BLOCK_SIZE from autotune configs - block_size = _get_min_block_size(_unpermute_kernel) - grid = (num_tokens, triton.cdiv(hidden_size, block_size)) + # Get min block size from autotune configs for consistency + block_size = _get_min_block_size(_unpermute_bwd_with_merging_probs_kernel) + # Pass all 5 inputs including pad_offsets (even though FUSION_UNPAD=False) return triton_call_lowering( ctx, - _unpermute_kernel, - inp, - row_id_map, + _unpermute_bwd_with_merging_probs_kernel, + fwd_output_grad, + fwd_input, merging_probs, - permuted_probs, + row_id_map, + pad_offsets, grid=grid, constexprs={ "stride_row_id_map_token": row_id_stride_token, "stride_row_id_map_expert": row_id_stride_expert, - "stride_input_token": inp_stride_token, - "stride_input_hidden": inp_stride_hidden, - "stride_output_token": output_stride_token, - "stride_output_hidden": output_stride_hidden, + "stride_fwd_output_grad_token": fwd_output_grad_stride_token, + "stride_fwd_output_grad_hidden": fwd_output_grad_stride_hidden, + 
"stride_fwd_input_grad_token": fwd_input_grad_stride_token, + "stride_fwd_input_grad_hidden": fwd_input_grad_stride_hidden, + "stride_fwd_input_token": fwd_input_stride_token, + "stride_fwd_input_hidden": fwd_input_stride_hidden, "stride_merging_probs_token": merging_probs_stride_token, "stride_merging_probs_expert": merging_probs_stride_expert, - "stride_permuted_probs_token": permuted_probs_stride_token, - "stride_unpermuted_probs_token": unpermuted_probs_stride_token, - "stride_unpermuted_probs_expert": unpermuted_probs_stride_expert, + "stride_merging_probs_grad_token": merging_probs_grad_stride_token, + "stride_merging_probs_grad_expert": merging_probs_grad_stride_expert, "num_experts": num_experts, "hidden_size": hidden_size, "PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts), - "WITH_MERGING_PROBS": with_merging_probs, - "PERMUTE_PROBS": with_probs, + "FUSION_UNPAD": False, "BLOCK_SIZE": block_size, }, ) -register_primitive(UnpermuteWithMaskMapPrimitive) +register_primitive(UnpermuteBwdWithMergingProbsPrimitive) -class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive): +class UnpermuteBwdWithMergingProbsAndUnpadPrimitive(BasePrimitive): """ - Backward pass for unpermute with merging probabilities. + Backward pass for unpermute with merging probabilities and fused unpadding. - This kernel computes gradients for both the input and merging_probs. + This kernel computes gradients for both the input and merging_probs, + while handling padded outputs. 
""" - name = "te_unpermute_bwd_with_merging_probs_triton" + name = "te_unpermute_bwd_with_merging_probs_and_unpad_triton" multiple_results = True - impl_static_args = (4, 5, 6, 7) # num_tokens, num_experts, num_out_tokens, hidden_size + impl_static_args = (5, 6, 7, 8) # num_tokens, num_experts, num_out_tokens, hidden_size inner_primitive = None outer_primitive = None @@ -557,14 +1024,15 @@ def abstract( fwd_input_aval, merging_probs_aval, row_id_map_aval, + pad_offsets_aval, *, num_tokens, num_experts, num_out_tokens, hidden_size, ): - """Shape/dtype inference for unpermute backward with merging probs.""" - del fwd_input_aval, row_id_map_aval + """Shape/dtype inference for unpermute backward with merging probs and unpadding.""" + del fwd_input_aval, row_id_map_aval, pad_offsets_aval # fwd_input_grad has same shape as fwd_input fwd_input_grad_shape = (num_out_tokens, hidden_size) @@ -584,18 +1052,20 @@ def impl( fwd_input, merging_probs, row_id_map, + pad_offsets, num_tokens, num_experts, num_out_tokens, hidden_size, ): """Forward to inner primitive.""" - assert UnpermuteBwdWithMergingProbsPrimitive.inner_primitive is not None - return UnpermuteBwdWithMergingProbsPrimitive.inner_primitive.bind( + assert UnpermuteBwdWithMergingProbsAndUnpadPrimitive.inner_primitive is not None + return UnpermuteBwdWithMergingProbsAndUnpadPrimitive.inner_primitive.bind( fwd_output_grad, fwd_input, merging_probs, row_id_map, + pad_offsets, num_tokens=num_tokens, num_experts=num_experts, num_out_tokens=num_out_tokens, @@ -609,6 +1079,7 @@ def lowering( fwd_input, merging_probs, row_id_map, + pad_offsets, *, num_tokens, num_experts, @@ -638,7 +1109,6 @@ def lowering( # Get min block size from autotune configs for consistency block_size = _get_min_block_size(_unpermute_bwd_with_merging_probs_kernel) - # Pass inputs in kernel argument order: fwd_output_grad, fwd_input, merging_probs, row_id_map return triton_call_lowering( ctx, _unpermute_bwd_with_merging_probs_kernel, @@ -646,6 +1116,7 @@ 
def lowering( fwd_input, merging_probs, row_id_map, + pad_offsets, grid=grid, constexprs={ "stride_row_id_map_token": row_id_stride_token, @@ -663,12 +1134,13 @@ def lowering( "num_experts": num_experts, "hidden_size": hidden_size, "PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts), + "FUSION_UNPAD": True, "BLOCK_SIZE": block_size, }, ) -register_primitive(UnpermuteBwdWithMergingProbsPrimitive) +register_primitive(UnpermuteBwdWithMergingProbsAndUnpadPrimitive) def unpermute_bwd_with_merging_probs( @@ -712,12 +1184,73 @@ def unpermute_bwd_with_merging_probs( merging_probs_grad : jnp.ndarray Gradient w.r.t. merging_probs of shape `[num_tokens, num_experts]`. """ - # Pass arguments in kernel order: fwd_output_grad, fwd_input, merging_probs, row_id_map + # Create dummy pad_offsets (not used when FUSION_UNPAD=False, but required by kernel signature) + dummy_pad_offsets = jnp.zeros((0,), dtype=jnp.int32) + # Pass arguments in kernel order: fwd_output_grad, fwd_input, merging_probs, row_id_map, pad_offsets return UnpermuteBwdWithMergingProbsPrimitive.outer_primitive.bind( fwd_output_grad, fwd_input, merging_probs, row_id_map, + dummy_pad_offsets, + num_tokens=num_tokens, + num_experts=num_experts, + num_out_tokens=num_out_tokens, + hidden_size=hidden_size, + ) + + +def unpermute_bwd_with_merging_probs_and_unpad( + fwd_output_grad: jnp.ndarray, + row_id_map: jnp.ndarray, + fwd_input: jnp.ndarray, + merging_probs: jnp.ndarray, + pad_offsets: jnp.ndarray, + num_tokens: int, + num_experts: int, + num_out_tokens: int, + hidden_size: int, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Backward pass for unpermute with merging probabilities and fused unpadding. + + This computes gradients for both the input tensor and merging_probs, + while handling padded outputs. + + Parameters + ---------- + fwd_output_grad : jnp.ndarray + Gradient of the forward output of shape `[num_tokens, hidden_size]`. 
+ row_id_map : jnp.ndarray + The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`. + fwd_input : jnp.ndarray + The input tensor from the forward pass of shape `[num_out_tokens, hidden_size]`. + merging_probs : jnp.ndarray + The merging probabilities of shape `[num_tokens, num_experts]`. + pad_offsets : jnp.ndarray + Per-expert cumulative padding offsets of shape `[num_experts]`. + num_tokens : int + Number of tokens in the unpermuted tensor. + num_experts : int + Number of experts. + num_out_tokens : int + Number of tokens in the permuted tensor (including padding). + hidden_size : int + Hidden size. + + Returns + ------- + fwd_input_grad : jnp.ndarray + Gradient w.r.t. the input tensor of shape `[num_out_tokens, hidden_size]`. + merging_probs_grad : jnp.ndarray + Gradient w.r.t. merging_probs of shape `[num_tokens, num_experts]`. + """ + return UnpermuteBwdWithMergingProbsAndUnpadPrimitive.outer_primitive.bind( + fwd_output_grad, + fwd_input, + merging_probs, + row_id_map, + pad_offsets, num_tokens=num_tokens, num_experts=num_experts, num_out_tokens=num_out_tokens, @@ -964,6 +1497,8 @@ def permute_with_mask_map( # Create dummy scale tensors (not used when PERMUTE_SCALE=False, but required by kernel signature) dummy_scale = inp dummy_permuted_scale = inp + # Create dummy pad_offsets (not used when FUSION_PAD=False, but required by kernel signature) + dummy_pad_offsets = jnp.zeros((0,), dtype=jnp.int32) output, permuted_probs = PermuteWithMaskMapPrimitive.outer_primitive.bind( inp, @@ -971,6 +1506,77 @@ def permute_with_mask_map( probs, dummy_scale, dummy_permuted_scale, + dummy_pad_offsets, + num_tokens=num_tokens, + num_experts=num_experts, + num_out_tokens=num_out_tokens, + hidden_size=hidden_size, + with_probs=with_probs, + ) + + if not with_probs: + permuted_probs = None + + return output, permuted_probs + + +def permute_with_mask_map_and_pad( + inp: jnp.ndarray, + row_id_map: jnp.ndarray, + probs: Optional[jnp.ndarray], + 
pad_offsets: jnp.ndarray, + num_tokens: int, + num_experts: int, + num_out_tokens: int, + hidden_size: int, +) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + """ + Permute the input tensor based on the row_id_map with fused padding. + + Parameters + ---------- + inp : jnp.ndarray + Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. + row_id_map : jnp.ndarray + The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`. + probs : Optional[jnp.ndarray] + The probabilities of the input tensor. If it is not None, it will be permuted. + pad_offsets : jnp.ndarray + Per-expert cumulative padding offsets of shape `[num_experts]`. + num_tokens : int + Number of tokens in the input tensor. + num_experts : int + Number of experts in the input tensor. + num_out_tokens : int + Number of tokens in the permuted tensor (including padding). + hidden_size : int + Hidden size of the input tensor. + + Returns + ------- + output : jnp.ndarray + Permuted and padded output tensor of shape `[num_out_tokens, hidden_size]`. + permuted_probs : Optional[jnp.ndarray] + Permuted probabilities if probs was provided, None otherwise. 
+ """ + with_probs = probs is not None + + # Handle None probs by creating dummy tensor + if not with_probs: + probs = jnp.zeros((0,), dtype=inp.dtype) + + # Create dummy scale tensors (not used when PERMUTE_SCALE=False, but required by kernel signature) + dummy_scale = inp + dummy_permuted_scale = inp + + # Args order must match kernel: inp, row_id_map, probs, scale, permuted_scale, pad_offsets + output, permuted_probs = PermuteWithMaskMapAndPadPrimitive.outer_primitive.bind( + inp, + row_id_map, + probs, + dummy_scale, + dummy_permuted_scale, + pad_offsets, num_tokens=num_tokens, num_experts=num_experts, num_out_tokens=num_out_tokens, @@ -1029,12 +1635,83 @@ def unpermute_with_mask_map( merging_probs = jnp.zeros((0,), dtype=inp.dtype) if not with_probs: permuted_probs = jnp.zeros((0,), dtype=inp.dtype) + # Create dummy pad_offsets (not used when FUSION_UNPAD=False, but required by kernel signature) + dummy_pad_offsets = jnp.zeros((0,), dtype=jnp.int32) output, unpermuted_probs = UnpermuteWithMaskMapPrimitive.outer_primitive.bind( inp, row_id_map, merging_probs, permuted_probs, + dummy_pad_offsets, + num_tokens=num_tokens, + num_experts=num_experts, + hidden_size=hidden_size, + with_merging_probs=with_merging_probs, + with_probs=with_probs, + ) + + if not with_probs: + unpermuted_probs = None + + return output, unpermuted_probs + + +def unpermute_with_mask_map_and_unpad( + inp: jnp.ndarray, + row_id_map: jnp.ndarray, + merging_probs: Optional[jnp.ndarray], + permuted_probs: Optional[jnp.ndarray], + pad_offsets: jnp.ndarray, + num_tokens: int, + num_experts: int, + hidden_size: int, +) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + """ + Unpermute the input tensor based on the row_id_map with fused unpadding. + + Parameters + ---------- + inp : jnp.ndarray + Input tensor of shape `[num_out_tokens, hidden_size]` (including padding). + row_id_map : jnp.ndarray + The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`. 
+ merging_probs : Optional[jnp.ndarray] + The merging probabilities of the input tensor. If it is not None, it will be used as weights + to reduce the unpermuted tokens. + permuted_probs : Optional[jnp.ndarray] + The permuted probabilities of the input tensor. If it is not None, it will be unpermuted. + pad_offsets : jnp.ndarray + Per-expert cumulative padding offsets of shape `[num_experts]`. + num_tokens : int + Number of tokens in the unpermuted tensor. + num_experts : int + Number of experts. + hidden_size : int + Hidden size of the tensor. + + Returns + ------- + output : jnp.ndarray + Unpermuted output tensor of shape `[num_tokens, hidden_size]`. + unpermuted_probs : Optional[jnp.ndarray] + Unpermuted probabilities if permuted_probs was provided, None otherwise. + """ + with_merging_probs = merging_probs is not None + with_probs = permuted_probs is not None + + # Handle None inputs by creating dummy tensors + if not with_merging_probs: + merging_probs = jnp.zeros((0,), dtype=inp.dtype) + if not with_probs: + permuted_probs = jnp.zeros((0,), dtype=inp.dtype) + + output, unpermuted_probs = UnpermuteWithMaskMapAndUnpadPrimitive.outer_primitive.bind( + inp, + row_id_map, + merging_probs, + permuted_probs, + pad_offsets, num_tokens=num_tokens, num_experts=num_experts, hidden_size=hidden_size, diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 12d6a9e3de4..f093d99c49e 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -142,16 +142,16 @@ def compile_triton( ) # Create kernel object for JAX + # From jax/jaxlib/gpu/triton_kernels.cc: + # Kernel::Kernel(kernel_name, num_warps, num_ctas, shared_mem_bytes, ptx, ttir, compute_capability) kernel = gpu_triton.TritonKernel( - compiled.name, - num_warps, - compiled.metadata.shared, - compiled.asm["ptx"], - "", # ttir - compute_capability, - 1, - 1, - 1, # cluster_dims + compiled.name, # 
arg0: kernel_name (str) + num_warps, # arg1: num_warps (int) + num_ctas, # arg2: num_ctas (int) + compiled.metadata.shared, # arg3: shared_mem_bytes (int) + compiled.asm["ptx"], # arg4: ptx (str) + "", # arg5: ttir (str) - empty + compute_capability, # arg6: compute_capability (int) ) _TRITON_KERNEL_CACHE[cache_key] = kernel From 7998ce8c828603114359ed71c6d54fedaa73babe Mon Sep 17 00:00:00 2001 From: tdophung Date: Wed, 17 Dec 2025 15:26:21 -0800 Subject: [PATCH 09/19] change tokens_per_experts APIs to num_out_tokens with conservative allocation of worst case padding for output buffer Signed-off-by: tdophung --- tests/jax/test_permutation.py | 86 +++++++------- transformer_engine/jax/permutation.py | 163 ++++++++++---------------- 2 files changed, 109 insertions(+), 140 deletions(-) diff --git a/tests/jax/test_permutation.py b/tests/jax/test_permutation.py index 7f73920a1e8..16e311387cd 100644 --- a/tests/jax/test_permutation.py +++ b/tests/jax/test_permutation.py @@ -698,7 +698,7 @@ def test_dispatch_combine_roundtrip( # ========================================================================= @pytest.mark.parametrize( - "num_tokens,num_experts,hidden_size,tokens_per_expert,align_size", + "num_tokens,num_experts,hidden_size,topk,align_size", [ (32, 8, 256, 2, 16), (64, 16, 512, 3, 32), @@ -707,15 +707,14 @@ def test_dispatch_combine_roundtrip( ) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) def test_token_dispatch_with_padding( - self, num_tokens, num_experts, hidden_size, tokens_per_expert, align_size, dtype + self, num_tokens, num_experts, hidden_size, topk, align_size, dtype ): """Test token_dispatch with padding forward and backward pass""" key = jax.random.PRNGKey(42) # Generate routing map - routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key) - tokens_per_expert_arr = jnp.sum(routing_map, axis=0).astype(jnp.int32) - num_out_tokens = int(jnp.sum(routing_map)) # Ignored when using padding + routing_map = 
self.generate_routing_map(num_tokens, num_experts, topk, key) + num_out_tokens = int(jnp.sum(routing_map)) # Generate input data key, inp_key = jax.random.split(key) @@ -724,17 +723,19 @@ def test_token_dispatch_with_padding( ) # Test forward pass with padding (using unified API) - # Note: num_out_tokens is not needed when using padding - it's computed internally + # Now we just pass num_out_tokens and align_size - tokens_per_expert is computed internally output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert = token_dispatch( inp, routing_map, - tokens_per_expert=tokens_per_expert_arr, + num_out_tokens, align_size=align_size, ) - # Check output shape - should be padded - expected_padded_tokens = int(jnp.sum(target_tokens_per_expert)) - assert output.shape == (expected_padded_tokens, hidden_size) + # Check output shape - should be worst-case padded size + worst_case_size = ( + (num_out_tokens + num_experts * (align_size - 1)) // align_size + ) * align_size + assert output.shape == (worst_case_size, hidden_size) assert permuted_probs is None # No probs provided # Check that each expert's tokens are aligned @@ -747,7 +748,7 @@ def loss_fn(x): out, _, _, _, _ = token_dispatch( x, routing_map, - tokens_per_expert=tokens_per_expert_arr, + num_out_tokens, align_size=align_size, ) return jnp.sum(out**2) @@ -757,7 +758,7 @@ def loss_fn(x): assert not jnp.any(jnp.isnan(grad)) @pytest.mark.parametrize( - "num_tokens,num_experts,hidden_size,tokens_per_expert,align_size", + "num_tokens,num_experts,hidden_size,topk,align_size", [ (32, 8, 256, 2, 16), (64, 16, 512, 3, 32), @@ -765,15 +766,14 @@ def loss_fn(x): ) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) def test_token_dispatch_with_padding_and_probs( - self, num_tokens, num_experts, hidden_size, tokens_per_expert, align_size, dtype + self, num_tokens, num_experts, hidden_size, topk, align_size, dtype ): """Test token_dispatch with padding and probs""" key = jax.random.PRNGKey(42) # 
Generate routing map - routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key) - tokens_per_expert_arr = jnp.sum(routing_map, axis=0).astype(jnp.int32) - num_out_tokens = int(jnp.sum(routing_map)) # Ignored when using padding + routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key) + num_out_tokens = int(jnp.sum(routing_map)) # Generate input data and probs key, inp_key, prob_key = jax.random.split(key, 3) @@ -785,28 +785,30 @@ def test_token_dispatch_with_padding_and_probs( ) # Test forward pass with padding and probs - # Note: num_out_tokens is not needed when using padding - it's computed internally + # tokens_per_expert is computed internally from routing_map output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert = token_dispatch( inp, routing_map, + num_out_tokens, probs=probs, - tokens_per_expert=tokens_per_expert_arr, align_size=align_size, ) - # Check output shape - expected_padded_tokens = int(jnp.sum(target_tokens_per_expert)) - assert output.shape == (expected_padded_tokens, hidden_size) + # Check output shape - should be worst-case padded size + worst_case_size = ( + (num_out_tokens + num_experts * (align_size - 1)) // align_size + ) * align_size + assert output.shape == (worst_case_size, hidden_size) assert permuted_probs is not None - assert permuted_probs.shape == (expected_padded_tokens,) + assert permuted_probs.shape == (worst_case_size,) # Test backward pass def loss_fn(x, p): out, perm_probs, _, _, _ = token_dispatch( x, routing_map, + num_out_tokens, probs=p, - tokens_per_expert=tokens_per_expert_arr, align_size=align_size, ) return jnp.sum(out**2) + jnp.sum(perm_probs**2) @@ -822,7 +824,7 @@ def loss_fn(x, p): # ========================================================================= @pytest.mark.parametrize( - "num_tokens,num_experts,hidden_size,tokens_per_expert,align_size", + "num_tokens,num_experts,hidden_size,topk,align_size", [ (32, 8, 256, 2, 16), (64, 16, 512, 3, 
32), @@ -835,7 +837,7 @@ def test_token_combine_with_unpad( num_tokens, num_experts, hidden_size, - tokens_per_expert, + topk, align_size, dtype, with_merging_probs, @@ -844,8 +846,7 @@ def test_token_combine_with_unpad( key = jax.random.PRNGKey(42) # Generate routing map - routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key) - tokens_per_expert_arr = jnp.sum(routing_map, axis=0).astype(jnp.int32) + routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key) num_out_tokens = int(jnp.sum(routing_map)) # Generate input and dispatch with padding to get row_id_map and pad_offsets @@ -854,18 +855,21 @@ def test_token_combine_with_unpad( inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 ) + # Dispatch with padding to get row_id_map and pad_offsets _, _, row_id_map, pad_offsets, target_tokens_per_expert = token_dispatch( inp, routing_map, - tokens_per_expert=tokens_per_expert_arr, + num_out_tokens, align_size=align_size, ) - # Generate expert output data (padded) - expected_padded_tokens = int(jnp.sum(target_tokens_per_expert)) + # Generate expert output data (worst-case padded size) + worst_case_size = ( + (num_out_tokens + num_experts * (align_size - 1)) // align_size + ) * align_size key, expert_key, merge_key = jax.random.split(key, 3) expert_output = jax.random.uniform( - expert_key, (expected_padded_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 + expert_key, (worst_case_size, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 ) if with_merging_probs: @@ -895,7 +899,7 @@ def loss_fn(x): # ========================================================================= @pytest.mark.parametrize( - "num_tokens,num_experts,hidden_size,tokens_per_expert,align_size", + "num_tokens,num_experts,hidden_size,topk,align_size", [ (32, 8, 256, 2, 16), (64, 16, 512, 3, 32), @@ -904,14 +908,13 @@ def loss_fn(x): ) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) def 
test_dispatch_combine_with_padding_roundtrip( - self, num_tokens, num_experts, hidden_size, tokens_per_expert, align_size, dtype + self, num_tokens, num_experts, hidden_size, topk, align_size, dtype ): """Test that token_dispatch with padding followed by token_combine with unpad recovers input""" key = jax.random.PRNGKey(42) # Generate routing map - routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key) - tokens_per_expert_arr = jnp.sum(routing_map, axis=0).astype(jnp.int32) + routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key) num_out_tokens = int(jnp.sum(routing_map)) # Generate input data @@ -926,11 +929,11 @@ def test_dispatch_combine_with_padding_roundtrip( ) # Dispatch tokens to experts with padding - # Note: num_out_tokens is not needed when using padding - it's computed internally + # tokens_per_expert is computed internally from routing_map dispatched, _, row_id_map, pad_offsets, _ = token_dispatch( inp, routing_map, - tokens_per_expert=tokens_per_expert_arr, + num_out_tokens, align_size=align_size, ) @@ -941,7 +944,7 @@ def test_dispatch_combine_with_padding_roundtrip( assert_allclose(combined, inp) @pytest.mark.parametrize( - "num_tokens,num_experts,hidden_size,tokens_per_expert,align_size", + "num_tokens,num_experts,hidden_size,topk,align_size", [ (32, 8, 256, 2, 16), (64, 16, 512, 3, 32), @@ -949,14 +952,13 @@ def test_dispatch_combine_with_padding_roundtrip( ) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) def test_dispatch_combine_with_padding_gradient_flow( - self, num_tokens, num_experts, hidden_size, tokens_per_expert, align_size, dtype + self, num_tokens, num_experts, hidden_size, topk, align_size, dtype ): """Test gradient flow through dispatch with padding -> combine with unpad""" key = jax.random.PRNGKey(42) # Generate routing map - routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key) - tokens_per_expert_arr = jnp.sum(routing_map, 
axis=0).astype(jnp.int32) + routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key) num_out_tokens = int(jnp.sum(routing_map)) # Generate input data @@ -975,7 +977,7 @@ def forward(x): dispatched, _, row_id_map, pad_offsets, _ = token_dispatch( x, routing_map, - tokens_per_expert=tokens_per_expert_arr, + num_out_tokens, align_size=align_size, ) # Simulate some expert processing (e.g., scaling) diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 24b85cc03e3..0c64e6491ce 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -45,9 +45,8 @@ def token_dispatch( inp: jnp.ndarray, routing_map: jnp.ndarray, - num_out_tokens: Optional[int] = None, + num_out_tokens: int, probs: Optional[jnp.ndarray] = None, - tokens_per_expert: Optional[jnp.ndarray] = None, align_size: Optional[int] = None, ) -> Tuple[ jnp.ndarray, @@ -63,9 +62,9 @@ def token_dispatch( to their designated experts according to the routing map. The row_id_map is computed internally from the routing_map. - Optionally supports fused padding for alignment when both `tokens_per_expert` - and `align_size` are provided. This is useful for efficient matrix multiplications - that require aligned tensor dimensions. + Optionally supports fused padding for alignment when `align_size` is provided. + This is useful for efficient matrix multiplications that require aligned tensor + dimensions. The padding is computed internally from the routing_map. Parameters ---------- @@ -74,31 +73,30 @@ def token_dispatch( routing_map : jnp.ndarray Routing mask of shape [batch, sequence, num_experts] or [num_tokens, num_experts]. Values: 1 = routed, 0 = not routed. - num_out_tokens : Optional[int], default = None - The number of output tokens after permutation. For the dropless case, this should be equal to - the sum of routing_map and must be provided explicitly for JIT compatibility when NOT - using padding. 
- When using padding (tokens_per_expert and align_size provided), this value - is ignored and computed internally based on aligned sizes. If provided along - with padding parameters, a warning will be issued. + num_out_tokens : int + The number of output tokens after permutation (before padding). For the dropless + case, this should be equal to the sum of routing_map. Must be provided explicitly + for JIT compatibility since output shape must be known at compile time. probs : Optional[jnp.ndarray] Optional routing probabilities of shape [batch, sequence, num_experts] or [num_tokens, num_experts]. If provided, permuted_probs will be returned. - tokens_per_expert : Optional[jnp.ndarray] - Optional tensor of shape [num_experts] containing actual token counts per expert. - Required for fused padding. If provided along with align_size, outputs will be - padded to align each expert's tokens, and num_out_tokens will be computed internally. align_size : Optional[int] - Optional alignment size for padding. Required for fused padding. - Each expert's tokens will be padded to a multiple of this size. + Optional alignment size for padding. If provided, outputs will be padded to + align each expert's tokens to a multiple of this size. The output buffer is + allocated with worst-case size, rounded down to a multiple of align_size: + ((num_out_tokens + num_experts * (align_size - 1)) // align_size) * align_size + This enables full JIT compatibility. Returns ------- output : jnp.ndarray - Permuted output tensor of shape [num_out_tokens, hidden_size] - (or [num_out_tokens_padded, hidden_size] when using padding fusion). + Permuted output tensor of shape [num_out_tokens, hidden_size] without padding, + or [worst_case_padded_size, hidden_size] when using padding fusion. + With padding, the actual used portion may be smaller than the buffer; use + the sum of the returned target_tokens_per_expert to get the actual size.
permuted_probs : Optional[jnp.ndarray] - Permuted probabilities of shape [num_out_tokens], or None if probs was not provided. + Permuted probabilities of shape [num_out_tokens] or [worst_case_padded_size], + or None if probs was not provided. row_id_map : jnp.ndarray Row ID map for use in token_combine (shape [num_tokens, num_experts * 2 + 1]). pad_offsets : Optional[jnp.ndarray] @@ -110,52 +108,33 @@ def token_dispatch( Note ---- - **JIT Compatibility with Fused Padding:** - - When using fused padding (tokens_per_expert and align_size provided), the output - size is computed from `tokens_per_expert` values. This requires concrete (non-traced) - values at compile time because JAX needs to know output shapes during tracing. - - If `tokens_per_expert` contains traced values (e.g., computed from traced inputs - inside a JIT-compiled function), a ValueError will be raised with instructions. - - To ensure compatibility, compute `tokens_per_expert` outside the JIT boundary - and pass it as a concrete array argument to the JIT-compiled function. - - Without padding (only `num_out_tokens` provided), the function is fully JIT-compatible - since `num_out_tokens` is a Python int known at trace time. + **JIT Compatibility:** + + This function is fully JIT-compatible. When using padding (align_size provided), + the output buffer is allocated with a fixed worst-case size that depends only on + compile-time constants (num_out_tokens, num_experts, align_size). The actual + padding offsets (pad_offsets) and aligned token counts (target_tokens_per_expert) + are computed internally from the routing_map and can be traced values. + + The worst-case output size is: + ((num_out_tokens + num_experts * (align_size - 1)) // align_size) * align_size + This accounts for the maximum possible padding when each expert needs (align_size - 1) + extra tokens to align, rounded down to align_size for buffer alignment. 
""" - # Check that both or neither padding parameters are provided - use_padding = tokens_per_expert is not None and align_size is not None - if (tokens_per_expert is None) != (align_size is None): - raise ValueError( - "Both tokens_per_expert and align_size must be provided together for fused padding, " - "or both must be None." - ) + use_padding = align_size is not None + num_experts = routing_map.shape[-1] - # Validate num_out_tokens usage if use_padding: - if num_out_tokens is not None: - warnings.warn( - "num_out_tokens is ignored when using fused padding (tokens_per_expert and " - "align_size are provided). The output token count will be computed internally " - "based on the aligned tokens_per_expert.", - UserWarning, - stacklevel=2, - ) - # Set a dummy value - will be recomputed in the forward rule - actual_num_out_tokens = -1 + # Compute worst-case output size (compile-time constant) + # This is the maximum possible size when each expert needs max padding + worst_case_out_tokens = ( + (num_out_tokens + num_experts * (align_size - 1)) // align_size + ) * align_size else: - if num_out_tokens is None: - raise ValueError( - "num_out_tokens must be provided when not using fused padding. " - "Either provide num_out_tokens, or provide both tokens_per_expert and align_size " - "for fused padding." 
- ) - actual_num_out_tokens = num_out_tokens + worst_case_out_tokens = num_out_tokens return _token_dispatch( - inp, routing_map, probs, actual_num_out_tokens, tokens_per_expert, align_size, use_padding + inp, routing_map, probs, num_out_tokens, worst_case_out_tokens, align_size, use_padding ) @@ -165,7 +144,7 @@ def _token_dispatch( routing_map: jnp.ndarray, probs: Optional[jnp.ndarray], num_out_tokens: int, - tokens_per_expert: Optional[jnp.ndarray], + worst_case_out_tokens: int, align_size: Optional[int], use_padding: bool, ) -> Tuple[ @@ -182,7 +161,7 @@ def _token_dispatch( routing_map, probs, num_out_tokens, - tokens_per_expert, + worst_case_out_tokens, align_size, use_padding, ) @@ -195,7 +174,7 @@ def _token_dispatch_fwd_rule( routing_map: jnp.ndarray, probs: Optional[jnp.ndarray], num_out_tokens: int, - tokens_per_expert: Optional[jnp.ndarray], + worst_case_out_tokens: int, align_size: Optional[int], use_padding: bool, ) -> Tuple[ @@ -235,38 +214,26 @@ def _token_dispatch_fwd_rule( with_probs = probs is not None if use_padding: - # Ensure tokens_per_expert contains concrete values (not traced). - # This is required because the output shape depends on the sum of aligned token counts. - # Using jax.ensure_compile_time_eval will raise a clear ConcretizationTypeError - # if tokens_per_expert is a traced array. - try: - with jax.ensure_compile_time_eval(): - # Calculate aligned token counts per expert - target_tokens_per_expert = ( - jnp.ceil(tokens_per_expert / align_size) * align_size - ).astype(jnp.int32) - - # Always compute pad_offsets when use_padding=True - # This ensures deterministic control flow for JIT compilation. - # If no padding is actually needed (tokens already aligned), pad_offsets - # will be all zeros, and the kernel handles this correctly (adding 0 is a no-op). 
- pad_lengths = target_tokens_per_expert - tokens_per_expert - cum_pad = jnp.cumsum(pad_lengths) - pad_offsets = jnp.concatenate( - [jnp.array([0], dtype=cum_pad.dtype), cum_pad[:-1]] - ) - - actual_num_out_tokens = int(jnp.sum(target_tokens_per_expert)) - except jax.errors.ConcretizationTypeError as e: - raise ValueError( - "tokens_per_expert must contain concrete (non-traced) values when using " - "fused padding. The output shape depends on the sum of aligned token counts, " - "which must be known at compile time. " - "Ensure tokens_per_expert is computed outside the JIT boundary or passed as " - "a concrete array to the JIT-compiled function." - ) from e - - # Always use the padded kernel when use_padding=True (static branch) + # Compute tokens_per_expert internally from routing_map + # This can be a traced value since output shape uses worst_case_out_tokens + tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32) + + # Calculate aligned token counts per expert + target_tokens_per_expert = ( + jnp.ceil(tokens_per_expert / align_size) * align_size + ).astype(jnp.int32) + + # Compute pad_offsets: cumulative padding for each expert + # pad_offsets[i] = sum of (target - actual) for experts 0..i-1 + pad_lengths = target_tokens_per_expert - tokens_per_expert + cum_pad = jnp.cumsum(pad_lengths) + pad_offsets = jnp.concatenate( + [jnp.array([0], dtype=cum_pad.dtype), cum_pad[:-1]] + ) + + # Use worst_case_out_tokens as the output buffer size (compile-time constant) + # The actual used size is sum(target_tokens_per_expert), which may be smaller. + # Unused positions will be zero-initialized by the kernel. 
output, permuted_probs = permute_with_mask_map_and_pad( inp, row_id_map, @@ -274,7 +241,7 @@ def _token_dispatch_fwd_rule( pad_offsets, num_tokens, num_experts, - actual_num_out_tokens, + worst_case_out_tokens, hidden_size, ) else: @@ -306,7 +273,7 @@ def _token_dispatch_fwd_rule( def _token_dispatch_bwd_rule( _routing_map: jnp.ndarray, _num_out_tokens: int, - _tokens_per_expert: Optional[jnp.ndarray], + _worst_case_out_tokens: int, _align_size: Optional[int], _use_padding: bool, residuals: Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool], From dd5c72a93312fb4dc03967d2dd962ebf9cc25558 Mon Sep 17 00:00:00 2001 From: tdophung Date: Fri, 19 Dec 2025 14:11:57 -0800 Subject: [PATCH 10/19] change test permutation to reduce test time Signed-off-by: tdophung --- tests/jax/test_permutation.py | 1038 ++++++++++++----------------- tests/pytorch/test_permutation.py | 28 +- 2 files changed, 441 insertions(+), 625 deletions(-) diff --git a/tests/jax/test_permutation.py b/tests/jax/test_permutation.py index 16e311387cd..6c84127d170 100644 --- a/tests/jax/test_permutation.py +++ b/tests/jax/test_permutation.py @@ -4,6 +4,8 @@ """Tests for permutation Triton kernels and high-level APIs""" +import functools + import jax import jax.numpy as jnp import pytest @@ -19,63 +21,60 @@ def reference_make_row_id_map( routing_map: jnp.ndarray, - num_tokens: int, - num_experts: int, ) -> jnp.ndarray: """ - Reference implementation of make_row_id_map using JAX primitives. + Vectorized reference implementation of make_row_id_map using JAX primitives. Parameters ---------- routing_map : jnp.ndarray Input tensor of shape [num_tokens, num_experts]. Mask indicating which experts are routed to which tokens (1 = routed, 0 = not routed). - num_tokens : int - Number of tokens in the input tensor. - num_experts : int - Number of experts in the input tensor. Returns ------- row_id_map : jnp.ndarray The row_id_map for the permutation of shape [num_tokens, num_experts * 2 + 1]. 
""" - row_id_map = jnp.full((num_tokens, num_experts * 2 + 1), -1, dtype=jnp.int32) + num_tokens, num_experts = routing_map.shape # For each expert, compute cumulative sum to get destination indices cumsum_per_expert = jnp.cumsum(routing_map, axis=0) - # Compute total tokens per expert + # Compute total tokens per expert and expert offsets tokens_per_expert = jnp.sum(routing_map, axis=0) expert_offsets = jnp.concatenate([jnp.array([0]), jnp.cumsum(tokens_per_expert)[:-1]]) - # Build the row_id_map - for token_idx in range(num_tokens): - routed_experts = jnp.where(routing_map[token_idx] == 1)[0] - n_routed = len(routed_experts) - - # Store number of routed experts in the last position - row_id_map = row_id_map.at[token_idx, -1].set(n_routed) - - # For each routed expert, compute destination row and store it - dest_rows = [] - expert_indices = [] - for expert_idx in routed_experts: - # Destination row = expert offset + (cumsum - 1) - dest_row = expert_offsets[expert_idx] + cumsum_per_expert[token_idx, expert_idx] - 1 - dest_rows.append(dest_row) - expert_indices.append(expert_idx) - - # Sort by destination row - if n_routed > 0: - sort_indices = jnp.argsort(-jnp.array(dest_rows)) # Negative for descending sort - sorted_dest_rows = jnp.array(dest_rows)[sort_indices] - sorted_expert_indices = jnp.array(expert_indices)[sort_indices] - - # Store sorted destination rows and expert indices - for i in range(n_routed): - row_id_map = row_id_map.at[token_idx, i].set(sorted_dest_rows[i]) - row_id_map = row_id_map.at[token_idx, num_experts + i].set(sorted_expert_indices[i]) + # Compute destination rows for all (token, expert) pairs + # dest_row[i, j] = expert_offsets[j] + cumsum_per_expert[i, j] - 1 if routed, else -1 + dest_rows_all = (expert_offsets[None, :] + cumsum_per_expert - 1) * routing_map + (-1) * ( + 1 - routing_map + ) + + # Count routed experts per token + n_routed_per_token = jnp.sum(routing_map, axis=1) + + # For each token, we need to sort by descending dest_row 
and pack into row_id_map + # Use a large positive sentinel (int32 max) for non-routed experts so they sort to the end + sort_keys = jnp.where(routing_map == 1, -dest_rows_all, jnp.iinfo(jnp.int32).max) + sorted_expert_indices = jnp.argsort(sort_keys, axis=1) + + # Gather the sorted destination rows and expert indices using advanced indexing + # Create indices for gathering + token_idx = jnp.broadcast_to( + jnp.arange(num_tokens)[:, None], (num_tokens, num_experts) + ) + sorted_dest_rows = dest_rows_all[token_idx, sorted_expert_indices] + + # Build row_id_map: [dest_row_0, ..., dest_row_{E-1}, expert_idx_0, ..., expert_idx_{E-1}, n_routed] + row_id_map = jnp.concatenate( + [ + sorted_dest_rows.astype(jnp.int32), + sorted_expert_indices.astype(jnp.int32), + n_routed_per_token.astype(jnp.int32)[:, None], + ], + axis=1, + ) return row_id_map @@ -84,13 +83,10 @@ def _reference_permute_impl( inp: jnp.ndarray, row_id_map: jnp.ndarray, probs: jnp.ndarray, - num_tokens: int, - num_experts: int, num_out_tokens: int, - hidden_size: int, ) -> tuple: """ - Internal helper for reference permutation implementation. + Vectorized internal helper for reference permutation implementation. Parameters ---------- @@ -100,14 +96,8 @@ def _reference_permute_impl( The token to expert mapping tensor of shape [num_tokens, num_experts * 2 + 1]. probs : jnp.ndarray The probabilities of the input tensor. - num_tokens : int - Number of tokens in the input tensor. - num_experts : int - Number of experts. num_out_tokens : int Number of tokens in the permuted tensor. - hidden_size : int - Hidden size of the input tensor. Returns ------- @@ -116,33 +106,63 @@ def _reference_permute_impl( permuted_probs : jnp.ndarray Permuted probabilities if probs was provided, None otherwise.
""" + num_tokens, hidden_size = inp.shape + num_experts = (row_id_map.shape[1] - 1) // 2 + + # Extract destination rows, expert indices, and n_routed from row_id_map + dest_rows = row_id_map[:, :num_experts] # [num_tokens, num_experts] + expert_indices = row_id_map[:, num_experts : 2 * num_experts] # [num_tokens, num_experts] + n_routed = row_id_map[:, 2 * num_experts] # [num_tokens] + + # Create mask for valid entries: slot_idx < n_routed[token] + # The kernel's row_id_map only guarantees valid data in the first n_routed slots + # (slots beyond n_routed may contain garbage, not -1) + slot_indices = jnp.arange(num_experts)[None, :] # [1, num_experts] + valid_mask = slot_indices < n_routed[:, None] # [num_tokens, num_experts] + + # Flatten for scatter operations + flat_dest_rows = dest_rows.flatten() # [num_tokens * num_experts] + flat_valid_mask = valid_mask.flatten() + flat_token_indices = jnp.repeat(jnp.arange(num_tokens), num_experts) + flat_expert_indices = expert_indices.flatten() + + # Set invalid dest_rows to num_out_tokens (out of bounds, will be dropped) + # This avoids overwriting valid entries at index 0 with zeros + flat_dest_rows_clamped = jnp.where(flat_valid_mask, flat_dest_rows, num_out_tokens) + + # Gather input tokens and scatter to output output = jnp.zeros((num_out_tokens, hidden_size), dtype=inp.dtype) - permuted_probs = None if probs is None else jnp.zeros((num_out_tokens,), dtype=probs.dtype) - - for token_idx in range(num_tokens): - n_routed = int(row_id_map[token_idx, -1]) # int() needed for Python range() - for i in range(n_routed): - # Don't use int() here - JAX can index with traced values, - # and int() breaks autodiff gradient tracking - dest_row = row_id_map[token_idx, i] - expert_idx = row_id_map[token_idx, num_experts + i] - - # Get probability for this expert - if probs is not None: - if probs.ndim == 1: - prob = probs[token_idx] - else: - prob = probs[token_idx, expert_idx] - - # Match kernel behavior: if prob == 0.0, zero out the 
output (padding indicator) - if prob == 0.0: - output = output.at[dest_row].set(0.0) - else: - output = output.at[dest_row].set(inp[token_idx]) - - permuted_probs = permuted_probs.at[dest_row].set(prob) - else: - output = output.at[dest_row].set(inp[token_idx]) + gathered_inp = inp[flat_token_indices] # [num_tokens * num_experts, hidden_size] + + # Use segment_sum-like operation via scatter + # For each valid (token, expert) pair, write inp[token] to output[dest_row] + # Invalid entries target num_out_tokens and get dropped by mode="drop" + output = output.at[flat_dest_rows_clamped].set( + gathered_inp, + mode="drop", + ) + + permuted_probs = None + if probs is not None: + permuted_probs = jnp.zeros((num_out_tokens,), dtype=probs.dtype) + + # Vectorized approach: gather probs and scatter to permuted_probs + if probs.ndim == 1: + flat_probs = probs[flat_token_indices] + else: + # Clamp invalid expert indices to 0 to avoid wraparound indexing with -1 + # The result for invalid entries will be ignored anyway since they target num_out_tokens + # Cast to int32 explicitly for consistent indexing behavior + flat_expert_indices_clamped = jnp.where( + flat_valid_mask, flat_expert_indices, 0 + ).astype(jnp.int32) + flat_probs = probs[flat_token_indices.astype(jnp.int32), flat_expert_indices_clamped] + + # Invalid entries target num_out_tokens and get dropped by mode="drop" + permuted_probs = permuted_probs.at[flat_dest_rows_clamped.astype(jnp.int32)].set( + flat_probs, + mode="drop", + ) return output, permuted_probs @@ -152,12 +172,9 @@ def _reference_unpermute_impl( row_id_map: jnp.ndarray, merging_probs: jnp.ndarray, permuted_probs: jnp.ndarray, - num_tokens: int, - num_experts: int, - hidden_size: int, ) -> tuple: """ - Internal helper for reference unpermutation implementation. + Vectorized internal helper for reference unpermutation implementation. 
Parameters ---------- @@ -169,12 +186,6 @@ def _reference_unpermute_impl( The merging probabilities for weighted reduction. permuted_probs : jnp.ndarray The permuted probabilities. - num_tokens : int - Number of tokens. - num_experts : int - Number of experts. - hidden_size : int - Hidden size. Returns ------- @@ -183,31 +194,48 @@ def _reference_unpermute_impl( unpermuted_probs : jnp.ndarray Unpermuted probabilities if permuted_probs was provided, None otherwise. """ - output = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype) - unpermuted_probs = ( - None - if permuted_probs is None - else jnp.zeros((num_tokens, num_experts), dtype=permuted_probs.dtype) - ) + num_tokens = row_id_map.shape[0] + num_experts = (row_id_map.shape[1] - 1) // 2 - for token_idx in range(num_tokens): - n_routed = int(row_id_map[token_idx, -1]) # int() needed for Python range() - for i in range(n_routed): - # Don't use int() here - JAX can index with traced values, - # and int() breaks autodiff gradient tracking - src_row = row_id_map[token_idx, i] - expert_idx = row_id_map[token_idx, num_experts + i] - - if merging_probs is not None: - weight = merging_probs[token_idx, expert_idx] - output = output.at[token_idx].add(inp[src_row] * weight) - else: - output = output.at[token_idx].add(inp[src_row]) - - if permuted_probs is not None: - unpermuted_probs = unpermuted_probs.at[token_idx, expert_idx].set( - permuted_probs[src_row] - ) + # Extract source rows, expert indices, and n_routed from row_id_map + src_rows = row_id_map[:, :num_experts] # [num_tokens, num_experts] + expert_indices = row_id_map[:, num_experts : 2 * num_experts] # [num_tokens, num_experts] + n_routed = row_id_map[:, 2 * num_experts] # [num_tokens] + + # Create mask for valid entries: slot_idx < n_routed[token] + # The kernel's row_id_map only guarantees valid data in the first n_routed slots + slot_indices = jnp.arange(num_experts)[None, :] # [1, num_experts] + valid_mask = slot_indices < n_routed[:, None] # [num_tokens, 
num_experts] + + # Clamp invalid src_rows to 0 (they won't be used due to masking) + src_rows_clamped = jnp.where(valid_mask, src_rows, 0) + + # Gather input from permuted positions + gathered_inp = inp[src_rows_clamped] # [num_tokens, num_experts, hidden_size] + + # Apply merging probs if provided + if merging_probs is not None: + # Gather the merging weights for each (token, expert) pair using advanced indexing + token_idx = jnp.broadcast_to( + jnp.arange(num_tokens)[:, None], (num_tokens, num_experts) + ) + weights = merging_probs[token_idx, expert_indices] # [num_tokens, num_experts] + gathered_inp = gathered_inp * weights[:, :, None] + + # Mask out invalid entries and sum across experts + gathered_inp = jnp.where(valid_mask[:, :, None], gathered_inp, 0.0) + output = jnp.sum(gathered_inp, axis=1) # [num_tokens, hidden_size] + + unpermuted_probs = None + if permuted_probs is not None: + gathered_probs = permuted_probs[src_rows_clamped] # [num_tokens, num_experts] + unpermuted_probs = jnp.zeros((num_tokens, num_experts), dtype=permuted_probs.dtype) + token_idx = jnp.broadcast_to( + jnp.arange(num_tokens)[:, None], (num_tokens, num_experts) + ) + unpermuted_probs = unpermuted_probs.at[token_idx, expert_indices].set( + jnp.where(valid_mask, gathered_probs, 0.0) + ) return output, unpermuted_probs @@ -241,13 +269,8 @@ def reference_token_dispatch( row_id_map : jnp.ndarray The row_id_map for the permutation. 
""" - num_tokens, num_experts = routing_map.shape - hidden_size = inp.shape[1] - - row_id_map = reference_make_row_id_map(routing_map, num_tokens, num_experts) - output, permuted_probs = _reference_permute_impl( - inp, row_id_map, probs, num_tokens, num_experts, num_out_tokens, hidden_size - ) + row_id_map = reference_make_row_id_map(routing_map) + output, permuted_probs = _reference_permute_impl(inp, row_id_map, probs, num_out_tokens) return output, permuted_probs, row_id_map @@ -274,13 +297,7 @@ def reference_token_combine( output : jnp.ndarray Unpermuted output tensor of shape [num_tokens, hidden_size]. """ - num_tokens = row_id_map.shape[0] - num_experts = (row_id_map.shape[1] - 1) // 2 - hidden_size = inp.shape[1] - - output, _ = _reference_unpermute_impl( - inp, row_id_map, merging_probs, None, num_tokens, num_experts, hidden_size - ) + output, _ = _reference_unpermute_impl(inp, row_id_map, merging_probs, None) return output @@ -289,10 +306,9 @@ def reference_make_chunk_sort_map( split_sizes: jnp.ndarray, sorted_indices: jnp.ndarray, num_tokens: int, - num_splits: int, ) -> jnp.ndarray: """ - Reference implementation of make_chunk_sort_map using JAX primitives. + Vectorized reference implementation of make_chunk_sort_map using JAX primitives. Parameters ---------- @@ -302,45 +318,48 @@ def reference_make_chunk_sort_map( The indices of the sorted chunks of shape [num_splits,]. num_tokens : int Number of tokens. - num_splits : int - Number of splits. Returns ------- row_id_map : jnp.ndarray Row ID map for chunk sorting of shape [num_tokens,]. 
""" - row_id_map = jnp.zeros((num_tokens,), dtype=jnp.int32) + # Compute source chunk boundaries (cumulative sum of original split_sizes) + src_cumsum = jnp.concatenate([jnp.array([0]), jnp.cumsum(split_sizes)]) - # Compute cumulative positions - cumsum_sizes = jnp.concatenate([jnp.array([0]), jnp.cumsum(split_sizes)]) + # Compute destination chunk boundaries based on sorted order + sorted_sizes = split_sizes[sorted_indices] + dest_cumsum = jnp.concatenate([jnp.array([0]), jnp.cumsum(sorted_sizes)]) - # For each chunk, compute the destination indices - dest_offset = 0 - for sorted_idx in sorted_indices: - chunk_start = cumsum_sizes[sorted_idx] - chunk_end = cumsum_sizes[sorted_idx + 1] - chunk_size = chunk_end - chunk_start + # For each source chunk, compute its destination offset + # inverse_indices[i] = position of chunk i in sorted order + inverse_indices = jnp.argsort(sorted_indices) + dest_offsets = dest_cumsum[inverse_indices] - # Map source positions to destination positions - for i in range(chunk_size): - row_id_map = row_id_map.at[chunk_start + i].set(dest_offset + i) + # Create row_id_map: for each token position, compute its destination + # First, figure out which chunk each position belongs to + position_indices = jnp.arange(num_tokens) - dest_offset += chunk_size + # chunk_ids[i] = which chunk position i belongs to + chunk_ids = jnp.searchsorted(src_cumsum[1:], position_indices, side="right") - return row_id_map + # within_chunk_offset[i] = position i's offset within its chunk + within_chunk_offset = position_indices - src_cumsum[chunk_ids] + + # destination[i] = dest_offsets[chunk_ids[i]] + within_chunk_offset[i] + row_id_map = dest_offsets[chunk_ids] + within_chunk_offset + + return row_id_map.astype(jnp.int32) def reference_sort_chunks_by_map( inp: jnp.ndarray, row_id_map: jnp.ndarray, probs: jnp.ndarray, - num_tokens: int, - hidden_size: int, is_forward: bool, ) -> tuple: """ - Reference implementation of sort_chunks_by_map using JAX primitives. 
+ Vectorized reference implementation of sort_chunks_by_map using JAX primitives. Parameters ---------- @@ -350,10 +369,6 @@ def reference_sort_chunks_by_map( The token to destination mapping of shape [num_tokens,]. probs : jnp.ndarray The probabilities. - num_tokens : int - Number of tokens. - hidden_size : int - Hidden size. is_forward : bool Whether this is forward or backward. @@ -364,25 +379,25 @@ def reference_sort_chunks_by_map( permuted_probs : jnp.ndarray Sorted probabilities if probs was provided, None otherwise. """ - output = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype) - permuted_probs = None if probs is None else jnp.zeros((num_tokens,), dtype=probs.dtype) + num_tokens = inp.shape[0] + hidden_size = inp.shape[1] if is_forward: - # Forward: src -> dest - for src_idx in range(num_tokens): - # Don't use int() - JAX can index with traced values - dest_idx = row_id_map[src_idx] - output = output.at[dest_idx].set(inp[src_idx]) - if probs is not None: - permuted_probs = permuted_probs.at[dest_idx].set(probs[src_idx]) + # Forward: scatter inp[src] to output[dest] where dest = row_id_map[src] + output = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype) + output = output.at[row_id_map].set(inp) + if probs is not None: + permuted_probs = jnp.zeros((num_tokens,), dtype=probs.dtype) + permuted_probs = permuted_probs.at[row_id_map].set(probs) + else: + permuted_probs = None else: - # Backward: dest -> src - for dest_idx in range(num_tokens): - # Don't use int() - JAX can index with traced values - src_idx = row_id_map[dest_idx] - output = output.at[dest_idx].set(inp[src_idx]) - if probs is not None: - permuted_probs = permuted_probs.at[dest_idx].set(probs[src_idx]) + # Backward: gather output[dest] = inp[src] where src = row_id_map[dest] + output = inp[row_id_map] + if probs is not None: + permuted_probs = probs[row_id_map] + else: + permuted_probs = None return output, permuted_probs @@ -416,19 +431,33 @@ def generate_routing_map( return routing_map # 
========================================================================= - # token_dispatch tests + # Consolidated dispatch + combine tests # ========================================================================= @pytest.mark.parametrize( "num_tokens,num_experts,hidden_size,tokens_per_expert", [ - (32, 8, 256, 2), - (64, 16, 512, 3), + (4096, 8, 1280, 2), + (4096, 256, 4096, 6), ], ) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) - def test_token_dispatch(self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype): - """Test token_dispatch forward and backward pass against reference""" + @pytest.mark.parametrize("with_probs", [True, False]) + def test_dispatch_and_combine( + self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_probs + ): + """ + Comprehensive test for token_dispatch and token_combine. + + Tests: + 1. Dispatch forward pass against reference (element-by-element) + 2. Dispatch backward pass against reference + 3. Combine forward pass against reference (element-by-element) + 4. Combine backward pass against reference + 5. Roundtrip: dispatch + combine recovers original input + 6. row_id_map n_routed column validation + 7. 
Probs permutation (when with_probs=True) + """ key = jax.random.PRNGKey(42) # Generate routing map @@ -436,160 +465,128 @@ def test_token_dispatch(self, num_tokens, num_experts, hidden_size, tokens_per_e num_out_tokens = int(jnp.sum(routing_map)) # Generate input data - key, inp_key = jax.random.split(key) + key, inp_key, prob_key, merge_key = jax.random.split(key, 4) inp = jax.random.uniform( inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 ) - # Define loss functions - def loss_fn(x): - output, _, _, _, _ = token_dispatch(x, routing_map, num_out_tokens) - return jnp.sum(output**2) - - def ref_loss_fn(x): - output, _, _ = reference_token_dispatch(x, routing_map, num_out_tokens) - return jnp.sum(output**2) - - loss_val, computed_grad = jax.value_and_grad(loss_fn)(inp) - ref_loss_val, ref_grad = jax.value_and_grad(ref_loss_fn)(inp) - - # Compare forward outputs - output, _, _, _, _ = token_dispatch(inp, routing_map, num_out_tokens) - ref_output, _, _ = reference_token_dispatch(inp, routing_map, num_out_tokens) - assert_allclose(output, ref_output) - - # Compare loss and gradient - assert_allclose(loss_val, ref_loss_val) - assert_allclose(computed_grad, ref_grad) - - # ========================================================================= - # token_dispatch with probs tests - # ========================================================================= - - @pytest.mark.parametrize( - "num_tokens,num_experts,hidden_size,tokens_per_expert", - [ - (32, 8, 256, 2), - (64, 16, 512, 3), - ], - ) - @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) - def test_token_dispatch_with_probs( - self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype - ): - """Test token_dispatch with probs forward and backward pass against reference""" - key = jax.random.PRNGKey(42) - - # Generate routing map - routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key) - num_out_tokens = int(jnp.sum(routing_map)) + # 
Generate probs if needed (minval > 0 to avoid kernel's special prob==0 handling) + probs = None + if with_probs: + probs = jax.random.uniform( + prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0 + ) - # Generate input data and probs - key, inp_key, prob_key = jax.random.split(key, 3) - inp = jax.random.uniform( - inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 + # Generate merging probs (normalized per token) + merging_probs = jax.random.uniform( + merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0 ) - probs = jax.random.uniform( - prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.0, maxval=1.0 + merging_probs = merging_probs * routing_map.astype(dtype) # Zero out non-routed + merging_probs = merging_probs / jnp.maximum( + jnp.sum(merging_probs, axis=1, keepdims=True), 1e-8 ) - # Define loss function that uses token_dispatch with probs - # We compute gradients w.r.t. both inp and probs - def loss_fn(x, p): - output, permuted_probs, _, _, _ = token_dispatch(x, routing_map, num_out_tokens, probs=p) - return jnp.sum(output**2) + jnp.sum(permuted_probs**2) - - def ref_loss_fn(x, p): - output, permuted_probs, _ = reference_token_dispatch( - x, routing_map, num_out_tokens, probs=p - ) - return jnp.sum(output**2) + jnp.sum(permuted_probs**2) - - loss_val, (inp_grad, probs_grad) = jax.value_and_grad(loss_fn, argnums=(0, 1))(inp, probs) - ref_loss_val, (ref_inp_grad, ref_probs_grad) = jax.value_and_grad( - ref_loss_fn, argnums=(0, 1) - )(inp, probs) - - output, permuted_probs, _, _, _ = token_dispatch(inp, routing_map, num_out_tokens, probs=probs) - - ref_output, ref_permuted_probs, _ = reference_token_dispatch( + # ===================================================================== + # Test 1: Dispatch forward pass + # ===================================================================== + output, permuted_probs, row_id_map, _, _ = token_dispatch( inp, routing_map, num_out_tokens, probs=probs 
) + ref_output, ref_permuted_probs = _reference_permute_impl( + inp, row_id_map, probs, num_out_tokens + ) - # Compare forward outputs - assert_allclose(output, ref_output) - assert_allclose(permuted_probs, ref_permuted_probs) + # Validate row_id_map structure: n_routed column should match routing_map sum + n_routed_actual = row_id_map[:, -1] + n_routed_expected = jnp.sum(routing_map, axis=1) + assert jnp.array_equal(n_routed_actual, n_routed_expected), ( + "make_row_id_map n_routed column mismatch" + ) - # Compare loss and gradients - assert_allclose(loss_val, ref_loss_val) - assert_allclose(inp_grad, ref_inp_grad) - assert_allclose(probs_grad, ref_probs_grad) + # Compare dispatch output + assert_allclose(output, ref_output, dtype=dtype) + if with_probs: + assert_allclose(permuted_probs, ref_permuted_probs, dtype=dtype) - # ========================================================================= - # token_combine tests - # ========================================================================= + # ===================================================================== + # Test 2: Dispatch backward pass + # ===================================================================== + if with_probs: - @pytest.mark.parametrize( - "num_tokens,num_experts,hidden_size,tokens_per_expert", - [ - (32, 8, 256, 2), - (64, 16, 512, 3), - ], - ) - @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) - @pytest.mark.parametrize("with_merging_probs", [True, False]) - def test_token_combine( - self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_merging_probs - ): - """Test token_combine forward and backward pass against reference""" - key = jax.random.PRNGKey(42) - - # Generate routing map - routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key) - num_out_tokens = int(jnp.sum(routing_map)) - - # Get row_id_map from reference_token_dispatch - key, dummy_key = jax.random.split(key) - dummy_inp = jax.random.uniform( - 
dummy_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 - ) - _, _, row_id_map = reference_token_dispatch(dummy_inp, routing_map, num_out_tokens) + @jax.jit + def dispatch_loss(x, p): + out, perm_probs, _, _, _ = token_dispatch( + x, routing_map, num_out_tokens, probs=p + ) + return jnp.sum(out**2) + jnp.sum(perm_probs**2) - # Generate input data (from expert outputs) - key, inp_key, merge_key = jax.random.split(key, 3) - inp = jax.random.uniform( - inp_key, (num_out_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 - ) + @jax.jit + def ref_dispatch_loss(x, p): + out, perm_probs = _reference_permute_impl(x, row_id_map, p, num_out_tokens) + return jnp.sum(out**2) + jnp.sum(perm_probs**2) - if with_merging_probs: - merging_probs = jax.random.uniform( - merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.0, maxval=1.0 + _, (inp_grad, probs_grad) = jax.value_and_grad(dispatch_loss, argnums=(0, 1))( + inp, probs ) - # Normalize per token - merging_probs = merging_probs / (jnp.sum(merging_probs, axis=1, keepdims=True) + 1e-8) + _, (ref_inp_grad, ref_probs_grad) = jax.value_and_grad( + ref_dispatch_loss, argnums=(0, 1) + )(inp, probs) + assert_allclose(inp_grad, ref_inp_grad, dtype=dtype) + assert_allclose(probs_grad, ref_probs_grad, dtype=dtype) else: - merging_probs = None - - # Define loss functions - def loss_fn(x): - output = token_combine(x, row_id_map, merging_probs) - return jnp.sum(output**2) - - def ref_loss_fn(x): - output = reference_token_combine(x, row_id_map, merging_probs) - return jnp.sum(output**2) - loss_val, computed_grad = jax.value_and_grad(loss_fn)(inp) - ref_loss_val, ref_grad = jax.value_and_grad(ref_loss_fn)(inp) + @jax.jit + def dispatch_loss_no_probs(x): + out, _, _, _, _ = token_dispatch(x, routing_map, num_out_tokens) + return jnp.sum(out**2) + + @jax.jit + def ref_dispatch_loss_no_probs(x): + out, _ = _reference_permute_impl(x, row_id_map, None, num_out_tokens) + return jnp.sum(out**2) + + _, inp_grad = 
jax.value_and_grad(dispatch_loss_no_probs)(inp) + _, ref_inp_grad = jax.value_and_grad(ref_dispatch_loss_no_probs)(inp) + assert_allclose(inp_grad, ref_inp_grad, dtype=dtype) + + # ===================================================================== + # Test 3: Combine forward pass + # ===================================================================== + combined = token_combine(output, row_id_map, merging_probs) + ref_combined = _reference_unpermute_impl(output, row_id_map, merging_probs, None)[0] + assert_allclose(combined, ref_combined, dtype=dtype) + + # ===================================================================== + # Test 4: Combine backward pass + # ===================================================================== + + @jax.jit + def combine_loss(x): + return jnp.sum(token_combine(x, row_id_map, merging_probs) ** 2) + + @jax.jit + def ref_combine_loss(x): + return jnp.sum(_reference_unpermute_impl(x, row_id_map, merging_probs, None)[0] ** 2) + + _, combine_grad = jax.value_and_grad(combine_loss)(output) + _, ref_combine_grad = jax.value_and_grad(ref_combine_loss)(output) + assert_allclose(combine_grad, ref_combine_grad, dtype=dtype) + + # ===================================================================== + # Test 5: Roundtrip (dispatch + combine = original) + # ===================================================================== + # Use uniform merging probs for perfect roundtrip + uniform_merging_probs = routing_map.astype(dtype) / jnp.maximum( + jnp.sum(routing_map, axis=1, keepdims=True), 1.0 + ) - # Compare forward outputs - output = token_combine(inp, row_id_map, merging_probs) - ref_output = reference_token_combine(inp, row_id_map, merging_probs) - assert_allclose(output, ref_output) + @jax.jit + def roundtrip(x): + dispatched, _, rid_map, _, _ = token_dispatch(x, routing_map, num_out_tokens) + return token_combine(dispatched, rid_map, uniform_merging_probs) - # Compare loss and gradient - assert_allclose(loss_val, ref_loss_val) - 
assert_allclose(computed_grad, ref_grad) + roundtrip_output = roundtrip(inp) + assert_allclose(roundtrip_output, inp, dtype=dtype) # ========================================================================= # sort_chunks_by_index tests @@ -598,8 +595,9 @@ def ref_loss_fn(x): @pytest.mark.parametrize( "num_splits,total_tokens,hidden_size", [ - (4, 128, 256), - (8, 256, 512), + (8, 4096, 1280), + (64, 4096, 4096), + (256, 4096, 9216), ], ) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) @@ -622,374 +620,186 @@ def test_sort_chunks_by_index(self, num_splits, total_tokens, hidden_size, dtype inp_key, (total_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 ) - row_id_map = reference_make_chunk_sort_map( - split_sizes, sorted_indices, total_tokens, num_splits - ) + # Get reference row_id_map + row_id_map = reference_make_chunk_sort_map(split_sizes, sorted_indices, total_tokens) - # Define loss functions + # Define loss functions (JIT compiled for performance) + @jax.jit def loss_fn(x): output, _ = sort_chunks_by_index(x, split_sizes, sorted_indices) return jnp.sum(output**2) + @jax.jit def ref_loss_fn(x): - output, _ = reference_sort_chunks_by_map( - x, row_id_map, None, total_tokens, hidden_size, is_forward=True - ) + output, _ = reference_sort_chunks_by_map(x, row_id_map, None, is_forward=True) return jnp.sum(output**2) + # Test forward pass + output, _ = sort_chunks_by_index(inp, split_sizes, sorted_indices) + ref_output, _ = reference_sort_chunks_by_map(inp, row_id_map, None, is_forward=True) + + # Test backward pass with JIT loss_val, computed_grad = jax.value_and_grad(loss_fn)(inp) ref_loss_val, ref_grad = jax.value_and_grad(ref_loss_fn)(inp) - # Compare forward outputs - output, _ = sort_chunks_by_index(inp, split_sizes, sorted_indices) - ref_output, _ = reference_sort_chunks_by_map( - inp, row_id_map, None, total_tokens, hidden_size, is_forward=True - ) + # Compare forward and backward assert_allclose(output, ref_output) - - # Compare 
loss and gradient assert_allclose(loss_val, ref_loss_val) assert_allclose(computed_grad, ref_grad) # ========================================================================= - # Round-trip tests (token_dispatch -> expert processing -> token_combine) - # ========================================================================= - - @pytest.mark.parametrize( - "num_tokens,num_experts,hidden_size,tokens_per_expert", - [ - (32, 8, 256, 2), - (64, 16, 512, 3), - ], - ) - @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) - def test_dispatch_combine_roundtrip( - self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype - ): - """Test that token_dispatch followed by token_combine recovers original input""" - key = jax.random.PRNGKey(42) - - # Generate routing map - routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key) - num_out_tokens = int(jnp.sum(routing_map)) - - # Generate input data - key, inp_key = jax.random.split(key) - inp = jax.random.uniform( - inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 - ) - - # Create uniform merging probs (equal weight for all routed experts) - merging_probs = routing_map.astype(dtype) / jnp.maximum( - jnp.sum(routing_map, axis=1, keepdims=True), 1.0 - ) - - # Dispatch tokens to experts (returns output, permuted_probs, row_id_map, ...) 
- dispatched, _, row_id_map, _, _ = token_dispatch(inp, routing_map, num_out_tokens) - - # Combine tokens back (with uniform merging) (new signature) - combined = token_combine(dispatched, row_id_map, merging_probs) - - # Compare with original input - assert_allclose(combined, inp) - - # ========================================================================= - # token_dispatch with padding tests (using unified API) + # Consolidated dispatch + combine with padding tests # ========================================================================= @pytest.mark.parametrize( "num_tokens,num_experts,hidden_size,topk,align_size", [ - (32, 8, 256, 2, 16), - (64, 16, 512, 3, 32), - (128, 8, 128, 4, 64), + (4096, 8, 1280, 2, 16), + (4096, 256, 4096, 6, 16), ], ) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) - def test_token_dispatch_with_padding( - self, num_tokens, num_experts, hidden_size, topk, align_size, dtype + @pytest.mark.parametrize("with_probs", [True, False]) + def test_dispatch_and_combine_with_padding( + self, num_tokens, num_experts, hidden_size, topk, align_size, dtype, with_probs ): - """Test token_dispatch with padding forward and backward pass""" + """ + Comprehensive test for token_dispatch and token_combine with padding/unpadding. + + Tests: + 1. Dispatch with padding: output shape and alignment + 2. Dispatch backward pass with padding + 3. Combine with unpad: output shape + 4. Combine backward pass with unpad + 5. Roundtrip with padding: dispatch + combine recovers original + 6. 
Probs permutation with padding (when with_probs=True) + """ key = jax.random.PRNGKey(42) # Generate routing map routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key) num_out_tokens = int(jnp.sum(routing_map)) - # Generate input data - key, inp_key = jax.random.split(key) - inp = jax.random.uniform( - inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 - ) - - # Test forward pass with padding (using unified API) - # Now we just pass num_out_tokens and align_size - tokens_per_expert is computed internally - output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert = token_dispatch( - inp, - routing_map, - num_out_tokens, - align_size=align_size, - ) - - # Check output shape - should be worst-case padded size + # Compute worst-case padded size worst_case_size = ( (num_out_tokens + num_experts * (align_size - 1)) // align_size ) * align_size - assert output.shape == (worst_case_size, hidden_size) - assert permuted_probs is None # No probs provided - - # Check that each expert's tokens are aligned - for expert_idx in range(num_experts): - expert_tokens = int(target_tokens_per_expert[expert_idx]) - assert expert_tokens % align_size == 0 or expert_tokens == 0 - - # Test backward pass - def loss_fn(x): - out, _, _, _, _ = token_dispatch( - x, - routing_map, - num_out_tokens, - align_size=align_size, - ) - return jnp.sum(out**2) - - grad = jax.grad(loss_fn)(inp) - assert grad.shape == inp.shape - assert not jnp.any(jnp.isnan(grad)) - - @pytest.mark.parametrize( - "num_tokens,num_experts,hidden_size,topk,align_size", - [ - (32, 8, 256, 2, 16), - (64, 16, 512, 3, 32), - ], - ) - @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) - def test_token_dispatch_with_padding_and_probs( - self, num_tokens, num_experts, hidden_size, topk, align_size, dtype - ): - """Test token_dispatch with padding and probs""" - key = jax.random.PRNGKey(42) - - # Generate routing map - routing_map = 
self.generate_routing_map(num_tokens, num_experts, topk, key) - num_out_tokens = int(jnp.sum(routing_map)) - # Generate input data and probs - key, inp_key, prob_key = jax.random.split(key, 3) + # Generate input data + key, inp_key, prob_key, merge_key = jax.random.split(key, 4) inp = jax.random.uniform( inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 ) - probs = jax.random.uniform( - prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.0, maxval=1.0 - ) - - # Test forward pass with padding and probs - # tokens_per_expert is computed internally from routing_map - output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert = token_dispatch( - inp, - routing_map, - num_out_tokens, - probs=probs, - align_size=align_size, - ) - # Check output shape - should be worst-case padded size - worst_case_size = ( - (num_out_tokens + num_experts * (align_size - 1)) // align_size - ) * align_size - assert output.shape == (worst_case_size, hidden_size) - assert permuted_probs is not None - assert permuted_probs.shape == (worst_case_size,) - - # Test backward pass - def loss_fn(x, p): - out, perm_probs, _, _, _ = token_dispatch( - x, - routing_map, - num_out_tokens, - probs=p, - align_size=align_size, + # Generate probs if needed (minval > 0 to avoid kernel's special prob==0 handling) + probs = None + if with_probs: + probs = jax.random.uniform( + prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0 ) - return jnp.sum(out**2) + jnp.sum(perm_probs**2) - (inp_grad, probs_grad) = jax.grad(loss_fn, argnums=(0, 1))(inp, probs) - assert inp_grad.shape == inp.shape - assert probs_grad.shape == probs.shape - assert not jnp.any(jnp.isnan(inp_grad)) - assert not jnp.any(jnp.isnan(probs_grad)) - - # ========================================================================= - # token_combine with unpad tests (using unified API) - # ========================================================================= - - 
@pytest.mark.parametrize( - "num_tokens,num_experts,hidden_size,topk,align_size", - [ - (32, 8, 256, 2, 16), - (64, 16, 512, 3, 32), - ], - ) - @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) - @pytest.mark.parametrize("with_merging_probs", [True, False]) - def test_token_combine_with_unpad( - self, - num_tokens, - num_experts, - hidden_size, - topk, - align_size, - dtype, - with_merging_probs, - ): - """Test token_combine with unpad forward and backward pass""" - key = jax.random.PRNGKey(42) - - # Generate routing map - routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key) - num_out_tokens = int(jnp.sum(routing_map)) - - # Generate input and dispatch with padding to get row_id_map and pad_offsets - key, inp_key = jax.random.split(key) - inp = jax.random.uniform( - inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 + # Generate merging probs (normalized per token) + merging_probs = jax.random.uniform( + merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0 ) - - # Dispatch with padding to get row_id_map and pad_offsets - _, _, row_id_map, pad_offsets, target_tokens_per_expert = token_dispatch( - inp, - routing_map, - num_out_tokens, - align_size=align_size, + merging_probs = merging_probs * routing_map.astype(dtype) # Zero out non-routed + merging_probs = merging_probs / jnp.maximum( + jnp.sum(merging_probs, axis=1, keepdims=True), 1e-8 ) - # Generate expert output data (worst-case padded size) - worst_case_size = ( - (num_out_tokens + num_experts * (align_size - 1)) // align_size - ) * align_size - key, expert_key, merge_key = jax.random.split(key, 3) - expert_output = jax.random.uniform( - expert_key, (worst_case_size, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 + # ===================================================================== + # Test 1: Dispatch with padding - forward pass + # ===================================================================== + output, 
permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert = ( + token_dispatch( + inp, routing_map, num_out_tokens, probs=probs, align_size=align_size + ) ) - if with_merging_probs: - merging_probs = jax.random.uniform( - merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.0, maxval=1.0 - ) - # Normalize per token - merging_probs = merging_probs / (jnp.sum(merging_probs, axis=1, keepdims=True) + 1e-8) + # Check output shape + assert output.shape == (worst_case_size, hidden_size) + if with_probs: + assert permuted_probs is not None + assert permuted_probs.shape == (worst_case_size,) else: - merging_probs = None + assert permuted_probs is None - # Test forward pass with unpad (using unified API) - output = token_combine(expert_output, row_id_map, merging_probs, pad_offsets) - assert output.shape == (num_tokens, hidden_size) - - # Test backward pass - def loss_fn(x): - out = token_combine(x, row_id_map, merging_probs, pad_offsets) - return jnp.sum(out**2) - - grad = jax.grad(loss_fn)(expert_output) - assert grad.shape == expert_output.shape - assert not jnp.any(jnp.isnan(grad)) - - # ========================================================================= - # Round-trip tests with padding - # ========================================================================= + # Check alignment: each expert's tokens should be aligned + for expert_idx in range(num_experts): + expert_tokens = int(target_tokens_per_expert[expert_idx]) + assert expert_tokens % align_size == 0 or expert_tokens == 0 - @pytest.mark.parametrize( - "num_tokens,num_experts,hidden_size,topk,align_size", - [ - (32, 8, 256, 2, 16), - (64, 16, 512, 3, 32), - (128, 8, 128, 4, 64), - ], - ) - @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) - def test_dispatch_combine_with_padding_roundtrip( - self, num_tokens, num_experts, hidden_size, topk, align_size, dtype - ): - """Test that token_dispatch with padding followed by token_combine with unpad recovers input""" - key = 
jax.random.PRNGKey(42) + # ===================================================================== + # Test 2: Dispatch with padding - backward pass + # ===================================================================== + if with_probs: - # Generate routing map - routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key) - num_out_tokens = int(jnp.sum(routing_map)) + @jax.jit + def dispatch_loss(x, p): + out, perm_probs, _, _, _ = token_dispatch( + x, routing_map, num_out_tokens, probs=p, align_size=align_size + ) + return jnp.sum(out**2) + jnp.sum(perm_probs**2) - # Generate input data - key, inp_key = jax.random.split(key) - inp = jax.random.uniform( - inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 - ) + inp_grad, probs_grad = jax.grad(dispatch_loss, argnums=(0, 1))(inp, probs) + assert inp_grad.shape == inp.shape + assert probs_grad.shape == probs.shape + assert not jnp.any(jnp.isnan(inp_grad)) + assert not jnp.any(jnp.isnan(probs_grad)) + else: - # Create uniform merging probs (equal weight for all routed experts) - merging_probs = routing_map.astype(dtype) / jnp.maximum( + @jax.jit + def dispatch_loss_no_probs(x): + out, _, _, _, _ = token_dispatch( + x, routing_map, num_out_tokens, align_size=align_size + ) + return jnp.sum(out**2) + + inp_grad = jax.grad(dispatch_loss_no_probs)(inp) + assert inp_grad.shape == inp.shape + assert not jnp.any(jnp.isnan(inp_grad)) + + # ===================================================================== + # Test 3: Combine with unpad - forward pass + # ===================================================================== + combined = token_combine(output, row_id_map, merging_probs, pad_offsets) + assert combined.shape == (num_tokens, hidden_size) + + # ===================================================================== + # Test 4: Combine with unpad - backward pass + # ===================================================================== + + @jax.jit + def combine_loss(x): + return 
jnp.sum(token_combine(x, row_id_map, merging_probs, pad_offsets) ** 2) + + combine_grad = jax.grad(combine_loss)(output) + assert combine_grad.shape == output.shape + assert not jnp.any(jnp.isnan(combine_grad)) + + # ===================================================================== + # Test 5: Roundtrip with padding (dispatch + combine = original) + # ===================================================================== + # Use uniform merging probs for perfect roundtrip + uniform_merging_probs = routing_map.astype(dtype) / jnp.maximum( jnp.sum(routing_map, axis=1, keepdims=True), 1.0 ) - # Dispatch tokens to experts with padding - # tokens_per_expert is computed internally from routing_map - dispatched, _, row_id_map, pad_offsets, _ = token_dispatch( - inp, - routing_map, - num_out_tokens, - align_size=align_size, - ) - - # Combine tokens back with unpadding - combined = token_combine(dispatched, row_id_map, merging_probs, pad_offsets) - - # Compare with original input - assert_allclose(combined, inp) - - @pytest.mark.parametrize( - "num_tokens,num_experts,hidden_size,topk,align_size", - [ - (32, 8, 256, 2, 16), - (64, 16, 512, 3, 32), - ], - ) - @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) - def test_dispatch_combine_with_padding_gradient_flow( - self, num_tokens, num_experts, hidden_size, topk, align_size, dtype - ): - """Test gradient flow through dispatch with padding -> combine with unpad""" - key = jax.random.PRNGKey(42) + @jax.jit + def roundtrip(x): + dispatched, _, rid_map, p_offsets, _ = token_dispatch( + x, routing_map, num_out_tokens, align_size=align_size + ) + return token_combine(dispatched, rid_map, uniform_merging_probs, p_offsets) - # Generate routing map - routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key) - num_out_tokens = int(jnp.sum(routing_map)) + roundtrip_output = roundtrip(inp) + assert_allclose(roundtrip_output, inp, dtype=dtype) - # Generate input data - key, inp_key = 
jax.random.split(key) - inp = jax.random.uniform( - inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 - ) + # Test roundtrip gradient + @jax.jit + def roundtrip_loss(x): + return jnp.sum(roundtrip(x) ** 2) - # Create uniform merging probs - merging_probs = routing_map.astype(dtype) / jnp.maximum( - jnp.sum(routing_map, axis=1, keepdims=True), 1.0 - ) - - # Define end-to-end function - def forward(x): - dispatched, _, row_id_map, pad_offsets, _ = token_dispatch( - x, - routing_map, - num_out_tokens, - align_size=align_size, - ) - # Simulate some expert processing (e.g., scaling) - processed = dispatched * 2.0 - combined = token_combine(processed, row_id_map, merging_probs, pad_offsets) - return jnp.sum(combined**2) - - # Test gradient computation - grad = jax.grad(forward)(inp) - assert grad.shape == inp.shape - assert not jnp.any(jnp.isnan(grad)) - - # Verify gradient is non-zero for inputs that are routed - routed_mask = jnp.any(routing_map > 0, axis=1) - assert jnp.any(grad[routed_mask] != 0) + roundtrip_grad = jax.grad(roundtrip_loss)(inp) + assert roundtrip_grad.shape == inp.shape + assert not jnp.any(jnp.isnan(roundtrip_grad)) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 9e440b6795c..edef134ebfc 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -1651,7 +1651,7 @@ def perf_test_cuda_kernel(cuda_kernel_fn): @pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) -@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("topK", [2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) def test_permutation_index_map( te_dtype, @@ -1680,7 +1680,7 @@ def test_permutation_index_map( @pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) -@pytest.mark.parametrize("topK", 
[1, 2, 5]) +@pytest.mark.parametrize("topK", [2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) def test_permutation_mask_map( te_dtype, @@ -1710,12 +1710,18 @@ def test_permutation_mask_map( @pytest.mark.parametrize( "num_tokens, num_expert, hidden_size, topK", [ +<<<<<<< HEAD (0, 8, 1280, 2), (4096, 64, 1280, 7), (4096, 64, 2048, 6), (4096, 160, 5120, 6), (4096, 256, 7168, 8), (4096, 384, 8192, 8), +======= + (4096, 8, 1280, 2), + (4096, 64, 4096, 6), + (4096, 256, 7168, 6), +>>>>>>> 5518c80f (change test permutation to reduce test time) (4096, 512, 9216, 8), ], ) @@ -1748,10 +1754,10 @@ def test_permutation_and_padding_mask_map( @pytest.mark.parametrize( "num_tokens, num_expert, hidden_size, topK", [ - (4096, 64, 1280, 7), - (4096, 64, 2048, 6), - (4096, 160, 5120, 6), - (4096, 256, 7168, 8), + (4096, 8, 1280, 2), + (4096, 64, 4096, 6), + (4096, 256, 7168, 6), + (4096, 512, 9216, 8), ], ) def test_permutation_and_padding_with_merging_probs( @@ -1797,9 +1803,9 @@ def test_permutation_mask_map_empty_input(te_dtype): @pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) -@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("topK", [2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) -@pytest.mark.parametrize("tp_size", [1, 2, 8]) +@pytest.mark.parametrize("tp_size", [1, 2]) def test_permutation_mask_map_alongside_probs( te_dtype, num_tokens, @@ -1849,10 +1855,10 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype): @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) -@pytest.mark.parametrize("num_tokens", [2048]) +@pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) -@pytest.mark.parametrize("topK", [1, 2, 5]) 
+@pytest.mark.parametrize("topK", [2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) @pytest.mark.parametrize("recipe", fp8_recipes) def test_permutation_mask_map_fp8( @@ -1937,7 +1943,7 @@ def test_permutation_mask_map_topk1_no_probs( @pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_expert", [7, 16]) -@pytest.mark.parametrize("tp_size", [1, 2, 8]) +@pytest.mark.parametrize("tp_size", [2, 8]) @pytest.mark.parametrize("hidden_size", [4096]) def test_chunk_permutation( te_dtype, From ce187b6434708365a0a93a86aabffdd076060b27 Mon Sep 17 00:00:00 2001 From: tdophung Date: Fri, 19 Dec 2025 14:37:06 -0800 Subject: [PATCH 11/19] triggering PR refresh Signed-off-by: tdophung From 7dc9ccb506a8442d9d0918ff0121984979c196be Mon Sep 17 00:00:00 2001 From: tdophung Date: Fri, 19 Dec 2025 16:22:59 -0800 Subject: [PATCH 12/19] format code Signed-off-by: tdophung --- tests/jax/test_permutation.py | 34 +++++++------------ tests/pytorch/test_permutation.py | 26 +++----------- transformer_engine/jax/permutation.py | 12 +++---- .../jax/triton_extensions/utils.py | 14 ++++---- transformer_engine/pytorch/permutation.py | 7 ++-- .../pytorch/triton/permutation.py | 6 ++-- 6 files changed, 36 insertions(+), 63 deletions(-) diff --git a/tests/jax/test_permutation.py b/tests/jax/test_permutation.py index 6c84127d170..95b9dc2db32 100644 --- a/tests/jax/test_permutation.py +++ b/tests/jax/test_permutation.py @@ -61,9 +61,7 @@ def reference_make_row_id_map( # Gather the sorted destination rows and expert indices using advanced indexing # Create indices for gathering - token_idx = jnp.broadcast_to( - jnp.arange(num_tokens)[:, None], (num_tokens, num_experts) - ) + token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts)) sorted_dest_rows = dest_rows_all[token_idx, sorted_expert_indices] # Build row_id_map: [dest_row_0, ..., dest_row_{E-1}, expert_idx_0, ..., expert_idx_{E-1}, 
n_routed] @@ -153,9 +151,9 @@ def _reference_permute_impl( # Clamp invalid expert indices to 0 to avoid wraparound indexing with -1 # The result for invalid entries will be ignored anyway since they target num_out_tokens # Cast to int32 explicitly for consistent indexing behavior - flat_expert_indices_clamped = jnp.where( - flat_valid_mask, flat_expert_indices, 0 - ).astype(jnp.int32) + flat_expert_indices_clamped = jnp.where(flat_valid_mask, flat_expert_indices, 0).astype( + jnp.int32 + ) flat_probs = probs[flat_token_indices.astype(jnp.int32), flat_expert_indices_clamped] # Invalid entries target num_out_tokens and get dropped by mode="drop" @@ -216,9 +214,7 @@ def _reference_unpermute_impl( # Apply merging probs if provided if merging_probs is not None: # Gather the merging weights for each (token, expert) pair using advanced indexing - token_idx = jnp.broadcast_to( - jnp.arange(num_tokens)[:, None], (num_tokens, num_experts) - ) + token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts)) weights = merging_probs[token_idx, expert_indices] # [num_tokens, num_experts] gathered_inp = gathered_inp * weights[:, :, None] @@ -230,9 +226,7 @@ def _reference_unpermute_impl( if permuted_probs is not None: gathered_probs = permuted_probs[src_rows_clamped] # [num_tokens, num_experts] unpermuted_probs = jnp.zeros((num_tokens, num_experts), dtype=permuted_probs.dtype) - token_idx = jnp.broadcast_to( - jnp.arange(num_tokens)[:, None], (num_tokens, num_experts) - ) + token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts)) unpermuted_probs = unpermuted_probs.at[token_idx, expert_indices].set( jnp.where(valid_mask, gathered_probs, 0.0) ) @@ -499,9 +493,9 @@ def test_dispatch_and_combine( # Validate row_id_map structure: n_routed column should match routing_map sum n_routed_actual = row_id_map[:, -1] n_routed_expected = jnp.sum(routing_map, axis=1) - assert jnp.array_equal(n_routed_actual, n_routed_expected), ( - 
"make_row_id_map n_routed column mismatch" - ) + assert jnp.array_equal( + n_routed_actual, n_routed_expected + ), "make_row_id_map n_routed column mismatch" # Compare dispatch output assert_allclose(output, ref_output, dtype=dtype) @@ -515,9 +509,7 @@ def test_dispatch_and_combine( @jax.jit def dispatch_loss(x, p): - out, perm_probs, _, _, _ = token_dispatch( - x, routing_map, num_out_tokens, probs=p - ) + out, perm_probs, _, _, _ = token_dispatch(x, routing_map, num_out_tokens, probs=p) return jnp.sum(out**2) + jnp.sum(perm_probs**2) @jax.jit @@ -710,10 +702,8 @@ def test_dispatch_and_combine_with_padding( # ===================================================================== # Test 1: Dispatch with padding - forward pass # ===================================================================== - output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert = ( - token_dispatch( - inp, routing_map, num_out_tokens, probs=probs, align_size=align_size - ) + output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert = token_dispatch( + inp, routing_map, num_out_tokens, probs=probs, align_size=align_size ) # Check output shape diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index edef134ebfc..f0a14b4b6b8 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -986,9 +986,7 @@ def _test_permutation_and_padding_with_merging_probs( _tmp_tensor = torch.zeros((num_tokens * num_expert,)) _tmp_tensor[: int(num_out_tokens)] = 1.0 _tmp_idx = torch.randperm(num_tokens * num_expert) - routing_map = ( - torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda() - ) + routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda() probs = torch.rand(num_tokens, num_expert).cuda() * routing_map row_sums = probs.sum(dim=1, keepdim=True) @@ -997,18 +995,14 @@ def _test_permutation_and_padding_with_merging_probs( probs.requires_grad_(True) 
tokens_per_expert = routing_map.sum(dim=0).cpu() - target_tokens_per_expert = ( - torch.ceil(tokens_per_expert / align_size) * align_size - ).long() + target_tokens_per_expert = (torch.ceil(tokens_per_expert / align_size) * align_size).long() num_permute_pad_out_tokens = target_tokens_per_expert.sum().item() permute_pad_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() permute_pad_bwd_input = torch.rand( (num_permute_pad_out_tokens, hidden_size), dtype=dtype ).cuda() - unpermute_unpad_bwd_input = torch.rand( - (num_tokens, hidden_size), dtype=dtype - ).cuda() + unpermute_unpad_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() permute_pad_fwd_input.requires_grad_(True) restore_shape = permute_pad_fwd_input.shape @@ -1071,9 +1065,7 @@ def _test_permutation_and_padding_with_merging_probs( ) fusion_permute_pad_bwd_input = permute_pad_bwd_input.detach() - fusion_permuted_padded_output.backward( - fusion_permute_pad_bwd_input, retain_graph=True - ) + fusion_permuted_padded_output.backward(fusion_permute_pad_bwd_input, retain_graph=True) # Fused: unpermute with BOTH merging_probs AND pad_offsets fusion_unpermute_fwd_input = fusion_permuted_padded_output.detach() @@ -1140,6 +1132,7 @@ def _test_permutation_and_padding_with_merging_probs( # ################################################################################################################################### if BENCHMARK: + def ref_unpad_unpermute(): unpaded = fp8_unpadding(ref_unpermute_fwd_input, tokens_per_expert_list) return te_unpermute(unpaded, row_id_map, ref_probs, restore_shape=restore_shape) @@ -1710,18 +1703,9 @@ def test_permutation_mask_map( @pytest.mark.parametrize( "num_tokens, num_expert, hidden_size, topK", [ -<<<<<<< HEAD - (0, 8, 1280, 2), - (4096, 64, 1280, 7), - (4096, 64, 2048, 6), - (4096, 160, 5120, 6), - (4096, 256, 7168, 8), - (4096, 384, 8192, 8), -======= (4096, 8, 1280, 2), (4096, 64, 4096, 6), (4096, 256, 7168, 6), ->>>>>>> 5518c80f (change 
test permutation to reduce test time) (4096, 512, 9216, 8), ], ) diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 0c64e6491ce..f4bca3c205c 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -219,17 +219,15 @@ def _token_dispatch_fwd_rule( tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32) # Calculate aligned token counts per expert - target_tokens_per_expert = ( - jnp.ceil(tokens_per_expert / align_size) * align_size - ).astype(jnp.int32) + target_tokens_per_expert = (jnp.ceil(tokens_per_expert / align_size) * align_size).astype( + jnp.int32 + ) # Compute pad_offsets: cumulative padding for each expert # pad_offsets[i] = sum of (target - actual) for experts 0..i-1 pad_lengths = target_tokens_per_expert - tokens_per_expert cum_pad = jnp.cumsum(pad_lengths) - pad_offsets = jnp.concatenate( - [jnp.array([0], dtype=cum_pad.dtype), cum_pad[:-1]] - ) + pad_offsets = jnp.concatenate([jnp.array([0], dtype=cum_pad.dtype), cum_pad[:-1]]) # Use worst_case_out_tokens as the output buffer size (compile-time constant) # The actual used size is sum(target_tokens_per_expert), which may be smaller. @@ -451,7 +449,7 @@ def _token_combine_bwd_rule( g: jnp.ndarray, ) -> Tuple[jnp.ndarray, None, Optional[jnp.ndarray], None]: """Backward pass rule for token_combine. - + Returns gradients for: (inp, row_id_map, merging_probs, pad_offsets) row_id_map and pad_offsets are integer arrays, so their gradients are None. 
""" diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index f093d99c49e..b7a62bb4d16 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -145,13 +145,13 @@ def compile_triton( # From jax/jaxlib/gpu/triton_kernels.cc: # Kernel::Kernel(kernel_name, num_warps, num_ctas, shared_mem_bytes, ptx, ttir, compute_capability) kernel = gpu_triton.TritonKernel( - compiled.name, # arg0: kernel_name (str) - num_warps, # arg1: num_warps (int) - num_ctas, # arg2: num_ctas (int) - compiled.metadata.shared, # arg3: shared_mem_bytes (int) - compiled.asm["ptx"], # arg4: ptx (str) - "", # arg5: ttir (str) - empty - compute_capability, # arg6: compute_capability (int) + compiled.name, # arg0: kernel_name (str) + num_warps, # arg1: num_warps (int) + num_ctas, # arg2: num_ctas (int) + compiled.metadata.shared, # arg3: shared_mem_bytes (int) + compiled.asm["ptx"], # arg4: ptx (str) + "", # arg5: ttir (str) - empty + compute_capability, # arg6: compute_capability (int) ) _TRITON_KERNEL_CACHE[cache_key] = kernel diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index f2809413a01..be6eb835f00 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -627,10 +627,11 @@ def moe_permute_and_pad_with_probs( if torch.equal(tokens_per_expert, target_tokens_per_expert): pad_offsets = None else: - pad_lengths = target_tokens_per_expert - tokens_per_expert + pad_lengths = (target_tokens_per_expert - tokens_per_expert).to(inp.device) cum_pad = torch.cumsum(pad_lengths, dim=0) - pad_offsets = torch.cat([torch.zeros(1, dtype=cum_pad.dtype), cum_pad[:-1]]) - pad_offsets = pad_offsets.to(inp.device) + pad_offsets = torch.cat( + [torch.zeros(1, dtype=cum_pad.dtype, device=inp.device), cum_pad[:-1]] + ) output, row_id_map, permuted_probs = _moe_permute_mask_map.apply( inp, 
routing_map, target_tokens_per_expert.sum().item(), probs, pad_offsets diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index fc98b8da083..27662e1b283 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -157,6 +157,8 @@ def permute_with_mask_map( scale_hidden_dim : int Hidden size of the scale tensor. """ + # Use torch.zeros when pad_offsets is provided to ensure padding regions are zeroed, + # since the kernel doesn't write to padding positions. alloc = torch.zeros if pad_offsets is not None else torch.empty output = alloc((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda") permuted_probs = ( @@ -313,9 +315,7 @@ def unpermute_with_mask_map_bwd_with_merging_probs( # by the kernel. This matches the behavior of Fp8Unpadding.backward which zeros # out the padding slots. alloc = torch.zeros if pad_offsets is not None else torch.empty - act_grad = alloc( - (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda" - ) + act_grad = alloc((num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda") merging_probs_grad = torch.empty( (num_tokens, num_experts), dtype=merging_probs.dtype, device="cuda" ) From 1fbe99cd3232755b4748f45580834a87cc5c353a Mon Sep 17 00:00:00 2001 From: tdophung Date: Fri, 19 Dec 2025 17:44:39 -0800 Subject: [PATCH 13/19] Remove some tests cases from pytorch side. Add a separate toekn_dispatch test for sanity in case combine accidentally undo an error on dispatch in the roundtrip test. 
Add distinction between L0 and L2 in test cases in jax Signed-off-by: tdophung --- tests/jax/test_permutation.py | 177 ++++++++++++++++++++++++++---- tests/pytorch/test_permutation.py | 14 ++- 2 files changed, 167 insertions(+), 24 deletions(-) diff --git a/tests/jax/test_permutation.py b/tests/jax/test_permutation.py index 95b9dc2db32..a79fb9764f3 100644 --- a/tests/jax/test_permutation.py +++ b/tests/jax/test_permutation.py @@ -16,7 +16,61 @@ token_combine, sort_chunks_by_index, ) -from utils import assert_allclose +from utils import assert_allclose, pytest_parametrize_wrapper + + +# ============================================================================= +# Test parameter definitions with L0 (fast) and L2 (comprehensive) levels +# ============================================================================= + +# All dispatch/combine test cases +ALL_DISPATCH_COMBINE_CASES = [ + (128, 5, 128, 3), + (1024, 8, 128, 8), + (4096, 32, 1280, 2), + (4096, 256, 4096, 6), +] +DISPATCH_COMBINE_CASES = { + "L0": ALL_DISPATCH_COMBINE_CASES[0:2], + "L2": ALL_DISPATCH_COMBINE_CASES, +} + +# All sort chunks test cases +ALL_SORT_CHUNKS_CASES = [ + (8, 4096, 1280), + (64, 4096, 4096), + (256, 4096, 9216), +] +SORT_CHUNKS_CASES = { + "L0": ALL_SORT_CHUNKS_CASES[0:2], + "L2": ALL_SORT_CHUNKS_CASES, +} + +# All dispatch/combine with padding test cases +ALL_DISPATCH_COMBINE_PADDING_CASES = [ + (128, 5, 128, 3, 8), + (1024, 8, 128, 8, 16), + (4096, 32, 1280, 2, 128), + (4096, 256, 4096, 6, 16), +] +DISPATCH_COMBINE_PADDING_CASES = { + "L0": ALL_DISPATCH_COMBINE_PADDING_CASES[0:2], + "L2": ALL_DISPATCH_COMBINE_PADDING_CASES, +} + +# Dtypes for testing +ALL_DTYPES = [jnp.float32, jnp.bfloat16] +DTYPES = { + "L0": ALL_DTYPES, + "L2": ALL_DTYPES, +} + +# With probs options +ALL_WITH_PROBS = [True, False] +WITH_PROBS = { + "L0": [True], + "L2": ALL_WITH_PROBS, +} def reference_make_row_id_map( @@ -424,19 +478,105 @@ def generate_routing_map( return routing_map + + 
@pytest_parametrize_wrapper( + "num_tokens,num_experts,hidden_size,tokens_per_expert", + DISPATCH_COMBINE_CASES, + ) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("with_probs", WITH_PROBS) + def test_token_dispatch( + self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_probs + ): + """ + Individual test for token_dispatch forward and backward passes. + + This test validates dispatch in isolation to catch errors that might be + masked when combined with token_combine in the roundtrip test. + + Uses value_and_grad to validate both forward (via loss comparison) and + backward (via gradient comparison) passes against reference implementation. + """ + key = jax.random.PRNGKey(42) + + # Generate routing map + routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key) + num_out_tokens = int(jnp.sum(routing_map)) + + # Generate input data + key, inp_key, prob_key = jax.random.split(key, 3) + inp = jax.random.uniform( + inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 + ) + + # Generate probs if needed (minval > 0 to avoid kernel's special prob==0 handling) + probs = None + if with_probs: + probs = jax.random.uniform( + prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0 + ) + + # Generate reference row_id_map for comparison + ref_row_id_map = reference_make_row_id_map(routing_map) + + # ===================================================================== + # Test forward and backward pass using value_and_grad + # (value validates forward, grad validates backward) + # ===================================================================== + if with_probs: + + @jax.jit + def dispatch_loss(x, p): + out, perm_probs, _, _, _ = token_dispatch(x, routing_map, num_out_tokens, probs=p) + return jnp.sum(out**2) + jnp.sum(perm_probs**2) + + @jax.jit + def ref_dispatch_loss(x, p): + out, perm_probs = _reference_permute_impl(x, ref_row_id_map, p, 
num_out_tokens) + return jnp.sum(out**2) + jnp.sum(perm_probs**2) + + loss_val, (inp_grad, probs_grad) = jax.value_and_grad(dispatch_loss, argnums=(0, 1))( + inp, probs + ) + ref_loss_val, (ref_inp_grad, ref_probs_grad) = jax.value_and_grad( + ref_dispatch_loss, argnums=(0, 1) + )(inp, probs) + + # Validate forward loss matches + assert_allclose(loss_val, ref_loss_val, dtype=dtype) + # Validate gradients + assert_allclose(inp_grad, ref_inp_grad, dtype=dtype) + assert_allclose(probs_grad, ref_probs_grad, dtype=dtype) + else: + + @jax.jit + def dispatch_loss_no_probs(x): + out, _, _, _, _ = token_dispatch(x, routing_map, num_out_tokens) + return jnp.sum(out**2) + + @jax.jit + def ref_dispatch_loss_no_probs(x): + out, _ = _reference_permute_impl(x, ref_row_id_map, None, num_out_tokens) + return jnp.sum(out**2) + + loss_val, inp_grad = jax.value_and_grad(dispatch_loss_no_probs)(inp) + ref_loss_val, ref_inp_grad = jax.value_and_grad(ref_dispatch_loss_no_probs)(inp) + + # Validate forward loss matches + assert_allclose(loss_val, ref_loss_val, dtype=dtype) + # Validate gradients + assert_allclose(inp_grad, ref_inp_grad, dtype=dtype) + # ========================================================================= # Consolidated dispatch + combine tests # ========================================================================= - @pytest.mark.parametrize( + @pytest_parametrize_wrapper( "num_tokens,num_experts,hidden_size,tokens_per_expert", - [ - (4096, 8, 1280, 2), - (4096, 256, 4096, 6), - ], + DISPATCH_COMBINE_CASES, ) - @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) - @pytest.mark.parametrize("with_probs", [True, False]) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("with_probs", WITH_PROBS) def test_dispatch_and_combine( self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_probs ): @@ -584,15 +724,11 @@ def roundtrip(x): # sort_chunks_by_index tests # 
========================================================================= - @pytest.mark.parametrize( + @pytest_parametrize_wrapper( "num_splits,total_tokens,hidden_size", - [ - (8, 4096, 1280), - (64, 4096, 4096), - (256, 4096, 9216), - ], + SORT_CHUNKS_CASES, ) - @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) + @pytest_parametrize_wrapper("dtype", DTYPES) def test_sort_chunks_by_index(self, num_splits, total_tokens, hidden_size, dtype): """Test sort_chunks_by_index forward and backward pass against reference""" key = jax.random.PRNGKey(42) @@ -643,15 +779,12 @@ def ref_loss_fn(x): # Consolidated dispatch + combine with padding tests # ========================================================================= - @pytest.mark.parametrize( + @pytest_parametrize_wrapper( "num_tokens,num_experts,hidden_size,topk,align_size", - [ - (4096, 8, 1280, 2, 16), - (4096, 256, 4096, 6, 16), - ], + DISPATCH_COMBINE_PADDING_CASES, ) - @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) - @pytest.mark.parametrize("with_probs", [True, False]) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("with_probs", WITH_PROBS) def test_dispatch_and_combine_with_padding( self, num_tokens, num_experts, hidden_size, topk, align_size, dtype, with_probs ): diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index f0a14b4b6b8..9a0cf6fb7cf 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. 
+import os import random import torch @@ -1962,6 +1963,10 @@ def test_chunk_permutation_empty_input(te_dtype): ) +@pytest.mark.skipif( + os.getenv("RUN_BENCHMARK_TESTS", "0") != "1", + reason="Benchmark test - run with: RUN_BENCHMARK_TESTS=1 pytest -k single_case", +) def test_permutation_single_case(): print("GPU:", torch.cuda.get_device_name(0)) @@ -2125,7 +2130,12 @@ def benchmark_single_case( torch.cuda.nvtx.range_pop() -def benchmark_multiple_cases(): +@pytest.mark.skipif( + os.getenv("RUN_BENCHMARK_TESTS", "0") != "1", + reason="Benchmark test - run with: RUN_BENCHMARK_TESTS=1 pytest -k benchmark", +) +def test_benchmark_multiple_cases(): + """Benchmark test - skipped by default. Run with: RUN_BENCHMARK_TESTS=1 pytest -k benchmark""" print("GPU:", torch.cuda.get_device_name(0)) # te_dtype = tex.DType.kFloat32 @@ -2167,4 +2177,4 @@ def benchmark_multiple_cases(): if __name__ == "__main__": - benchmark_multiple_cases() + test_benchmark_multiple_cases() From 592f67599ebb2e6dcb5891281226ebb1064c361f Mon Sep 17 00:00:00 2001 From: tdophung Date: Fri, 19 Dec 2025 17:46:54 -0800 Subject: [PATCH 14/19] format code Signed-off-by: tdophung --- tests/jax/test_permutation.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/jax/test_permutation.py b/tests/jax/test_permutation.py index a79fb9764f3..9d1bcc820fa 100644 --- a/tests/jax/test_permutation.py +++ b/tests/jax/test_permutation.py @@ -25,9 +25,9 @@ # All dispatch/combine test cases ALL_DISPATCH_COMBINE_CASES = [ - (128, 5, 128, 3), - (1024, 8, 128, 8), - (4096, 32, 1280, 2), + (128, 5, 128, 3), + (1024, 8, 128, 8), + (4096, 32, 1280, 2), (4096, 256, 4096, 6), ] DISPATCH_COMBINE_CASES = { @@ -48,9 +48,9 @@ # All dispatch/combine with padding test cases ALL_DISPATCH_COMBINE_PADDING_CASES = [ - (128, 5, 128, 3, 8), - (1024, 8, 128, 8, 16), - (4096, 32, 1280, 2, 128), + (128, 5, 128, 3, 8), + (1024, 8, 128, 8, 16), + (4096, 32, 1280, 2, 128), (4096, 256, 4096, 6, 16), ] 
DISPATCH_COMBINE_PADDING_CASES = { @@ -478,7 +478,6 @@ def generate_routing_map( return routing_map - @pytest_parametrize_wrapper( "num_tokens,num_experts,hidden_size,tokens_per_expert", DISPATCH_COMBINE_CASES, From 1d43279df8ae126d0e7624abd33105801e4882b9 Mon Sep 17 00:00:00 2001 From: tdophung Date: Fri, 19 Dec 2025 18:20:48 -0800 Subject: [PATCH 15/19] remove chance for inefficiency in moving between CPU and GPU, remove redundant primitive using a new static bool for padding, add assert for align size Signed-off-by: tdophung --- .../jax/triton_extensions/permutation.py | 186 ++---------------- transformer_engine/pytorch/permutation.py | 7 +- 2 files changed, 21 insertions(+), 172 deletions(-) diff --git a/transformer_engine/jax/triton_extensions/permutation.py b/transformer_engine/jax/triton_extensions/permutation.py index 206164665da..01b15c5adc8 100644 --- a/transformer_engine/jax/triton_extensions/permutation.py +++ b/transformer_engine/jax/triton_extensions/permutation.py @@ -246,20 +246,21 @@ def lowering(ctx, row_id_map, *, num_tokens, num_experts): class PermuteWithMaskMapPrimitive(BasePrimitive): """ - Permute the input tensor based on the row_id_map. + Permute the input tensor based on the row_id_map, optionally with fused padding. 
""" name = "te_permute_with_mask_map_triton" multiple_results = True - # scale, permuted_scale, and pad_offsets are dummy inputs (not used when PERMUTE_SCALE=False, FUSION_PAD=False) - # but they need to be in the signature for the kernel call + # scale, permuted_scale are dummy inputs (not used when PERMUTE_SCALE=False) + # pad_offsets can be shape (0,) when not doing padding, or (num_experts,) when padding impl_static_args = ( 6, 7, 8, 9, 10, - ) # num_tokens, num_experts, num_out_tokens, hidden_size, with_probs + 11, + ) # num_tokens, num_experts, num_out_tokens, hidden_size, with_probs, with_pad inner_primitive = None outer_primitive = None @@ -270,17 +271,18 @@ def abstract( probs_aval, scale_aval, # dummy, same shape as inp permuted_scale_aval, # dummy, same shape as inp - pad_offsets_aval, # dummy, not used when FUSION_PAD=False + pad_offsets_aval, *, num_tokens, num_experts, num_out_tokens, hidden_size, with_probs, + with_pad, ): """Shape/dtype inference for permute.""" del row_id_map_aval, scale_aval, permuted_scale_aval, pad_offsets_aval - del num_tokens, num_experts + del num_tokens, num_experts, with_pad output_shape = (num_out_tokens, hidden_size) output_aval = jax.core.ShapedArray(output_shape, inp_aval.dtype) @@ -305,6 +307,7 @@ def impl( num_out_tokens, hidden_size, with_probs, + with_pad, ): """Forward to inner primitive.""" assert PermuteWithMaskMapPrimitive.inner_primitive is not None @@ -320,6 +323,7 @@ def impl( num_out_tokens=num_out_tokens, hidden_size=hidden_size, with_probs=with_probs, + with_pad=with_pad, ) @staticmethod @@ -337,6 +341,7 @@ def lowering( num_out_tokens, hidden_size, with_probs, + with_pad, ): """MLIR lowering using triton_call_lowering.""" del num_out_tokens @@ -366,7 +371,6 @@ def lowering( block_size = _get_min_block_size(_permute_kernel) grid = (num_tokens, triton.cdiv(hidden_size, block_size)) - # Pass all 6 inputs including pad_offsets (even though FUSION_PAD=False) return triton_call_lowering( ctx, _permute_kernel, 
@@ -396,7 +400,7 @@ def lowering( "hidden_size": hidden_size, "PERMUTE_PROBS": with_probs, "PERMUTE_SCALE": False, - "FUSION_PAD": False, + "FUSION_PAD": with_pad, "BLOCK_SIZE": block_size, }, ) @@ -405,167 +409,6 @@ def lowering( register_primitive(PermuteWithMaskMapPrimitive) -class PermuteWithMaskMapAndPadPrimitive(BasePrimitive): - """ - Permute the input tensor based on the row_id_map with fused padding. - """ - - name = "te_permute_with_mask_map_and_pad_triton" - multiple_results = True - # scale and permuted_scale are dummy inputs (not used when PERMUTE_SCALE=False) - # Order must match kernel: inp, row_id_map, probs, scale, permuted_scale, pad_offsets - impl_static_args = ( - 6, - 7, - 8, - 9, - 10, - ) # num_tokens, num_experts, num_out_tokens, hidden_size, with_probs - inner_primitive = None - outer_primitive = None - - @staticmethod - def abstract( - inp_aval, - row_id_map_aval, - probs_aval, - scale_aval, # dummy, same shape as inp - permuted_scale_aval, # dummy, same shape as inp - pad_offsets_aval, - *, - num_tokens, - num_experts, - num_out_tokens, - hidden_size, - with_probs, - ): - """Shape/dtype inference for permute with padding.""" - del row_id_map_aval, scale_aval, permuted_scale_aval, pad_offsets_aval - del num_tokens, num_experts - - output_shape = (num_out_tokens, hidden_size) - output_aval = jax.core.ShapedArray(output_shape, inp_aval.dtype) - - if with_probs: - permuted_probs_aval = jax.core.ShapedArray((num_out_tokens,), probs_aval.dtype) - else: - permuted_probs_aval = jax.core.ShapedArray((0,), inp_aval.dtype) - - return output_aval, permuted_probs_aval - - @staticmethod - def impl( - inp, - row_id_map, - probs, - scale, - permuted_scale, - pad_offsets, - num_tokens, - num_experts, - num_out_tokens, - hidden_size, - with_probs, - ): - """Forward to inner primitive.""" - assert PermuteWithMaskMapAndPadPrimitive.inner_primitive is not None - return PermuteWithMaskMapAndPadPrimitive.inner_primitive.bind( - inp, - row_id_map, - probs, - 
scale, - permuted_scale, - pad_offsets, - num_tokens=num_tokens, - num_experts=num_experts, - num_out_tokens=num_out_tokens, - hidden_size=hidden_size, - with_probs=with_probs, - ) - - @staticmethod - def lowering( - ctx, - inp, - row_id_map, - probs, - scale, - permuted_scale, - pad_offsets, - *, - num_tokens, - num_experts, - num_out_tokens, - hidden_size, - with_probs, - ): - """MLIR lowering using triton_call_lowering.""" - del num_out_tokens - inp_stride_token = hidden_size - inp_stride_hidden = 1 - output_stride_token = hidden_size - output_stride_hidden = 1 - row_id_stride_token = num_experts * 2 + 1 - row_id_stride_expert = 1 - permuted_probs_stride_token = 1 - - if with_probs: - # Check if probs is 2D [num_tokens, num_experts] or 1D [num_tokens] - probs_aval = ctx.avals_in[2] - if len(probs_aval.shape) > 1: - probs_stride_token = num_experts - probs_stride_expert = 1 - else: - probs_stride_token = 1 - probs_stride_expert = 1 - else: - probs_stride_token = 0 - probs_stride_expert = 0 - - # Grid function equivalent: (num_tokens, cdiv(hidden_size, BLOCK_SIZE)) - # Use minimum BLOCK_SIZE from autotune configs to ensure grid covers all elements - block_size = _get_min_block_size(_permute_kernel) - grid = (num_tokens, triton.cdiv(hidden_size, block_size)) - - # Args order must match kernel: inp, row_id_map, probs, scale, permuted_scale, pad_offsets - return triton_call_lowering( - ctx, - _permute_kernel, - inp, - row_id_map, - probs, - scale, - permuted_scale, - pad_offsets, - grid=grid, - constexprs={ - "scale_hidden_dim": 0, - "stride_row_id_map_token": row_id_stride_token, - "stride_row_id_map_expert": row_id_stride_expert, - "stride_input_token": inp_stride_token, - "stride_input_hidden": inp_stride_hidden, - "stride_output_token": output_stride_token, - "stride_output_hidden": output_stride_hidden, - "stride_probs_token": probs_stride_token, - "stride_probs_expert": probs_stride_expert, - "stride_scale_token": hidden_size, - "stride_scale_hidden": 1, - 
"stride_permuted_probs_token": permuted_probs_stride_token, - "stride_permuted_scale_token": hidden_size, - "stride_permuted_scale_hidden": 1, - "num_experts": num_experts, - "hidden_size": hidden_size, - "PERMUTE_PROBS": with_probs, - "PERMUTE_SCALE": False, - "FUSION_PAD": True, - "BLOCK_SIZE": block_size, - }, - ) - - -register_primitive(PermuteWithMaskMapAndPadPrimitive) - - class UnpermuteWithMaskMapPrimitive(BasePrimitive): """ Unpermute the input tensor based on the row_id_map. @@ -1512,6 +1355,7 @@ def permute_with_mask_map( num_out_tokens=num_out_tokens, hidden_size=hidden_size, with_probs=with_probs, + with_pad=False, ) if not with_probs: @@ -1569,8 +1413,7 @@ def permute_with_mask_map_and_pad( dummy_scale = inp dummy_permuted_scale = inp - # Args order must match kernel: inp, row_id_map, probs, scale, permuted_scale, pad_offsets - output, permuted_probs = PermuteWithMaskMapAndPadPrimitive.outer_primitive.bind( + output, permuted_probs = PermuteWithMaskMapPrimitive.outer_primitive.bind( inp, row_id_map, probs, @@ -1582,6 +1425,7 @@ def permute_with_mask_map_and_pad( num_out_tokens=num_out_tokens, hidden_size=hidden_size, with_probs=with_probs, + with_pad=True, ) if not with_probs: diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index be6eb835f00..d15814585ee 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -620,6 +620,11 @@ def moe_permute_and_pad_with_probs( assert ( tokens_per_expert is not None ), "tokens_per_expert must be provided to the fused permute padding function." 
+ assert align_size > 0, f"align_size must be positive, got {align_size}" + + # Ensure tokens_per_expert is on the same device as input to avoid device transfers + if tokens_per_expert.device != inp.device: + tokens_per_expert = tokens_per_expert.to(inp.device) # Calculate aligned token counts per expert target_tokens_per_expert = (torch.ceil(tokens_per_expert / align_size) * align_size).long() @@ -627,7 +632,7 @@ def moe_permute_and_pad_with_probs( if torch.equal(tokens_per_expert, target_tokens_per_expert): pad_offsets = None else: - pad_lengths = (target_tokens_per_expert - tokens_per_expert).to(inp.device) + pad_lengths = target_tokens_per_expert - tokens_per_expert cum_pad = torch.cumsum(pad_lengths, dim=0) pad_offsets = torch.cat( [torch.zeros(1, dtype=cum_pad.dtype, device=inp.device), cum_pad[:-1]] From 4169a4eea5a331f999ac34ee9cdc63607459d58d Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 22 Dec 2025 10:34:00 -0800 Subject: [PATCH 16/19] fix lint in jax Signed-off-by: tdophung --- transformer_engine/jax/permutation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index f4bca3c205c..32de0b1a3c8 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -16,7 +16,6 @@ - Backward: Permute gradients (scatter to experts) """ -import warnings from functools import partial from typing import Optional, Tuple From c619adf13ac89d6dc9925b5f81024462cb14b3bf Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 22 Dec 2025 13:30:16 -0800 Subject: [PATCH 17/19] account for both jax newer and older than version 0.8.2. 
Adjusted gpu triton binding accordingly Signed-off-by: tdophung --- .../jax/triton_extensions/utils.py | 33 +++++++++++++------ 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index b7a62bb4d16..e23095a2586 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -143,16 +143,29 @@ def compile_triton( # Create kernel object for JAX # From jax/jaxlib/gpu/triton_kernels.cc: - # Kernel::Kernel(kernel_name, num_warps, num_ctas, shared_mem_bytes, ptx, ttir, compute_capability) - kernel = gpu_triton.TritonKernel( - compiled.name, # arg0: kernel_name (str) - num_warps, # arg1: num_warps (int) - num_ctas, # arg2: num_ctas (int) - compiled.metadata.shared, # arg3: shared_mem_bytes (int) - compiled.asm["ptx"], # arg4: ptx (str) - "", # arg5: ttir (str) - empty - compute_capability, # arg6: compute_capability (int) - ) + from packaging import version + if version.parse(jax.__version__) >= version.parse("0.8.2"): + kernel = gpu_triton.TritonKernel( + compiled.name, # arg0: kernel_name (str) + num_warps, # arg1: num_warps (int) + num_ctas, # arg2: num_ctas (int) + compiled.metadata.shared, # arg3: shared_mem_bytes (int) + compiled.asm["ptx"], # arg4: ptx (str) + "", # arg5: ttir (str) - empty + compute_capability, # arg6: compute_capability (int) + ) + else: + kernel = gpu_triton.TritonKernel( + compile.name, + num_warps, + compiled.metadata.shared, + compiled.asm["ptx"], + "", # ttir + compute_capability, + 1, + 1, + 1, + ) _TRITON_KERNEL_CACHE[cache_key] = kernel return kernel From 405b34139a5ab7908077ae9ffe8382f9e2da92c6 Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 22 Dec 2025 13:36:30 -0800 Subject: [PATCH 18/19] format code Signed-off-by: tdophung --- transformer_engine/jax/triton_extensions/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index e23095a2586..139c4b3de65 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -144,6 +144,7 @@ def compile_triton( # Create kernel object for JAX # From jax/jaxlib/gpu/triton_kernels.cc: from packaging import version + if version.parse(jax.__version__) >= version.parse("0.8.2"): kernel = gpu_triton.TritonKernel( compiled.name, # arg0: kernel_name (str) @@ -160,7 +161,7 @@ def compile_triton( num_warps, compiled.metadata.shared, compiled.asm["ptx"], - "", # ttir + "", # ttir compute_capability, 1, 1, From 7cad5c5bf1b74e8d41a658454aa9a4a4f38418b6 Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 22 Dec 2025 18:40:45 -0800 Subject: [PATCH 19/19] fix typo Signed-off-by: tdophung --- transformer_engine/jax/triton_extensions/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 139c4b3de65..41ce15303c7 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -157,7 +157,7 @@ def compile_triton( ) else: kernel = gpu_triton.TritonKernel( - compile.name, + compiled.name, num_warps, compiled.metadata.shared, compiled.asm["ptx"],