From 2d2e33a24136dc9fe9f938f207df7c9628cce439 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 16 Dec 2025 23:10:56 +0000 Subject: [PATCH 01/18] Fix incorrect calculation of segment pos from segment ids for thd cases and load balanced cases in from_segment_ids_and_pos. Enforce passing of segment_pos for THD cases and load balanced cases Signed-off-by: Kshitij Lakhani --- transformer_engine/jax/attention.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 21680dc8057..8c97545a4a5 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -795,10 +795,11 @@ def from_seqlens_and_offsets( def from_segment_ids_and_pos( cls, segment_ids: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], - segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None, + segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]], ) -> SequenceDescriptor: """ - Experimental factory method for inputs with segment IDs and optional positions. (THD) + Experimental factory method for inputs with segment IDs and optional positions. + segment_pos = None to be used only for : BSHD without load balancing Args: segment_ids(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_ids, kv_segment_ids): - q_segment_ids (jnp.ndarray): @@ -817,10 +818,16 @@ def from_segment_ids_and_pos( """ q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids) - if segment_pos is not None: - segment_pos = cls._expand_to_pair(segment_pos) - else: - + if segment_pos is None: + warnings.warn( + "segment_pos no longer defaults to None and must be explicitly passed", + DeprecationWarning, + ) + warnings.warn( + "segment_pos = None is only acceptable if using BSHD and no load balancing. 
For all other cases, " \ + " please explicitly pass the segment_pos", + UserWarning, + ) def generate_default_pos(segment_ids): seqlen = segment_ids.shape[-1] return jnp.broadcast_to(jnp.arange(seqlen), segment_ids.shape) @@ -828,6 +835,8 @@ def generate_default_pos(segment_ids): q_seg_pos = generate_default_pos(q_seg_ids) kv_seg_pos = generate_default_pos(kv_seg_ids) segment_pos = (q_seg_pos, kv_seg_pos) + else: + segment_pos = cls._expand_to_pair(segment_pos) return cls( segment_ids=(q_seg_ids, kv_seg_ids), From 65e6b4b05d4fa703c477d44ffd42fcfde0643e7a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Dec 2025 23:12:55 +0000 Subject: [PATCH 02/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/attention.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 8c97545a4a5..e756e9a4e8b 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -824,10 +824,11 @@ def from_segment_ids_and_pos( DeprecationWarning, ) warnings.warn( - "segment_pos = None is only acceptable if using BSHD and no load balancing. For all other cases, " \ - " please explicitly pass the segment_pos", + "segment_pos = None is only acceptable if using BSHD and no load balancing. 
For all" + " other cases, please explicitly pass the segment_pos", UserWarning, ) + def generate_default_pos(segment_ids): seqlen = segment_ids.shape[-1] return jnp.broadcast_to(jnp.arange(seqlen), segment_ids.shape) From 98575779e42c5530dff917139ead6cfadb4c0113 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 17 Dec 2025 02:02:31 +0000 Subject: [PATCH 03/18] Correct the assert condition Signed-off-by: Kshitij Lakhani --- transformer_engine/jax/attention.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index e756e9a4e8b..90937f3890a 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -791,11 +791,14 @@ def from_seqlens_and_offsets( q_offsets, kv_offsets = cls._expand_to_pair(seq_offsets) return cls(seqlens=(q_seqlens, kv_seqlens), seq_offsets=(q_offsets, kv_offsets)) + #TODO(KshitijLakhani), TODO(mgoldfarb-nvidia): Consider adding support for THD layout (non load balanced). @classmethod def from_segment_ids_and_pos( cls, segment_ids: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], - segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]], + segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None, + is_thd: bool = False, + is_load_balanced: bool = False, ) -> SequenceDescriptor: """ Experimental factory method for inputs with segment IDs and optional positions. @@ -813,19 +816,19 @@ def from_segment_ids_and_pos( The position inside each segment for query, with shape [batch, max_seqlen]. - kv_segment_pos (jnp.ndarray): The position inside each segment for key, value, with shape [batch, max_seqlen]. + is_thd(bool): If True, QKVLayout is of type THD, else it is BSHD + is_load_balanced(bool): If True, CP is being used and the inputs have been load balanced. Return: A SequenceDescriptor with segment_ids/segment_pos initialized. 
""" - q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids) - + # If using defaults if segment_pos is None: + # Segment pos is not calculated implicitly for THD cases and Load balancing cases + assert not is_load_balanced, f"segment_pos = None default arg is not supported for load balanced inputs" + assert not is_thd, f"segment_pos = None default arg is not supported for THD layouts" warnings.warn( - "segment_pos no longer defaults to None and must be explicitly passed", - DeprecationWarning, - ) - warnings.warn( - "segment_pos = None is only acceptable if using BSHD and no load balancing. For all" - " other cases, please explicitly pass the segment_pos", + "segment_pos = None is only acceptable if using BSHD and no load balancing. For all other cases, " \ + "segment_pos must be passed explicitly", UserWarning, ) @@ -836,9 +839,11 @@ def generate_default_pos(segment_ids): q_seg_pos = generate_default_pos(q_seg_ids) kv_seg_pos = generate_default_pos(kv_seg_ids) segment_pos = (q_seg_pos, kv_seg_pos) - else: + else: # Explicitly passed segment_pos segment_pos = cls._expand_to_pair(segment_pos) + q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids) + return cls( segment_ids=(q_seg_ids, kv_seg_ids), segment_pos=segment_pos, From b20ac22e3b2152c9851ff563e9b194b8c5eba52a Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 17 Dec 2025 02:03:52 +0000 Subject: [PATCH 04/18] Modify fused attn tests to pass new args to from_segment_ids_and_pos() Signed-off-by: Kshitij Lakhani --- tests/jax/test_fused_attn.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 49372fda1d4..1f7bee50d50 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -677,6 +677,8 @@ def generate_random_segment_ids( self.cp_reorder_fn(self.segment_pos_q), self.cp_reorder_fn(self.segment_pos_kv), ), + self.qkv_layout.is_thd(), + self.cp_size > 1 and self.cp_load_balanced, ) case _: raise 
ValueError(f"Unknown {self.seq_desc_format=}") @@ -704,6 +706,8 @@ def generate_random_segment_ids( self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos( (self.segment_ids_q, self.segment_ids_kv), None, + self.qkv_layout.is_thd(), + self.cp_size > 1 and self.cp_load_balanced, ) case _: raise ValueError(f"Unknown {self.seq_desc_format=}") From 03398a4c2d01b0054245539b21590d66293b9a33 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Dec 2025 02:08:23 +0000 Subject: [PATCH 05/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/attention.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 90937f3890a..a1669919e94 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -791,7 +791,7 @@ def from_seqlens_and_offsets( q_offsets, kv_offsets = cls._expand_to_pair(seq_offsets) return cls(seqlens=(q_seqlens, kv_seqlens), seq_offsets=(q_offsets, kv_offsets)) - #TODO(KshitijLakhani), TODO(mgoldfarb-nvidia): Consider adding support for THD layout (non load balanced). + # TODO(KshitijLakhani), TODO(mgoldfarb-nvidia): Consider adding support for THD layout (non load balanced). 
@classmethod def from_segment_ids_and_pos( cls, @@ -823,12 +823,14 @@ def from_segment_ids_and_pos( """ # If using defaults if segment_pos is None: - # Segment pos is not calculated implicitly for THD cases and Load balancing cases - assert not is_load_balanced, f"segment_pos = None default arg is not supported for load balanced inputs" + # Segment pos is not calculated implicitly for THD cases and Load balancing cases + assert ( + not is_load_balanced + ), f"segment_pos = None default arg is not supported for load balanced inputs" assert not is_thd, f"segment_pos = None default arg is not supported for THD layouts" warnings.warn( - "segment_pos = None is only acceptable if using BSHD and no load balancing. For all other cases, " \ - "segment_pos must be passed explicitly", + "segment_pos = None is only acceptable if using BSHD and no load balancing. For all" + " other cases, segment_pos must be passed explicitly", UserWarning, ) @@ -839,7 +841,7 @@ def generate_default_pos(segment_ids): q_seg_pos = generate_default_pos(q_seg_ids) kv_seg_pos = generate_default_pos(kv_seg_ids) segment_pos = (q_seg_pos, kv_seg_pos) - else: # Explicitly passed segment_pos + else: # Explicitly passed segment_pos segment_pos = cls._expand_to_pair(segment_pos) q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids) From 0a47eb6abaf4919f0d40c82ae33eeb23f6e8b93d Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 17 Dec 2025 02:47:25 +0000 Subject: [PATCH 06/18] Calculate seg ids before pos Signed-off-by: Kshitij Lakhani --- transformer_engine/jax/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index a1669919e94..6519dd3ed0e 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -821,6 +821,8 @@ def from_segment_ids_and_pos( Return: A SequenceDescriptor with segment_ids/segment_pos initialized. 
""" + q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids) + # If using defaults if segment_pos is None: # Segment pos is not calculated implicitly for THD cases and Load balancing cases @@ -844,8 +846,6 @@ def generate_default_pos(segment_ids): else: # Explicitly passed segment_pos segment_pos = cls._expand_to_pair(segment_pos) - q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids) - return cls( segment_ids=(q_seg_ids, kv_seg_ids), segment_pos=segment_pos, From 217ea588c79e62d1218423955e556b0e7941883d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Dec 2025 02:48:47 +0000 Subject: [PATCH 07/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 6519dd3ed0e..95ab90f4f73 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -822,7 +822,7 @@ def from_segment_ids_and_pos( A SequenceDescriptor with segment_ids/segment_pos initialized. """ q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids) - + # If using defaults if segment_pos is None: # Segment pos is not calculated implicitly for THD cases and Load balancing cases From ca9d3bc04af0d3235d86f62175248adb865fe6d9 Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Mon, 22 Dec 2025 16:42:26 -0800 Subject: [PATCH 08/18] 1. Change the signature for from_segment_ids_and_pos() 2. Add support for THD in from_segment_ids_and_pos() 3. 
Assert if load balanced segment_ids is passed to generate a segment_pos Signed-off-by: Kshitij Janardan Lakhani --- transformer_engine/jax/attention.py | 60 ++++++++++++++++++++--------- 1 file changed, 41 insertions(+), 19 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 95ab90f4f73..73f55b86adb 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -797,8 +797,9 @@ def from_segment_ids_and_pos( cls, segment_ids: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None, - is_thd: bool = False, - is_load_balanced: bool = False, + *, + is_thd: bool, + is_load_balanced: bool, ) -> SequenceDescriptor: """ Experimental factory method for inputs with segment IDs and optional positions. @@ -823,27 +824,48 @@ def from_segment_ids_and_pos( """ q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids) - # If using defaults + # Using defaults. segment pos has to be generated. if segment_pos is None: - # Segment pos is not calculated implicitly for THD cases and Load balancing cases - assert ( - not is_load_balanced - ), f"segment_pos = None default arg is not supported for load balanced inputs" - assert not is_thd, f"segment_pos = None default arg is not supported for THD layouts" - warnings.warn( - "segment_pos = None is only acceptable if using BSHD and no load balancing. For all" - " other cases, segment_pos must be passed explicitly", - UserWarning, - ) - - def generate_default_pos(segment_ids): - seqlen = segment_ids.shape[-1] - return jnp.broadcast_to(jnp.arange(seqlen), segment_ids.shape) - + # Segment pos is not calculated implicitly Load balancing cases + assert not is_load_balanced, (f"{segment_pos=} default arg is not supported for load balanced inputs. 
" + "Please pass the load balanced segment_pos and segment_ids using helper function reorder_causal_load_balancing()") + + def generate_default_pos(seg_ids): + if is_thd: + batch_size, seq_size = seg_ids.shape + # Assume that the first token belongs to a segment and is not a padded token + first_is_segment = jnp.full((batch_size, 1), True, dtype=bool) + # Get segment start positions + segment_start = jnp.concatenate([ + first_is_segment, # First valid element starts a segment + (seg_ids[..., 1:] != seg_ids[..., :-1]) & (seg_ids[..., 1:] != 0) + ], axis=-1) + # Get offset for location where new segment starts + segment_start_idx = jax.vmap(lambda row: jnp.arange(row.size) * row)(segment_start) + segment_start_offsets = jax.vmap(lambda row: jnp.maximum.accumulate(row))(segment_start_idx) + + # Get the last non-zero index - after this everything is padding + # (B,) + last_nonzero_idx = jax.vmap(lambda segids_row: jnp.max(jnp.where(segids_row != 0, jnp.arange(seq_size), -1)))(seg_ids) + seg_pos_no_thd = jnp.arange(seq_size) + # Get a mask which can be used to zero out all the padding at the end (after the non-zero index) + mask = seg_pos_no_thd <= last_nonzero_idx[:, None] + + # Get the unmasked seg_pos for the THD sequence + seg_pos = jnp.broadcast_to(jnp.arange(seq_size), seg_ids.shape) - segment_start_offsets + + # Use the mask to zero out the padding at the end (after the non-zero index) + segment_pos = jax.vmap(lambda pos_row, mask_row: jnp.where(mask_row, pos_row, 0))(seg_pos, mask) + return segment_pos + else: + seqlen = segment_ids.shape[-1] + return jnp.broadcast_to(jnp.arange(seqlen), segment_ids.shape) + q_seg_pos = generate_default_pos(q_seg_ids) kv_seg_pos = generate_default_pos(kv_seg_ids) segment_pos = (q_seg_pos, kv_seg_pos) - else: # Explicitly passed segment_pos + # Explicitly passed segment_pos + else: segment_pos = cls._expand_to_pair(segment_pos) return cls( From 0ee40a5ba4e567d886fb61c29862c07b6387b697 Mon Sep 17 00:00:00 2001 From: 
"pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 00:45:13 +0000 Subject: [PATCH 09/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/attention.py | 41 +++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 73f55b86adb..730518e0cb9 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -827,8 +827,11 @@ def from_segment_ids_and_pos( # Using defaults. segment pos has to be generated. if segment_pos is None: # Segment pos is not calculated implicitly Load balancing cases - assert not is_load_balanced, (f"{segment_pos=} default arg is not supported for load balanced inputs. " - "Please pass the load balanced segment_pos and segment_ids using helper function reorder_causal_load_balancing()") + assert not is_load_balanced, ( + f"{segment_pos=} default arg is not supported for load balanced inputs. 
Please pass" + " the load balanced segment_pos and segment_ids using helper function" + " reorder_causal_load_balancing()" + ) def generate_default_pos(seg_ids): if is_thd: @@ -836,31 +839,47 @@ def generate_default_pos(seg_ids): # Assume that the first token belongs to a segment and is not a padded token first_is_segment = jnp.full((batch_size, 1), True, dtype=bool) # Get segment start positions - segment_start = jnp.concatenate([ + segment_start = jnp.concatenate( + [ first_is_segment, # First valid element starts a segment - (seg_ids[..., 1:] != seg_ids[..., :-1]) & (seg_ids[..., 1:] != 0) - ], axis=-1) + (seg_ids[..., 1:] != seg_ids[..., :-1]) & (seg_ids[..., 1:] != 0), + ], + axis=-1, + ) # Get offset for location where new segment starts - segment_start_idx = jax.vmap(lambda row: jnp.arange(row.size) * row)(segment_start) - segment_start_offsets = jax.vmap(lambda row: jnp.maximum.accumulate(row))(segment_start_idx) + segment_start_idx = jax.vmap(lambda row: jnp.arange(row.size) * row)( + segment_start + ) + segment_start_offsets = jax.vmap(lambda row: jnp.maximum.accumulate(row))( + segment_start_idx + ) # Get the last non-zero index - after this everything is padding # (B,) - last_nonzero_idx = jax.vmap(lambda segids_row: jnp.max(jnp.where(segids_row != 0, jnp.arange(seq_size), -1)))(seg_ids) + last_nonzero_idx = jax.vmap( + lambda segids_row: jnp.max( + jnp.where(segids_row != 0, jnp.arange(seq_size), -1) + ) + )(seg_ids) seg_pos_no_thd = jnp.arange(seq_size) # Get a mask which can be used to zero out all the padding at the end (after the non-zero index) mask = seg_pos_no_thd <= last_nonzero_idx[:, None] # Get the unmasked seg_pos for the THD sequence - seg_pos = jnp.broadcast_to(jnp.arange(seq_size), seg_ids.shape) - segment_start_offsets + seg_pos = ( + jnp.broadcast_to(jnp.arange(seq_size), seg_ids.shape) + - segment_start_offsets + ) # Use the mask to zero out the padding at the end (after the non-zero index) - segment_pos = jax.vmap(lambda pos_row, 
mask_row: jnp.where(mask_row, pos_row, 0))(seg_pos, mask) + segment_pos = jax.vmap( + lambda pos_row, mask_row: jnp.where(mask_row, pos_row, 0) + )(seg_pos, mask) return segment_pos else: seqlen = segment_ids.shape[-1] return jnp.broadcast_to(jnp.arange(seqlen), segment_ids.shape) - + q_seg_pos = generate_default_pos(q_seg_ids) kv_seg_pos = generate_default_pos(kv_seg_ids) segment_pos = (q_seg_pos, kv_seg_pos) From ceec1ead66022581bd19bb39a409efaa6dd34f11 Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Mon, 22 Dec 2025 16:55:04 -0800 Subject: [PATCH 10/18] Pass keyword-only args by name Signed-off-by: Kshitij Janardan Lakhani --- tests/jax/test_fused_attn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 1f7bee50d50..f4ee2b3cc5f 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -677,8 +677,8 @@ def generate_random_segment_ids( self.cp_reorder_fn(self.segment_pos_q), self.cp_reorder_fn(self.segment_pos_kv), ), - self.qkv_layout.is_thd(), - self.cp_size > 1 and self.cp_load_balanced, + is_thd=self.qkv_layout.is_thd(), + is_load_balanced=self.cp_size > 1 and self.cp_load_balanced, ) case _: raise ValueError(f"Unknown {self.seq_desc_format=}") @@ -706,8 +706,8 @@ def generate_random_segment_ids( self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos( (self.segment_ids_q, self.segment_ids_kv), None, - self.qkv_layout.is_thd(), - self.cp_size > 1 and self.cp_load_balanced, + is_thd=self.qkv_layout.is_thd(), + is_load_balanced=self.cp_size > 1 and self.cp_load_balanced, ) case _: raise ValueError(f"Unknown {self.seq_desc_format=}") From ab00bb081476584890a075fa74e21c23ec83cc48 Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Mon, 22 Dec 2025 16:57:12 -0800 Subject: [PATCH 11/18] nit: Fix typo to use seg_ids instead of segment_ids Signed-off-by: Kshitij Janardan Lakhani --- transformer_engine/jax/attention.py | 
4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 730518e0cb9..6a884a8c484 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -877,8 +877,8 @@ def generate_default_pos(seg_ids): )(seg_pos, mask) return segment_pos else: - seqlen = segment_ids.shape[-1] - return jnp.broadcast_to(jnp.arange(seqlen), segment_ids.shape) + seqlen = seg_ids.shape[-1] + return jnp.broadcast_to(jnp.arange(seqlen), seg_ids.shape) q_seg_pos = generate_default_pos(q_seg_ids) kv_seg_pos = generate_default_pos(kv_seg_ids) From 059c48dc5ee89c7921a1df20bc9fc21ae9a3cd49 Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Mon, 22 Dec 2025 17:15:32 -0800 Subject: [PATCH 12/18] nit: Fix comments Signed-off-by: Kshitij Janardan Lakhani --- transformer_engine/jax/attention.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 6a884a8c484..7f9acf8568b 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -791,7 +791,6 @@ def from_seqlens_and_offsets( q_offsets, kv_offsets = cls._expand_to_pair(seq_offsets) return cls(seqlens=(q_seqlens, kv_seqlens), seq_offsets=(q_offsets, kv_offsets)) - # TODO(KshitijLakhani), TODO(mgoldfarb-nvidia): Consider adding support for THD layout (non load balanced). @classmethod def from_segment_ids_and_pos( cls, @@ -803,7 +802,7 @@ def from_segment_ids_and_pos( ) -> SequenceDescriptor: """ Experimental factory method for inputs with segment IDs and optional positions. 
- segment_pos = None to be used only for : BSHD without load balancing + segment_pos = None to be used only for : BSHD and THD without load balancing Args: segment_ids(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_ids, kv_segment_ids): - q_segment_ids (jnp.ndarray): From d524ad676028b72ba53e08e488068fff01e0172f Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Tue, 23 Dec 2025 14:17:38 -0800 Subject: [PATCH 13/18] Modify the function call to differentiate between load balancing and actually reordered segment_ids and segment_pos Signed-off-by: Kshitij Janardan Lakhani --- tests/jax/test_fused_attn.py | 4 ++-- transformer_engine/jax/attention.py | 27 +++++++++++++++++---------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index f4ee2b3cc5f..4fc1b4681bd 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -678,7 +678,7 @@ def generate_random_segment_ids( self.cp_reorder_fn(self.segment_pos_kv), ), is_thd=self.qkv_layout.is_thd(), - is_load_balanced=self.cp_size > 1 and self.cp_load_balanced, + is_segment_ids_reordered=True, ) case _: raise ValueError(f"Unknown {self.seq_desc_format=}") @@ -707,7 +707,7 @@ def generate_random_segment_ids( (self.segment_ids_q, self.segment_ids_kv), None, is_thd=self.qkv_layout.is_thd(), - is_load_balanced=self.cp_size > 1 and self.cp_load_balanced, + is_segment_ids_reordered=False ) case _: raise ValueError(f"Unknown {self.seq_desc_format=}") diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 7f9acf8568b..0fd9439f59a 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -798,7 +798,7 @@ def from_segment_ids_and_pos( segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None, *, is_thd: bool, - is_load_balanced: bool, + is_segment_ids_reordered: bool, ) -> SequenceDescriptor: """ Experimental factory method for inputs 
with segment IDs and optional positions. @@ -817,21 +817,28 @@ def from_segment_ids_and_pos( - kv_segment_pos (jnp.ndarray): The position inside each segment for key, value, with shape [batch, max_seqlen]. is_thd(bool): If True, QKVLayout is of type THD, else it is BSHD - is_load_balanced(bool): If True, CP is being used and the inputs have been load balanced. + is_segment_ids_reordered(bool): If True, the segment ids have been reordered for load balancing Return: A SequenceDescriptor with segment_ids/segment_pos initialized. """ q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids) - # Using defaults. segment pos has to be generated. + # Using defaults : segment pos has to be generated. if segment_pos is None: - # Segment pos is not calculated implicitly Load balancing cases - assert not is_load_balanced, ( - f"{segment_pos=} default arg is not supported for load balanced inputs. Please pass" - " the load balanced segment_pos and segment_ids using helper function" - " reorder_causal_load_balancing()" - ) - + # THD + load balanced segment_ids are not supported in this function + # BSHD + load balanced segment_ids are incorrect as BSHD handles reordering within the primitive itself + if is_segment_ids_reordered: + assert not is_thd, ( + f"{segment_pos=} default arg is not supported for load balanced reordered (Striped) THD inputs." + " Please pass the load balanced reordered segment_pos and segment_ids explicitly to" + " {from_segment_ids_and_pos.__qualname__} using convenience function reorder_causal_load_balancing()" + ) + assert is_thd, ( + f"{segment_pos=} default arg is not supported for load balanced reordered (Dual Chunk) BSHD inputs." + " BSHD segment_pos and segment_ids do not need to be load balanced reordered. 
The reordering for these" + " is performed within the primitive" + ) + # Generate the default pos for THD and BSHD non-reordered segment_ids def generate_default_pos(seg_ids): if is_thd: batch_size, seq_size = seg_ids.shape From d419f98ff4b559a7f57be367d1b46004f3ccd076 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 22:18:50 +0000 Subject: [PATCH 14/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_fused_attn.py | 2 +- transformer_engine/jax/attention.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 4fc1b4681bd..46e55a5eb95 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -707,7 +707,7 @@ def generate_random_segment_ids( (self.segment_ids_q, self.segment_ids_kv), None, is_thd=self.qkv_layout.is_thd(), - is_segment_ids_reordered=False + is_segment_ids_reordered=False, ) case _: raise ValueError(f"Unknown {self.seq_desc_format=}") diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 0fd9439f59a..fb5cba5e01a 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -829,15 +829,18 @@ def from_segment_ids_and_pos( # BSHD + load balanced segment_ids are incorrect as BSHD handles reordering within the primitive itself if is_segment_ids_reordered: assert not is_thd, ( - f"{segment_pos=} default arg is not supported for load balanced reordered (Striped) THD inputs." - " Please pass the load balanced reordered segment_pos and segment_ids explicitly to" - " {from_segment_ids_and_pos.__qualname__} using convenience function reorder_causal_load_balancing()" + f"{segment_pos=} default arg is not supported for load balanced reordered" + " (Striped) THD inputs. 
Please pass the load balanced reordered segment_pos" + " and segment_ids explicitly to {from_segment_ids_and_pos.__qualname__}" + " using convenience function reorder_causal_load_balancing()" ) assert is_thd, ( - f"{segment_pos=} default arg is not supported for load balanced reordered (Dual Chunk) BSHD inputs." - " BSHD segment_pos and segment_ids do not need to be load balanced reordered. The reordering for these" - " is performed within the primitive" + f"{segment_pos=} default arg is not supported for load balanced reordered (Dual" + " Chunk) BSHD inputs. BSHD segment_pos and segment_ids do not need to be load" + " balanced reordered. The reordering for these is performed within the" + " primitive" ) + # Generate the default pos for THD and BSHD non-reordered segment_ids def generate_default_pos(seg_ids): if is_thd: From 3efa504198ecc2ca156794a844f3e8adfe4453b4 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 24 Dec 2025 00:14:58 +0000 Subject: [PATCH 15/18] Fix the is_segment_ids_reordered to be set only when CP and load balancing Signed-off-by: Kshitij Lakhani --- tests/jax/test_fused_attn.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 46e55a5eb95..5ed1eeaca8b 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -668,6 +668,8 @@ def generate_random_segment_ids( (self.offsets_q, self.offsets_kv), ) case SeqDescFormat.SegmentIDs: + # Exercise the path to generate the in segment_pos in from_segment_ids_and_pos() + # if no CP, else explicitly pass the segment_pos self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos( ( self.cp_reorder_fn(self.segment_ids_q), @@ -676,9 +678,9 @@ def generate_random_segment_ids( ( self.cp_reorder_fn(self.segment_pos_q), self.cp_reorder_fn(self.segment_pos_kv), - ), + ) if self.cp_size > 1 and self.cp_load_balanced else None, is_thd=self.qkv_layout.is_thd(), - 
is_segment_ids_reordered=True, + is_segment_ids_reordered=True if self.cp_size > 1 and self.cp_load_balanced else False, ) case _: raise ValueError(f"Unknown {self.seq_desc_format=}") From e65062cedc435acc6d28d720e72df8de7c0dddbf Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 24 Dec 2025 00:15:39 +0000 Subject: [PATCH 16/18] Fix comments for from_segment_ids_and_pos() Signed-off-by: Kshitij Lakhani --- transformer_engine/jax/attention.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index fb5cba5e01a..591e8606109 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -658,7 +658,7 @@ class SequenceDescriptor: - SequenceDescriptor.from_seqlens_and_offsets For THD (packed) cases, where each batch may have not only 1 sequence. - SequenceDescriptor.from_segment_ids_and_pos - Experimental feature for THD (packed) cases with context parallelism. + Experimental feature for BSHD (with and without reordering) and THD (packed) cases without reordering """ seqlens: Optional[Tuple[jnp.ndarray, jnp.ndarray]] @@ -802,7 +802,8 @@ def from_segment_ids_and_pos( ) -> SequenceDescriptor: """ Experimental factory method for inputs with segment IDs and optional positions. - segment_pos = None to be used only for : BSHD and THD without load balancing + segment_pos = None to be used only for: BSHD with or without load balancing and, + THD without load balancing Args: segment_ids(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_ids, kv_segment_ids): - q_segment_ids (jnp.ndarray): @@ -817,7 +818,8 @@ def from_segment_ids_and_pos( - kv_segment_pos (jnp.ndarray): The position inside each segment for key, value, with shape [batch, max_seqlen]. 
is_thd(bool): If True, QKVLayout is of type THD, else it is BSHD - is_segment_ids_reordered(bool): If True, the segment ids have been reordered for load balancing + is_segment_ids_reordered(bool): If True, the segment ids have been reordered for load balancing. + Only THD with load balancing is expected to have this flag set to True Return: A SequenceDescriptor with segment_ids/segment_pos initialized. """ From 74a352e11ae2c4980ed318d11ee48a5d27c43bc2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Dec 2025 00:17:04 +0000 Subject: [PATCH 17/18] Code clean up for more information, see https://pre-commit.ci Fix lint errors Signed-off-by: Kshitij Lakhani --- tests/jax/test_fused_attn.py | 18 ++++++++++++------ transformer_engine/jax/attention.py | 10 +++++----- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 5ed1eeaca8b..f7267af5b8a 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -668,19 +668,25 @@ def generate_random_segment_ids( (self.offsets_q, self.offsets_kv), ) case SeqDescFormat.SegmentIDs: - # Exercise the path to generate the in segment_pos in from_segment_ids_and_pos() - # if no CP, else explicitly pass the segment_pos + # Exercise the path to generate the segment_pos in from_segment_ids_and_pos() + # if no CP and load balancing, else explicitly pass the segment_pos self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos( ( self.cp_reorder_fn(self.segment_ids_q), self.cp_reorder_fn(self.segment_ids_kv), ), ( - self.cp_reorder_fn(self.segment_pos_q), - self.cp_reorder_fn(self.segment_pos_kv), - ) if self.cp_size > 1 and self.cp_load_balanced else None, + ( + self.cp_reorder_fn(self.segment_pos_q), + self.cp_reorder_fn(self.segment_pos_kv), + ) + if self.cp_size > 1 and self.cp_load_balanced + else None + ), is_thd=self.qkv_layout.is_thd(), - 
is_segment_ids_reordered=True if self.cp_size > 1 and self.cp_load_balanced else False, + is_segment_ids_reordered=( + True if self.cp_size > 1 and self.cp_load_balanced else False + ), ) case _: raise ValueError(f"Unknown {self.seq_desc_format=}") diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 591e8606109..3a7a90f034a 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -852,7 +852,7 @@ def generate_default_pos(seg_ids): # Get segment start positions segment_start = jnp.concatenate( [ - first_is_segment, # First valid element starts a segment + first_is_segment, (seg_ids[..., 1:] != seg_ids[..., :-1]) & (seg_ids[..., 1:] != 0), ], axis=-1, @@ -861,7 +861,7 @@ def generate_default_pos(seg_ids): segment_start_idx = jax.vmap(lambda row: jnp.arange(row.size) * row)( segment_start ) - segment_start_offsets = jax.vmap(lambda row: jnp.maximum.accumulate(row))( + segment_start_offsets = jax.vmap(jnp.maximum.accumulate)( segment_start_idx ) @@ -887,9 +887,9 @@ def generate_default_pos(seg_ids): lambda pos_row, mask_row: jnp.where(mask_row, pos_row, 0) )(seg_pos, mask) return segment_pos - else: - seqlen = seg_ids.shape[-1] - return jnp.broadcast_to(jnp.arange(seqlen), seg_ids.shape) + + seqlen = seg_ids.shape[-1] + return jnp.broadcast_to(jnp.arange(seqlen), seg_ids.shape) q_seg_pos = generate_default_pos(q_seg_ids) kv_seg_pos = generate_default_pos(kv_seg_ids) From e5381fbee290dd0ee354f36cbfb8a3d8f362d54b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Dec 2025 01:46:50 +0000 Subject: [PATCH 18/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/attention.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 3a7a90f034a..09a29f4cb83 
100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -861,9 +861,7 @@ def generate_default_pos(seg_ids): segment_start_idx = jax.vmap(lambda row: jnp.arange(row.size) * row)( segment_start ) - segment_start_offsets = jax.vmap(jnp.maximum.accumulate)( - segment_start_idx - ) + segment_start_offsets = jax.vmap(jnp.maximum.accumulate)(segment_start_idx) # Get the last non-zero index - after this everything is padding # (B,) @@ -887,7 +885,7 @@ def generate_default_pos(seg_ids): lambda pos_row, mask_row: jnp.where(mask_row, pos_row, 0) )(seg_pos, mask) return segment_pos - + seqlen = seg_ids.shape[-1] return jnp.broadcast_to(jnp.arange(seqlen), seg_ids.shape)