From b002b928ccee403fea63364c901c6ab7f68f061b Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 16 Dec 2025 16:33:59 +0000 Subject: [PATCH 1/2] Get seqlens and offsets in O(N) space instead of O(N*N) space Signed-off-by: Kshitij Lakhani --- transformer_engine/jax/attention.py | 112 ++++++++++++---------------- 1 file changed, 48 insertions(+), 64 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 0a32be96796..40454c3191f 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -506,6 +506,44 @@ def run_length_fill(segment_ids) -> jnp.ndarray: run_length_segment_id_shape = jax.vmap(run_length_fill_flattened, in_axes=0)(segment_ids_flat) return run_length_segment_id_shape.reshape(orig_shape) +def _get_seqlens_thd(segment_ids, max_segments_per_seq): + # Create mask for non-zero seg ids and get the non-zero indices associated with the same + non_zero_mask = segment_ids != 0 + max_size = segment_ids.shape[-1] + non_zero_indices = jax.vmap( + lambda mask_row: jnp.where(mask_row, size=max_size, fill_value=-1)[0] + )(non_zero_mask) + + # Pick non-zero seg ids and seg pos using take_along_axis to index within the seg ids and pos + # Clip -1 to 0 for safe indexing + clipped_indices = jnp.clip(non_zero_indices, 0, None) + valid_segment_ids = jnp.where( + non_zero_indices >= 0, jnp.take_along_axis(segment_ids, clipped_indices, axis=-1), 0 + ) + seqlens_all = jax.vmap( + lambda sp_row: jnp.bincount(sp_row, length=max_segments_per_seq + 1)[1:] + )(valid_segment_ids) + seqlens_all_pad_neg = jnp.where(seqlens_all == 0, -1, seqlens_all) + return seqlens_all_pad_neg + +def _get_seqoffsets_thd(segment_ids, segment_pos, max_segments_per_seq): + segment_changes = jnp.concatenate( + [ + jnp.full( + (segment_pos.shape[0], 1), True, dtype=bool + ), # First valid element starts a segment + (segment_pos[..., 1:] != segment_pos[..., :-1] + 1), # Segment pos changed + ], + axis=-1, + ) + # Remove any padded region segment changes + segment_changes_masked = jnp.where(segment_ids != 0, segment_changes, False) + # Get the indices for segment changes (these are the offsets) + seq_offsets = jax.vmap( + lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq + 1, fill_value=-1)[0] + )(segment_changes_masked) + return seq_offsets + def _segment_ids_pos_to_seqlens_offsets( segment_ids_q, @@ -536,70 +574,16 @@ def _segment_ids_pos_to_seqlens_offsets( # It does not need to involve SW for this mask's creation # TODO(KshitijLakhani): Try exercising the fast path for BRCM as well - if (attn_mask_type.is_causal() and window_size is None) or ( - window_size == (-1, -1) and not attn_mask_type.is_bottom_right() - ): - return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( - segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq - ) - - # (1 = attend, 0 = masked) - segment_mask = make_attention_mask( - segment_ids_q, - segment_ids_kv, - jnp.equal, - ) - segment_mask_with_id = make_attention_mask( - segment_ids_q, - segment_ids_kv, - lambda x, y: jnp.equal(x, y) * x, - ) - # TE JAX Attn expects the THD segments to have q_token <= kv_tokens so that a correct cross-attn type BRCM can be applied - attn_mask = segment_mask - if attn_mask_type.is_bottom_right(): - run_length_out_q = run_length_fill(segment_ids_q) - run_length_out_kv = run_length_fill(segment_ids_kv) - # Example for brcm: - # run_length_out_q: [3 3 3 0 4 4 4 4] - # segment_pos_q: [0 1 2 3 0 1 2 3] - # segment_ids_q: [1 1 1 0 2 2 2 2] - # run_length_out_kv: [4 4 4 4 0 0 10 10 10 10 10 10 10 10 10 10] - # segment_pos_kv: [0 1 2 3 4 5 0 1 2 3 4 5 6 7 8 9] - # segment_ids_kv: [1 1 1 1 0 0 2 2 2 2 2 2 2 2 2 2] - # brcm: [[[1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0] - # [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0] - # [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1] - # [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1] - # [1 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0] - # [1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0] - # [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0] - # [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]]] - # attn_mask(noswa):[[[1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0] - # [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0] - # [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] - # [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] - # [0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0] - # [0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0] - # [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0] - # [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]]] - bottom_right_causal_mask = make_attention_mask( - run_length_out_q - segment_pos_q, - run_length_out_kv - segment_pos_kv, - jnp.less_equal, - ) - attn_mask = jnp.logical_and(segment_mask, bottom_right_causal_mask) - elif attn_mask_type.is_causal(): - causal_mask = make_attention_mask( - segment_pos_q, - segment_pos_kv, - jnp.greater_equal, - ) - attn_mask = jnp.logical_and(segment_mask, causal_mask) - - attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0) - q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset( - attn_mask_with_id, max_segments_per_seq - ) + # if (attn_mask_type.is_causal() and window_size is None) or ( + # window_size == (-1, -1) and not attn_mask_type.is_bottom_right() + # ): + # return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( + # segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq + # ) + q_seqlen = _get_seqlens_thd(segment_ids=segment_ids_q, max_segments_per_seq=max_segments_per_seq) + kv_seqlen = _get_seqlens_thd(segment_ids=segment_ids_kv, max_segments_per_seq=max_segments_per_seq) + q_offset = _get_seqoffsets_thd(segment_ids=segment_ids_q, segment_pos=segment_pos_q, max_segments_per_seq=max_segments_per_seq) + kv_offset = _get_seqoffsets_thd(segment_ids=segment_ids_kv, segment_pos=segment_pos_kv, max_segments_per_seq=max_segments_per_seq) return q_seqlen, kv_seqlen, q_offset, kv_offset From dbc634ae2e3e843690291152d86c9e8e32e95326 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Dec 2025 22:59:39 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/attention.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 40454c3191f..f8382fb73e8 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -506,6 +506,7 @@ def run_length_fill(segment_ids) -> jnp.ndarray: run_length_segment_id_shape = jax.vmap(run_length_fill_flattened, in_axes=0)(segment_ids_flat) return run_length_segment_id_shape.reshape(orig_shape) + def _get_seqlens_thd(segment_ids, max_segments_per_seq): # Create mask for non-zero seg ids and get the non-zero indices associated with the same non_zero_mask = segment_ids != 0 @@ -522,10 +523,11 @@ def _get_seqlens_thd(segment_ids, max_segments_per_seq): ) seqlens_all = jax.vmap( lambda sp_row: jnp.bincount(sp_row, length=max_segments_per_seq + 1)[1:] - )(valid_segment_ids) + )(valid_segment_ids) seqlens_all_pad_neg = jnp.where(seqlens_all == 0, -1, seqlens_all) return seqlens_all_pad_neg + def _get_seqoffsets_thd(segment_ids, segment_pos, max_segments_per_seq): segment_changes = jnp.concatenate( [ @@ -543,7 +545,7 @@ def _get_seqoffsets_thd(segment_ids, segment_pos, max_segments_per_seq): lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq + 1, fill_value=-1)[0] )(segment_changes_masked) return seq_offsets - + def _segment_ids_pos_to_seqlens_offsets( segment_ids_q, @@ -580,10 +582,22 @@ def _segment_ids_pos_to_seqlens_offsets( # return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( # segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq # ) - q_seqlen = _get_seqlens_thd(segment_ids=segment_ids_q, max_segments_per_seq=max_segments_per_seq) - kv_seqlen = _get_seqlens_thd(segment_ids=segment_ids_kv, max_segments_per_seq=max_segments_per_seq) - q_offset = _get_seqoffsets_thd(segment_ids=segment_ids_q, segment_pos=segment_pos_q, max_segments_per_seq=max_segments_per_seq) - kv_offset = _get_seqoffsets_thd(segment_ids=segment_ids_kv, segment_pos=segment_pos_kv, max_segments_per_seq=max_segments_per_seq) + q_seqlen = _get_seqlens_thd( + segment_ids=segment_ids_q, max_segments_per_seq=max_segments_per_seq + ) + kv_seqlen = _get_seqlens_thd( + segment_ids=segment_ids_kv, max_segments_per_seq=max_segments_per_seq + ) + q_offset = _get_seqoffsets_thd( + segment_ids=segment_ids_q, + segment_pos=segment_pos_q, + max_segments_per_seq=max_segments_per_seq, + ) + kv_offset = _get_seqoffsets_thd( + segment_ids=segment_ids_kv, + segment_pos=segment_pos_kv, + max_segments_per_seq=max_segments_per_seq, + ) return q_seqlen, kv_seqlen, q_offset, kv_offset