diff --git a/tests/jax/test_permutation.py b/tests/jax/test_permutation.py
index 23d9f506090..9d1bcc820fa 100644
--- a/tests/jax/test_permutation.py
+++ b/tests/jax/test_permutation.py
@@ -4,6 +4,8 @@
 """Tests for permutation Triton kernels and high-level APIs"""
+import functools
+
 import jax
 import jax.numpy as jnp
 import pytest
@@ -14,68 +16,117 @@
     token_combine,
     sort_chunks_by_index,
 )
-from utils import assert_allclose
+from utils import assert_allclose, pytest_parametrize_wrapper
+
+
+# =============================================================================
+# Test parameter definitions with L0 (fast) and L2 (comprehensive) levels
+# =============================================================================
+
+# All dispatch/combine test cases
+ALL_DISPATCH_COMBINE_CASES = [
+    (128, 5, 128, 3),
+    (1024, 8, 128, 8),
+    (4096, 32, 1280, 2),
+    (4096, 256, 4096, 6),
+]
+DISPATCH_COMBINE_CASES = {
+    "L0": ALL_DISPATCH_COMBINE_CASES[0:2],
+    "L2": ALL_DISPATCH_COMBINE_CASES,
+}
+
+# All sort chunks test cases
+ALL_SORT_CHUNKS_CASES = [
+    (8, 4096, 1280),
+    (64, 4096, 4096),
+    (256, 4096, 9216),
+]
+SORT_CHUNKS_CASES = {
+    "L0": ALL_SORT_CHUNKS_CASES[0:2],
+    "L2": ALL_SORT_CHUNKS_CASES,
+}
+
+# All dispatch/combine with padding test cases
+ALL_DISPATCH_COMBINE_PADDING_CASES = [
+    (128, 5, 128, 3, 8),
+    (1024, 8, 128, 8, 16),
+    (4096, 32, 1280, 2, 128),
+    (4096, 256, 4096, 6, 16),
+]
+DISPATCH_COMBINE_PADDING_CASES = {
+    "L0": ALL_DISPATCH_COMBINE_PADDING_CASES[0:2],
+    "L2": ALL_DISPATCH_COMBINE_PADDING_CASES,
+}
+
+# Dtypes for testing
+ALL_DTYPES = [jnp.float32, jnp.bfloat16]
+DTYPES = {
+    "L0": ALL_DTYPES,
+    "L2": ALL_DTYPES,
+}
+
+# With probs options
+ALL_WITH_PROBS = [True, False]
+WITH_PROBS = {
+    "L0": [True],
+    "L2": ALL_WITH_PROBS,
+}
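`pytest_parametrize_wrapper` comes from the shared `utils` module and selects the L0 or L2 parameter subset at collection time; its implementation is not part of this diff. A hedged sketch of what such a wrapper plausibly looks like — the helper name suffix and the `NVTE_TEST_LEVEL` variable are assumptions, not confirmed by this diff:

```python
import os
import pytest

def pytest_parametrize_wrapper_sketch(arg_names, params_by_level):
    # Hypothetical: pick the parameter subset for the current test level.
    # The real helper lives in tests/jax/utils.py; the env var name is assumed.
    level = os.getenv("NVTE_TEST_LEVEL", "L0")
    return pytest.mark.parametrize(arg_names, params_by_level[level])
```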
""" - row_id_map = jnp.full((num_tokens, num_experts * 2 + 1), -1, dtype=jnp.int32) + num_tokens, num_experts = routing_map.shape # For each expert, compute cumulative sum to get destination indices cumsum_per_expert = jnp.cumsum(routing_map, axis=0) - # Compute total tokens per expert + # Compute total tokens per expert and expert offsets tokens_per_expert = jnp.sum(routing_map, axis=0) expert_offsets = jnp.concatenate([jnp.array([0]), jnp.cumsum(tokens_per_expert)[:-1]]) - # Build the row_id_map - for token_idx in range(num_tokens): - routed_experts = jnp.where(routing_map[token_idx] == 1)[0] - n_routed = len(routed_experts) - - # Store number of routed experts in the last position - row_id_map = row_id_map.at[token_idx, -1].set(n_routed) - - # For each routed expert, compute destination row and store it - dest_rows = [] - expert_indices = [] - for expert_idx in routed_experts: - # Destination row = expert offset + (cumsum - 1) - dest_row = expert_offsets[expert_idx] + cumsum_per_expert[token_idx, expert_idx] - 1 - dest_rows.append(dest_row) - expert_indices.append(expert_idx) - - # Sort by destination row - if n_routed > 0: - sort_indices = jnp.argsort(-jnp.array(dest_rows)) # Negative for descending sort - sorted_dest_rows = jnp.array(dest_rows)[sort_indices] - sorted_expert_indices = jnp.array(expert_indices)[sort_indices] - - # Store sorted destination rows and expert indices - for i in range(n_routed): - row_id_map = row_id_map.at[token_idx, i].set(sorted_dest_rows[i]) - row_id_map = row_id_map.at[token_idx, num_experts + i].set(sorted_expert_indices[i]) + # Compute destination rows for all (token, expert) pairs + # dest_row[i, j] = expert_offsets[j] + cumsum_per_expert[i, j] - 1 if routed, else -1 + dest_rows_all = (expert_offsets[None, :] + cumsum_per_expert - 1) * routing_map + (-1) * ( + 1 - routing_map + ) + + # Count routed experts per token + n_routed_per_token = jnp.sum(routing_map, axis=1) + + # For each token, we need to sort by descending dest_row and pack into row_id_map + # Use a large negative value for non-routed experts so they sort to the end + sort_keys = jnp.where(routing_map == 1, -dest_rows_all, jnp.iinfo(jnp.int32).max) + sorted_expert_indices = jnp.argsort(sort_keys, axis=1) + + # Gather the sorted destination rows and expert indices using advanced indexing + # Create indices for gathering + token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts)) + sorted_dest_rows = dest_rows_all[token_idx, sorted_expert_indices] + + # Build row_id_map: [dest_row_0, ..., dest_row_{E-1}, expert_idx_0, ..., expert_idx_{E-1}, n_routed] + row_id_map = jnp.concatenate( + [ + sorted_dest_rows.astype(jnp.int32), + sorted_expert_indices.astype(jnp.int32), + n_routed_per_token.astype(jnp.int32)[:, None], + ], + axis=1, + ) return row_id_map @@ -84,13 +135,10 @@ def _reference_permute_impl( inp: jnp.ndarray, row_id_map: jnp.ndarray, probs: jnp.ndarray, - num_tokens: int, - num_experts: int, num_out_tokens: int, - hidden_size: int, ) -> tuple: """ - Internal helper for reference permutation implementation. + Vectorized internal helper for reference permutation implementation. Parameters ---------- @@ -100,14 +148,8 @@ def _reference_permute_impl( The token to expert mapping tensor of shape [num_tokens, num_experts * 2 + 1]. probs : jnp.ndarray The probabilities of the input tensor. - num_tokens : int - Number of tokens in the input tensor. - num_experts : int - Number of experts. num_out_tokens : int Number of tokens in the permuted tensor. 
@@ -84,13 +135,10 @@ def _reference_permute_impl(
     inp: jnp.ndarray,
     row_id_map: jnp.ndarray,
     probs: jnp.ndarray,
-    num_tokens: int,
-    num_experts: int,
     num_out_tokens: int,
-    hidden_size: int,
 ) -> tuple:
     """
-    Internal helper for reference permutation implementation.
+    Vectorized internal helper for reference permutation implementation.
 
     Parameters
     ----------
@@ -100,14 +148,8 @@
         The token to expert mapping tensor of shape [num_tokens, num_experts * 2 + 1].
     probs : jnp.ndarray
         The probabilities of the input tensor.
-    num_tokens : int
-        Number of tokens in the input tensor.
-    num_experts : int
-        Number of experts.
     num_out_tokens : int
         Number of tokens in the permuted tensor.
-    hidden_size : int
-        Hidden size of the input tensor.
 
     Returns
     -------
     output : jnp.ndarray
         Permuted output tensor.
     permuted_probs : jnp.ndarray
         Permuted probabilities if probs was provided, None otherwise.
     """
+    num_tokens, hidden_size = inp.shape
+    num_experts = (row_id_map.shape[1] - 1) // 2
+
+    # Extract destination rows, expert indices, and n_routed from row_id_map
+    dest_rows = row_id_map[:, :num_experts]  # [num_tokens, num_experts]
+    expert_indices = row_id_map[:, num_experts : 2 * num_experts]  # [num_tokens, num_experts]
+    n_routed = row_id_map[:, 2 * num_experts]  # [num_tokens]
+
+    # Create mask for valid entries: slot_idx < n_routed[token]
+    # The kernel's row_id_map only guarantees valid data in the first n_routed slots
+    # (slots beyond n_routed may contain garbage, not -1)
+    slot_indices = jnp.arange(num_experts)[None, :]  # [1, num_experts]
+    valid_mask = slot_indices < n_routed[:, None]  # [num_tokens, num_experts]
+
+    # Flatten for scatter operations
+    flat_dest_rows = dest_rows.flatten()  # [num_tokens * num_experts]
+    flat_valid_mask = valid_mask.flatten()
+    flat_token_indices = jnp.repeat(jnp.arange(num_tokens), num_experts)
+    flat_expert_indices = expert_indices.flatten()
+
+    # Set invalid dest_rows to num_out_tokens (out of bounds, will be dropped)
+    # This avoids overwriting valid entries at index 0 with zeros
+    flat_dest_rows_clamped = jnp.where(flat_valid_mask, flat_dest_rows, num_out_tokens)
+
+    # Gather input tokens and scatter to output
     output = jnp.zeros((num_out_tokens, hidden_size), dtype=inp.dtype)
-    permuted_probs = None if probs is None else jnp.zeros((num_out_tokens,), dtype=probs.dtype)
-
-    for token_idx in range(num_tokens):
-        n_routed = int(row_id_map[token_idx, -1])  # int() needed for Python range()
-        for i in range(n_routed):
-            # Don't use int() here - JAX can index with traced values,
-            # and int() breaks autodiff gradient tracking
-            dest_row = row_id_map[token_idx, i]
-            expert_idx = row_id_map[token_idx, num_experts + i]
-
-            # Get probability for this expert
-            if probs is not None:
-                if probs.ndim == 1:
-                    prob = probs[token_idx]
-                else:
-                    prob = probs[token_idx, expert_idx]
-
-                # Match kernel behavior: if prob == 0.0, zero out the output (padding indicator)
-                if prob == 0.0:
-                    output = output.at[dest_row].set(0.0)
-                else:
-                    output = output.at[dest_row].set(inp[token_idx])
-
-                    permuted_probs = permuted_probs.at[dest_row].set(prob)
-            else:
-                output = output.at[dest_row].set(inp[token_idx])
+    gathered_inp = inp[flat_token_indices]  # [num_tokens * num_experts, hidden_size]
+
+    # Scatter each valid (token, expert) row into the output:
+    # for each valid pair, write inp[token] to output[dest_row];
+    # invalid entries target num_out_tokens and get dropped by mode="drop"
+    output = output.at[flat_dest_rows_clamped].set(
+        gathered_inp,
+        mode="drop",
+    )
+
+    permuted_probs = None
+    if probs is not None:
+        permuted_probs = jnp.zeros((num_out_tokens,), dtype=probs.dtype)
+
+        # Vectorized approach: gather probs and scatter to permuted_probs
+        if probs.ndim == 1:
+            flat_probs = probs[flat_token_indices]
+        else:
+            # Clamp invalid expert indices to 0 to avoid wraparound indexing with -1
+            # The result for invalid entries is ignored anyway since they target num_out_tokens
+            # Cast to int32 explicitly for consistent indexing behavior
+            flat_expert_indices_clamped = jnp.where(flat_valid_mask, flat_expert_indices, 0).astype(
+                jnp.int32
+            )
+            flat_probs = probs[flat_token_indices.astype(jnp.int32), flat_expert_indices_clamped]
+
+        # Invalid entries target num_out_tokens and get dropped by mode="drop"
+        permuted_probs = permuted_probs.at[flat_dest_rows_clamped.astype(jnp.int32)].set(
+            flat_probs,
+            mode="drop",
+        )
 
     return output, permuted_probs
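The out-of-bounds-sentinel pattern used twice above deserves a standalone illustration; a minimal sketch of how `mode="drop"` discards the sentinel updates:

```python
import jax.numpy as jnp

out = jnp.zeros(4)
idx = jnp.array([2, 0, 4])          # 4 is out of bounds for a length-4 array
val = jnp.array([1.0, 2.0, 3.0])
# mode="drop" silently discards updates whose index is out of bounds, so the
# sentinel entry never lands anywhere (no NumPy-style wraparound).
print(out.at[idx].set(val, mode="drop"))  # [2. 0. 1. 0.]
```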
@@ -152,12 +224,9 @@ def _reference_unpermute_impl(
     row_id_map: jnp.ndarray,
     merging_probs: jnp.ndarray,
     permuted_probs: jnp.ndarray,
-    num_tokens: int,
-    num_experts: int,
-    hidden_size: int,
 ) -> tuple:
     """
-    Internal helper for reference unpermutation implementation.
+    Vectorized internal helper for reference unpermutation implementation.
 
     Parameters
     ----------
@@ -169,12 +238,6 @@
         The merging probabilities for weighted reduction.
     permuted_probs : jnp.ndarray
         The permuted probabilities.
-    num_tokens : int
-        Number of tokens.
-    num_experts : int
-        Number of experts.
-    hidden_size : int
-        Hidden size.
 
     Returns
     -------
     output : jnp.ndarray
         Unpermuted output tensor.
     unpermuted_probs : jnp.ndarray
         Unpermuted probabilities if permuted_probs was provided, None otherwise.
     """
-    output = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype)
-    unpermuted_probs = (
-        None
-        if permuted_probs is None
-        else jnp.zeros((num_tokens, num_experts), dtype=permuted_probs.dtype)
-    )
+    num_tokens = row_id_map.shape[0]
+    num_experts = (row_id_map.shape[1] - 1) // 2
 
-    for token_idx in range(num_tokens):
-        n_routed = int(row_id_map[token_idx, -1])  # int() needed for Python range()
-        for i in range(n_routed):
-            # Don't use int() here - JAX can index with traced values,
-            # and int() breaks autodiff gradient tracking
-            src_row = row_id_map[token_idx, i]
-            expert_idx = row_id_map[token_idx, num_experts + i]
-
-            if merging_probs is not None:
-                weight = merging_probs[token_idx, expert_idx]
-                output = output.at[token_idx].add(inp[src_row] * weight)
-            else:
-                output = output.at[token_idx].add(inp[src_row])
-
-            if permuted_probs is not None:
-                unpermuted_probs = unpermuted_probs.at[token_idx, expert_idx].set(
-                    permuted_probs[src_row]
-                )
+    # Extract source rows, expert indices, and n_routed from row_id_map
+    src_rows = row_id_map[:, :num_experts]  # [num_tokens, num_experts]
+    expert_indices = row_id_map[:, num_experts : 2 * num_experts]  # [num_tokens, num_experts]
+    n_routed = row_id_map[:, 2 * num_experts]  # [num_tokens]
+
+    # Create mask for valid entries: slot_idx < n_routed[token]
+    # The kernel's row_id_map only guarantees valid data in the first n_routed slots
+    slot_indices = jnp.arange(num_experts)[None, :]  # [1, num_experts]
+    valid_mask = slot_indices < n_routed[:, None]  # [num_tokens, num_experts]
+
+    # Clamp invalid src_rows to 0 (they won't be used due to masking)
+    src_rows_clamped = jnp.where(valid_mask, src_rows, 0)
+
+    # Gather input from permuted positions
+    gathered_inp = inp[src_rows_clamped]  # [num_tokens, num_experts, hidden_size]
+
+    # Apply merging probs if provided
+    if merging_probs is not None:
+        # Gather the merging weights for each (token, expert) pair using advanced indexing
+        token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts))
+        weights = merging_probs[token_idx, expert_indices]  # [num_tokens, num_experts]
+        gathered_inp = gathered_inp * weights[:, :, None]
+
+    # Mask out invalid entries and sum across experts
+    gathered_inp = jnp.where(valid_mask[:, :, None], gathered_inp, 0.0)
+    output = jnp.sum(gathered_inp, axis=1)  # [num_tokens, hidden_size]
+
+    unpermuted_probs = None
+    if permuted_probs is not None:
+        gathered_probs = permuted_probs[src_rows_clamped]  # [num_tokens, num_experts]
+        unpermuted_probs = jnp.zeros((num_tokens, num_experts), dtype=permuted_probs.dtype)
+        token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts))
+        # Route invalid slots to an out-of-bounds column so mode="drop" discards them,
+        # instead of scattering zeros through garbage expert indices
+        safe_expert_indices = jnp.where(valid_mask, expert_indices, num_experts)
+        unpermuted_probs = unpermuted_probs.at[token_idx, safe_expert_indices].set(
+            gathered_probs,
+            mode="drop",
+        )
 
     return output, unpermuted_probs
 
 
@@ -241,13 +317,8 @@
     row_id_map : jnp.ndarray
         The row_id_map for the permutation.
     """
-    num_tokens, num_experts = routing_map.shape
-    hidden_size = inp.shape[1]
-
-    row_id_map = reference_make_row_id_map(routing_map, num_tokens, num_experts)
-    output, permuted_probs = _reference_permute_impl(
-        inp, row_id_map, probs, num_tokens, num_experts, num_out_tokens, hidden_size
-    )
+    row_id_map = reference_make_row_id_map(routing_map)
+    output, permuted_probs = _reference_permute_impl(inp, row_id_map, probs, num_out_tokens)
 
     return output, permuted_probs, row_id_map
 
 
@@ -274,13 +345,7 @@
     output : jnp.ndarray
         Unpermuted output tensor of shape [num_tokens, hidden_size].
     """
-    num_tokens = row_id_map.shape[0]
-    num_experts = (row_id_map.shape[1] - 1) // 2
-    hidden_size = inp.shape[1]
-
-    output, _ = _reference_unpermute_impl(
-        inp, row_id_map, merging_probs, None, num_tokens, num_experts, hidden_size
-    )
+    output, _ = _reference_unpermute_impl(inp, row_id_map, merging_probs, None)
 
     return output
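The chunk-sort reference below inverts a permutation with `argsort`; since that identity is easy to misread, a minimal sketch:

```python
import jax.numpy as jnp

perm = jnp.array([2, 0, 3, 1])   # sorted_indices: which chunk fills each output slot
inv = jnp.argsort(perm)          # position of each source chunk in the output
print(inv)                       # [1 3 0 2]
# Sanity check: composing a permutation with its inverse is the identity.
print(perm[inv], inv[perm])      # [0 1 2 3] [0 1 2 3]
```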
""" - row_id_map = jnp.zeros((num_tokens,), dtype=jnp.int32) + # Compute source chunk boundaries (cumulative sum of original split_sizes) + src_cumsum = jnp.concatenate([jnp.array([0]), jnp.cumsum(split_sizes)]) - # Compute cumulative positions - cumsum_sizes = jnp.concatenate([jnp.array([0]), jnp.cumsum(split_sizes)]) + # Compute destination chunk boundaries based on sorted order + sorted_sizes = split_sizes[sorted_indices] + dest_cumsum = jnp.concatenate([jnp.array([0]), jnp.cumsum(sorted_sizes)]) - # For each chunk, compute the destination indices - dest_offset = 0 - for sorted_idx in sorted_indices: - chunk_start = cumsum_sizes[sorted_idx] - chunk_end = cumsum_sizes[sorted_idx + 1] - chunk_size = chunk_end - chunk_start + # For each source chunk, compute its destination offset + # inverse_indices[i] = position of chunk i in sorted order + inverse_indices = jnp.argsort(sorted_indices) + dest_offsets = dest_cumsum[inverse_indices] - # Map source positions to destination positions - for i in range(chunk_size): - row_id_map = row_id_map.at[chunk_start + i].set(dest_offset + i) + # Create row_id_map: for each token position, compute its destination + # First, figure out which chunk each position belongs to + position_indices = jnp.arange(num_tokens) - dest_offset += chunk_size + # chunk_ids[i] = which chunk position i belongs to + chunk_ids = jnp.searchsorted(src_cumsum[1:], position_indices, side="right") - return row_id_map + # within_chunk_offset[i] = position i's offset within its chunk + within_chunk_offset = position_indices - src_cumsum[chunk_ids] + + # destination[i] = dest_offsets[chunk_ids[i]] + within_chunk_offset[i] + row_id_map = dest_offsets[chunk_ids] + within_chunk_offset + + return row_id_map.astype(jnp.int32) def reference_sort_chunks_by_map( inp: jnp.ndarray, row_id_map: jnp.ndarray, probs: jnp.ndarray, - num_tokens: int, - hidden_size: int, is_forward: bool, ) -> tuple: """ - Reference implementation of sort_chunks_by_map using JAX primitives. + Vectorized reference implementation of sort_chunks_by_map using JAX primitives. Parameters ---------- @@ -350,10 +417,6 @@ def reference_sort_chunks_by_map( The token to destination mapping of shape [num_tokens,]. probs : jnp.ndarray The probabilities. - num_tokens : int - Number of tokens. - hidden_size : int - Hidden size. is_forward : bool Whether this is forward or backward. @@ -364,25 +427,25 @@ def reference_sort_chunks_by_map( permuted_probs : jnp.ndarray Sorted probabilities if probs was provided, None otherwise. 
""" - output = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype) - permuted_probs = None if probs is None else jnp.zeros((num_tokens,), dtype=probs.dtype) + num_tokens = inp.shape[0] + hidden_size = inp.shape[1] if is_forward: - # Forward: src -> dest - for src_idx in range(num_tokens): - # Don't use int() - JAX can index with traced values - dest_idx = row_id_map[src_idx] - output = output.at[dest_idx].set(inp[src_idx]) - if probs is not None: - permuted_probs = permuted_probs.at[dest_idx].set(probs[src_idx]) + # Forward: scatter inp[src] to output[dest] where dest = row_id_map[src] + output = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype) + output = output.at[row_id_map].set(inp) + if probs is not None: + permuted_probs = jnp.zeros((num_tokens,), dtype=probs.dtype) + permuted_probs = permuted_probs.at[row_id_map].set(probs) + else: + permuted_probs = None else: - # Backward: dest -> src - for dest_idx in range(num_tokens): - # Don't use int() - JAX can index with traced values - src_idx = row_id_map[dest_idx] - output = output.at[dest_idx].set(inp[src_idx]) - if probs is not None: - permuted_probs = permuted_probs.at[dest_idx].set(probs[src_idx]) + # Backward: gather output[dest] = inp[src] where src = row_id_map[dest] + output = inp[row_id_map] + if probs is not None: + permuted_probs = probs[row_id_map] + else: + permuted_probs = None return output, permuted_probs @@ -415,20 +478,24 @@ def generate_routing_map( return routing_map - # ========================================================================= - # token_dispatch tests - # ========================================================================= - - @pytest.mark.parametrize( + @pytest_parametrize_wrapper( "num_tokens,num_experts,hidden_size,tokens_per_expert", - [ - (32, 8, 256, 2), - (64, 16, 512, 3), - ], + DISPATCH_COMBINE_CASES, ) - @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) - def test_token_dispatch(self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype): - """Test token_dispatch forward and backward pass against reference""" + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("with_probs", WITH_PROBS) + def test_token_dispatch( + self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_probs + ): + """ + Individual test for token_dispatch forward and backward passes. + + This test validates dispatch in isolation to catch errors that might be + masked when combined with token_combine in the roundtrip test. + + Uses value_and_grad to validate both forward (via loss comparison) and + backward (via gradient comparison) passes against reference implementation. 
+ """ key = jax.random.PRNGKey(42) # Generate routing map @@ -436,173 +503,231 @@ def test_token_dispatch(self, num_tokens, num_experts, hidden_size, tokens_per_e num_out_tokens = int(jnp.sum(routing_map)) # Generate input data - key, inp_key = jax.random.split(key) + key, inp_key, prob_key = jax.random.split(key, 3) inp = jax.random.uniform( inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 ) - # Define loss functions - def loss_fn(x): - output, _, _ = token_dispatch(x, routing_map, num_out_tokens) - return jnp.sum(output**2) + # Generate probs if needed (minval > 0 to avoid kernel's special prob==0 handling) + probs = None + if with_probs: + probs = jax.random.uniform( + prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0 + ) - def ref_loss_fn(x): - output, _, _ = reference_token_dispatch(x, routing_map, num_out_tokens) - return jnp.sum(output**2) + # Generate reference row_id_map for comparison + ref_row_id_map = reference_make_row_id_map(routing_map) - loss_val, computed_grad = jax.value_and_grad(loss_fn)(inp) - ref_loss_val, ref_grad = jax.value_and_grad(ref_loss_fn)(inp) + # ===================================================================== + # Test forward and backward pass using value_and_grad + # (value validates forward, grad validates backward) + # ===================================================================== + if with_probs: - # Compare forward outputs - output, _, _ = token_dispatch(inp, routing_map, num_out_tokens) - ref_output, _, _ = reference_token_dispatch(inp, routing_map, num_out_tokens) - assert_allclose(output, ref_output) + @jax.jit + def dispatch_loss(x, p): + out, perm_probs, _, _, _ = token_dispatch(x, routing_map, num_out_tokens, probs=p) + return jnp.sum(out**2) + jnp.sum(perm_probs**2) - # Compare loss and gradient - assert_allclose(loss_val, ref_loss_val) - assert_allclose(computed_grad, ref_grad) + @jax.jit + def ref_dispatch_loss(x, p): + out, perm_probs = _reference_permute_impl(x, ref_row_id_map, p, num_out_tokens) + return jnp.sum(out**2) + jnp.sum(perm_probs**2) + + loss_val, (inp_grad, probs_grad) = jax.value_and_grad(dispatch_loss, argnums=(0, 1))( + inp, probs + ) + ref_loss_val, (ref_inp_grad, ref_probs_grad) = jax.value_and_grad( + ref_dispatch_loss, argnums=(0, 1) + )(inp, probs) + + # Validate forward loss matches + assert_allclose(loss_val, ref_loss_val, dtype=dtype) + # Validate gradients + assert_allclose(inp_grad, ref_inp_grad, dtype=dtype) + assert_allclose(probs_grad, ref_probs_grad, dtype=dtype) + else: + + @jax.jit + def dispatch_loss_no_probs(x): + out, _, _, _, _ = token_dispatch(x, routing_map, num_out_tokens) + return jnp.sum(out**2) + + @jax.jit + def ref_dispatch_loss_no_probs(x): + out, _ = _reference_permute_impl(x, ref_row_id_map, None, num_out_tokens) + return jnp.sum(out**2) + + loss_val, inp_grad = jax.value_and_grad(dispatch_loss_no_probs)(inp) + ref_loss_val, ref_inp_grad = jax.value_and_grad(ref_dispatch_loss_no_probs)(inp) + + # Validate forward loss matches + assert_allclose(loss_val, ref_loss_val, dtype=dtype) + # Validate gradients + assert_allclose(inp_grad, ref_inp_grad, dtype=dtype) # ========================================================================= - # token_dispatch with probs tests + # Consolidated dispatch + combine tests # ========================================================================= - @pytest.mark.parametrize( + @pytest_parametrize_wrapper( "num_tokens,num_experts,hidden_size,tokens_per_expert", - [ - (32, 8, 256, 2), - (64, 16, 512, 3), 
 
     # =========================================================================
-    # token_dispatch with probs tests
+    # Consolidated dispatch + combine tests
     # =========================================================================
 
-    @pytest.mark.parametrize(
+    @pytest_parametrize_wrapper(
         "num_tokens,num_experts,hidden_size,tokens_per_expert",
-        [
-            (32, 8, 256, 2),
-            (64, 16, 512, 3),
-        ],
+        DISPATCH_COMBINE_CASES,
     )
-    @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
-    def test_token_dispatch_with_probs(
-        self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype
+    @pytest_parametrize_wrapper("dtype", DTYPES)
+    @pytest_parametrize_wrapper("with_probs", WITH_PROBS)
+    def test_dispatch_and_combine(
+        self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_probs
     ):
-        """Test token_dispatch with probs forward and backward pass against reference"""
+        """
+        Comprehensive test for token_dispatch and token_combine.
+
+        Tests:
+        1. Dispatch forward pass against reference (element-by-element)
+        2. Dispatch backward pass against reference
+        3. Combine forward pass against reference (element-by-element)
+        4. Combine backward pass against reference
+        5. Roundtrip: dispatch + combine recovers original input
+        6. row_id_map n_routed column validation
+        7. Probs permutation (when with_probs=True)
+        """
         key = jax.random.PRNGKey(42)
 
         # Generate routing map
         routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key)
         num_out_tokens = int(jnp.sum(routing_map))
 
-        # Generate input data and probs
-        key, inp_key, prob_key = jax.random.split(key, 3)
+        # Generate input data
+        key, inp_key, prob_key, merge_key = jax.random.split(key, 4)
         inp = jax.random.uniform(
             inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
         )
-        probs = jax.random.uniform(
-            prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.0, maxval=1.0
-        )
 
-        # Define loss function that uses token_dispatch with probs
-        # We compute gradients w.r.t. both inp and probs
-        def loss_fn(x, p):
-            output, permuted_probs, _ = token_dispatch(x, routing_map, num_out_tokens, probs=p)
-            return jnp.sum(output**2) + jnp.sum(permuted_probs**2)
-
-        def ref_loss_fn(x, p):
-            output, permuted_probs, _ = reference_token_dispatch(
-                x, routing_map, num_out_tokens, probs=p
+        # Generate probs if needed (minval > 0 to avoid kernel's special prob==0 handling)
+        probs = None
+        if with_probs:
+            probs = jax.random.uniform(
+                prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
            )
-            return jnp.sum(output**2) + jnp.sum(permuted_probs**2)
-
-        loss_val, (inp_grad, probs_grad) = jax.value_and_grad(loss_fn, argnums=(0, 1))(inp, probs)
-        ref_loss_val, (ref_inp_grad, ref_probs_grad) = jax.value_and_grad(
-            ref_loss_fn, argnums=(0, 1)
-        )(inp, probs)
-        output, permuted_probs, _ = token_dispatch(inp, routing_map, num_out_tokens, probs=probs)
-
-        ref_output, ref_permuted_probs, _ = reference_token_dispatch(
-            inp, routing_map, num_out_tokens, probs=probs
+        # Generate merging probs (normalized per token)
+        merging_probs = jax.random.uniform(
+            merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
         )
-
-        # Compare forward outputs
-        assert_allclose(output, ref_output)
-        assert_allclose(permuted_probs, ref_permuted_probs)
-
-        # Compare loss and gradients
-        assert_allclose(loss_val, ref_loss_val)
-        assert_allclose(inp_grad, ref_inp_grad)
-        assert_allclose(probs_grad, ref_probs_grad)
-
-    # =========================================================================
-    # token_combine tests
-    # =========================================================================
-
-    @pytest.mark.parametrize(
-        "num_tokens,num_experts,hidden_size,tokens_per_expert",
-        [
-            (32, 8, 256, 2),
-            (64, 16, 512, 3),
-        ],
-    )
-    @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
-    @pytest.mark.parametrize("with_merging_probs", [True, False])
-    def test_token_combine(
-        self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_merging_probs
-    ):
-        """Test token_combine forward and backward pass against reference"""
-        key = jax.random.PRNGKey(42)
-
-        # Generate routing map
-        routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key)
-        num_out_tokens = int(jnp.sum(routing_map))
-
-        # Get row_id_map from reference_token_dispatch
-        key, dummy_key = jax.random.split(key)
-        dummy_inp = jax.random.uniform(
-            dummy_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
+        merging_probs = merging_probs * routing_map.astype(dtype)  # Zero out non-routed
+        merging_probs = merging_probs / jnp.maximum(
+            jnp.sum(merging_probs, axis=1, keepdims=True), 1e-8
         )
-        _, _, row_id_map = reference_token_dispatch(dummy_inp, routing_map, num_out_tokens)
 
-        # Generate input data (from expert outputs)
-        key, inp_key, merge_key = jax.random.split(key, 3)
-        inp = jax.random.uniform(
-            inp_key, (num_out_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
+        # =====================================================================
+        # Test 1: Dispatch forward pass
+        # =====================================================================
+        output, permuted_probs, row_id_map, _, _ = token_dispatch(
+            inp, routing_map, num_out_tokens, probs=probs
+        )
+        ref_output, ref_permuted_probs = _reference_permute_impl(
+            inp, row_id_map, probs, num_out_tokens
         )
 
-        if with_merging_probs:
-            merging_probs = jax.random.uniform(
-                merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.0, maxval=1.0
+        # Validate row_id_map structure: n_routed column should match routing_map sum
+        n_routed_actual = row_id_map[:, -1]
+        n_routed_expected = jnp.sum(routing_map, axis=1)
+        assert jnp.array_equal(
+            n_routed_actual, n_routed_expected
+        ), "make_row_id_map n_routed column mismatch"
+
+        # Compare dispatch output
+        assert_allclose(output, ref_output, dtype=dtype)
+        if with_probs:
+            assert_allclose(permuted_probs, ref_permuted_probs, dtype=dtype)
+
+        # =====================================================================
+        # Test 2: Dispatch backward pass
+        # =====================================================================
+        if with_probs:
+
+            @jax.jit
+            def dispatch_loss(x, p):
+                out, perm_probs, _, _, _ = token_dispatch(x, routing_map, num_out_tokens, probs=p)
+                return jnp.sum(out**2) + jnp.sum(perm_probs**2)
+
+            @jax.jit
+            def ref_dispatch_loss(x, p):
+                out, perm_probs = _reference_permute_impl(x, row_id_map, p, num_out_tokens)
+                return jnp.sum(out**2) + jnp.sum(perm_probs**2)
+
+            _, (inp_grad, probs_grad) = jax.value_and_grad(dispatch_loss, argnums=(0, 1))(
+                inp, probs
             )
-            # Normalize per token
-            merging_probs = merging_probs / (jnp.sum(merging_probs, axis=1, keepdims=True) + 1e-8)
+            _, (ref_inp_grad, ref_probs_grad) = jax.value_and_grad(
+                ref_dispatch_loss, argnums=(0, 1)
+            )(inp, probs)
+            assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)
+            assert_allclose(probs_grad, ref_probs_grad, dtype=dtype)
         else:
-            merging_probs = None
 
-        # Define loss functions
-        def loss_fn(x):
-            output = token_combine(x, row_id_map, merging_probs)
-            return jnp.sum(output**2)
-
-        def ref_loss_fn(x):
-            output = reference_token_combine(x, row_id_map, merging_probs)
-            return jnp.sum(output**2)
-
-        loss_val, computed_grad = jax.value_and_grad(loss_fn)(inp)
-        ref_loss_val, ref_grad = jax.value_and_grad(ref_loss_fn)(inp)
+            @jax.jit
+            def dispatch_loss_no_probs(x):
+                out, _, _, _, _ = token_dispatch(x, routing_map, num_out_tokens)
+                return jnp.sum(out**2)
 
+            @jax.jit
+            def ref_dispatch_loss_no_probs(x):
+                out, _ = _reference_permute_impl(x, row_id_map, None, num_out_tokens)
+                return jnp.sum(out**2)
+
+            _, inp_grad = jax.value_and_grad(dispatch_loss_no_probs)(inp)
+            _, ref_inp_grad = jax.value_and_grad(ref_dispatch_loss_no_probs)(inp)
+            assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)
+
+        # =====================================================================
+        # Test 3: Combine forward pass
+        # =====================================================================
+        combined = token_combine(output, row_id_map, merging_probs)
+        ref_combined = _reference_unpermute_impl(output, row_id_map, merging_probs, None)[0]
+        assert_allclose(combined, ref_combined, dtype=dtype)
+
+        # =====================================================================
+        # Test 4: Combine backward pass
+        # =====================================================================
+
+        @jax.jit
+        def combine_loss(x):
+            return jnp.sum(token_combine(x, row_id_map, merging_probs) ** 2)
+
+        @jax.jit
+        def ref_combine_loss(x):
+            return jnp.sum(_reference_unpermute_impl(x, row_id_map, merging_probs, None)[0] ** 2)
+
+        _, combine_grad = jax.value_and_grad(combine_loss)(output)
+        _, ref_combine_grad = jax.value_and_grad(ref_combine_loss)(output)
+        assert_allclose(combine_grad, ref_combine_grad, dtype=dtype)
+
+        # =====================================================================
+        # Test 5: Roundtrip (dispatch + combine = original)
+        # =====================================================================
+        # Use uniform merging probs for perfect roundtrip
+        uniform_merging_probs = routing_map.astype(dtype) / jnp.maximum(
+            jnp.sum(routing_map, axis=1, keepdims=True), 1.0
+        )
 
-        # Compare forward outputs
-        output = token_combine(inp, row_id_map, merging_probs)
-        ref_output = reference_token_combine(inp, row_id_map, merging_probs)
-        assert_allclose(output, ref_output)
+        @jax.jit
+        def roundtrip(x):
+            dispatched, _, rid_map, _, _ = token_dispatch(x, routing_map, num_out_tokens)
+            return token_combine(dispatched, rid_map, uniform_merging_probs)
 
-        # Compare loss and gradient
-        assert_allclose(loss_val, ref_loss_val)
-        assert_allclose(computed_grad, ref_grad)
+        roundtrip_output = roundtrip(inp)
+        assert_allclose(roundtrip_output, inp, dtype=dtype)
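Why uniform merging probabilities give an exact roundtrip: dispatch copies each token to its k routed experts, and combining with weight 1/k averages k identical copies back to the original row. A minimal sketch of that identity:

```python
import jax.numpy as jnp

x = jnp.array([5.0, -3.0])              # one token's hidden vector
k = 3                                   # routed to k experts
copies = jnp.tile(x, (k, 1))            # dispatch makes k identical copies
combined = jnp.sum(copies / k, axis=0)  # combine with uniform weight 1/k
print(jnp.allclose(combined, x))        # True
```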
 
     # =========================================================================
     # sort_chunks_by_index tests
     # =========================================================================
 
-    @pytest.mark.parametrize(
+    @pytest_parametrize_wrapper(
         "num_splits,total_tokens,hidden_size",
-        [
-            (4, 128, 256),
-            (8, 256, 512),
-        ],
+        SORT_CHUNKS_CASES,
     )
-    @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
+    @pytest_parametrize_wrapper("dtype", DTYPES)
     def test_sort_chunks_by_index(self, num_splits, total_tokens, hidden_size, dtype):
         """Test sort_chunks_by_index forward and backward pass against reference"""
         key = jax.random.PRNGKey(42)
@@ -622,73 +747,181 @@ def test_sort_chunks_by_index(self, num_splits, total_tokens, hidden_size, dtype
             inp_key, (total_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
         )
 
-        row_id_map = reference_make_chunk_sort_map(
-            split_sizes, sorted_indices, total_tokens, num_splits
-        )
+        # Get reference row_id_map
+        row_id_map = reference_make_chunk_sort_map(split_sizes, sorted_indices, total_tokens)
 
-        # Define loss functions
+        # Define loss functions (JIT compiled for performance)
+        @jax.jit
         def loss_fn(x):
             output, _ = sort_chunks_by_index(x, split_sizes, sorted_indices)
             return jnp.sum(output**2)
 
+        @jax.jit
         def ref_loss_fn(x):
-            output, _ = reference_sort_chunks_by_map(
-                x, row_id_map, None, total_tokens, hidden_size, is_forward=True
-            )
+            output, _ = reference_sort_chunks_by_map(x, row_id_map, None, is_forward=True)
             return jnp.sum(output**2)
 
+        # Test forward pass
+        output, _ = sort_chunks_by_index(inp, split_sizes, sorted_indices)
+        ref_output, _ = reference_sort_chunks_by_map(inp, row_id_map, None, is_forward=True)
+
+        # Test backward pass with JIT
         loss_val, computed_grad = jax.value_and_grad(loss_fn)(inp)
         ref_loss_val, ref_grad = jax.value_and_grad(ref_loss_fn)(inp)
 
-        # Compare forward outputs
-        output, _ = sort_chunks_by_index(inp, split_sizes, sorted_indices)
-        ref_output, _ = reference_sort_chunks_by_map(
-            inp, row_id_map, None, total_tokens, hidden_size, is_forward=True
-        )
+        # Compare forward and backward
         assert_allclose(output, ref_output)
-
-        # Compare loss and gradient
         assert_allclose(loss_val, ref_loss_val)
         assert_allclose(computed_grad, ref_grad)
 
     # =========================================================================
-    # Round-trip tests (token_dispatch -> expert processing -> token_combine)
+    # Consolidated dispatch + combine with padding tests
     # =========================================================================
 
-    @pytest.mark.parametrize(
-        "num_tokens,num_experts,hidden_size,tokens_per_expert",
-        [
-            (32, 8, 256, 2),
-            (64, 16, 512, 3),
-        ],
+    @pytest_parametrize_wrapper(
+        "num_tokens,num_experts,hidden_size,topk,align_size",
+        DISPATCH_COMBINE_PADDING_CASES,
     )
-    @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
-    def test_dispatch_combine_roundtrip(
-        self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype
+    @pytest_parametrize_wrapper("dtype", DTYPES)
+    @pytest_parametrize_wrapper("with_probs", WITH_PROBS)
+    def test_dispatch_and_combine_with_padding(
+        self, num_tokens, num_experts, hidden_size, topk, align_size, dtype, with_probs
    ):
-        """Test that token_dispatch followed by token_combine recovers original input"""
+        """
+        Comprehensive test for token_dispatch and token_combine with padding/unpadding.
+
+        Tests:
+        1. Dispatch with padding: output shape and alignment
+        2. Dispatch backward pass with padding
+        3. Combine with unpad: output shape
+        4. Combine backward pass with unpad
+        5. Roundtrip with padding: dispatch + combine recovers original
+        6. Probs permutation with padding (when with_probs=True)
+        """
         key = jax.random.PRNGKey(42)
 
         # Generate routing map
-        routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key)
+        routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key)
         num_out_tokens = int(jnp.sum(routing_map))
 
+        # Compute worst-case padded size
+        worst_case_size = (
+            (num_out_tokens + num_experts * (align_size - 1)) // align_size
+        ) * align_size
+
         # Generate input data
-        key, inp_key = jax.random.split(key)
+        key, inp_key, prob_key, merge_key = jax.random.split(key, 4)
         inp = jax.random.uniform(
             inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
         )
 
-        # Create uniform merging probs (equal weight for all routed experts)
-        merging_probs = routing_map.astype(dtype) / jnp.maximum(
+        # Generate probs if needed (minval > 0 to avoid kernel's special prob==0 handling)
+        probs = None
+        if with_probs:
+            probs = jax.random.uniform(
+                prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
+            )
+
+        # Generate merging probs (normalized per token)
+        merging_probs = jax.random.uniform(
+            merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
+        )
+        merging_probs = merging_probs * routing_map.astype(dtype)  # Zero out non-routed
+        merging_probs = merging_probs / jnp.maximum(
+            jnp.sum(merging_probs, axis=1, keepdims=True), 1e-8
+        )
+
+        # =====================================================================
+        # Test 1: Dispatch with padding - forward pass
+        # =====================================================================
+        output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert = token_dispatch(
+            inp, routing_map, num_out_tokens, probs=probs, align_size=align_size
+        )
+
+        # Check output shape
+        assert output.shape == (worst_case_size, hidden_size)
+        if with_probs:
+            assert permuted_probs is not None
+            assert permuted_probs.shape == (worst_case_size,)
+        else:
+            assert permuted_probs is None
+
+        # Check alignment: each expert's tokens should be aligned
+        for expert_idx in range(num_experts):
+            expert_tokens = int(target_tokens_per_expert[expert_idx])
+            assert expert_tokens % align_size == 0 or expert_tokens == 0
+
+        # =====================================================================
+        # Test 2: Dispatch with padding - backward pass
+        # =====================================================================
+        if with_probs:
+
+            @jax.jit
+            def dispatch_loss(x, p):
+                out, perm_probs, _, _, _ = token_dispatch(
+                    x, routing_map, num_out_tokens, probs=p, align_size=align_size
+                )
+                return jnp.sum(out**2) + jnp.sum(perm_probs**2)
+
+            inp_grad, probs_grad = jax.grad(dispatch_loss, argnums=(0, 1))(inp, probs)
+            assert inp_grad.shape == inp.shape
+            assert probs_grad.shape == probs.shape
+            assert not jnp.any(jnp.isnan(inp_grad))
+            assert not jnp.any(jnp.isnan(probs_grad))
+        else:
+
+            @jax.jit
+            def dispatch_loss_no_probs(x):
+                out, _, _, _, _ = token_dispatch(
+                    x, routing_map, num_out_tokens, align_size=align_size
+                )
+                return jnp.sum(out**2)
+
+            inp_grad = jax.grad(dispatch_loss_no_probs)(inp)
+            assert inp_grad.shape == inp.shape
+            assert not jnp.any(jnp.isnan(inp_grad))
+
+        # =====================================================================
+        # Test 3: Combine with unpad - forward pass
+        # =====================================================================
+        combined = token_combine(output, row_id_map, merging_probs, pad_offsets)
+        assert combined.shape == (num_tokens, hidden_size)
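The worst-case size checked above comes from per-expert rounding: each expert's token count is padded up to a multiple of align_size, adding at most align_size - 1 rows per expert, and the padded total is itself a multiple of align_size, so the bound rounds *down*. A hand-checked sketch (the helper name is illustrative, not from the library):

```python
def worst_case_padded(num_out_tokens, num_experts, align_size):
    # Each expert adds at most align_size - 1 pad rows; since the padded total
    # is a multiple of align_size, round the raw bound down to a multiple.
    return ((num_out_tokens + num_experts * (align_size - 1)) // align_size) * align_size

# 10 routed tokens over 3 experts, align to 4: bound is (10 + 9) // 4 * 4 = 16.
# Per-expert counts (2, 3, 5) pad to 4 + 4 + 8 = 16, reaching the bound;
# counts (4, 3, 3) pad to 4 + 4 + 4 = 12, staying under it.
print(worst_case_padded(10, 3, 4))  # 16
```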
+        # =====================================================================
+        # Test 4: Combine with unpad - backward pass
+        # =====================================================================
+
+        @jax.jit
+        def combine_loss(x):
+            return jnp.sum(token_combine(x, row_id_map, merging_probs, pad_offsets) ** 2)
+
+        combine_grad = jax.grad(combine_loss)(output)
+        assert combine_grad.shape == output.shape
+        assert not jnp.any(jnp.isnan(combine_grad))
+
+        # =====================================================================
+        # Test 5: Roundtrip with padding (dispatch + combine = original)
+        # =====================================================================
+        # Use uniform merging probs for perfect roundtrip
+        uniform_merging_probs = routing_map.astype(dtype) / jnp.maximum(
             jnp.sum(routing_map, axis=1, keepdims=True), 1.0
         )
 
-        # Dispatch tokens to experts (returns output, permuted_probs, row_id_map)
-        dispatched, _, row_id_map = token_dispatch(inp, routing_map, num_out_tokens)
+        @jax.jit
+        def roundtrip(x):
+            dispatched, _, rid_map, p_offsets, _ = token_dispatch(
+                x, routing_map, num_out_tokens, align_size=align_size
+            )
+            return token_combine(dispatched, rid_map, uniform_merging_probs, p_offsets)
+
+        roundtrip_output = roundtrip(inp)
+        assert_allclose(roundtrip_output, inp, dtype=dtype)
 
-        # Combine tokens back (with uniform merging) (new signature)
-        combined = token_combine(dispatched, row_id_map, merging_probs)
+        # Test roundtrip gradient
+        @jax.jit
+        def roundtrip_loss(x):
+            return jnp.sum(roundtrip(x) ** 2)
 
-        # Compare with original input
-        assert_allclose(combined, inp)
+        roundtrip_grad = jax.grad(roundtrip_loss)(inp)
+        assert roundtrip_grad.shape == inp.shape
+        assert not jnp.any(jnp.isnan(roundtrip_grad))
diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py
index e8a7bedc873..9a0cf6fb7cf 100644
--- a/tests/pytorch/test_permutation.py
+++ b/tests/pytorch/test_permutation.py
@@ -2,6 +2,7 @@
 #
 # See LICENSE for license information.
+import os
 import random
 
 import torch
@@ -13,6 +14,7 @@
 from transformer_engine.pytorch import (
     moe_permute as te_permute,
     moe_permute_with_probs as te_permute_with_probs,
+    moe_permute_and_pad_with_probs as te_permute_and_pad_with_probs,
     moe_unpermute as te_unpermute,
     moe_sort_chunks_by_index as te_sort_chunks_by_index,
     moe_sort_chunks_by_index_with_probs as te_sort_chunks_by_index_with_probs,
@@ -24,6 +26,7 @@
     MXFP8Quantizer,
 )
 import transformer_engine_torch as tex
+from transformer_engine.pytorch import Fp8Padding, Fp8Unpadding
 import copy
 
 seed = 1234
@@ -653,6 +656,522 @@ def _test_permutation_mask_map(
     print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
 
 
+def _test_permutation_and_padding_mask_map(
+    te_dtype,
+    num_tokens,
+    num_expert,
+    hidden_size,
+    topK,
+    num_out_tokens,
+    with_merging_probs=False,
+    align_size=16,
+    BENCHMARK=False,
+):
+    if topK > num_expert:
+        pytest.skip("topK should be smaller than the number of experts.")
+
+    if num_out_tokens is None:
+        num_out_tokens = num_tokens * topK
+
+    print(
+        "permutation and padding:"
+        f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK}"
+        f" with_merging_probs:{with_merging_probs} align_size:{align_size} {te_dtype}"
+    )
+
+    # Convert TE dtypes to PyTorch dtypes
+    if te_dtype == tex.DType.kFloat32:
+        dtype = torch.float32
+    elif te_dtype == tex.DType.kFloat16:
+        dtype = torch.float16
+    elif te_dtype == tex.DType.kBFloat16:
+        dtype = torch.bfloat16
+    else:
+        pytest.skip("Invalid dtype.")
+
+    _tmp_tensor = torch.zeros((num_tokens * num_expert,))
+    _tmp_tensor[: int(num_out_tokens)] = 1.0
+    _tmp_idx = torch.randperm(num_tokens * num_expert)
+    routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda()
+
+    probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
+    row_sums = probs.sum(dim=1, keepdim=True)
+    probs = probs / row_sums
+    probs = probs.to(dtype)
+    probs.requires_grad_(True)
+
+    tokens_per_expert = routing_map.sum(dim=0).cpu()
+    target_tokens_per_expert = (torch.ceil(tokens_per_expert / align_size) * align_size).long()
+    num_permute_pad_out_tokens = target_tokens_per_expert.sum().item()
+
+    permute_pad_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
+    permute_pad_bwd_input = torch.rand(
+        (num_permute_pad_out_tokens, hidden_size), dtype=dtype
+    ).cuda()
+    unpermute_unpad_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
+    permute_pad_fwd_input.requires_grad_(True)
+
+    restore_shape = permute_pad_fwd_input.shape
+    ###################################################################################################################################
+    #
+    # moe_permute_with_probs and Fp8Padding, moe_unpermute and Fp8Unpadding
+    #
+    ###################################################################################################################################
+    # permute + padding
+    permuted_output, permuted_probs, row_id_map = te_permute_with_probs(
+        permute_pad_fwd_input,
+        probs,
+        routing_map,
+        num_out_tokens=num_out_tokens,
+    )
+    tokens_per_expert_list = tokens_per_expert.tolist()
+    fp8_padding = Fp8Padding(num_expert, align_size)
+    permuted_paded_output, _ = fp8_padding(permuted_output, tokens_per_expert_list)
+    permuted_paded_probs, _ = fp8_padding(permuted_probs.unsqueeze(-1), tokens_per_expert_list)
+
+    permuted_paded_output.backward(permute_pad_bwd_input, retain_graph=True)
+
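The routing-map construction above (a zeros tensor with the first num_out_tokens entries set, shuffled by `randperm`) fixes the total number of routed pairs exactly, though not a per-token top-k. A minimal self-contained sketch:

```python
import torch

num_tokens, num_expert, num_out_tokens = 4, 3, 5
flat = torch.zeros(num_tokens * num_expert)
flat[:num_out_tokens] = 1.0
routing_map = flat[torch.randperm(flat.numel())].reshape(num_tokens, num_expert).bool()
print(routing_map.sum().item())  # always 5, regardless of the shuffle
```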
+    # unpadding + unpermute
+
+    unpermute_unpad_fwd_input = permuted_paded_output.detach()
+    unpermute_unpad_fwd_input.requires_grad_(True)
+
+    fp8_unpadding = Fp8Unpadding(num_expert, align_size)
+    unpaded_output = fp8_unpadding(unpermute_unpad_fwd_input, tokens_per_expert_list)
+
+    probs_naive = probs
+    unpermuted_unpaded_output = te_unpermute(
+        unpaded_output,
+        row_id_map,
+        merging_probs=probs_naive if with_merging_probs else None,
+        restore_shape=restore_shape,
+    )
+
+    unpermuted_unpaded_output.backward(unpermute_unpad_bwd_input, retain_graph=True)
+
+    ###################################################################################################################################
+    #
+    # fused moe_permute_with_probs and Fp8Padding, fused moe_unpermute and Fp8Unpadding
+    #
+    ###################################################################################################################################
+    # fusion permute_and_pad
+    fusion_permute_and_pad_fwd_input = permute_pad_fwd_input.detach()
+    fusion_permute_and_pad_fwd_input.requires_grad_(True)
+    probs_fusion = probs_naive.detach().clone()
+    probs_fusion.requires_grad_(True)
+
+    (
+        fusion_permuted_padded_output,
+        fusion_permuted_padded_probs,
+        row_id_map,
+        pad_offsets,
+        target_tokens_per_expert,
+    ) = te_permute_and_pad_with_probs(
+        fusion_permute_and_pad_fwd_input,
+        probs_fusion,
+        routing_map,
+        tokens_per_expert,
+        align_size,
+    )
+    fusion_permuted_padded_probs = fusion_permuted_padded_probs.unsqueeze(-1)
+
+    fusion_permute_pad_bwd_input = permute_pad_bwd_input.detach()
+    fusion_permuted_padded_output.backward(fusion_permute_pad_bwd_input, retain_graph=True)
+
+    # fusion unpad and unpermute
+    fusion_unpermute_unpad_fwd_input = fusion_permuted_padded_output.detach()
+    fusion_unpermute_unpad_fwd_input.requires_grad_(True)
+
+    fusion_unpermuted_unpaded_output = te_unpermute(
+        fusion_unpermute_unpad_fwd_input,
+        row_id_map,
+        merging_probs=probs_fusion if with_merging_probs else None,
+        restore_shape=restore_shape,
+        pad_offsets=pad_offsets,
+    )
+
+    fusion_unpermute_bwd_input = unpermute_unpad_bwd_input.detach()
+    fusion_unpermuted_unpaded_output.backward(fusion_unpermute_bwd_input, retain_graph=True)
+
+    ###################################################################################################################################
+    #
+    # Results Check
+    #
+    ###################################################################################################################################
+    tols = dtype_tols(te_dtype)
+
+    permuted_paded_output_ = permuted_paded_output.float()
+    fusion_permuted_padded_output_ = fusion_permuted_padded_output.float()
+    permute_pad_fwd_input_grad = permute_pad_fwd_input.grad.float()
+    fusion_permute_and_pad_fwd_input_grad = fusion_permute_and_pad_fwd_input.grad.float()
+
+    unpermuted_unpaded_output_ = unpermuted_unpaded_output.float()
+    fusion_unpermuted_unpaded_output_ = fusion_unpermuted_unpaded_output.float()
+    unpermute_unpad_fwd_input_grad = unpermute_unpad_fwd_input.grad.float()
+    fusion_unpermute_unpad_fwd_input_grad = fusion_unpermute_unpad_fwd_input.grad.float()
+
+    if not BENCHMARK:
+        torch.testing.assert_close(
+            permuted_paded_output_,
+            fusion_permuted_padded_output_,
+            msg=f"Mismatch in te_permute_and_pad fwd",
+            **tols,
+        )
+        torch.testing.assert_close(
+            permute_pad_fwd_input_grad,
+            fusion_permute_and_pad_fwd_input_grad,
+            msg=f"Mismatch in te_permute_and_pad bwd",
+            **tols,
+        )
+        torch.testing.assert_close(
+            unpermuted_unpaded_output_,
+            fusion_unpermuted_unpaded_output_,
+            msg=f"Mismatch in te_unpermute fwd",
+            **tols,
+        )
+        torch.testing.assert_close(
+            unpermute_unpad_fwd_input_grad,
+            fusion_unpermute_unpad_fwd_input_grad,
+            msg=f"Mismatch in te_unpermute bwd",
+            **tols,
+        )
+        torch.testing.assert_close(
+            permuted_paded_probs.float(),
+            fusion_permuted_padded_probs.float(),
+            msg=f"Mismatch in te_permute_and_pad fwd (permuted probs)",
+            **tols,
+        )
+        if with_merging_probs:
+            torch.testing.assert_close(
+                probs_naive.grad.float(),
+                probs_fusion.grad.float(),
+                msg=f"Mismatch in te_unpermute bwd (probs grad)",
+                **tols,
+            )
+
+    ###################################################################################################################################
+    #
+    # Benchmark
+    #
+    ###################################################################################################################################
+    if BENCHMARK:
+
+        def permute_and_pad():
+            permuted_output, permuted_probs, row_id_map = te_permute_with_probs(
+                permute_pad_fwd_input,
+                probs,
+                routing_map,
+                num_out_tokens=num_out_tokens,
+            )
+            fp8_padding(permuted_output, tokens_per_expert_list)
+            fp8_padding(permuted_probs.unsqueeze(-1), tokens_per_expert_list)
+
+        def fusion_permute_and_pad():
+            (
+                fusion_permuted_padded_output,
+                fusion_permuted_padded_probs,
+                row_id_map,
+                pad_offsets,
+                target_tokens_per_expert,
+            ) = te_permute_and_pad_with_probs(
+                fusion_permute_and_pad_fwd_input,
+                probs,
+                routing_map,
+                tokens_per_expert,
+                align_size,
+            )
+            fusion_permuted_padded_probs = fusion_permuted_padded_probs.unsqueeze(-1)
+
+        t1 = perf_test_cuda_kernel(lambda: permute_and_pad())
+
+        t2 = perf_test_cuda_kernel(lambda: fusion_permute_and_pad())
+
+        print(f"permute_and_pad\t\tfwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")
+
+        t1 = perf_test_cuda_kernel(
+            lambda: backward_wrapper(
+                permuted_paded_output,
+                permute_pad_bwd_input,
+                forward_input=[permute_pad_fwd_input],
+                retain_graph=True,
+                accumulate_grad=False,
+            )
+        )
+        t2 = perf_test_cuda_kernel(
+            lambda: backward_wrapper(
+                fusion_permuted_padded_output,
+                fusion_permute_pad_bwd_input,
+                forward_input=[fusion_permute_and_pad_fwd_input],
+                retain_graph=True,
+                accumulate_grad=False,
+            )
+        )
+        print(f"permute_and_pad\t\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")
+
+        def unpad_unpermute():
+            unpaded_output = fp8_unpadding(unpermute_unpad_fwd_input, tokens_per_expert_list)
+            unpermuted_unpaded_output = te_unpermute(
+                unpaded_output, row_id_map, restore_shape=restore_shape
+            )
+
+            unpermuted_unpaded_output.backward(unpermute_unpad_bwd_input, retain_graph=True)
+
+        t1 = perf_test_cuda_kernel(lambda: unpad_unpermute())
+        t2 = perf_test_cuda_kernel(
+            lambda: te_unpermute(
+                fusion_unpermute_unpad_fwd_input,
+                row_id_map,
+                restore_shape=restore_shape,
+                pad_offsets=pad_offsets,
+            )
+        )
+        print(f"unpermute_and_unpad\tfwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")
+
+        t1 = perf_test_cuda_kernel(
+            lambda: backward_wrapper(
+                unpermuted_unpaded_output,
+                unpermute_unpad_bwd_input,
+                forward_input=([unpermute_unpad_fwd_input, probs]),
+                retain_graph=True,
+                accumulate_grad=False,
+            )
+        )
+        t2 = perf_test_cuda_kernel(
+            lambda: backward_wrapper(
+                fusion_unpermuted_unpaded_output,
+                fusion_unpermute_bwd_input,
+                forward_input=([fusion_unpermute_unpad_fwd_input, probs]),
+                retain_graph=True,
+                accumulate_grad=False,
+            )
+        )
+        print(f"unpermute_and_unpad\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")
+
+
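`perf_test_cuda_kernel` (defined further down in this file) is what times each variant above; its body is not part of this diff, but a CUDA-event-based timer of roughly this shape is the usual pattern — a hedged sketch, not the file's actual implementation:

```python
import torch

def cuda_kernel_timer_sketch(fn, warmup=10, iters=100):
    """Return the average runtime of fn() in milliseconds using CUDA events."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()  # wait for the recorded events to complete
    return start.elapsed_time(end) / iters
```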
+def _test_permutation_and_padding_with_merging_probs(
+    te_dtype,
+    num_tokens,
+    num_expert,
+    hidden_size,
+    topK,
+    num_out_tokens,
+    align_size=16,
+    BENCHMARK=False,
+):
+    """
+    Test the combination of merging_probs AND pad_offsets together in moe_unpermute.
+    This specifically tests the backward pass fix where pad_offsets must be used
+    when computing gradients with merging_probs.
+    """
+    if topK > num_expert:
+        pytest.skip("topK should be smaller than the number of experts.")
+
+    if num_out_tokens is None:
+        num_out_tokens = num_tokens * topK
+
+    print(
+        "permutation and padding with merging probs:"
+        f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} align_size:{align_size} {te_dtype}"
+    )
+
+    # Convert TE dtypes to PyTorch dtypes
+    if te_dtype == tex.DType.kFloat32:
+        dtype = torch.float32
+    elif te_dtype == tex.DType.kFloat16:
+        dtype = torch.float16
+    elif te_dtype == tex.DType.kBFloat16:
+        dtype = torch.bfloat16
+    else:
+        pytest.skip("Invalid dtype.")
+
+    _tmp_tensor = torch.zeros((num_tokens * num_expert,))
+    _tmp_tensor[: int(num_out_tokens)] = 1.0
+    _tmp_idx = torch.randperm(num_tokens * num_expert)
+    routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda()
+
+    probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
+    row_sums = probs.sum(dim=1, keepdim=True)
+    probs = probs / row_sums
+    probs = probs.to(dtype)
+    probs.requires_grad_(True)
+
+    tokens_per_expert = routing_map.sum(dim=0).cpu()
+    target_tokens_per_expert = (torch.ceil(tokens_per_expert / align_size) * align_size).long()
+    num_permute_pad_out_tokens = target_tokens_per_expert.sum().item()
+
+    permute_pad_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
+    permute_pad_bwd_input = torch.rand(
+        (num_permute_pad_out_tokens, hidden_size), dtype=dtype
+    ).cuda()
+    unpermute_unpad_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
+    permute_pad_fwd_input.requires_grad_(True)
+
+    restore_shape = permute_pad_fwd_input.shape
+    ###################################################################################################################################
+    #
+    # Reference: moe_permute_with_probs + Fp8Padding, then Fp8Unpadding + moe_unpermute with merging_probs
+    #
+    ###################################################################################################################################
+    # permute + padding
+    permuted_output, permuted_probs, row_id_map = te_permute_with_probs(
+        permute_pad_fwd_input,
+        probs,
+        routing_map,
+        num_out_tokens=num_out_tokens,
+    )
+    tokens_per_expert_list = tokens_per_expert.tolist()
+    fp8_padding = Fp8Padding(num_expert, align_size)
+    permuted_paded_output, _ = fp8_padding(permuted_output, tokens_per_expert_list)
+
+    permuted_paded_output.backward(permute_pad_bwd_input, retain_graph=True)
+
+    # Reference: unpadding + unpermute WITH merging_probs
+    ref_unpermute_fwd_input = permuted_paded_output.detach()
+    ref_unpermute_fwd_input.requires_grad_(True)
+
+    ref_probs = probs.detach()
+    ref_probs.requires_grad_(True)
+
+    fp8_unpadding = Fp8Unpadding(num_expert, align_size)
+    unpaded_output = fp8_unpadding(ref_unpermute_fwd_input, tokens_per_expert_list)
+    ref_unpermuted_output = te_unpermute(
+        unpaded_output, row_id_map, ref_probs, restore_shape=restore_shape
+    )
+
+    ref_unpermuted_output.backward(unpermute_unpad_bwd_input, retain_graph=True)
+
+    ###################################################################################################################################
+    #
+    # Fused: moe_permute_and_pad_with_probs, then moe_unpermute with BOTH merging_probs AND pad_offsets
+    #
+    ###################################################################################################################################
+    # fusion permute_and_pad
+    fusion_permute_fwd_input = permute_pad_fwd_input.detach()
+    fusion_permute_fwd_input.requires_grad_(True)
+    fusion_probs = probs.detach()
+    fusion_probs.requires_grad_(True)
+
+    (
+        fusion_permuted_padded_output,
+        fusion_permuted_padded_probs,
+        fused_row_id_map,
+        pad_offsets,
+        _,
+    ) = te_permute_and_pad_with_probs(
+        fusion_permute_fwd_input,
+        fusion_probs,
+        routing_map,
+        tokens_per_expert,
+        align_size,
+    )
+
+    fusion_permute_pad_bwd_input = permute_pad_bwd_input.detach()
+    fusion_permuted_padded_output.backward(fusion_permute_pad_bwd_input, retain_graph=True)
+
+    # Fused: unpermute with BOTH merging_probs AND pad_offsets
+    fusion_unpermute_fwd_input = fusion_permuted_padded_output.detach()
+    fusion_unpermute_fwd_input.requires_grad_(True)
+
+    fusion_merging_probs = probs.detach()
+    fusion_merging_probs.requires_grad_(True)
+
+    fusion_unpermuted_output = te_unpermute(
+        fusion_unpermute_fwd_input,
+        fused_row_id_map,
+        fusion_merging_probs,
+        restore_shape=restore_shape,
+        pad_offsets=pad_offsets,
+    )
+
+    fusion_unpermute_bwd_input = unpermute_unpad_bwd_input.detach()
+    fusion_unpermuted_output.backward(fusion_unpermute_bwd_input, retain_graph=True)
+
+    ###################################################################################################################################
+    #
+    # Results Check
+    #
+    ###################################################################################################################################
+    tols = dtype_tols(te_dtype)
+
+    # Check forward pass
+    ref_unpermuted_output_ = ref_unpermuted_output.float()
+    fusion_unpermuted_output_ = fusion_unpermuted_output.float()
+
+    if not BENCHMARK:
+        torch.testing.assert_close(
+            ref_unpermuted_output_,
+            fusion_unpermuted_output_,
+            msg=f"Mismatch in te_unpermute with merging_probs and pad_offsets fwd",
+            **tols,
+        )
+
+        # Check backward pass - activation gradients
+        ref_unpermute_fwd_input_grad = ref_unpermute_fwd_input.grad.float()
+        fusion_unpermute_fwd_input_grad = fusion_unpermute_fwd_input.grad.float()
+
+        torch.testing.assert_close(
+            ref_unpermute_fwd_input_grad,
+            fusion_unpermute_fwd_input_grad,
+            msg=f"Mismatch in te_unpermute with merging_probs and pad_offsets bwd (act_grad)",
+            **tols,
+        )
+
+        # Check backward pass - probs gradients
+        ref_probs_grad = ref_probs.grad.float()
+        fusion_probs_grad = fusion_merging_probs.grad.float()
+
+        torch.testing.assert_close(
+            ref_probs_grad,
+            fusion_probs_grad,
+            msg=f"Mismatch in te_unpermute with merging_probs and pad_offsets bwd (probs_grad)",
+            **tols,
+        )
+
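For the probs gradients checked above, the underlying identity is simple: with out[t] = Σ_e w[t,e] · x[src(t,e)], the gradient w.r.t. a merging weight is the inner product of the upstream gradient with the corresponding expert row. A minimal autograd sketch of that identity (toy tensors, no TE calls):

```python
import torch

g = torch.tensor([1.0, 2.0])   # upstream gradient dL/dout for one token
x = torch.tensor([[3.0, 4.0],  # expert row 0 for this token
                  [5.0, 6.0]]) # expert row 1
w = torch.tensor([0.7, 0.3], requires_grad=True)
out = (w[:, None] * x).sum(dim=0)  # weighted combine of the two rows
out.backward(g)
print(w.grad)                      # [<g,x0>, <g,x1>] = [11., 17.]
```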
@@ -1126,7 +1645,7 @@ def perf_test_cuda_kernel(cuda_kernel_fn):
 @pytest.mark.parametrize("num_tokens", [4096])
 @pytest.mark.parametrize("num_expert", [7, 16])
 @pytest.mark.parametrize("hidden_size", [4096])
-@pytest.mark.parametrize("topK", [1, 2, 5])
+@pytest.mark.parametrize("topK", [2, 5])
 @pytest.mark.parametrize("num_out_tokens", [None, 2039])
 def test_permutation_index_map(
     te_dtype,
@@ -1155,7 +1674,7 @@ def test_permutation_index_map(
 @pytest.mark.parametrize("num_tokens", [4096])
 @pytest.mark.parametrize("num_expert", [7, 16])
 @pytest.mark.parametrize("hidden_size", [4096])
-@pytest.mark.parametrize("topK", [1, 2, 5])
+@pytest.mark.parametrize("topK", [2, 5])
 @pytest.mark.parametrize("num_out_tokens", [None, 2039])
 def test_permutation_mask_map(
     te_dtype,
@@ -1180,6 +1699,74 @@ def test_permutation_mask_map(
     )
 
 
+@pytest.mark.parametrize("te_dtype", _te_dtypes)
+@pytest.mark.parametrize("num_out_tokens", [None])
+@pytest.mark.parametrize(
+    "num_tokens, num_expert, hidden_size, topK",
+    [
+        (4096, 8, 1280, 2),
+        (4096, 64, 4096, 6),
+        (4096, 256, 7168, 6),
+        (4096, 512, 9216, 8),
+    ],
+)
+@pytest.mark.parametrize("with_merging_probs", [True, False])
+def test_permutation_and_padding_mask_map(
+    te_dtype,
+    num_tokens,
+    num_expert,
+    hidden_size,
+    topK,
+    num_out_tokens,
+    with_merging_probs,
+):
+    BENCHMARK = False
+
+    _test_permutation_and_padding_mask_map(
+        te_dtype=te_dtype,
+        num_tokens=num_tokens,
+        num_expert=num_expert,
+        hidden_size=hidden_size,
+        topK=topK,
+        num_out_tokens=num_out_tokens,
+        with_merging_probs=with_merging_probs,
+        BENCHMARK=BENCHMARK,
+    )
+
+
+@pytest.mark.parametrize("te_dtype", _te_dtypes)
+@pytest.mark.parametrize("num_out_tokens", [None])
+@pytest.mark.parametrize(
+    "num_tokens, num_expert, hidden_size, topK",
+    [
+        (4096, 8, 1280, 2),
+        (4096, 64, 4096, 6),
+        (4096, 256, 7168, 6),
+        (4096, 512, 9216, 8),
+    ],
+)
+def test_permutation_and_padding_with_merging_probs(
+    te_dtype,
+    num_tokens,
+    num_expert,
+    hidden_size,
+    topK,
+    num_out_tokens,
+):
+    """Test moe_unpermute forward and backward with BOTH merging_probs AND pad_offsets."""
+    BENCHMARK = False
+
+    _test_permutation_and_padding_with_merging_probs(
+        te_dtype=te_dtype,
+        num_tokens=num_tokens,
+        num_expert=num_expert,
+        hidden_size=hidden_size,
+        topK=topK,
+        num_out_tokens=num_out_tokens,
+        BENCHMARK=BENCHMARK,
+    )
+
+
 @pytest.mark.parametrize("te_dtype", _te_dtypes)
 def test_permutation_mask_map_empty_input(te_dtype):
     with_probs = True
@@ -1201,9 +1788,9 @@ def test_permutation_mask_map_empty_input(te_dtype):
 @pytest.mark.parametrize("num_tokens", [4096])
 @pytest.mark.parametrize("num_expert", [7, 16])
 @pytest.mark.parametrize("hidden_size", [4096])
-@pytest.mark.parametrize("topK", [1, 2, 5])
+@pytest.mark.parametrize("topK", [2, 5])
 @pytest.mark.parametrize("num_out_tokens", [None, 2039])
-@pytest.mark.parametrize("tp_size", [1, 2, 8])
+@pytest.mark.parametrize("tp_size", [1, 2]) def test_permutation_mask_map_alongside_probs( te_dtype, num_tokens, @@ -1253,10 +1840,10 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype): @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) -@pytest.mark.parametrize("num_tokens", [2048]) +@pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) -@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("topK", [2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) @pytest.mark.parametrize("recipe", fp8_recipes) def test_permutation_mask_map_fp8( @@ -1341,7 +1928,7 @@ def test_permutation_mask_map_topk1_no_probs( @pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_expert", [7, 16]) -@pytest.mark.parametrize("tp_size", [1, 2, 8]) +@pytest.mark.parametrize("tp_size", [2, 8]) @pytest.mark.parametrize("hidden_size", [4096]) def test_chunk_permutation( te_dtype, @@ -1376,6 +1963,10 @@ def test_chunk_permutation_empty_input(te_dtype): ) +@pytest.mark.skipif( + os.getenv("RUN_BENCHMARK_TESTS", "0") != "1", + reason="Benchmark test - run with: RUN_BENCHMARK_TESTS=1 pytest -k single_case", +) def test_permutation_single_case(): print("GPU:", torch.cuda.get_device_name(0)) @@ -1413,6 +2004,26 @@ def test_permutation_single_case(): BENCHMARK=Benchmark, ) + _test_permutation_and_padding_mask_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + BENCHMARK=Benchmark, + ) + + _test_permutation_and_padding_with_merging_probs( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + BENCHMARK=Benchmark, + ) + _test_moe_chunk_sort( te_dtype=te_dtype, num_tokens=num_tokens, @@ -1479,6 +2090,30 @@ def benchmark_single_case( ) torch.cuda.nvtx.range_pop() + torch.cuda.nvtx.range_push("permutation_and_padding_mask_map") + _test_permutation_and_padding_mask_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + BENCHMARK=True, + ) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push("permutation_and_padding_with_merging_probs") + _test_permutation_and_padding_with_merging_probs( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + BENCHMARK=True, + ) + torch.cuda.nvtx.range_pop() + torch.cuda.nvtx.range_push("permutation_mask_map_alongside_probs") _test_permutation_mask_map_alongside_probs( te_dtype=te_dtype, @@ -1495,7 +2130,12 @@ def benchmark_single_case( torch.cuda.nvtx.range_pop() -def benchmark_multiple_cases(): +@pytest.mark.skipif( + os.getenv("RUN_BENCHMARK_TESTS", "0") != "1", + reason="Benchmark test - run with: RUN_BENCHMARK_TESTS=1 pytest -k benchmark", +) +def test_benchmark_multiple_cases(): + """Benchmark test - skipped by default. 
Run with: RUN_BENCHMARK_TESTS=1 pytest -k benchmark""" print("GPU:", torch.cuda.get_device_name(0)) # te_dtype = tex.DType.kFloat32 @@ -1537,4 +2177,4 @@ def benchmark_multiple_cases(): if __name__ == "__main__": - benchmark_multiple_cases() + test_benchmark_multiple_cases() diff --git a/transformer_engine/common/triton/permutation.py b/transformer_engine/common/triton/permutation.py index 87a9c245334..de30c7c532d 100644 --- a/transformer_engine/common/triton/permutation.py +++ b/transformer_engine/common/triton/permutation.py @@ -200,6 +200,7 @@ def _permute_kernel( probs_ptr, scale_ptr, permuted_scale_ptr, + pad_offsets_ptr, # sizes scale_hidden_dim, # strides @@ -224,8 +225,11 @@ def _permute_kernel( hidden_size: tl.constexpr, PERMUTE_PROBS: tl.constexpr, PERMUTE_SCALE: tl.constexpr, + FUSION_PAD: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): + expert_idx = 0 + pid_t = tl.program_id(0) pid_h = tl.program_id(1) cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) @@ -246,6 +250,15 @@ def _permute_kernel( dst_row = tl.load( row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert ).to(tl.int64) + if FUSION_PAD or PERMUTE_PROBS: + expert_idx = tl.load( + row_id_map_ptr + + pid_t * stride_row_id_map_token + + (num_experts + idx) * stride_row_id_map_expert + ) + if FUSION_PAD: + pad_off = tl.load(pad_offsets_ptr + expert_idx) + dst_row = dst_row + pad_off output_off = dst_row * stride_output_token + cur_off * stride_output_hidden if PERMUTE_SCALE: permuted_scale_off = ( @@ -253,11 +266,6 @@ def _permute_kernel( ) tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale) if PERMUTE_PROBS: - expert_idx = tl.load( - row_id_map_ptr - + pid_t * stride_row_id_map_token - + (num_experts + idx) * stride_row_id_map_expert - ) prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert prob = tl.load(probs_ptr + prob_off) if pid_h == 0: @@ -297,6 +305,7 @@ def _unpermute_kernel( row_id_map_ptr, merging_probs_ptr, permuted_probs_ptr, + pad_offsets_ptr, # strides stride_row_id_map_token, stride_row_id_map_expert, @@ -318,10 +327,12 @@ def _unpermute_kernel( PROBS_LOAD_WIDTH: tl.constexpr, WITH_MERGING_PROBS: tl.constexpr, PERMUTE_PROBS: tl.constexpr, + FUSION_UNPAD: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): data_type = input_ptr.dtype.element_ty compute_type = tl.float32 + expert_idx = 0 pid_t = tl.program_id(0) pid_h = tl.program_id(1) @@ -348,15 +359,19 @@ def _unpermute_kernel( src_row = tl.load( row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert ).to(tl.int64) - input_off = src_row * stride_input_token + current_offset * stride_input_hidden - inp = tl.load(input_ptr + input_off, mask=mask) - inp = inp.to(compute_type) - if WITH_MERGING_PROBS: + if FUSION_UNPAD or WITH_MERGING_PROBS: expert_idx = tl.load( row_id_map_ptr + pid_t * stride_row_id_map_token + (num_experts + idx) * stride_row_id_map_expert ) + if FUSION_UNPAD: + pad_off = tl.load(pad_offsets_ptr + expert_idx) + src_row = src_row + pad_off + input_off = src_row * stride_input_token + current_offset * stride_input_hidden + inp = tl.load(input_ptr + input_off, mask=mask) + inp = inp.to(compute_type) + if WITH_MERGING_PROBS: merging_prob_off = ( pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert ) @@ -407,6 +422,7 @@ def _unpermute_bwd_with_merging_probs_kernel( fwd_input_ptr, merging_probs_ptr, row_id_map_ptr, + pad_offsets_ptr, # strides stride_row_id_map_token, stride_row_id_map_expert, @@ -427,6 +443,7 @@ def 
_unpermute_bwd_with_merging_probs_kernel(
     num_experts: tl.constexpr,
     hidden_size: tl.constexpr,
     PROBS_LOAD_WIDTH: tl.constexpr,
+    FUSION_UNPAD: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     data_type = fwd_output_grad_ptr.dtype.element_ty
@@ -450,6 +467,9 @@ def _unpermute_bwd_with_merging_probs_kernel(
                 + pid * stride_row_id_map_token
                 + (num_experts + idx) * stride_row_id_map_expert
             )
+            if FUSION_UNPAD:
+                pad_off = tl.load(pad_offsets_ptr + expert_idx)
+                dst_row = dst_row + pad_off
             prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
             current_start = 0
             while current_start < hidden_size:
diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py
index 55a59a1650a..32de0b1a3c8 100644
--- a/transformer_engine/jax/permutation.py
+++ b/transformer_engine/jax/permutation.py
@@ -25,8 +25,11 @@ from transformer_engine.jax.triton_extensions.permutation import (
     make_row_id_map,
     permute_with_mask_map,
+    permute_with_mask_map_and_pad,
     unpermute_with_mask_map,
+    unpermute_with_mask_map_and_unpad,
     unpermute_bwd_with_merging_probs,
+    unpermute_bwd_with_merging_probs_and_unpad,
     make_chunk_sort_map,
     sort_chunks_by_map,
 )
@@ -43,7 +46,14 @@ def token_dispatch(
     routing_map: jnp.ndarray,
     num_out_tokens: int,
     probs: Optional[jnp.ndarray] = None,
-) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]:
+    align_size: Optional[int] = None,
+) -> Tuple[
+    jnp.ndarray,
+    Optional[jnp.ndarray],
+    jnp.ndarray,
+    Optional[jnp.ndarray],
+    Optional[jnp.ndarray],
+]:
     """
     Dispatch tokens to experts based on routing map.
 
@@ -51,6 +61,10 @@ def token_dispatch(
     to their designated experts according to the routing map. The row_id_map is
     computed internally from the routing_map.
 
+    Optionally supports fused padding for alignment when `align_size` is provided.
+    This is useful for efficient matrix multiplications that require aligned tensor
+    dimensions. The padding is computed internally from the routing_map.
+
     Parameters
     ----------
     inp : jnp.ndarray
@@ -59,36 +73,99 @@ def token_dispatch(
         Routing mask of shape [batch, sequence, num_experts] or [num_tokens, num_experts].
         Values: 1 = routed, 0 = not routed.
     num_out_tokens : int
-        The number of output tokens after permutation. This should equal the sum of
-        routing_map and must be provided explicitly for JIT compatibility.
+        The number of output tokens after permutation (before padding). For the dropless
+        case, this should be equal to the sum of routing_map. Must be provided explicitly
+        for JIT compatibility since the output shape must be known at compile time.
     probs : Optional[jnp.ndarray]
         Optional routing probabilities of shape [batch, sequence, num_experts] or
         [num_tokens, num_experts]. If provided, permuted_probs will be returned.
+    align_size : Optional[int]
+        Optional alignment size for padding. If provided, outputs will be padded to
+        align each expert's tokens to a multiple of this size. The output buffer is
+        allocated with the worst-case size, rounded down to a multiple of align_size:
+        ((num_out_tokens + num_experts * (align_size - 1)) // align_size) * align_size.
+        This enables full JIT compatibility.
 
     Returns
     -------
     output : jnp.ndarray
-        Permuted output tensor of shape [num_out_tokens, hidden_size].
+        Permuted output tensor of shape [num_out_tokens, hidden_size] without padding,
+        or [worst_case_padded_size, hidden_size] when using padding fusion.
+        With padding, the actually used portion may be smaller than the buffer; the
+        actual size is the sum of target_tokens_per_expert.
     permuted_probs : Optional[jnp.ndarray]
-        Permuted probabilities of shape [num_out_tokens], or None if probs was not provided.
+        Permuted probabilities of shape [num_out_tokens] or [worst_case_padded_size],
+        or None if probs was not provided.
     row_id_map : jnp.ndarray
         Row ID map for use in token_combine (shape [num_tokens, num_experts * 2 + 1]).
+    pad_offsets : Optional[jnp.ndarray]
+        Per-expert cumulative padding offsets of shape [num_experts] when using padding,
+        None otherwise. Pass this to token_combine when unpadding is needed.
+    target_tokens_per_expert : Optional[jnp.ndarray]
+        Aligned token counts per expert of shape [num_experts] when using padding,
+        None otherwise.
+
+    Note
+    ----
+    **JIT Compatibility:**
+
+    This function is fully JIT-compatible. When using padding (align_size provided),
+    the output buffer is allocated with a fixed worst-case size that depends only on
+    compile-time constants (num_out_tokens, num_experts, align_size). The actual
+    padding offsets (pad_offsets) and aligned token counts (target_tokens_per_expert)
+    are computed internally from the routing_map and can be traced values.
+
+    The worst-case output size is:
+    ((num_out_tokens + num_experts * (align_size - 1)) // align_size) * align_size
+    This accounts for the maximum possible padding, when every expert needs (align_size - 1)
+    extra tokens to align; the total is rounded down to a multiple of align_size so the
+    buffer itself stays aligned.
     """
-    return _token_dispatch(inp, routing_map, probs, num_out_tokens)
+    use_padding = align_size is not None
+    num_experts = routing_map.shape[-1]
+    if use_padding:
+        # Compute worst-case output size (compile-time constant)
+        # This is the maximum possible size when each expert needs max padding
+        worst_case_out_tokens = (
+            (num_out_tokens + num_experts * (align_size - 1)) // align_size
+        ) * align_size
+    else:
+        worst_case_out_tokens = num_out_tokens
+
+    return _token_dispatch(
+        inp, routing_map, probs, num_out_tokens, worst_case_out_tokens, align_size, use_padding
+    )
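The JIT contract described in the docstring above is easiest to see in use: num_out_tokens and align_size are static, while routing_map and the derived pad_offsets may be traced. A minimal usage sketch (the wrapper name, shapes, and the reuse of `probs` as merging weights are illustrative only, not part of the API):

    import functools
    import jax
    from transformer_engine.jax.permutation import token_dispatch, token_combine

    @functools.partial(jax.jit, static_argnames=("num_out_tokens", "align_size"))
    def moe_block(x, routing_map, probs, num_out_tokens, align_size):
        # Dispatch with fused padding; output buffer size is a compile-time constant.
        out, permuted_probs, row_id_map, pad_offsets, _ = token_dispatch(
            x, routing_map, num_out_tokens, probs=probs, align_size=align_size
        )
        # ... per-expert computation on the padded, expert-grouped `out` goes here ...
        # Combine with fused unpadding by passing pad_offsets back in.
        return token_combine(out, row_id_map, merging_probs=probs, pad_offsets=pad_offsets)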
-@partial(jax.custom_vjp, nondiff_argnums=(1, 3))
+@partial(jax.custom_vjp, nondiff_argnums=(1, 3, 4, 5, 6))
 def _token_dispatch(
     inp: jnp.ndarray,
     routing_map: jnp.ndarray,
     probs: Optional[jnp.ndarray],
     num_out_tokens: int,
-) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]:
+    worst_case_out_tokens: int,
+    align_size: Optional[int],
+    use_padding: bool,
+) -> Tuple[
+    jnp.ndarray,
+    Optional[jnp.ndarray],
+    jnp.ndarray,
+    Optional[jnp.ndarray],
+    Optional[jnp.ndarray],
+]:
     """Internal token_dispatch with custom VJP."""
-    (output, permuted_probs, row_id_map), _ = _token_dispatch_fwd_rule(
-        inp, routing_map, probs, num_out_tokens
+    (output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert), _ = (
+        _token_dispatch_fwd_rule(
+            inp,
+            routing_map,
+            probs,
+            num_out_tokens,
+            worst_case_out_tokens,
+            align_size,
+            use_padding,
+        )
     )
-    return output, permuted_probs, row_id_map
+    return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert
 
 
 def _token_dispatch_fwd_rule(
@@ -96,9 +173,18 @@ def _token_dispatch_fwd_rule(
     inp: jnp.ndarray,
     routing_map: jnp.ndarray,
     probs: Optional[jnp.ndarray],
     num_out_tokens: int,
+    worst_case_out_tokens: int,
+    align_size: Optional[int],
+    use_padding: bool,
 ) -> Tuple[
-    Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray],
-    Tuple[jnp.ndarray, int, int, int, bool],
+    Tuple[
+        jnp.ndarray,
+        Optional[jnp.ndarray],
+        jnp.ndarray,
+        Optional[jnp.ndarray],
+        Optional[jnp.ndarray],
+    ],
+    Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool],
 ]:
     """Forward pass rule for 
token_dispatch.""" # Validate input dimensions @@ -126,42 +212,102 @@ def _token_dispatch_fwd_rule( with_probs = probs is not None - output, permuted_probs = permute_with_mask_map( - inp, - row_id_map, - probs, - num_tokens, - num_experts, - num_out_tokens, - hidden_size, - ) + if use_padding: + # Compute tokens_per_expert internally from routing_map + # This can be a traced value since output shape uses worst_case_out_tokens + tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32) + + # Calculate aligned token counts per expert + target_tokens_per_expert = (jnp.ceil(tokens_per_expert / align_size) * align_size).astype( + jnp.int32 + ) + + # Compute pad_offsets: cumulative padding for each expert + # pad_offsets[i] = sum of (target - actual) for experts 0..i-1 + pad_lengths = target_tokens_per_expert - tokens_per_expert + cum_pad = jnp.cumsum(pad_lengths) + pad_offsets = jnp.concatenate([jnp.array([0], dtype=cum_pad.dtype), cum_pad[:-1]]) + + # Use worst_case_out_tokens as the output buffer size (compile-time constant) + # The actual used size is sum(target_tokens_per_expert), which may be smaller. + # Unused positions will be zero-initialized by the kernel. + output, permuted_probs = permute_with_mask_map_and_pad( + inp, + row_id_map, + probs, + pad_offsets, + num_tokens, + num_experts, + worst_case_out_tokens, + hidden_size, + ) + else: + # No padding + pad_offsets = None + target_tokens_per_expert = None + + output, permuted_probs = permute_with_mask_map( + inp, + row_id_map, + probs, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + ) # Return (primals, residuals) - # Include with_probs flag to know how to handle backward pass - residuals = (row_id_map, num_tokens, num_experts, hidden_size, with_probs) - return (output, permuted_probs, row_id_map), residuals + residuals = (row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs) + return ( + output, + permuted_probs, + row_id_map, + pad_offsets, + target_tokens_per_expert, + ), residuals def _token_dispatch_bwd_rule( _routing_map: jnp.ndarray, _num_out_tokens: int, - residuals: Tuple[jnp.ndarray, int, int, int, bool], - g: Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray], + _worst_case_out_tokens: int, + _align_size: Optional[int], + _use_padding: bool, + residuals: Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool], + g: Tuple[ + jnp.ndarray, + Optional[jnp.ndarray], + jnp.ndarray, + Optional[jnp.ndarray], + Optional[jnp.ndarray], + ], ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: """Backward pass rule for token_dispatch.""" - row_id_map, num_tokens, num_experts, hidden_size, with_probs = residuals - output_grad, permuted_probs_grad, _ = g # Ignore row_id_map gradient + row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs = residuals + output_grad, permuted_probs_grad, _, _, _ = g # Ignore row_id_map, pad_offsets, target grads # Backward: unpermute gradients (gather from experts back to tokens) - inp_grad, probs_grad = unpermute_with_mask_map( - output_grad, - row_id_map, - None, # No merging probs - permuted_probs_grad if with_probs else None, - num_tokens, - num_experts, - hidden_size, - ) + if pad_offsets is not None: + inp_grad, probs_grad = unpermute_with_mask_map_and_unpad( + output_grad, + row_id_map, + None, # No merging probs + permuted_probs_grad if with_probs else None, + pad_offsets, + num_tokens, + num_experts, + hidden_size, + ) + else: + inp_grad, probs_grad = unpermute_with_mask_map( + output_grad, + row_id_map, + None, # No merging 
probs + permuted_probs_grad if with_probs else None, + num_tokens, + num_experts, + hidden_size, + ) return inp_grad, probs_grad if with_probs else None @@ -178,6 +324,7 @@ def token_combine( inp: jnp.ndarray, row_id_map: jnp.ndarray, merging_probs: Optional[jnp.ndarray] = None, + pad_offsets: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: """ Combine tokens from experts back to original token positions. @@ -185,33 +332,42 @@ def token_combine( This is the forward pass of MoE unpermutation. Tokens are gathered from experts and merged (optionally weighted by merging_probs). + Optionally supports fused unpadding when `pad_offsets` is provided (from + token_dispatch with padding enabled). + Parameters ---------- inp : jnp.ndarray - Input tensor from experts of shape [num_out_tokens, hidden_size]. + Input tensor from experts of shape [num_out_tokens, hidden_size] + (or [num_out_tokens_padded, hidden_size] when using unpadding). row_id_map : jnp.ndarray Row ID map from token_dispatch of shape [num_tokens, num_experts * 2 + 1]. merging_probs : Optional[jnp.ndarray] Merging weights of shape [batch, sequence, num_experts] or [num_tokens, num_experts]. If provided, tokens from different experts are weighted-summed. If None, tokens are summed directly. + pad_offsets : Optional[jnp.ndarray] + Per-expert cumulative padding offsets of shape [num_experts] from token_dispatch. + If provided, fused unpadding will be performed. This should be the pad_offsets + returned by token_dispatch when using padding. Returns ------- output : jnp.ndarray Combined output tensor of shape [num_tokens, hidden_size]. """ - return _token_combine(inp, row_id_map, merging_probs) + return _token_combine(inp, row_id_map, merging_probs, pad_offsets) -@partial(jax.custom_vjp, nondiff_argnums=(1,)) +@jax.custom_vjp def _token_combine( inp: jnp.ndarray, row_id_map: jnp.ndarray, merging_probs: Optional[jnp.ndarray], + pad_offsets: Optional[jnp.ndarray], ) -> jnp.ndarray: """Internal token_combine with custom VJP.""" - output, _ = _token_combine_fwd_rule(inp, row_id_map, merging_probs) + output, _ = _token_combine_fwd_rule(inp, row_id_map, merging_probs, pad_offsets) return output @@ -219,7 +375,20 @@ def _token_combine_fwd_rule( inp: jnp.ndarray, row_id_map: jnp.ndarray, merging_probs: Optional[jnp.ndarray], -) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], int, int, int, int]]: + pad_offsets: Optional[jnp.ndarray], +) -> Tuple[ + jnp.ndarray, + Tuple[ + jnp.ndarray, + Optional[jnp.ndarray], + jnp.ndarray, + Optional[jnp.ndarray], + int, + int, + int, + int, + ], +]: """Forward pass rule for token_combine.""" # Infer dimensions from row_id_map shape: [num_tokens, num_experts * 2 + 1] num_tokens = row_id_map.shape[0] @@ -227,21 +396,34 @@ def _token_combine_fwd_rule( hidden_size = inp.shape[-1] num_out_tokens = inp.shape[0] - # Call triton extension - output, _ = unpermute_with_mask_map( - inp, - row_id_map, - merging_probs, - None, # No permuted probs to unpermute - num_tokens, - num_experts, - hidden_size, - ) + # Call triton extension with or without unpadding + if pad_offsets is not None: + output, _ = unpermute_with_mask_map_and_unpad( + inp, + row_id_map, + merging_probs, + None, # No permuted probs to unpermute + pad_offsets, + num_tokens, + num_experts, + hidden_size, + ) + else: + output, _ = unpermute_with_mask_map( + inp, + row_id_map, + merging_probs, + None, # No permuted probs to unpermute + num_tokens, + num_experts, + hidden_size, + ) # Return (primal, residuals) # Include inp in 
residuals for backward with merging_probs residuals = ( row_id_map, + pad_offsets, inp, merging_probs, num_tokens, @@ -253,13 +435,26 @@ def _token_combine_fwd_rule( def _token_combine_bwd_rule( - row_id_map: jnp.ndarray, - residuals: Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], int, int, int, int], + residuals: Tuple[ + jnp.ndarray, + Optional[jnp.ndarray], + jnp.ndarray, + Optional[jnp.ndarray], + int, + int, + int, + int, + ], g: jnp.ndarray, -) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: - """Backward pass rule for token_combine.""" +) -> Tuple[jnp.ndarray, None, Optional[jnp.ndarray], None]: + """Backward pass rule for token_combine. + + Returns gradients for: (inp, row_id_map, merging_probs, pad_offsets) + row_id_map and pad_offsets are integer arrays, so their gradients are None. + """ ( row_id_map, + pad_offsets, fwd_input, merging_probs, num_tokens, @@ -273,30 +468,63 @@ def _token_combine_bwd_rule( if with_merging_probs: # Use specialized backward kernel that properly scales by merging_probs - inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs( - output_grad, - row_id_map, - fwd_input, - merging_probs, - num_tokens, - num_experts, - num_out_tokens, - hidden_size, - ) + if pad_offsets is not None: + inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs_and_unpad( + output_grad, + row_id_map, + fwd_input, + merging_probs, + pad_offsets, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + ) + # The backward kernel only writes to positions that tokens map to. + # Padded positions may contain uninitialized (NaN) values - replace with zeros. + inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad) + else: + inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs( + output_grad, + row_id_map, + fwd_input, + merging_probs, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + ) else: # Simple case: just permute gradients back - inp_grad, _ = permute_with_mask_map( - output_grad, - row_id_map, - None, - num_tokens, - num_experts, - num_out_tokens, - hidden_size, - ) + if pad_offsets is not None: + inp_grad, _ = permute_with_mask_map_and_pad( + output_grad, + row_id_map, + None, + pad_offsets, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + ) + # The permute kernel only writes to positions that tokens map to. + # Padded positions may contain uninitialized (NaN) values - replace with zeros. 
+ inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad) + else: + inp_grad, _ = permute_with_mask_map( + output_grad, + row_id_map, + None, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + ) merging_probs_grad = None - return inp_grad, merging_probs_grad + # Return gradients for: inp, row_id_map, merging_probs, pad_offsets + # row_id_map and pad_offsets are integer arrays, so their gradients are None + return inp_grad, None, merging_probs_grad, None _token_combine.defvjp(_token_combine_fwd_rule, _token_combine_bwd_rule) diff --git a/transformer_engine/jax/triton_extensions/permutation.py b/transformer_engine/jax/triton_extensions/permutation.py index 4f59f65a87e..01b15c5adc8 100644 --- a/transformer_engine/jax/triton_extensions/permutation.py +++ b/transformer_engine/jax/triton_extensions/permutation.py @@ -27,8 +27,11 @@ __all__ = [ "make_row_id_map", "permute_with_mask_map", + "permute_with_mask_map_and_pad", "unpermute_with_mask_map", + "unpermute_with_mask_map_and_unpad", "unpermute_bwd_with_merging_probs", + "unpermute_bwd_with_merging_probs_and_unpad", "make_chunk_sort_map", "sort_chunks_by_map", ] @@ -243,20 +246,21 @@ def lowering(ctx, row_id_map, *, num_tokens, num_experts): class PermuteWithMaskMapPrimitive(BasePrimitive): """ - Permute the input tensor based on the row_id_map. + Permute the input tensor based on the row_id_map, optionally with fused padding. """ name = "te_permute_with_mask_map_triton" multiple_results = True - # scale and permuted_scale are dummy inputs (not used when PERMUTE_SCALE=False) - # but they need to be in the signature for the kernel call + # scale, permuted_scale are dummy inputs (not used when PERMUTE_SCALE=False) + # pad_offsets can be shape (0,) when not doing padding, or (num_experts,) when padding impl_static_args = ( - 5, 6, 7, 8, 9, - ) # num_tokens, num_experts, num_out_tokens, hidden_size, with_probs + 10, + 11, + ) # num_tokens, num_experts, num_out_tokens, hidden_size, with_probs, with_pad inner_primitive = None outer_primitive = None @@ -267,16 +271,18 @@ def abstract( probs_aval, scale_aval, # dummy, same shape as inp permuted_scale_aval, # dummy, same shape as inp + pad_offsets_aval, *, num_tokens, num_experts, num_out_tokens, hidden_size, with_probs, + with_pad, ): """Shape/dtype inference for permute.""" - del row_id_map_aval, scale_aval, permuted_scale_aval - del num_tokens, num_experts + del row_id_map_aval, scale_aval, permuted_scale_aval, pad_offsets_aval + del num_tokens, num_experts, with_pad output_shape = (num_out_tokens, hidden_size) output_aval = jax.core.ShapedArray(output_shape, inp_aval.dtype) @@ -295,11 +301,13 @@ def impl( probs, scale, permuted_scale, + pad_offsets, num_tokens, num_experts, num_out_tokens, hidden_size, with_probs, + with_pad, ): """Forward to inner primitive.""" assert PermuteWithMaskMapPrimitive.inner_primitive is not None @@ -309,11 +317,13 @@ def impl( probs, scale, permuted_scale, + pad_offsets, num_tokens=num_tokens, num_experts=num_experts, num_out_tokens=num_out_tokens, hidden_size=hidden_size, with_probs=with_probs, + with_pad=with_pad, ) @staticmethod @@ -324,12 +334,14 @@ def lowering( probs, scale, permuted_scale, + pad_offsets, *, num_tokens, num_experts, num_out_tokens, hidden_size, with_probs, + with_pad, ): """MLIR lowering using triton_call_lowering.""" del num_out_tokens @@ -367,6 +379,7 @@ def lowering( probs, scale, permuted_scale, + pad_offsets, grid=grid, constexprs={ "scale_hidden_dim": 0, @@ -387,6 +400,7 @@ def lowering( "hidden_size": hidden_size, 
"PERMUTE_PROBS": with_probs, "PERMUTE_SCALE": False, + "FUSION_PAD": with_pad, "BLOCK_SIZE": block_size, }, ) @@ -403,11 +417,11 @@ class UnpermuteWithMaskMapPrimitive(BasePrimitive): name = "te_unpermute_with_mask_map_triton" multiple_results = True impl_static_args = ( - 4, 5, 6, 7, 8, + 9, ) # num_tokens, num_experts, hidden_size, with_merging_probs, with_probs inner_primitive = None outer_primitive = None @@ -418,6 +432,7 @@ def abstract( row_id_map_aval, merging_probs_aval, permuted_probs_aval, + pad_offsets_aval, # dummy, not used when FUSION_UNPAD=False *, num_tokens, num_experts, @@ -426,7 +441,7 @@ def abstract( with_probs, ): """Shape/dtype inference for unpermute.""" - del row_id_map_aval, merging_probs_aval, with_merging_probs + del row_id_map_aval, merging_probs_aval, with_merging_probs, pad_offsets_aval output_shape = (num_tokens, hidden_size) output_aval = jax.core.ShapedArray(output_shape, inp_aval.dtype) @@ -447,6 +462,7 @@ def impl( row_id_map, merging_probs, permuted_probs, + pad_offsets, num_tokens, num_experts, hidden_size, @@ -460,6 +476,7 @@ def impl( row_id_map, merging_probs, permuted_probs, + pad_offsets, num_tokens=num_tokens, num_experts=num_experts, hidden_size=hidden_size, @@ -474,6 +491,7 @@ def lowering( row_id_map, merging_probs, permuted_probs, + pad_offsets, *, num_tokens, num_experts, @@ -505,6 +523,7 @@ def lowering( block_size = _get_min_block_size(_unpermute_kernel) grid = (num_tokens, triton.cdiv(hidden_size, block_size)) + # Pass all 5 inputs including pad_offsets (even though FUSION_UNPAD=False) return triton_call_lowering( ctx, _unpermute_kernel, @@ -512,6 +531,7 @@ def lowering( row_id_map, merging_probs, permuted_probs, + pad_offsets, grid=grid, constexprs={ "stride_row_id_map_token": row_id_stride_token, @@ -530,6 +550,7 @@ def lowering( "PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts), "WITH_MERGING_PROBS": with_merging_probs, "PERMUTE_PROBS": with_probs, + "FUSION_UNPAD": False, "BLOCK_SIZE": block_size, }, ) @@ -538,6 +559,155 @@ def lowering( register_primitive(UnpermuteWithMaskMapPrimitive) +class UnpermuteWithMaskMapAndUnpadPrimitive(BasePrimitive): + """ + Unpermute the input tensor based on the row_id_map with fused unpadding. 
+ """ + + name = "te_unpermute_with_mask_map_and_unpad_triton" + multiple_results = True + impl_static_args = ( + 5, + 6, + 7, + 8, + 9, + ) # num_tokens, num_experts, hidden_size, with_merging_probs, with_probs + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + inp_aval, + row_id_map_aval, + merging_probs_aval, + permuted_probs_aval, + pad_offsets_aval, + *, + num_tokens, + num_experts, + hidden_size, + with_merging_probs, + with_probs, + ): + """Shape/dtype inference for unpermute with unpadding.""" + del row_id_map_aval, merging_probs_aval, with_merging_probs, pad_offsets_aval + + output_shape = (num_tokens, hidden_size) + output_aval = jax.core.ShapedArray(output_shape, inp_aval.dtype) + + if with_probs: + unpermuted_probs_shape = (num_tokens, num_experts) + unpermuted_probs_aval = jax.core.ShapedArray( + unpermuted_probs_shape, permuted_probs_aval.dtype + ) + else: + unpermuted_probs_aval = jax.core.ShapedArray((0,), inp_aval.dtype) + + return output_aval, unpermuted_probs_aval + + @staticmethod + def impl( + inp, + row_id_map, + merging_probs, + permuted_probs, + pad_offsets, + num_tokens, + num_experts, + hidden_size, + with_merging_probs, + with_probs, + ): + """Forward to inner primitive.""" + assert UnpermuteWithMaskMapAndUnpadPrimitive.inner_primitive is not None + return UnpermuteWithMaskMapAndUnpadPrimitive.inner_primitive.bind( + inp, + row_id_map, + merging_probs, + permuted_probs, + pad_offsets, + num_tokens=num_tokens, + num_experts=num_experts, + hidden_size=hidden_size, + with_merging_probs=with_merging_probs, + with_probs=with_probs, + ) + + @staticmethod + def lowering( + ctx, + inp, + row_id_map, + merging_probs, + permuted_probs, + pad_offsets, + *, + num_tokens, + num_experts, + hidden_size, + with_merging_probs, + with_probs, + ): + """MLIR lowering using triton_call_lowering.""" + # Compute strides + inp_stride_token = hidden_size + inp_stride_hidden = 1 + output_stride_token = hidden_size + output_stride_hidden = 1 + row_id_stride_token = num_experts * 2 + 1 + row_id_stride_expert = 1 + + if with_merging_probs: + merging_probs_stride_token = num_experts + merging_probs_stride_expert = 1 + else: + merging_probs_stride_token = 0 + merging_probs_stride_expert = 0 + + permuted_probs_stride_token = 1 + unpermuted_probs_stride_token = num_experts + unpermuted_probs_stride_expert = 1 + + # Grid - use minimum BLOCK_SIZE from autotune configs + block_size = _get_min_block_size(_unpermute_kernel) + grid = (num_tokens, triton.cdiv(hidden_size, block_size)) + + return triton_call_lowering( + ctx, + _unpermute_kernel, + inp, + row_id_map, + merging_probs, + permuted_probs, + pad_offsets, + grid=grid, + constexprs={ + "stride_row_id_map_token": row_id_stride_token, + "stride_row_id_map_expert": row_id_stride_expert, + "stride_input_token": inp_stride_token, + "stride_input_hidden": inp_stride_hidden, + "stride_output_token": output_stride_token, + "stride_output_hidden": output_stride_hidden, + "stride_merging_probs_token": merging_probs_stride_token, + "stride_merging_probs_expert": merging_probs_stride_expert, + "stride_permuted_probs_token": permuted_probs_stride_token, + "stride_unpermuted_probs_token": unpermuted_probs_stride_token, + "stride_unpermuted_probs_expert": unpermuted_probs_stride_expert, + "num_experts": num_experts, + "hidden_size": hidden_size, + "PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts), + "WITH_MERGING_PROBS": with_merging_probs, + "PERMUTE_PROBS": with_probs, + "FUSION_UNPAD": True, + "BLOCK_SIZE": 
block_size, + }, + ) + + +register_primitive(UnpermuteWithMaskMapAndUnpadPrimitive) + + class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive): """ Backward pass for unpermute with merging probabilities. @@ -547,7 +717,7 @@ class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive): name = "te_unpermute_bwd_with_merging_probs_triton" multiple_results = True - impl_static_args = (4, 5, 6, 7) # num_tokens, num_experts, num_out_tokens, hidden_size + impl_static_args = (5, 6, 7, 8) # num_tokens, num_experts, num_out_tokens, hidden_size inner_primitive = None outer_primitive = None @@ -557,6 +727,7 @@ def abstract( fwd_input_aval, merging_probs_aval, row_id_map_aval, + pad_offsets_aval, # dummy, not used when FUSION_UNPAD=False *, num_tokens, num_experts, @@ -564,7 +735,7 @@ def abstract( hidden_size, ): """Shape/dtype inference for unpermute backward with merging probs.""" - del fwd_input_aval, row_id_map_aval + del fwd_input_aval, row_id_map_aval, pad_offsets_aval # fwd_input_grad has same shape as fwd_input fwd_input_grad_shape = (num_out_tokens, hidden_size) @@ -584,6 +755,7 @@ def impl( fwd_input, merging_probs, row_id_map, + pad_offsets, num_tokens, num_experts, num_out_tokens, @@ -596,6 +768,7 @@ def impl( fwd_input, merging_probs, row_id_map, + pad_offsets, num_tokens=num_tokens, num_experts=num_experts, num_out_tokens=num_out_tokens, @@ -609,6 +782,7 @@ def lowering( fwd_input, merging_probs, row_id_map, + pad_offsets, *, num_tokens, num_experts, @@ -638,7 +812,7 @@ def lowering( # Get min block size from autotune configs for consistency block_size = _get_min_block_size(_unpermute_bwd_with_merging_probs_kernel) - # Pass inputs in kernel argument order: fwd_output_grad, fwd_input, merging_probs, row_id_map + # Pass all 5 inputs including pad_offsets (even though FUSION_UNPAD=False) return triton_call_lowering( ctx, _unpermute_bwd_with_merging_probs_kernel, @@ -646,6 +820,7 @@ def lowering( fwd_input, merging_probs, row_id_map, + pad_offsets, grid=grid, constexprs={ "stride_row_id_map_token": row_id_stride_token, @@ -663,6 +838,7 @@ def lowering( "num_experts": num_experts, "hidden_size": hidden_size, "PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts), + "FUSION_UNPAD": False, "BLOCK_SIZE": block_size, }, ) @@ -671,6 +847,145 @@ def lowering( register_primitive(UnpermuteBwdWithMergingProbsPrimitive) +class UnpermuteBwdWithMergingProbsAndUnpadPrimitive(BasePrimitive): + """ + Backward pass for unpermute with merging probabilities and fused unpadding. + + This kernel computes gradients for both the input and merging_probs, + while handling padded outputs. 
+ """ + + name = "te_unpermute_bwd_with_merging_probs_and_unpad_triton" + multiple_results = True + impl_static_args = (5, 6, 7, 8) # num_tokens, num_experts, num_out_tokens, hidden_size + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + fwd_output_grad_aval, + fwd_input_aval, + merging_probs_aval, + row_id_map_aval, + pad_offsets_aval, + *, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + ): + """Shape/dtype inference for unpermute backward with merging probs and unpadding.""" + del fwd_input_aval, row_id_map_aval, pad_offsets_aval + + # fwd_input_grad has same shape as fwd_input + fwd_input_grad_shape = (num_out_tokens, hidden_size) + fwd_input_grad_aval = jax.core.ShapedArray(fwd_input_grad_shape, fwd_output_grad_aval.dtype) + + # merging_probs_grad has same shape as merging_probs + merging_probs_grad_shape = (num_tokens, num_experts) + merging_probs_grad_aval = jax.core.ShapedArray( + merging_probs_grad_shape, merging_probs_aval.dtype + ) + + return fwd_input_grad_aval, merging_probs_grad_aval + + @staticmethod + def impl( + fwd_output_grad, + fwd_input, + merging_probs, + row_id_map, + pad_offsets, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + ): + """Forward to inner primitive.""" + assert UnpermuteBwdWithMergingProbsAndUnpadPrimitive.inner_primitive is not None + return UnpermuteBwdWithMergingProbsAndUnpadPrimitive.inner_primitive.bind( + fwd_output_grad, + fwd_input, + merging_probs, + row_id_map, + pad_offsets, + num_tokens=num_tokens, + num_experts=num_experts, + num_out_tokens=num_out_tokens, + hidden_size=hidden_size, + ) + + @staticmethod + def lowering( + ctx, + fwd_output_grad, + fwd_input, + merging_probs, + row_id_map, + pad_offsets, + *, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + ): + """MLIR lowering using triton_call_lowering.""" + del num_out_tokens + + # Compute strides + row_id_stride_token = num_experts * 2 + 1 + row_id_stride_expert = 1 + fwd_output_grad_stride_token = hidden_size + fwd_output_grad_stride_hidden = 1 + fwd_input_grad_stride_token = hidden_size + fwd_input_grad_stride_hidden = 1 + fwd_input_stride_token = hidden_size + fwd_input_stride_hidden = 1 + merging_probs_stride_token = num_experts + merging_probs_stride_expert = 1 + merging_probs_grad_stride_token = num_experts + merging_probs_grad_stride_expert = 1 + + # Grid - one program per token + grid = (num_tokens,) + + # Get min block size from autotune configs for consistency + block_size = _get_min_block_size(_unpermute_bwd_with_merging_probs_kernel) + + return triton_call_lowering( + ctx, + _unpermute_bwd_with_merging_probs_kernel, + fwd_output_grad, + fwd_input, + merging_probs, + row_id_map, + pad_offsets, + grid=grid, + constexprs={ + "stride_row_id_map_token": row_id_stride_token, + "stride_row_id_map_expert": row_id_stride_expert, + "stride_fwd_output_grad_token": fwd_output_grad_stride_token, + "stride_fwd_output_grad_hidden": fwd_output_grad_stride_hidden, + "stride_fwd_input_grad_token": fwd_input_grad_stride_token, + "stride_fwd_input_grad_hidden": fwd_input_grad_stride_hidden, + "stride_fwd_input_token": fwd_input_stride_token, + "stride_fwd_input_hidden": fwd_input_stride_hidden, + "stride_merging_probs_token": merging_probs_stride_token, + "stride_merging_probs_expert": merging_probs_stride_expert, + "stride_merging_probs_grad_token": merging_probs_grad_stride_token, + "stride_merging_probs_grad_expert": merging_probs_grad_stride_expert, + "num_experts": num_experts, + "hidden_size": hidden_size, + 
"PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts), + "FUSION_UNPAD": True, + "BLOCK_SIZE": block_size, + }, + ) + + +register_primitive(UnpermuteBwdWithMergingProbsAndUnpadPrimitive) + + def unpermute_bwd_with_merging_probs( fwd_output_grad: jnp.ndarray, row_id_map: jnp.ndarray, @@ -712,12 +1027,73 @@ def unpermute_bwd_with_merging_probs( merging_probs_grad : jnp.ndarray Gradient w.r.t. merging_probs of shape `[num_tokens, num_experts]`. """ - # Pass arguments in kernel order: fwd_output_grad, fwd_input, merging_probs, row_id_map + # Create dummy pad_offsets (not used when FUSION_UNPAD=False, but required by kernel signature) + dummy_pad_offsets = jnp.zeros((0,), dtype=jnp.int32) + # Pass arguments in kernel order: fwd_output_grad, fwd_input, merging_probs, row_id_map, pad_offsets return UnpermuteBwdWithMergingProbsPrimitive.outer_primitive.bind( fwd_output_grad, fwd_input, merging_probs, row_id_map, + dummy_pad_offsets, + num_tokens=num_tokens, + num_experts=num_experts, + num_out_tokens=num_out_tokens, + hidden_size=hidden_size, + ) + + +def unpermute_bwd_with_merging_probs_and_unpad( + fwd_output_grad: jnp.ndarray, + row_id_map: jnp.ndarray, + fwd_input: jnp.ndarray, + merging_probs: jnp.ndarray, + pad_offsets: jnp.ndarray, + num_tokens: int, + num_experts: int, + num_out_tokens: int, + hidden_size: int, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Backward pass for unpermute with merging probabilities and fused unpadding. + + This computes gradients for both the input tensor and merging_probs, + while handling padded outputs. + + Parameters + ---------- + fwd_output_grad : jnp.ndarray + Gradient of the forward output of shape `[num_tokens, hidden_size]`. + row_id_map : jnp.ndarray + The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`. + fwd_input : jnp.ndarray + The input tensor from the forward pass of shape `[num_out_tokens, hidden_size]`. + merging_probs : jnp.ndarray + The merging probabilities of shape `[num_tokens, num_experts]`. + pad_offsets : jnp.ndarray + Per-expert cumulative padding offsets of shape `[num_experts]`. + num_tokens : int + Number of tokens in the unpermuted tensor. + num_experts : int + Number of experts. + num_out_tokens : int + Number of tokens in the permuted tensor (including padding). + hidden_size : int + Hidden size. + + Returns + ------- + fwd_input_grad : jnp.ndarray + Gradient w.r.t. the input tensor of shape `[num_out_tokens, hidden_size]`. + merging_probs_grad : jnp.ndarray + Gradient w.r.t. merging_probs of shape `[num_tokens, num_experts]`. 
+ """ + return UnpermuteBwdWithMergingProbsAndUnpadPrimitive.outer_primitive.bind( + fwd_output_grad, + fwd_input, + merging_probs, + row_id_map, + pad_offsets, num_tokens=num_tokens, num_experts=num_experts, num_out_tokens=num_out_tokens, @@ -957,6 +1333,78 @@ def permute_with_mask_map( """ with_probs = probs is not None + # Handle None probs by creating dummy tensor + if not with_probs: + probs = jnp.zeros((0,), dtype=inp.dtype) + + # Create dummy scale tensors (not used when PERMUTE_SCALE=False, but required by kernel signature) + dummy_scale = inp + dummy_permuted_scale = inp + # Create dummy pad_offsets (not used when FUSION_PAD=False, but required by kernel signature) + dummy_pad_offsets = jnp.zeros((0,), dtype=jnp.int32) + + output, permuted_probs = PermuteWithMaskMapPrimitive.outer_primitive.bind( + inp, + row_id_map, + probs, + dummy_scale, + dummy_permuted_scale, + dummy_pad_offsets, + num_tokens=num_tokens, + num_experts=num_experts, + num_out_tokens=num_out_tokens, + hidden_size=hidden_size, + with_probs=with_probs, + with_pad=False, + ) + + if not with_probs: + permuted_probs = None + + return output, permuted_probs + + +def permute_with_mask_map_and_pad( + inp: jnp.ndarray, + row_id_map: jnp.ndarray, + probs: Optional[jnp.ndarray], + pad_offsets: jnp.ndarray, + num_tokens: int, + num_experts: int, + num_out_tokens: int, + hidden_size: int, +) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + """ + Permute the input tensor based on the row_id_map with fused padding. + + Parameters + ---------- + inp : jnp.ndarray + Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. + row_id_map : jnp.ndarray + The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`. + probs : Optional[jnp.ndarray] + The probabilities of the input tensor. If it is not None, it will be permuted. + pad_offsets : jnp.ndarray + Per-expert cumulative padding offsets of shape `[num_experts]`. + num_tokens : int + Number of tokens in the input tensor. + num_experts : int + Number of experts in the input tensor. + num_out_tokens : int + Number of tokens in the permuted tensor (including padding). + hidden_size : int + Hidden size of the input tensor. + + Returns + ------- + output : jnp.ndarray + Permuted and padded output tensor of shape `[num_out_tokens, hidden_size]`. + permuted_probs : Optional[jnp.ndarray] + Permuted probabilities if probs was provided, None otherwise. 
+ """ + with_probs = probs is not None + # Handle None probs by creating dummy tensor if not with_probs: probs = jnp.zeros((0,), dtype=inp.dtype) @@ -971,11 +1419,13 @@ def permute_with_mask_map( probs, dummy_scale, dummy_permuted_scale, + pad_offsets, num_tokens=num_tokens, num_experts=num_experts, num_out_tokens=num_out_tokens, hidden_size=hidden_size, with_probs=with_probs, + with_pad=True, ) if not with_probs: @@ -1029,12 +1479,83 @@ def unpermute_with_mask_map( merging_probs = jnp.zeros((0,), dtype=inp.dtype) if not with_probs: permuted_probs = jnp.zeros((0,), dtype=inp.dtype) + # Create dummy pad_offsets (not used when FUSION_UNPAD=False, but required by kernel signature) + dummy_pad_offsets = jnp.zeros((0,), dtype=jnp.int32) output, unpermuted_probs = UnpermuteWithMaskMapPrimitive.outer_primitive.bind( inp, row_id_map, merging_probs, permuted_probs, + dummy_pad_offsets, + num_tokens=num_tokens, + num_experts=num_experts, + hidden_size=hidden_size, + with_merging_probs=with_merging_probs, + with_probs=with_probs, + ) + + if not with_probs: + unpermuted_probs = None + + return output, unpermuted_probs + + +def unpermute_with_mask_map_and_unpad( + inp: jnp.ndarray, + row_id_map: jnp.ndarray, + merging_probs: Optional[jnp.ndarray], + permuted_probs: Optional[jnp.ndarray], + pad_offsets: jnp.ndarray, + num_tokens: int, + num_experts: int, + hidden_size: int, +) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + """ + Unpermute the input tensor based on the row_id_map with fused unpadding. + + Parameters + ---------- + inp : jnp.ndarray + Input tensor of shape `[num_out_tokens, hidden_size]` (including padding). + row_id_map : jnp.ndarray + The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`. + merging_probs : Optional[jnp.ndarray] + The merging probabilities of the input tensor. If it is not None, it will be used as weights + to reduce the unpermuted tokens. + permuted_probs : Optional[jnp.ndarray] + The permuted probabilities of the input tensor. If it is not None, it will be unpermuted. + pad_offsets : jnp.ndarray + Per-expert cumulative padding offsets of shape `[num_experts]`. + num_tokens : int + Number of tokens in the unpermuted tensor. + num_experts : int + Number of experts. + hidden_size : int + Hidden size of the tensor. + + Returns + ------- + output : jnp.ndarray + Unpermuted output tensor of shape `[num_tokens, hidden_size]`. + unpermuted_probs : Optional[jnp.ndarray] + Unpermuted probabilities if permuted_probs was provided, None otherwise. 
+ """ + with_merging_probs = merging_probs is not None + with_probs = permuted_probs is not None + + # Handle None inputs by creating dummy tensors + if not with_merging_probs: + merging_probs = jnp.zeros((0,), dtype=inp.dtype) + if not with_probs: + permuted_probs = jnp.zeros((0,), dtype=inp.dtype) + + output, unpermuted_probs = UnpermuteWithMaskMapAndUnpadPrimitive.outer_primitive.bind( + inp, + row_id_map, + merging_probs, + permuted_probs, + pad_offsets, num_tokens=num_tokens, num_experts=num_experts, hidden_size=hidden_size, diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 12d6a9e3de4..41ce15303c7 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -142,17 +142,31 @@ def compile_triton( ) # Create kernel object for JAX - kernel = gpu_triton.TritonKernel( - compiled.name, - num_warps, - compiled.metadata.shared, - compiled.asm["ptx"], - "", # ttir - compute_capability, - 1, - 1, - 1, # cluster_dims - ) + # From jax/jaxlib/gpu/triton_kernels.cc: + from packaging import version + + if version.parse(jax.__version__) >= version.parse("0.8.2"): + kernel = gpu_triton.TritonKernel( + compiled.name, # arg0: kernel_name (str) + num_warps, # arg1: num_warps (int) + num_ctas, # arg2: num_ctas (int) + compiled.metadata.shared, # arg3: shared_mem_bytes (int) + compiled.asm["ptx"], # arg4: ptx (str) + "", # arg5: ttir (str) - empty + compute_capability, # arg6: compute_capability (int) + ) + else: + kernel = gpu_triton.TritonKernel( + compiled.name, + num_warps, + compiled.metadata.shared, + compiled.asm["ptx"], + "", # ttir + compute_capability, + 1, + 1, + 1, + ) _TRITON_KERNEL_CACHE[cache_key] = kernel return kernel diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 5341af3d742..9f4a9678eb9 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -34,6 +34,7 @@ from transformer_engine.pytorch.permutation import ( moe_permute, moe_permute_with_probs, + moe_permute_and_pad_with_probs, moe_unpermute, moe_sort_chunks_by_index, moe_sort_chunks_by_index_with_probs, diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 9fce9cefcf7..d15814585ee 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -"""MoE Permutaion API""" +"""MoE Permutation API""" import warnings from typing import Optional, Tuple import torch @@ -191,6 +191,7 @@ def forward( routing_map: torch.Tensor, num_out_tokens: int, probs: torch.Tensor, + pad_offsets: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # pylint: disable=missing-function-docstring if not inp.numel(): @@ -201,6 +202,8 @@ def forward( assert routing_map.is_cuda, "TransformerEngine needs CUDA." if probs is not None: assert probs.is_cuda, "TransformerEngine needs CUDA." + if pad_offsets is not None: + assert pad_offsets.is_cuda, "TransformerEngine needs CUDA." 
assert inp.size(0) == routing_map.size(0), "Permute not possible" num_tokens, hidden_size = inp.size() @@ -250,6 +253,7 @@ def forward( row_id_map, probs, fp8_scale, + pad_offsets, num_tokens, num_experts, num_out_tokens, @@ -292,7 +296,7 @@ def forward( requires_grad=output.requires_grad, ) - ctx.save_for_backward(row_id_map) + ctx.save_for_backward(row_id_map, pad_offsets) ctx.num_experts = num_experts ctx.num_tokens = num_tokens ctx.hidden_size = hidden_size @@ -307,12 +311,12 @@ def backward( ) -> Tuple[torch.Tensor, ...]: # pylint: disable=missing-function-docstring if not permuted_act_grad.numel(): - return permuted_act_grad, None, None, ctx.probs + return permuted_act_grad, None, None, ctx.probs, None act_grad = None probs_grad = None if ctx.needs_input_grad[0]: - (row_id_map,) = ctx.saved_tensors + row_id_map, pad_offsets = ctx.saved_tensors assert not isinstance( permuted_act_grad, QuantizedTensor ), "The backward of moe_permute does not support FP8." @@ -321,13 +325,14 @@ def backward( row_id_map, None, permuted_probs_grad, + pad_offsets, ctx.num_tokens, ctx.num_experts, ctx.hidden_size, ) if not ctx.needs_input_grad[3]: probs_grad = None - return act_grad, None, None, probs_grad + return act_grad, None, None, probs_grad, None class _moe_unpermute_mask_map(torch.autograd.Function): @@ -340,6 +345,7 @@ def forward( row_id_map: torch.Tensor, merging_probs: Optional[torch.Tensor], restore_shape: Optional[torch.Size], + pad_offsets: Optional[torch.Tensor], ) -> torch.Tensor: # pylint: disable=missing-function-docstring if not inp.numel(): @@ -358,6 +364,8 @@ def forward( # Device check assert inp.is_cuda, "TransformerEngine needs CUDA." assert row_id_map.is_cuda, "TransformerEngine needs CUDA." + if pad_offsets is not None: + assert pad_offsets.is_cuda, "TransformerEngine needs CUDA." 
        assert not isinstance(
             inp, QuantizedTensor
         )
@@ -367,15 +375,16 @@ def forward(
             row_id_map,
             merging_probs,
             None,
+            pad_offsets,
             num_tokens,
             num_experts,
             hidden_size,
         )
 
         if with_probs:
-            ctx.save_for_backward(inp, row_id_map, merging_probs)
+            ctx.save_for_backward(inp, row_id_map, merging_probs, pad_offsets)
         else:
-            ctx.save_for_backward(row_id_map)
+            ctx.save_for_backward(row_id_map, pad_offsets)
         ctx.num_experts = num_experts
         ctx.num_tokens = num_tokens
         ctx.num_permuted_tokens = inp.size(0)
@@ -387,15 +396,15 @@ def backward(ctx, unpermuted_act_grad):
         # pylint: disable=missing-function-docstring
         if not unpermuted_act_grad.numel():
-            return unpermuted_act_grad, None, ctx.merging_probs, None
+            return unpermuted_act_grad, None, ctx.merging_probs, None, None
 
         act_grad = None
         probs_grad = None
         if ctx.needs_input_grad[0]:
             if ctx.with_probs:
-                fwd_input, row_id_map, merging_probs = ctx.saved_tensors
+                fwd_input, row_id_map, merging_probs, pad_offsets = ctx.saved_tensors
             else:
-                (row_id_map,) = ctx.saved_tensors
+                row_id_map, pad_offsets = ctx.saved_tensors
             fp8 = isinstance(unpermuted_act_grad, QuantizedTensor)
             per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor)
@@ -441,6 +450,7 @@ def backward(ctx, unpermuted_act_grad):
                     row_id_map,
                     fwd_input,
                     merging_probs,
+                    pad_offsets,
                     ctx.num_tokens,
                     ctx.num_experts,
                     ctx.num_permuted_tokens,
@@ -453,6 +463,7 @@ def backward(ctx, unpermuted_act_grad):
                     row_id_map,
                     None,
                     fp8_scale,
+                    pad_offsets,
                     ctx.num_tokens,
                     ctx.num_experts,
                     ctx.num_permuted_tokens,
@@ -497,7 +508,7 @@ def backward(ctx, unpermuted_act_grad):
 
         if not ctx.needs_input_grad[2]:
             probs_grad = None
-        return act_grad, None, probs_grad, None
+        return act_grad, None, probs_grad, None, None
 
 
 def moe_permute(
@@ -537,7 +548,9 @@ def moe_permute(
     if map_type == "index":
         return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num)
     if map_type == "mask":
-        output, row_id_map, _ = _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens, None)
+        output, row_id_map, _ = _moe_permute_mask_map.apply(
+            inp, routing_map, num_out_tokens, None, None
+        )
         return output, row_id_map
     raise ValueError("map_type should be one of 'mask' or 'index'")
 
@@ -570,11 +583,67 @@ def moe_permute_with_probs(
         By default, set to '-1', meaning no tokens are dropped.
     """
     output, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
-        inp, routing_map, num_out_tokens, probs
+        inp, routing_map, num_out_tokens, probs, None
     )
     return output, permuted_probs, row_id_map
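The alignment arithmetic performed by `moe_permute_and_pad_with_probs` (added below) is easiest to check on concrete numbers. An illustrative example with `tokens_per_expert = [5, 8, 3]` and `align_size = 4` (values chosen only for illustration), mirroring the computation in the new function:

    import torch

    tokens_per_expert = torch.tensor([5, 8, 3])
    align_size = 4
    # Round each expert's token count up to a multiple of align_size.
    target = (torch.ceil(tokens_per_expert / align_size) * align_size).long()
    # target == [8, 8, 4]
    pad_lengths = target - tokens_per_expert   # [3, 0, 1] padding rows per expert
    cum_pad = torch.cumsum(pad_lengths, dim=0)  # [3, 3, 4]
    pad_offsets = torch.cat([torch.zeros(1, dtype=cum_pad.dtype), cum_pad[:-1]])
    # pad_offsets == [0, 3, 3]: how far each expert's block shifts in the padded buffer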
+
+
+def moe_permute_and_pad_with_probs(
+    inp: torch.Tensor,
+    probs: torch.Tensor,
+    routing_map: torch.Tensor,
+    tokens_per_expert: torch.Tensor,
+    align_size: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
+    """
+    Permute the tokens and probs based on the routing_map, padding each expert's
+    token block up to a multiple of align_size.
+    Tokens with the same designated expert will be grouped together.
+    The routing_map indicates which experts were selected by each token.
+
+    Parameters
+    ----------
+    inp: torch.Tensor
+        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
+    probs: torch.Tensor
+        The tensor of probabilities of shape [num_tokens, num_experts]. It will be
+        permuted along with the tokens according to the routing_map.
+    routing_map: torch.Tensor
+        The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'.
+        A value of 1 means the token is routed to that expert and 0 means it is not.
+    tokens_per_expert : torch.Tensor
+        Tensor of shape `[num_experts]` containing the actual token count of each expert.
+    align_size : int
+        The alignment size; each expert's token count is padded up to a multiple of it.
+    """
+    assert (
+        tokens_per_expert is not None
+    ), "tokens_per_expert must be provided to the fused permute padding function."
+    assert align_size > 0, f"align_size must be positive, got {align_size}"
+
+    # Ensure tokens_per_expert is on the same device as the input to avoid device transfers
+    if tokens_per_expert.device != inp.device:
+        tokens_per_expert = tokens_per_expert.to(inp.device)
+
+    # Calculate aligned token counts per expert
+    target_tokens_per_expert = (torch.ceil(tokens_per_expert / align_size) * align_size).long()
+
+    if torch.equal(tokens_per_expert, target_tokens_per_expert):
+        pad_offsets = None
+    else:
+        pad_lengths = target_tokens_per_expert - tokens_per_expert
+        cum_pad = torch.cumsum(pad_lengths, dim=0)
+        pad_offsets = torch.cat(
+            [torch.zeros(1, dtype=cum_pad.dtype, device=inp.device), cum_pad[:-1]]
+        )
+
+    output, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
+        inp, routing_map, target_tokens_per_expert.sum().item(), probs, pad_offsets
+    )
+    return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert
+
+
 def moe_unpermute(
     inp: torch.Tensor,
     row_id_map: torch.Tensor,
@@ -582,6 +651,7 @@ def moe_unpermute(
     restore_shape: Optional[torch.Size] = None,
     map_type: str = "mask",
     probs: Optional[torch.Tensor] = None,
+    pad_offsets: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
@@ -605,6 +675,10 @@ def moe_unpermute(
         Options are: 'mask', 'index'.
     probs : torch.Tensor, default = None
         Renamed to merging_probs. Keep for backward compatibility.
+    pad_offsets : torch.Tensor, default = None
+        Tensor of per-expert cumulative padding offsets used to remove the padding
+        added during permutation. This is the fourth output of `moe_permute_and_pad_with_probs`
+        and is required when unpermuting padded outputs.
     """
     if probs is not None:
         if merging_probs is not None:
@@ -616,7 +690,9 @@ def moe_unpermute(
     if map_type == "index":
         return _moe_unpermute_index_map.apply(inp, row_id_map, merging_probs)
     if map_type == "mask":
-        return _moe_unpermute_mask_map.apply(inp, row_id_map, merging_probs, restore_shape)
+        return _moe_unpermute_mask_map.apply(
+            inp, row_id_map, merging_probs, restore_shape, pad_offsets
+        )
     raise ValueError("map_type should be one of 'mask' or 'index'")

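# Worked example of the padding arithmetic in moe_permute_and_pad_with_probs
# above (a CPU-only sketch; the counts are made up):
import torch

tokens_per_expert = torch.tensor([5, 8, 2, 6])
align_size = 4
# Round each expert's token count up to a multiple of align_size.
target = (torch.ceil(tokens_per_expert / align_size) * align_size).long()
# target == tensor([8, 8, 4, 8]); the padded output has target.sum() == 28 rows.
pad_lengths = target - tokens_per_expert    # tensor([3, 0, 2, 2])
cum_pad = torch.cumsum(pad_lengths, dim=0)  # tensor([3, 3, 5, 7])
# pad_offsets[e] is the total padding inserted before expert e's block.
pad_offsets = torch.cat([torch.zeros(1, dtype=cum_pad.dtype), cum_pad[:-1]])
# pad_offsets == tensor([0, 3, 3, 5])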
diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py
index 8f953e9c31d..27662e1b283 100644
--- a/transformer_engine/pytorch/triton/permutation.py
+++ b/transformer_engine/pytorch/triton/permutation.py
@@ -123,6 +123,7 @@ def permute_with_mask_map(
     row_id_map: torch.Tensor,
     probs: torch.Tensor,
     scale: torch.Tensor,
+    pad_offsets: torch.Tensor,
     num_tokens: int,
     num_experts: int,
     num_out_tokens: int,
@@ -142,6 +143,9 @@ def permute_with_mask_map(
         The probabilities of the input tensor. If it is not None, it will be permuted.
     scale : torch.Tensor
         The scale of the input tensor. If it is not None, it will be permuted.
+    pad_offsets : torch.Tensor
+        Per-expert padding offsets of shape `[num_experts]` for FP8 fused padding.
+        If it is not None, output buffers are allocated with the aligned (padded) sizes.
     num_tokens : int
         Number of tokens in the input tensor.
     num_experts : int
@@ -153,18 +157,18 @@ def permute_with_mask_map(
     scale_hidden_dim : int
         Hidden size of the scale tensor.
     """
-    output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
-    if probs is not None:
-        permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda")
-    else:
-        permuted_probs = None
-
-    if scale is not None:
-        permuted_scale = torch.empty(
-            (num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda"
-        )
-    else:
-        permuted_scale = None
+    # Use torch.zeros when pad_offsets is provided to ensure the padding regions are
+    # zeroed, since the kernel does not write to padding positions.
+    alloc = torch.zeros if pad_offsets is not None else torch.empty
+    output = alloc((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
+    permuted_probs = (
+        alloc((num_out_tokens,), dtype=probs.dtype, device="cuda") if probs is not None else None
+    )
+    permuted_scale = (
+        torch.empty((num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda")
+        if scale is not None
+        else None
+    )
     # pylint: disable=unnecessary-lambda-assignment
     grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
     _permute_kernel[grid](
@@ -173,6 +177,7 @@ def permute_with_mask_map(
         probs,
         scale,
         permuted_scale,
+        pad_offsets,
         scale_hidden_dim,
         row_id_map.stride(0),
         row_id_map.stride(1),
@@ -193,6 +198,7 @@ def permute_with_mask_map(
         hidden_size,
         PERMUTE_PROBS=probs is not None,
         PERMUTE_SCALE=scale is not None,
+        FUSION_PAD=pad_offsets is not None,
     )
     return output, permuted_scale, permuted_probs
@@ -202,6 +208,7 @@ def unpermute_with_mask_map(
     row_id_map: torch.Tensor,
     merging_probs: Union[torch.Tensor, None],
     permuted_probs: Union[torch.Tensor, None],
+    pad_offsets: Union[torch.Tensor, None],
     num_tokens: int,
     num_experts: int,
     hidden_size: int,
@@ -220,6 +227,9 @@ def unpermute_with_mask_map(
         to reduce the unpermuted tokens.
     permuted_probs : torch.Tensor
         The permuted probabilities of the input tensor. If it is not None, it will be unpermuted.
+    pad_offsets : torch.Tensor
+        Per-expert padding offsets of shape `[num_experts]` for FP8 fused unpadding.
+        If it is not None, the padding added during permutation is removed.
     num_tokens : int
         Number of tokens in the permuted tensor.
     num_experts : int
@@ -241,6 +251,7 @@ def unpermute_with_mask_map(
         row_id_map,
         merging_probs,
         permuted_probs,
+        pad_offsets,
         row_id_map.stride(0),
         row_id_map.stride(1),
         inp.stride(0),
@@ -259,6 +270,7 @@ def unpermute_with_mask_map(
         PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
         WITH_MERGING_PROBS=merging_probs is not None,
         PERMUTE_PROBS=permuted_probs is not None,
+        FUSION_UNPAD=pad_offsets is not None,
     )
     return output, unpermuted_probs
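# A sketch of what FUSION_PAD/FUSION_UNPAD imply for row indexing. This is an
# assumption based on how pad_offsets is constructed in
# moe_permute_and_pad_with_probs, not a transcription of the Triton kernel:
# every destination row belonging to expert e is shifted by the padding
# accumulated before expert e's block, and the rows the kernel never writes
# keep the zeros from the torch.zeros allocation.
import torch

def padded_dest_row(dest_row, expert_id, pad_offsets):
    # Map an unpadded destination row to its row in the padded buffer.
    return dest_row + pad_offsets[expert_id]

# With pad_offsets = [0, 3, 3, 5] (from tokens_per_expert = [5, 8, 2, 6] and
# align_size = 4), expert 1's unpadded rows 5..12 land at padded rows 8..15.
pad_offsets = torch.tensor([0, 3, 3, 5])
assert padded_dest_row(torch.tensor(9), torch.tensor(1), pad_offsets) == 12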
@@ -268,6 +280,7 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
     row_id_map: torch.Tensor,
     fwd_input: torch.Tensor,
     merging_probs: torch.Tensor,
+    pad_offsets: Union[torch.Tensor, None],
     num_tokens: int,
     num_experts: int,
     num_out_tokens: int,
@@ -286,6 +299,9 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
         The input tensor of the forward pass of shape `[num_out_tokens, hidden_size]`.
     merging_probs : torch.Tensor
         The merging probabilities of the input tensor of shape `[num_tokens, num_experts]`.
+    pad_offsets : torch.Tensor
+        Per-expert padding offsets of shape `[num_experts]` for FP8 fused padding.
+        If it is not None, output buffers are allocated with the aligned (padded) sizes.
     num_tokens : int
         Number of tokens in the permuted tensor.
     num_experts : int
@@ -295,9 +311,11 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
     hidden_size : int
         Hidden size of the output tensor.
     """
-    act_grad = torch.empty(
-        (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda"
-    )
+    # Use zeros when pad_offsets is used because the padding slots won't be written
+    # to by the kernel. This matches the behavior of Fp8Unpadding.backward, which
+    # zeroes out the padding slots.
+    alloc = torch.zeros if pad_offsets is not None else torch.empty
+    act_grad = alloc((num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda")
     merging_probs_grad = torch.empty(
         (num_tokens, num_experts), dtype=merging_probs.dtype, device="cuda"
    )
@@ -307,6 +325,7 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
         fwd_input,
         merging_probs,
         row_id_map,
+        pad_offsets,
         row_id_map.stride(0),
         row_id_map.stride(1),
         fwd_output_grad.stride(0),
@@ -324,6 +343,7 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
         num_experts,
         hidden_size,
         PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
+        FUSION_UNPAD=pad_offsets is not None,
     )
     return act_grad, merging_probs_grad
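# End-to-end sketch of the fused pad/unpad path added in this diff, assuming a
# CUDA device and that the public functions are importable from the module the
# diff touches; shapes and the 16-row alignment are made up for illustration.
import torch
from transformer_engine.pytorch.permutation import (
    moe_permute_and_pad_with_probs,
    moe_unpermute,
)

inp = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
probs = torch.rand(1024, 8, device="cuda")
routing_map = (torch.rand(1024, 8, device="cuda") < 0.25).int()
tokens_per_expert = routing_map.sum(dim=0)

out, permuted_probs, row_id_map, pad_offsets, padded_counts = (
    moe_permute_and_pad_with_probs(inp, probs, routing_map, tokens_per_expert, align_size=16)
)
# Each expert's block in `out` is 16-row aligned; the padding rows are zero-filled.
# ... grouped expert computation on `out` would go here ...
# Passing pad_offsets back strips the padding while unpermuting.
restored = moe_unpermute(out, row_id_map, pad_offsets=pad_offsets)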