diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 0a32be96796..f8382fb73e8 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -507,6 +507,46 @@ def run_length_fill(segment_ids) -> jnp.ndarray: return run_length_segment_id_shape.reshape(orig_shape) +def _get_seqlens_thd(segment_ids, max_segments_per_seq): + # Create mask for non-zero seg ids and get the non-zero indices associated with the same + non_zero_mask = segment_ids != 0 + max_size = segment_ids.shape[-1] + non_zero_indices = jax.vmap( + lambda mask_row: jnp.where(mask_row, size=max_size, fill_value=-1)[0] + )(non_zero_mask) + + # Pick non-zero seg ids and seg pos using take_along_axis to index within the seg ids and pos + # Clip -1 to 0 for safe indexing + clipped_indices = jnp.clip(non_zero_indices, 0, None) + valid_segment_ids = jnp.where( + non_zero_indices >= 0, jnp.take_along_axis(segment_ids, clipped_indices, axis=-1), 0 + ) + seqlens_all = jax.vmap( + lambda sp_row: jnp.bincount(sp_row, length=max_segments_per_seq + 1)[1:] + )(valid_segment_ids) + seqlens_all_pad_neg = jnp.where(seqlens_all == 0, -1, seqlens_all) + return seqlens_all_pad_neg + + +def _get_seqoffsets_thd(segment_ids, segment_pos, max_segments_per_seq): + segment_changes = jnp.concatenate( + [ + jnp.full( + (segment_pos.shape[0], 1), True, dtype=bool + ), # First valid element starts a segment + (segment_pos[..., 1:] != segment_pos[..., :-1] + 1), # Segment pos changed + ], + axis=-1, + ) + # Remove any padded region segment changes + segment_changes_masked = jnp.where(segment_ids != 0, segment_changes, False) + # Get the indices for segment changes (these are the offsets) + seq_offsets = jax.vmap( + lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq + 1, fill_value=-1)[0] + )(segment_changes_masked) + return seq_offsets + + def _segment_ids_pos_to_seqlens_offsets( segment_ids_q, segment_ids_kv, @@ -536,69 +576,27 @@ def _segment_ids_pos_to_seqlens_offsets( # It does not need to involve SW for this mask's creation # TODO(KshitijLakhani): Try exercising the fast path for BRCM as well - if (attn_mask_type.is_causal() and window_size is None) or ( - window_size == (-1, -1) and not attn_mask_type.is_bottom_right() - ): - return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( - segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq - ) - - # (1 = attend, 0 = masked) - segment_mask = make_attention_mask( - segment_ids_q, - segment_ids_kv, - jnp.equal, + # if (attn_mask_type.is_causal() and window_size is None) or ( + # window_size == (-1, -1) and not attn_mask_type.is_bottom_right() + # ): + # return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( + # segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq + # ) + q_seqlen = _get_seqlens_thd( + segment_ids=segment_ids_q, max_segments_per_seq=max_segments_per_seq ) - segment_mask_with_id = make_attention_mask( - segment_ids_q, - segment_ids_kv, - lambda x, y: jnp.equal(x, y) * x, + kv_seqlen = _get_seqlens_thd( + segment_ids=segment_ids_kv, max_segments_per_seq=max_segments_per_seq ) - # TE JAX Attn expects the THD segments to have q_token <= kv_tokens so that a correct cross-attn type BRCM can be applied - attn_mask = segment_mask - if attn_mask_type.is_bottom_right(): - run_length_out_q = run_length_fill(segment_ids_q) - run_length_out_kv = run_length_fill(segment_ids_kv) - # Example for brcm: - # run_length_out_q: [3 3 3 0 4 4 4 4] - # segment_pos_q: [0 1 2 3 0 1 2 3] - # segment_ids_q: [1 1 1 0 2 2 2 2] - # run_length_out_kv: [4 4 4 4 0 0 10 10 10 10 10 10 10 10 10 10] - # segment_pos_kv: [0 1 2 3 4 5 0 1 2 3 4 5 6 7 8 9] - # segment_ids_kv: [1 1 1 1 0 0 2 2 2 2 2 2 2 2 2 2] - # brcm: [[[1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0] - # [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0] - # [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1] - # [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1] - # [1 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0] - # [1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0] - # [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0] - # [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]]] - # attn_mask(noswa):[[[1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0] - # [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0] - # [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] - # [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] - # [0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0] - # [0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0] - # [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0] - # [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]]] - bottom_right_causal_mask = make_attention_mask( - run_length_out_q - segment_pos_q, - run_length_out_kv - segment_pos_kv, - jnp.less_equal, - ) - attn_mask = jnp.logical_and(segment_mask, bottom_right_causal_mask) - elif attn_mask_type.is_causal(): - causal_mask = make_attention_mask( - segment_pos_q, - segment_pos_kv, - jnp.greater_equal, - ) - attn_mask = jnp.logical_and(segment_mask, causal_mask) - - attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0) - q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset( - attn_mask_with_id, max_segments_per_seq + q_offset = _get_seqoffsets_thd( + segment_ids=segment_ids_q, + segment_pos=segment_pos_q, + max_segments_per_seq=max_segments_per_seq, + ) + kv_offset = _get_seqoffsets_thd( + segment_ids=segment_ids_kv, + segment_pos=segment_pos_kv, + max_segments_per_seq=max_segments_per_seq, ) return q_seqlen, kv_seqlen, q_offset, kv_offset