diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 49372fda1d4..f7267af5b8a 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -668,14 +668,24 @@ def generate_random_segment_ids( (self.offsets_q, self.offsets_kv), ) case SeqDescFormat.SegmentIDs: + # Exercise the path to generate the segment_pos in from_segment_ids_and_pos() + # if no CP and load balancing, else explicitly pass the segment_pos self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos( ( self.cp_reorder_fn(self.segment_ids_q), self.cp_reorder_fn(self.segment_ids_kv), ), ( - self.cp_reorder_fn(self.segment_pos_q), - self.cp_reorder_fn(self.segment_pos_kv), + ( + self.cp_reorder_fn(self.segment_pos_q), + self.cp_reorder_fn(self.segment_pos_kv), + ) + if self.cp_size > 1 and self.cp_load_balanced + else None + ), + is_thd=self.qkv_layout.is_thd(), + is_segment_ids_reordered=( + True if self.cp_size > 1 and self.cp_load_balanced else False ), ) case _: @@ -704,6 +714,8 @@ def generate_random_segment_ids( self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos( (self.segment_ids_q, self.segment_ids_kv), None, + is_thd=self.qkv_layout.is_thd(), + is_segment_ids_reordered=False, ) case _: raise ValueError(f"Unknown {self.seq_desc_format=}") diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 21680dc8057..09a29f4cb83 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -658,7 +658,7 @@ class SequenceDescriptor: - SequenceDescriptor.from_seqlens_and_offsets For THD (packed) cases, where each batch may have not only 1 sequence. - SequenceDescriptor.from_segment_ids_and_pos - Experimental feature for THD (packed) cases with context parallelism. + Experimental feature for BSHD (with and without reordering) and THD (packed) cases without reordering """ seqlens: Optional[Tuple[jnp.ndarray, jnp.ndarray]] @@ -796,9 +796,14 @@ def from_segment_ids_and_pos( cls, segment_ids: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None, + *, + is_thd: bool, + is_segment_ids_reordered: bool, ) -> SequenceDescriptor: """ - Experimental factory method for inputs with segment IDs and optional positions. (THD) + Experimental factory method for inputs with segment IDs and optional positions. + segment_pos = None to be used only for: BSHD with or without load balancing and, + THD without load balancing Args: segment_ids(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_ids, kv_segment_ids): - q_segment_ids (jnp.ndarray): @@ -812,22 +817,84 @@ def from_segment_ids_and_pos( The position inside each segment for query, with shape [batch, max_seqlen]. - kv_segment_pos (jnp.ndarray): The position inside each segment for key, value, with shape [batch, max_seqlen]. + is_thd(bool): If True, QKVLayout is of type THD, else it is BSHD + is_segment_ids_reordered(bool): If True, the segment ids have been reordered for load balancing. + Only THD with load balancing is expected to have this flag set to True Return: A SequenceDescriptor with segment_ids/segment_pos initialized. """ q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids) - if segment_pos is not None: - segment_pos = cls._expand_to_pair(segment_pos) - else: - - def generate_default_pos(segment_ids): - seqlen = segment_ids.shape[-1] - return jnp.broadcast_to(jnp.arange(seqlen), segment_ids.shape) + # Using defaults : segment pos has to be generated. + if segment_pos is None: + # THD + load balanced segment_ids are not supported in this function + # BSHD + load balanced segment_ids are incorrect as BSHD handles reordering within the primitive itself + if is_segment_ids_reordered: + assert not is_thd, ( + f"{segment_pos=} default arg is not supported for load balanced reordered" + " (Striped) THD inputs. Please pass the load balanced reordered segment_pos" + " and segment_ids explicitly to {from_segment_ids_and_pos.__qualname__}" + " using convenience function reorder_causal_load_balancing()" + ) + assert is_thd, ( + f"{segment_pos=} default arg is not supported for load balanced reordered (Dual" + " Chunk) BSHD inputs. BSHD segment_pos and segment_ids do not need to be load" + " balanced reordered. The reordering for these is performed within the" + " primitive" + ) + + # Generate the default pos for THD and BSHD non-reordered segment_ids + def generate_default_pos(seg_ids): + if is_thd: + batch_size, seq_size = seg_ids.shape + # Assume that the first token belongs to a segment and is not a padded token + first_is_segment = jnp.full((batch_size, 1), True, dtype=bool) + # Get segment start positions + segment_start = jnp.concatenate( + [ + first_is_segment, + (seg_ids[..., 1:] != seg_ids[..., :-1]) & (seg_ids[..., 1:] != 0), + ], + axis=-1, + ) + # Get offset for location where new segment starts + segment_start_idx = jax.vmap(lambda row: jnp.arange(row.size) * row)( + segment_start + ) + segment_start_offsets = jax.vmap(jnp.maximum.accumulate)(segment_start_idx) + + # Get the last non-zero index - after this everything is padding + # (B,) + last_nonzero_idx = jax.vmap( + lambda segids_row: jnp.max( + jnp.where(segids_row != 0, jnp.arange(seq_size), -1) + ) + )(seg_ids) + seg_pos_no_thd = jnp.arange(seq_size) + # Get a mask which can be used to zero out all the padding at the end (after the non-zero index) + mask = seg_pos_no_thd <= last_nonzero_idx[:, None] + + # Get the unmasked seg_pos for the THD sequence + seg_pos = ( + jnp.broadcast_to(jnp.arange(seq_size), seg_ids.shape) + - segment_start_offsets + ) + + # Use the mask to zero out the padding at the end (after the non-zero index) + segment_pos = jax.vmap( + lambda pos_row, mask_row: jnp.where(mask_row, pos_row, 0) + )(seg_pos, mask) + return segment_pos + + seqlen = seg_ids.shape[-1] + return jnp.broadcast_to(jnp.arange(seqlen), seg_ids.shape) q_seg_pos = generate_default_pos(q_seg_ids) kv_seg_pos = generate_default_pos(kv_seg_ids) segment_pos = (q_seg_pos, kv_seg_pos) + # Explicitly passed segment_pos + else: + segment_pos = cls._expand_to_pair(segment_pos) return cls( segment_ids=(q_seg_ids, kv_seg_ids),