Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 59 additions & 61 deletions transformer_engine/jax/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,46 @@ def run_length_fill(segment_ids) -> jnp.ndarray:
return run_length_segment_id_shape.reshape(orig_shape)


def _get_seqlens_thd(segment_ids, max_segments_per_seq):
    """Return the per-segment token counts for every row of ``segment_ids``.

    Args:
        segment_ids: 2-D integer array; each row is processed independently
            (``jax.vmap`` over the leading axis). An ID of 0 marks padding.
            NOTE(review): IDs are assumed to lie in
            ``[0, max_segments_per_seq]`` - confirm against callers.
        max_segments_per_seq: Static Python int; number of segment slots
            reported per row.

    Returns:
        Integer array of shape ``(segment_ids.shape[0], max_segments_per_seq)``
        where entry ``[b, i]`` is the number of tokens in row ``b`` whose
        segment ID equals ``i + 1``. Slots with no tokens hold -1 instead of 0.
    """
    num_bins = max_segments_per_seq + 1

    def _count_row(ids_row):
        # bincount does not depend on token order, so padding tokens (ID 0)
        # are counted in place and their bin (index 0) is simply discarded.
        return jnp.bincount(ids_row, length=num_bins)[1:]

    seqlens_all = jax.vmap(_count_row)(segment_ids)
    # Empty segment slots are reported as -1 rather than 0.
    return jnp.where(seqlens_all == 0, -1, seqlens_all)


def _get_seqoffsets_thd(segment_ids, segment_pos, max_segments_per_seq):
    """Return the start index of every segment in each row.

    A token is treated as a segment start when it is the first token of the
    row, or when its position is not exactly one more than the previous
    token's position. Candidates that fall on padding (segment ID 0) are
    discarded.

    Args:
        segment_ids: 2-D integer array of segment IDs; 0 marks padding.
        segment_pos: 2-D integer array of per-token positions, same shape as
            ``segment_ids``.
        max_segments_per_seq: Static Python int; the output has
            ``max_segments_per_seq + 1`` slots per row.

    Returns:
        Integer array of shape
        ``(segment_pos.shape[0], max_segments_per_seq + 1)`` holding the
        ascending start indices for each row, padded with -1.
    """
    num_slots = max_segments_per_seq + 1

    # The first token of a row always opens a candidate segment.
    leading_start = jnp.ones((segment_pos.shape[0], 1), dtype=bool)
    # Any later token whose position does not continue the run opens one too.
    breaks_run = segment_pos[..., 1:] != segment_pos[..., :-1] + 1
    candidate_starts = jnp.concatenate([leading_start, breaks_run], axis=-1)

    # Ignore candidates located in the padded region.
    is_real_token = segment_ids != 0
    segment_starts = jnp.logical_and(is_real_token, candidate_starts)

    def _row_offsets(starts_row):
        # Fixed-size nonzero: indices of the True entries, -1 filled.
        return jnp.where(starts_row, size=num_slots, fill_value=-1)[0]

    return jax.vmap(_row_offsets)(segment_starts)


def _segment_ids_pos_to_seqlens_offsets(
segment_ids_q,
segment_ids_kv,
Expand Down Expand Up @@ -570,69 +610,27 @@ def _segment_ids_pos_to_seqlens_offsets(
# It does not need to involve SW for this mask's creation

# TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
if (attn_mask_type.is_causal() and window_size is None) or (
window_size == (-1, -1) and not attn_mask_type.is_bottom_right()
):
return _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
)

# (1 = attend, 0 = masked)
segment_mask = make_attention_mask(
segment_ids_q,
segment_ids_kv,
jnp.equal,
# if (attn_mask_type.is_causal() and window_size is None) or (
# window_size == (-1, -1) and not attn_mask_type.is_bottom_right()
# ):
# return _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
# segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
# )
q_seqlen = _get_seqlens_thd(
segment_ids=segment_ids_q, max_segments_per_seq=max_segments_per_seq
)
segment_mask_with_id = make_attention_mask(
segment_ids_q,
segment_ids_kv,
lambda x, y: jnp.equal(x, y) * x,
kv_seqlen = _get_seqlens_thd(
segment_ids=segment_ids_kv, max_segments_per_seq=max_segments_per_seq
)
# TE JAX Attn expects the THD segments to have q_token <= kv_tokens so that a correct cross-attn type BRCM can be applied
attn_mask = segment_mask
if attn_mask_type.is_bottom_right():
run_length_out_q = run_length_fill(segment_ids_q)
run_length_out_kv = run_length_fill(segment_ids_kv)
# Example for brcm:
# run_length_out_q: [3 3 3 0 4 4 4 4]
# segment_pos_q: [0 1 2 3 0 1 2 3]
# segment_ids_q: [1 1 1 0 2 2 2 2]
# run_length_out_kv: [4 4 4 4 0 0 10 10 10 10 10 10 10 10 10 10]
# segment_pos_kv: [0 1 2 3 4 5 0 1 2 3 4 5 6 7 8 9]
# segment_ids_kv: [1 1 1 1 0 0 2 2 2 2 2 2 2 2 2 2]
# brcm: [[[1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0]
# [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0]
# [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]
# [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]
# [1 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0]
# [1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0]
# [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0]
# [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]]]
# attn_mask(noswa):[[[1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
# [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
# [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0]
# [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
# [0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0]
# [0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0]
# [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0]
# [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]]]
bottom_right_causal_mask = make_attention_mask(
run_length_out_q - segment_pos_q,
run_length_out_kv - segment_pos_kv,
jnp.less_equal,
)
attn_mask = jnp.logical_and(segment_mask, bottom_right_causal_mask)
elif attn_mask_type.is_causal():
causal_mask = make_attention_mask(
segment_pos_q,
segment_pos_kv,
jnp.greater_equal,
)
attn_mask = jnp.logical_and(segment_mask, causal_mask)

attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0)
q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset(
attn_mask_with_id, max_segments_per_seq
q_offset = _get_seqoffsets_thd(
segment_ids=segment_ids_q,
segment_pos=segment_pos_q,
max_segments_per_seq=max_segments_per_seq,
)
kv_offset = _get_seqoffsets_thd(
segment_ids=segment_ids_kv,
segment_pos=segment_pos_kv,
max_segments_per_seq=max_segments_per_seq,
)
return q_seqlen, kv_seqlen, q_offset, kv_offset

Expand Down