diff --git a/examples/benchmark/configs/deepseek_v3_te_deepep.yaml b/examples/benchmark/configs/deepseek_v3_te_deepep.yaml
index d5e56a70e..10ab5c34a 100644
--- a/examples/benchmark/configs/deepseek_v3_te_deepep.yaml
+++ b/examples/benchmark/configs/deepseek_v3_te_deepep.yaml
@@ -72,6 +72,7 @@ model:
   attn: te
   linear: te
   rms_norm: te
+  rope_fusion: true
   enable_deepep: true
   fake_balanced_gate: true
   enable_hf_state_dict_adapter: false
diff --git a/examples/benchmark/configs/deepseek_v3_te_deepep_1024.yaml b/examples/benchmark/configs/deepseek_v3_te_deepep_1024.yaml
index 1414ec852..168560858 100644
--- a/examples/benchmark/configs/deepseek_v3_te_deepep_1024.yaml
+++ b/examples/benchmark/configs/deepseek_v3_te_deepep_1024.yaml
@@ -53,6 +53,7 @@ model:
   attn: te
   linear: te
   rms_norm: te
+  rope_fusion: true
   enable_deepep: true
   fake_balanced_gate: true
   enable_hf_state_dict_adapter: false
diff --git a/examples/benchmark/configs/glm_4.5_air_te_deepep.yaml b/examples/benchmark/configs/glm_4.5_air_te_deepep.yaml
index 566fe5a4c..9150c0ed0 100644
--- a/examples/benchmark/configs/glm_4.5_air_te_deepep.yaml
+++ b/examples/benchmark/configs/glm_4.5_air_te_deepep.yaml
@@ -71,6 +71,7 @@ model:
   attn: te
   linear: te
   rms_norm: te
+  rope_fusion: true
   enable_deepep: true
   fake_balanced_gate: true
   enable_hf_state_dict_adapter: true
diff --git a/examples/benchmark/configs/gptoss_120b_te_deepep.yaml b/examples/benchmark/configs/gptoss_120b_te_deepep.yaml
index 014432e8e..219c4b0ae 100644
--- a/examples/benchmark/configs/gptoss_120b_te_deepep.yaml
+++ b/examples/benchmark/configs/gptoss_120b_te_deepep.yaml
@@ -71,6 +71,7 @@ model:
   attn: flex
   linear: te
   rms_norm: te
+  rope_fusion: true
   enable_deepep: true
   fake_balanced_gate: true
   enable_hf_state_dict_adapter: false
diff --git a/examples/benchmark/configs/gptoss_20b_te_deepep.yaml b/examples/benchmark/configs/gptoss_20b_te_deepep.yaml
index 24e9087f2..21b4a697a 100644
--- a/examples/benchmark/configs/gptoss_20b_te_deepep.yaml
+++ b/examples/benchmark/configs/gptoss_20b_te_deepep.yaml
@@ -70,6 +70,7 @@ model:
   attn: flex
   linear: te
   rms_norm: te
+  rope_fusion: true
   enable_deepep: true
   fake_balanced_gate: true
   enable_hf_state_dict_adapter: false
diff --git a/examples/benchmark/configs/kimi_k2_te_deepep.yaml b/examples/benchmark/configs/kimi_k2_te_deepep.yaml
index 73cfdcde8..72746408d 100644
--- a/examples/benchmark/configs/kimi_k2_te_deepep.yaml
+++ b/examples/benchmark/configs/kimi_k2_te_deepep.yaml
@@ -70,6 +70,7 @@ model:
   attn: te
   linear: te
   rms_norm: te
+  rope_fusion: true
   enable_deepep: true
   fake_balanced_gate: true
   enable_hf_state_dict_adapter: false
diff --git a/examples/benchmark/configs/moonlight_16b_te_deepep.yaml b/examples/benchmark/configs/moonlight_16b_te_deepep.yaml
index 8c8562a34..533db1f74 100644
--- a/examples/benchmark/configs/moonlight_16b_te_deepep.yaml
+++ b/examples/benchmark/configs/moonlight_16b_te_deepep.yaml
@@ -71,6 +71,7 @@ model:
   attn: te
   linear: te
   rms_norm: te
+  rope_fusion: true
   enable_deepep: true
   fake_balanced_gate: true
   enable_hf_state_dict_adapter: true
diff --git a/examples/benchmark/configs/qwen3_moe_235b_te_deepep.yaml b/examples/benchmark/configs/qwen3_moe_235b_te_deepep.yaml
index 266368f34..f177acb6d 100644
--- a/examples/benchmark/configs/qwen3_moe_235b_te_deepep.yaml
+++ b/examples/benchmark/configs/qwen3_moe_235b_te_deepep.yaml
@@ -72,6 +72,7 @@ model:
   attn: te
   linear: te
   rms_norm: te
+  rope_fusion: true
   enable_deepep: true
   fake_balanced_gate: true
   enable_hf_state_dict_adapter: true
diff --git 
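The `rope_fusion: true` knob added to each of these benchmark configs corresponds to the new `rope_fusion` field on `BackendConfig` (see the `nemo_automodel/components/moe/utils.py` hunk further down). A minimal sketch of the equivalent in-code construction, assuming the `model:` backend keys map one-to-one onto `BackendConfig` fields the way the surrounding keys do:

from nemo_automodel.components.moe.utils import BackendConfig

backend = BackendConfig(
    attn="te",
    linear="te",
    rms_norm="te",
    rope_fusion=True,  # new flag: route RoPE through Transformer Engine's fused kernel
    enable_deepep=True,
    fake_balanced_gate=True,
    enable_hf_state_dict_adapter=False,
)
# When omitted, rope_fusion defaults to `HAVE_TE and torch.cuda.is_available()`.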
a/examples/benchmark/configs/qwen3_moe_30b_te_deepep.yaml b/examples/benchmark/configs/qwen3_moe_30b_te_deepep.yaml index 273b962b1..62b9fdf5b 100644 --- a/examples/benchmark/configs/qwen3_moe_30b_te_deepep.yaml +++ b/examples/benchmark/configs/qwen3_moe_30b_te_deepep.yaml @@ -70,6 +70,7 @@ model: attn: te linear: te rms_norm: te + rope_fusion: true enable_deepep: true fake_balanced_gate: true enable_hf_state_dict_adapter: true diff --git a/examples/benchmark/configs/qwen3_next_te_deepep.yaml b/examples/benchmark/configs/qwen3_next_te_deepep.yaml index 8fd4756c3..f91458ba2 100644 --- a/examples/benchmark/configs/qwen3_next_te_deepep.yaml +++ b/examples/benchmark/configs/qwen3_next_te_deepep.yaml @@ -71,6 +71,7 @@ model: attn: te linear: te rms_norm: te + rope_fusion: true enable_deepep: true fake_balanced_gate: true enable_hf_state_dict_adapter: true diff --git a/nemo_automodel/components/distributed/cp_utils.py b/nemo_automodel/components/distributed/cp_utils.py index f84929e5a..c12ff9f39 100644 --- a/nemo_automodel/components/distributed/cp_utils.py +++ b/nemo_automodel/components/distributed/cp_utils.py @@ -288,6 +288,8 @@ def make_cp_batch_for_te( "max_seqlen": torch.stack([chunk["max_seqlen"] for chunk in chunks]), "qkv_format": qkv_format, "padding_mask": torch.stack([chunk["padding_mask"] for chunk in chunks]), + "cp_size": cp_mesh.size() if cp_mesh is not None else 1, + "cp_rank": torch.distributed.get_rank(group=cp_mesh.get_group()) if cp_mesh is not None else 0, } @@ -329,5 +331,7 @@ def _shard_thd_chunk_for_te( "max_seqlen": torch.tensor(max_seqlen).to(torch.int32).to(device=cu_seqlens_padded.device), "qkv_format": qkv_format, "padding_mask": (batch["input_ids"] == padding_token_id).bool().contiguous(), + "cp_size": cp_size, + "cp_rank": cp_rank, } return output_batch diff --git a/nemo_automodel/components/models/deepseek_v3/layers.py b/nemo_automodel/components/models/deepseek_v3/layers.py index 260004f74..7544d9b94 100644 --- a/nemo_automodel/components/models/deepseek_v3/layers.py +++ b/nemo_automodel/components/models/deepseek_v3/layers.py @@ -23,7 +23,10 @@ postprocess_output_for_attn, preprocess_args_and_kwargs_for_attn, ) -from nemo_automodel.components.models.deepseek_v3.rope_utils import apply_rotary_emb, yarn_get_mscale +from nemo_automodel.components.models.deepseek_v3.rope_utils import ( + apply_rotary_emb_qk, + yarn_get_mscale, +) from nemo_automodel.components.moe.utils import ( BackendConfig, initialize_linear_module, @@ -46,6 +49,7 @@ def __init__(self, config: DeepseekV3Config, backend: BackendConfig): self.v_head_dim = config.v_head_dim self.backend = backend + self.rope_fusion = backend.rope_fusion attn_impl = backend.attn linear_impl = backend.linear rms_norm_impl = backend.rms_norm @@ -135,14 +139,34 @@ def forward( q = q.view(bsz, local_seq_len, self.n_heads, self.qk_head_dim) q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - q_pe = apply_rotary_emb(q_pe, freqs_cis, qkv_format, unsqueeze_dim=None) - - q = torch.cat([q_nope, q_pe], dim=-1) kv = self.kv_a_proj_with_mqa(x) kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv = self.kv_a_layernorm(kv) - k_pe = apply_rotary_emb(k_pe, freqs_cis, qkv_format, unsqueeze_dim=2) + + # For MLA, k_pe needs an extra head dimension for apply_rotary_emb_qk + # k_pe shape: (B, S, rope_dim) -> (B, S, 1, rope_dim) for bshd + # k_pe shape: (T, rope_dim) -> (T, 1, rope_dim) for thd + head_unsqueeze_dim = 2 if qkv_format == "bshd" else 1 + k_pe = 
k_pe.unsqueeze(head_unsqueeze_dim) + + # Apply rotary embeddings to q_pe and k_pe + cu_seqlens = attn_kwargs.get("cu_seqlens", None) + q_pe, k_pe = apply_rotary_emb_qk( + q_pe, + k_pe, + freqs_cis, + format=qkv_format, + rope_fusion=self.rope_fusion, + cu_seqlens=cu_seqlens, + cp_size=attn_kwargs.get("cp_size", 1), + cp_rank=attn_kwargs.get("cp_rank", 0), + ) + + # Remove the head dimension we added to k_pe + k_pe = k_pe.squeeze(head_unsqueeze_dim) + + q = torch.cat([q_nope, q_pe], dim=-1) kv = self.kv_b_proj(kv) if qkv_format == "thd": diff --git a/nemo_automodel/components/models/deepseek_v3/model.py b/nemo_automodel/components/models/deepseek_v3/model.py index 862057498..a202f7985 100644 --- a/nemo_automodel/components/models/deepseek_v3/model.py +++ b/nemo_automodel/components/models/deepseek_v3/model.py @@ -167,7 +167,13 @@ def forward( ) with torch.no_grad(): - freqs_cis = freqs_cis_from_position_ids(position_ids, self.freqs_cis) + freqs_cis = freqs_cis_from_position_ids( + position_ids, + self.freqs_cis, + qkv_format=attn_kwargs.get("qkv_format", "bshd"), + for_fused_rope=self.backend.rope_fusion, + cp_size=attn_kwargs.get("cp_size", 1), + ) h = self.embed_tokens(input_ids) if self.embed_tokens is not None else input_ids diff --git a/nemo_automodel/components/models/deepseek_v3/rope_utils.py b/nemo_automodel/components/models/deepseek_v3/rope_utils.py index 9f48e6d8f..c0be5574e 100644 --- a/nemo_automodel/components/models/deepseek_v3/rope_utils.py +++ b/nemo_automodel/components/models/deepseek_v3/rope_utils.py @@ -135,7 +135,7 @@ def apply_rotary_emb( if unsqueeze_dim is not None: x = x.unsqueeze(unsqueeze_dim) - x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) + x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) # [b, s, h, d] -> [b, s, h, d/2, 2] freqs_cis = freqs_cis.view(x.size(0), x.size(1), 1, x.size(-1)) y = torch.view_as_real(x * freqs_cis).flatten(3) y = y.to(dtype) @@ -148,7 +148,89 @@ def apply_rotary_emb( return y -def freqs_cis_from_position_ids(position_ids: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - freqs = torch.matmul(position_ids.unsqueeze(-1).float(), freqs.unsqueeze(0)) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) +def apply_rotary_emb_qk( + q: torch.Tensor, + k: torch.Tensor, + freqs_cis: torch.Tensor, + format: str = "bshd", + rope_fusion: bool = True, + cu_seqlens: torch.Tensor | None = None, + cp_size: int = 1, + cp_rank: int = 0, +) -> tuple[torch.Tensor, torch.Tensor]: + if rope_fusion: + from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb + + q = apply_rotary_pos_emb( + q, + freqs_cis, + tensor_format=format, + interleaved=True, + fused=True, + cu_seqlens=cu_seqlens, + cp_size=cp_size, + cp_rank=cp_rank, + ) + k = apply_rotary_pos_emb( + k, + freqs_cis, + tensor_format=format, + interleaved=True, + fused=True, + cu_seqlens=cu_seqlens, + cp_size=cp_size, + cp_rank=cp_rank, + ) + return q, k + else: + q = apply_rotary_emb(q, freqs_cis, qkv_format=format) + k = apply_rotary_emb(k, freqs_cis, qkv_format=format) + return q, k + + +@torch.no_grad() +def freqs_cis_from_position_ids( + position_ids: torch.Tensor, + freqs: torch.Tensor, + qkv_format: str = "bshd", + for_fused_rope: bool = False, + cp_size: int = 1, +) -> torch.Tensor: + if qkv_format == "thd": + if for_fused_rope: + # For fused rope with thd, use sequential positions + position_ids = torch.arange( + position_ids.shape[0] * cp_size, device=position_ids.device, dtype=torch.int32 + ).unsqueeze(0) + else: + # For non-fused thd, 
ensure 2D + if position_ids.dim() == 1: + position_ids = position_ids.unsqueeze(0) + + freqs = freqs.to(device=position_ids.device) + + # Compute angles: (B, T, D/2) - same as original implementation + # angles = torch.matmul(position_ids.unsqueeze(-1).float(), freqs.unsqueeze(0)) + angles = torch.einsum("bt,d->btd", position_ids.to(dtype=torch.float32), freqs) + + if for_fused_rope: + if qkv_format == "thd": + freqs_cis = angles.squeeze(0) + else: + # For bshd, take first batch (assumes uniform positions) + freqs_cis = angles[0] + + # TE fused rope expects [angles, angles] format + # freqs_cis = torch.cat((angles, angles), dim=-1) + freqs_cis = torch.stack((freqs_cis.view(-1, 1), freqs_cis.view(-1, 1)), dim=-1).view(freqs_cis.shape[0], -1) + + # Reshape to (T, 1, 1, D) for TE fused rope + freqs_cis = freqs_cis.reshape(freqs_cis.size(0), 1, 1, freqs_cis.size(1)).contiguous() + else: + # Return complex exponentials for non-fused rope (original behavior) + freqs_cis = torch.polar(torch.ones_like(angles), angles) + + if qkv_format == "thd": + freqs_cis = freqs_cis.squeeze(0) + return freqs_cis diff --git a/nemo_automodel/components/models/glm4_moe/layers.py b/nemo_automodel/components/models/glm4_moe/layers.py index a42b08aa5..50bc5c4f2 100644 --- a/nemo_automodel/components/models/glm4_moe/layers.py +++ b/nemo_automodel/components/models/glm4_moe/layers.py @@ -23,7 +23,7 @@ postprocess_output_for_attn, preprocess_args_and_kwargs_for_attn, ) -from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb +from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk from nemo_automodel.components.moe.utils import ( BackendConfig, initialize_linear_module, @@ -121,11 +121,16 @@ def forward( k = self.k_norm(k) # Partial RoPE (only apply to first partial_rotary_factor of head_dim) - rotary_dim = int(self.head_dim * self.partial_rotary_factor) - cos, sin = freqs_cis.split(rotary_dim // 2, dim=-1) - - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) + q, k = apply_rotary_emb_qk( + q, + k, + freqs_cis, + format=qkv_format, + rope_fusion=self.backend.rope_fusion, + cu_seqlens=attn_kwargs.get("cu_seqlens", None), + cp_size=attn_kwargs.get("cp_size", 1), + cp_rank=attn_kwargs.get("cp_rank", 0), + ) # Backend-specific attention q, k, v, _attn_kwargs = preprocess_args_and_kwargs_for_attn( diff --git a/nemo_automodel/components/models/glm4_moe/model.py b/nemo_automodel/components/models/glm4_moe/model.py index 57511545b..3a25e3867 100644 --- a/nemo_automodel/components/models/glm4_moe/model.py +++ b/nemo_automodel/components/models/glm4_moe/model.py @@ -154,7 +154,11 @@ def forward( # Compute freqs_cis from RotaryEmbedding inv_freq and current position_ids; then concat [cos, sin] freqs_cis = position_ids_to_freqs_cis( - self.rotary_emb, position_ids, qkv_format=attn_kwargs.get("qkv_format", "bshd") + self.rotary_emb, + position_ids, + qkv_format=attn_kwargs.get("qkv_format", "bshd"), + for_fused_rope=self.backend.rope_fusion, + cp_size=attn_kwargs.get("cp_size", 1), ) h = self.embed_tokens(input_ids) if self.embed_tokens is not None else input_ids diff --git a/nemo_automodel/components/models/gpt_oss/layers.py b/nemo_automodel/components/models/gpt_oss/layers.py index 9ed55ed2d..ffb34713d 100644 --- a/nemo_automodel/components/models/gpt_oss/layers.py +++ b/nemo_automodel/components/models/gpt_oss/layers.py @@ -18,6 +18,7 @@ from torch import nn from torch.distributed.tensor import DTensor +from 
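The two branches of `freqs_cis_from_position_ids` above package the same per-position angles in different layouts: the non-fused path keeps unit-magnitude complex exponentials of shape (T, D/2), while the fused path duplicates each angle into consecutive pairs and reshapes to the (T, 1, 1, D) table that TE's fused kernel reads. A minimal, self-contained sketch of that relationship (base 10000 and the sizes are illustrative; the stack/view trick in the patch produces the same interleaving as `repeat_interleave`):

import torch

T, half = 4, 8
pos = torch.arange(T, dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(half, dtype=torch.float32) / half))
angles = torch.einsum("t,d->td", pos, inv_freq)                         # (T, half)

non_fused = torch.polar(torch.ones_like(angles), angles)                # complex64, consumed by apply_rotary_emb
fused = angles.repeat_interleave(2, dim=-1).reshape(T, 1, 1, 2 * half)  # (T, 1, 1, D), [a0, a0, a1, a1, ...]

torch.testing.assert_close(non_fused.angle(), fused[:, 0, 0, 0::2])     # same angles, different packaging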
nemo_automodel.components.models.deepseek_v3.rope_utils import yarn_get_mscale from nemo_automodel.shared.import_utils import is_te_min_version if TYPE_CHECKING: @@ -28,7 +29,7 @@ postprocess_output_for_attn, preprocess_args_and_kwargs_for_attn, ) -from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb +from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk from nemo_automodel.components.moe.utils import ( BackendConfig, initialize_linear_module, @@ -60,6 +61,12 @@ def __init__(self, config: "GptOssConfig", backend: BackendConfig, use_sliding_a ) self.softmax_scale = self.head_dim**-0.5 + # When using fused rope, YaRN concentration is not baked into freqs_cis, + # so we need to apply concentration to q and k after fused rope + if backend.rope_fusion: + self.yarn_concentration = yarn_get_mscale(config.rope_scaling["factor"]) + else: + self.yarn_concentration = None assert backend.attn in ("flex", "te"), "Only Flex and TE Attention are supported for GPT-OSS" if backend.attn == "te" and not is_te_min_version("2.8.0"): @@ -107,10 +114,18 @@ def forward( k = k.view(bsz, seqlen, self.num_key_value_heads, self.head_dim) v = v.view(bsz, seqlen, self.num_key_value_heads, self.head_dim) - # freqs_cis is concatenated [cos, sin] along last dim with shape (B, T, head_dim) - cos, sin = freqs_cis.split(self.head_dim // 2, dim=-1) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) + # Apply rotary positional embeddings + q, k = apply_rotary_emb_qk( + q, + k, + freqs_cis, + format=qkv_format, + rope_fusion=self.backend.rope_fusion, + cu_seqlens=attn_kwargs.get("cu_seqlens", None), + concentration=self.yarn_concentration, + cp_size=attn_kwargs.get("cp_size", 1), + cp_rank=attn_kwargs.get("cp_rank", 0), + ) if self.backend.attn == "flex": updated_attn_kwargs = { @@ -146,8 +161,8 @@ def init_weights(self, buffer_device: torch.device, init_std: float = 0.02): ] if self.backend.attn == "flex": - nn.init.trunc_normal_(self.sinks, mean=0.0, std=init_std) + nn.init.normal_(self.sinks, mean=0.0, std=init_std) else: - nn.init.trunc_normal_(self.attn_module.softmax_offset, mean=0.0, std=init_std) + nn.init.normal_(self.attn_module.softmax_offset, mean=0.0, std=init_std) for linear in linear_list: nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) diff --git a/nemo_automodel/components/models/gpt_oss/model.py b/nemo_automodel/components/models/gpt_oss/model.py index 23d2590ff..e6dbb7307 100644 --- a/nemo_automodel/components/models/gpt_oss/model.py +++ b/nemo_automodel/components/models/gpt_oss/model.py @@ -153,7 +153,11 @@ def forward( # Compute cos/sin from RotaryEmbedding inv_freq and current position_ids; then concat [cos, sin] freqs_cis = position_ids_to_freqs_cis( - self.rotary_emb, position_ids, qkv_format=attn_kwargs.get("qkv_format", "bshd") + self.rotary_emb, + position_ids, + qkv_format=attn_kwargs.get("qkv_format", "bshd"), + for_fused_rope=self.backend.rope_fusion, + cp_size=attn_kwargs.get("cp_size", 1), ) h = self.embed_tokens(input_ids) if self.embed_tokens is not None else input_ids diff --git a/nemo_automodel/components/models/gpt_oss/rope_utils.py b/nemo_automodel/components/models/gpt_oss/rope_utils.py index 5177f0220..51cb14ff0 100644 --- a/nemo_automodel/components/models/gpt_oss/rope_utils.py +++ b/nemo_automodel/components/models/gpt_oss/rope_utils.py @@ -134,20 +134,90 @@ def forward( return query, key +def apply_rotary_emb_qk( + q: torch.Tensor, + k: torch.Tensor, + freqs_cis: torch.Tensor, + format: str = 
"bshd", + rope_fusion: bool = True, + cu_seqlens: torch.Tensor | None = None, + concentration: float | None = None, + cp_size: int = 1, + cp_rank: int = 0, +) -> tuple[torch.Tensor, torch.Tensor]: + """Apply rotary embeddings to query and key tensors. + + Args: + q: Query tensor. + k: Key tensor. + freqs_cis: Frequency tensor. Format depends on rope_fusion: + - If rope_fusion=True: [angles, angles] for TE fused rope + - If rope_fusion=False: [cos, sin] with concentration applied + format: QKV format ("bshd" or "thd"). + rope_fusion: If True, use TE fused rope. If False, use non-fused rope. + cu_seqlens: Cumulative sequence lengths for variable-length sequences. + cp_size: Context parallelism size. + cp_rank: Context parallelism rank. + + Returns: + Tuple of (q, k) with rotary embeddings applied. + """ + if rope_fusion: + from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb + + q = apply_rotary_pos_emb( + q, freqs_cis, tensor_format=format, fused=True, cu_seqlens=cu_seqlens, cp_size=cp_size, cp_rank=cp_rank + ) + k = apply_rotary_pos_emb( + k, freqs_cis, tensor_format=format, fused=True, cu_seqlens=cu_seqlens, cp_size=cp_size, cp_rank=cp_rank + ) + if concentration is not None: + q = q * concentration + k = k * concentration + return q, k + else: + cos, sin = freqs_cis.split(freqs_cis.shape[-1] // 2, dim=-1) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + return q, k + + @torch.no_grad() def position_ids_to_freqs_cis( - rotary_emb: RotaryEmbedding, position_ids: torch.Tensor, qkv_format: str = "bshd" + rotary_emb: RotaryEmbedding, + position_ids: torch.Tensor, + qkv_format: str = "bshd", + for_fused_rope: bool = True, + cp_size: int = 1, ) -> torch.Tensor: if qkv_format == "thd": - position_ids = position_ids.unsqueeze(0) + if for_fused_rope: + position_ids = torch.arange( + position_ids.shape[0] * cp_size, device=position_ids.device, dtype=torch.int32 + ).unsqueeze(0) + else: + position_ids = position_ids.unsqueeze(0) concentration, inv_freq = rotary_emb._compute_concentration_and_inv_freq() inv_freq = inv_freq.to(device=position_ids.device, dtype=torch.float32) # angles: (B, T, D/2) angles = torch.einsum("bt,d->btd", position_ids.to(dtype=torch.float32), inv_freq) - cos = torch.cos(angles) * concentration - sin = torch.sin(angles) * concentration - freqs_cis = torch.cat([cos, sin], dim=-1) + + if for_fused_rope: + # TE fused rope expects [angles, angles] + freqs_cis = torch.cat((angles, angles), dim=-1) + else: + # Non-fused rope expects [cos, sin] with concentration applied + cos = torch.cos(angles) * concentration + sin = torch.sin(angles) * concentration + freqs_cis = torch.cat([cos, sin], dim=-1) + if qkv_format == "thd": freqs_cis = freqs_cis.squeeze(0) + else: + freqs_cis = freqs_cis[0] if for_fused_rope else freqs_cis + + if for_fused_rope: + freqs_cis = freqs_cis.reshape(freqs_cis.size(0), 1, 1, freqs_cis.size(1)).contiguous() + return freqs_cis diff --git a/nemo_automodel/components/models/qwen3_moe/layers.py b/nemo_automodel/components/models/qwen3_moe/layers.py index fb1a33127..52cfc98e5 100644 --- a/nemo_automodel/components/models/qwen3_moe/layers.py +++ b/nemo_automodel/components/models/qwen3_moe/layers.py @@ -23,7 +23,7 @@ postprocess_output_for_attn, preprocess_args_and_kwargs_for_attn, ) -from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb +from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk from nemo_automodel.components.moe.utils import ( BackendConfig, 
initialize_linear_module, @@ -113,9 +113,16 @@ def forward( k = self.k_norm(k) # RoPE (complex rotation) - cos, sin = freqs_cis.split(self.head_dim // 2, dim=-1) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) + q, k = apply_rotary_emb_qk( + q, + k, + freqs_cis, + format=qkv_format, + rope_fusion=self.backend.rope_fusion, + cu_seqlens=attn_kwargs.get("cu_seqlens", None), + cp_size=attn_kwargs.get("cp_size", 1), + cp_rank=attn_kwargs.get("cp_rank", 0), + ) # Backend-specific attention q, k, v, _attn_kwargs = preprocess_args_and_kwargs_for_attn( diff --git a/nemo_automodel/components/models/qwen3_moe/model.py b/nemo_automodel/components/models/qwen3_moe/model.py index 16af3f155..a99c2f40b 100644 --- a/nemo_automodel/components/models/qwen3_moe/model.py +++ b/nemo_automodel/components/models/qwen3_moe/model.py @@ -156,7 +156,11 @@ def forward( # Compute freqs_cis from RotaryEmbedding inv_freq and current position_ids; then concat [cos, sin] freqs_cis = position_ids_to_freqs_cis( - self.rotary_emb, position_ids, qkv_format=attn_kwargs.get("qkv_format", "bshd") + self.rotary_emb, + position_ids, + qkv_format=attn_kwargs.get("qkv_format", "bshd"), + for_fused_rope=self.backend.rope_fusion, + cp_size=attn_kwargs.get("cp_size", 1), ) h = self.embed_tokens(input_ids) if self.embed_tokens is not None else input_ids diff --git a/nemo_automodel/components/models/qwen3_next/layers.py b/nemo_automodel/components/models/qwen3_next/layers.py index 60b1d6cbb..eb4e25fdc 100644 --- a/nemo_automodel/components/models/qwen3_next/layers.py +++ b/nemo_automodel/components/models/qwen3_next/layers.py @@ -23,7 +23,7 @@ postprocess_output_for_attn, preprocess_args_and_kwargs_for_attn, ) -from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb +from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk from nemo_automodel.components.moe.utils import ( BackendConfig, initialize_linear_module, @@ -125,10 +125,17 @@ def forward( q = self.q_norm(q) k = self.k_norm(k) - # Apply RoPE (split freqs_cis by half, accounting for partial rotary factor) - cos, sin = freqs_cis.split(freqs_cis.shape[-1] // 2, dim=-1) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) + # Apply RoPE + q, k = apply_rotary_emb_qk( + q, + k, + freqs_cis, + format=qkv_format, + rope_fusion=self.backend.rope_fusion, + cu_seqlens=attn_kwargs.get("cu_seqlens", None), + cp_size=attn_kwargs.get("cp_size", 1), + cp_rank=attn_kwargs.get("cp_rank", 0), + ) # Backend-specific attention q, k, v, _attn_kwargs = preprocess_args_and_kwargs_for_attn( diff --git a/nemo_automodel/components/models/qwen3_next/model.py b/nemo_automodel/components/models/qwen3_next/model.py index 995f3905c..cc9dc6391 100644 --- a/nemo_automodel/components/models/qwen3_next/model.py +++ b/nemo_automodel/components/models/qwen3_next/model.py @@ -173,7 +173,11 @@ def forward( # Compute freqs_cis from RotaryEmbedding inv_freq and current position_ids; then concat [cos, sin] freqs_cis = position_ids_to_freqs_cis( - self.rotary_emb, position_ids, qkv_format=attn_kwargs.get("qkv_format", "bshd") + self.rotary_emb, + position_ids, + qkv_format=attn_kwargs.get("qkv_format", "bshd"), + for_fused_rope=self.backend.rope_fusion, + cp_size=attn_kwargs.get("cp_size", 1), ) h = self.embed_tokens(input_ids) if self.embed_tokens is not None else input_ids diff --git a/nemo_automodel/components/moe/parallelizer.py b/nemo_automodel/components/moe/parallelizer.py index 670aa6348..e276ee39e 100644 --- 
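All of the `thd` call sites above forward `cu_seqlens` from `attn_kwargs` into the rotary call. For packed batches this follows the usual varlen convention of TE/flash-attention-style kernels: an int32 tensor of cumulative sequence boundaries with a leading zero. A small sketch of how such a tensor is typically formed (the sequence lengths here are illustrative):

import torch

seq_lens = torch.tensor([5, 3, 8])
cu_seqlens = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    torch.cumsum(seq_lens, dim=0, dtype=torch.int32),
])
# tensor([ 0,  5,  8, 16], dtype=torch.int32) -> 16 packed tokens in total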
a/nemo_automodel/components/moe/parallelizer.py +++ b/nemo_automodel/components/moe/parallelizer.py @@ -138,7 +138,10 @@ def apply_fsdp( if mp_policy is None: mp_policy = MixedPrecisionPolicy( - param_dtype=torch.bfloat16, reduce_dtype=torch.float32, output_dtype=torch.bfloat16 + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + output_dtype=torch.bfloat16, + cast_forward_inputs=True, ) fully_shard_default = functools.partial( diff --git a/nemo_automodel/components/moe/utils.py b/nemo_automodel/components/moe/utils.py index 83e83ffd9..ba907a742 100644 --- a/nemo_automodel/components/moe/utils.py +++ b/nemo_automodel/components/moe/utils.py @@ -30,6 +30,7 @@ class BackendConfig: attn: Literal["te", "sdpa", "flex"] = "te" if HAVE_TE and torch.cuda.is_available() else "sdpa" linear: Literal["torch", "te"] = "te" if HAVE_TE and torch.cuda.is_available() else "torch" rms_norm: Literal["torch", "te"] = "te" if HAVE_TE and torch.cuda.is_available() else "torch" + rope_fusion: bool = HAVE_TE and torch.cuda.is_available() enable_deepep: bool = HAVE_DEEP_EP fake_balanced_gate: bool = False enable_hf_state_dict_adapter: bool = True diff --git a/tests/unit_tests/models/deepseek_v3/test_rope_utils.py b/tests/unit_tests/models/deepseek_v3/test_rope_utils.py index e3323591f..85d589302 100644 --- a/tests/unit_tests/models/deepseek_v3/test_rope_utils.py +++ b/tests/unit_tests/models/deepseek_v3/test_rope_utils.py @@ -19,6 +19,7 @@ from nemo_automodel.components.models.deepseek_v3.rope_utils import ( apply_rotary_emb, + apply_rotary_emb_qk, freqs_cis_from_position_ids, precompute_freqs_cis, yarn_get_mscale, @@ -573,3 +574,268 @@ def test_rope_with_scaling_long_context(self): # Verify output assert result.shape == x.shape assert result.dtype == x.dtype + + +class TestApplyRotaryEmbQk: + """Tests for apply_rotary_emb_qk function (DeepseekV3 version)""" + + def test_non_fused_rope_bshd(self): + """Test apply_rotary_emb_qk with non-fused rope in bshd format""" + batch_size = 2 + seq_len = 4 + num_heads = 8 + head_dim = 64 + + q = torch.randn(batch_size, seq_len, num_heads, head_dim) + k = torch.randn(batch_size, seq_len, num_heads, head_dim) + freqs_cis = torch.randn(batch_size, seq_len, head_dim // 2, dtype=torch.complex64) + + q_out, k_out = apply_rotary_emb_qk(q, k, freqs_cis, format="bshd", rope_fusion=False) + + # Check output shapes + assert q_out.shape == q.shape + assert k_out.shape == k.shape + assert q_out.dtype == q.dtype + assert k_out.dtype == k.dtype + + def test_non_fused_rope_thd(self): + """Test apply_rotary_emb_qk with non-fused rope in thd format""" + total_tokens = 16 + num_heads = 8 + head_dim = 64 + + q = torch.randn(total_tokens, num_heads, head_dim) + k = torch.randn(total_tokens, num_heads, head_dim) + freqs_cis = torch.randn(1, total_tokens, head_dim // 2, dtype=torch.complex64) + + q_out, k_out = apply_rotary_emb_qk(q, k, freqs_cis, format="thd", rope_fusion=False) + + # Check output shapes + assert q_out.shape == q.shape + assert k_out.shape == k.shape + + def test_non_fused_rope_consistency(self): + """Test that apply_rotary_emb_qk gives same results as individual apply_rotary_emb calls""" + batch_size = 2 + seq_len = 4 + num_heads = 4 + head_dim = 32 + + q = torch.randn(batch_size, seq_len, num_heads, head_dim) + k = torch.randn(batch_size, seq_len, num_heads, head_dim) + freqs_cis = torch.randn(batch_size, seq_len, head_dim // 2, dtype=torch.complex64) + + # Normalize freqs_cis to have magnitude 1 + freqs_cis = freqs_cis / freqs_cis.abs() + + # Apply using 
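The `parallelizer.py` hunk above adds `cast_forward_inputs=True` to the FSDP2 mixed-precision policy so that forward inputs are cast to `param_dtype` (bf16) before entering each sharded module. A minimal sketch of the resulting policy, assuming the public FSDP2 (`fully_shard`) API in recent PyTorch; the real helper also threads the device mesh and other options:

import functools

import torch
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    output_dtype=torch.bfloat16,
    cast_forward_inputs=True,  # cast incoming activations to param_dtype
)
fully_shard_default = functools.partial(fully_shard, mp_policy=mp_policy)
# fully_shard_default(transformer_block)  # applied per block by apply_fsdp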
apply_rotary_emb_qk + q_out_qk, k_out_qk = apply_rotary_emb_qk(q, k, freqs_cis, format="bshd", rope_fusion=False) + + # Apply using individual apply_rotary_emb calls + q_out_individual = apply_rotary_emb(q, freqs_cis, qkv_format="bshd") + k_out_individual = apply_rotary_emb(k, freqs_cis, qkv_format="bshd") + + # Results should be identical + torch.testing.assert_close(q_out_qk, q_out_individual) + torch.testing.assert_close(k_out_qk, k_out_individual) + + def test_dtype_preservation(self): + """Test that output dtype matches input dtype""" + batch_size = 2 + seq_len = 4 + num_heads = 4 + head_dim = 32 + + for dtype in [torch.float32, torch.float16, torch.bfloat16]: + q = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype) + k = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype) + freqs_cis = torch.randn(batch_size, seq_len, head_dim // 2, dtype=torch.complex64) + + q_out, k_out = apply_rotary_emb_qk(q, k, freqs_cis, format="bshd", rope_fusion=False) + + assert q_out.dtype == dtype + assert k_out.dtype == dtype + + def test_norm_preservation(self): + """Test that rotary embeddings preserve norm when freqs_cis has unit magnitude""" + batch_size = 2 + seq_len = 4 + num_heads = 2 + head_dim = 8 + + q = torch.randn(batch_size, seq_len, num_heads, head_dim) + k = torch.randn(batch_size, seq_len, num_heads, head_dim) + freqs_cis = torch.randn(batch_size, seq_len, head_dim // 2, dtype=torch.complex64) + + # Normalize freqs_cis to have magnitude 1 + freqs_cis = freqs_cis / freqs_cis.abs() + + q_out, k_out = apply_rotary_emb_qk(q, k, freqs_cis, format="bshd", rope_fusion=False) + + # The norm should be approximately preserved + q_input_norm = torch.norm(q.reshape(-1, head_dim), dim=-1) + q_output_norm = torch.norm(q_out.reshape(-1, head_dim), dim=-1) + torch.testing.assert_close(q_input_norm, q_output_norm, rtol=1e-4, atol=1e-4) + + k_input_norm = torch.norm(k.reshape(-1, head_dim), dim=-1) + k_output_norm = torch.norm(k_out.reshape(-1, head_dim), dim=-1) + torch.testing.assert_close(k_input_norm, k_output_norm, rtol=1e-4, atol=1e-4) + + +class TestFreqsCisFromPositionIdsFusedRope: + """Tests for freqs_cis_from_position_ids with fused rope support""" + + def test_non_fused_bshd_format(self): + """Test non-fused rope output format for bshd""" + batch_size = 2 + seq_len = 4 + head_dim = 64 + + position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, seq_len) + freqs = torch.randn(head_dim // 2) + + freqs_cis = freqs_cis_from_position_ids( + position_ids, freqs, qkv_format="bshd", for_fused_rope=False + ) + + # Non-fused should return complex exponentials with shape (B, T, D/2) + assert freqs_cis.shape == (batch_size, seq_len, head_dim // 2) + assert freqs_cis.dtype == torch.complex64 + + def test_fused_bshd_format(self): + """Test fused rope output format for bshd""" + batch_size = 2 + seq_len = 4 + head_dim = 64 + + position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, seq_len) + freqs = torch.randn(head_dim // 2) + + freqs_cis = freqs_cis_from_position_ids( + position_ids, freqs, qkv_format="bshd", for_fused_rope=True + ) + + # Fused rope expects shape (T, 1, 1, D) where D = head_dim (angles duplicated) + assert freqs_cis.shape == (seq_len, 1, 1, head_dim) + assert freqs_cis.dtype == torch.float32 + + def test_non_fused_thd_format(self): + """Test non-fused rope output format for thd""" + seq_len = 8 + head_dim = 64 + + position_ids = torch.arange(seq_len) + freqs = torch.randn(head_dim // 2) + + freqs_cis = freqs_cis_from_position_ids( + 
position_ids, freqs, qkv_format="thd", for_fused_rope=False + ) + + # Non-fused thd should return complex exponentials with shape (T, D/2) + assert freqs_cis.shape == (seq_len, head_dim // 2) + assert freqs_cis.dtype == torch.complex64 + + def test_fused_thd_format(self): + """Test fused rope output format for thd""" + seq_len = 8 + head_dim = 64 + + position_ids = torch.arange(seq_len) + freqs = torch.randn(head_dim // 2) + + freqs_cis = freqs_cis_from_position_ids( + position_ids, freqs, qkv_format="thd", for_fused_rope=True + ) + + # Fused rope with thd expects shape (T, 1, 1, D) where D = head_dim + assert freqs_cis.shape == (seq_len, 1, 1, head_dim) + assert freqs_cis.dtype == torch.float32 + + def test_fused_angles_are_interleaved(self): + """Test that fused rope format has angles interleaved [a0, a0, a1, a1, ...]""" + seq_len = 4 + head_dim = 32 + + position_ids = torch.arange(seq_len).unsqueeze(0) + freqs = torch.ones(head_dim // 2) * 0.1 + + freqs_cis = freqs_cis_from_position_ids( + position_ids, freqs, qkv_format="bshd", for_fused_rope=True + ) + + # Shape should be (T, 1, 1, head_dim) + assert freqs_cis.shape == (seq_len, 1, 1, head_dim) + + # Angles should be interleaved: [a0, a0, a1, a1, a2, a2, ...] + # Even indices and odd indices should be equal (each angle is duplicated consecutively) + even_indices = freqs_cis[:, 0, 0, 0::2] # a0, a1, a2, ... + odd_indices = freqs_cis[:, 0, 0, 1::2] # a0, a1, a2, ... + torch.testing.assert_close(even_indices, odd_indices) + + def test_fused_thd_uses_sequential_positions(self): + """Test that fused rope with thd uses sequential positions regardless of input""" + seq_len = 8 + head_dim = 64 + + # Non-sequential position IDs (e.g., from packed sequences) + position_ids = torch.tensor([0, 1, 0, 1, 2, 0, 1, 2]) + freqs = torch.ones(head_dim // 2) * 0.1 + + freqs_cis = freqs_cis_from_position_ids( + position_ids, freqs, qkv_format="thd", for_fused_rope=True + ) + + # Shape should be (T, 1, 1, head_dim) + assert freqs_cis.shape == (seq_len, 1, 1, head_dim) + + # For fused rope, positions should be sequential 0, 1, 2, ... 
regardless of input + # So angles at position 0 and position 2 (original) should be different + # (because they map to sequential positions 0 and 2) + # With interleaved format, extract unique angles from even indices + angles = freqs_cis[:, 0, 0, 0::2] + + # Each position should have unique angles (since they're sequential) + for i in range(seq_len - 1): + assert not torch.allclose(angles[i], angles[i + 1]) + + def test_non_fused_thd_preserves_positions(self): + """Test that non-fused rope with thd preserves input position IDs""" + seq_len = 8 + head_dim = 64 + + # Non-sequential position IDs (e.g., from packed sequences) + position_ids = torch.tensor([0, 1, 0, 1, 2, 0, 1, 2]) + freqs = torch.ones(head_dim // 2) * 0.1 + + freqs_cis = freqs_cis_from_position_ids( + position_ids, freqs, qkv_format="thd", for_fused_rope=False + ) + + # Shape should be (T, D/2) for non-fused thd + assert freqs_cis.shape == (seq_len, head_dim // 2) + + # Positions with same ID should have same freqs_cis + # Position 0 appears at indices 0, 2, 5 + torch.testing.assert_close(freqs_cis[0], freqs_cis[2]) + torch.testing.assert_close(freqs_cis[0], freqs_cis[5]) + + # Position 1 appears at indices 1, 3, 6 + torch.testing.assert_close(freqs_cis[1], freqs_cis[3]) + torch.testing.assert_close(freqs_cis[1], freqs_cis[6]) + + def test_non_fused_complex_magnitude_is_one(self): + """Test that non-fused output has unit magnitude complex numbers""" + batch_size = 2 + seq_len = 4 + head_dim = 64 + + position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, seq_len) + freqs = torch.randn(head_dim // 2) + + freqs_cis = freqs_cis_from_position_ids( + position_ids, freqs, qkv_format="bshd", for_fused_rope=False + ) + + # All complex numbers should have magnitude 1 + magnitudes = torch.abs(freqs_cis) + torch.testing.assert_close(magnitudes, torch.ones_like(magnitudes), rtol=1e-5, atol=1e-5) diff --git a/tests/unit_tests/models/glm4_moe/test_glm4_moe_layers.py b/tests/unit_tests/models/glm4_moe/test_glm4_moe_layers.py index 5967a1e62..3d75ff476 100644 --- a/tests/unit_tests/models/glm4_moe/test_glm4_moe_layers.py +++ b/tests/unit_tests/models/glm4_moe/test_glm4_moe_layers.py @@ -213,7 +213,7 @@ def test_forward_shape_is_preserved_bshd_format(self, config, sdpa_backend): fake_attn = torch.zeros(batch_size, config.num_attention_heads, seq_len, config.head_dim) attention.attn_func = MagicMock(return_value=fake_attn.to(torch.bfloat16)) - with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb_qk", side_effect=lambda q, k, *_, **__: (q, k)): out = attention(hidden, freqs_cis=freqs_cis) assert out.shape == (batch_size, seq_len, config.hidden_size) @@ -228,7 +228,7 @@ def test_forward_shape_is_preserved_thd_format(self, config, sdpa_backend): fake_attn = torch.zeros(num_tokens, config.num_attention_heads, config.head_dim) attention.attn_func = MagicMock(return_value=fake_attn.to(torch.bfloat16)) - with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb_qk", side_effect=lambda q, k, *_, **__: (q, k)): out = attention(hidden, freqs_cis=freqs_cis) assert out.shape == (num_tokens, config.hidden_size) @@ -245,7 +245,7 @@ def test_forward_applies_qk_norm_when_enabled(self, config, sdpa_backend): fake_attn = torch.zeros(batch_size, config.num_attention_heads, seq_len, 
config.head_dim).to(torch.bfloat16) attention.attn_func = MagicMock(return_value=fake_attn) - with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb_qk", side_effect=lambda q, k, *_, **__: (q, k)): with patch.object(attention.q_norm, "forward", wraps=attention.q_norm.forward) as mock_q_norm, \ patch.object(attention.k_norm, "forward", wraps=attention.k_norm.forward) as mock_k_norm: attention(hidden, freqs_cis=freqs_cis) @@ -263,7 +263,7 @@ def test_forward_skips_qk_norm_when_disabled(self, config_without_qk_norm, sdpa_ fake_attn = torch.zeros(batch_size, config_without_qk_norm.num_attention_heads, seq_len, config_without_qk_norm.head_dim).to(torch.bfloat16) attention.attn_func = MagicMock(return_value=fake_attn) - with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb_qk", side_effect=lambda q, k, *_, **__: (q, k)): out = attention(hidden, freqs_cis=freqs_cis) # Should complete successfully without QK norm @@ -282,19 +282,12 @@ def test_forward_applies_partial_rotary_embedding(self, config, sdpa_backend): return_value=torch.zeros(batch_size, config.num_attention_heads, seq_len, config.head_dim).to(torch.bfloat16) ) - with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb") as mock_rotary: - mock_rotary.side_effect = lambda x, *_: x + with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb_qk") as mock_rotary: + mock_rotary.side_effect = lambda q, k, *_, **__: (q, k) attention(hidden, freqs_cis=freqs_cis) - # Should apply rotary to both q and k - assert mock_rotary.call_count == 2 - # Verify that cos and sin are split correctly based on partial_rotary_factor - for call_args in mock_rotary.call_args_list: - cos = call_args[0][1] - sin = call_args[0][2] - # cos and sin should be half of rotary_dim - assert cos.shape[-1] == rotary_dim // 2 - assert sin.shape[-1] == rotary_dim // 2 + # Should apply rotary once (to both q and k together) + assert mock_rotary.call_count == 1 def test_forward_passes_preprocessed_kwargs(self, config, sdpa_backend): attention = Glm4MoeAttention(config, sdpa_backend) @@ -306,7 +299,7 @@ def test_forward_passes_preprocessed_kwargs(self, config, sdpa_backend): fake_attn = torch.zeros(batch, config.num_attention_heads, seq_len, config.head_dim).to(torch.bfloat16) attention.attn_func = MagicMock(return_value=fake_attn.to(torch.bfloat16)) - with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb_qk", side_effect=lambda q, k, *_, **__: (q, k)): attention(hidden, freqs_cis=freqs_cis, attention_mask=attention_mask) _, kwargs = attention.attn_func.call_args @@ -358,7 +351,7 @@ def test_forward_with_te_backend_supports_attention_mask(self, config, te_backen freqs_cis = torch.randn(batch, seq_len, int(config.head_dim * config.partial_rotary_factor)) attention_mask = torch.tensor([[1, 0, 1]], dtype=torch.bool) - with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb_qk", side_effect=lambda q, k, *_, **__: (q, k)): attention(hidden, freqs_cis=freqs_cis, attention_mask=attention_mask) _, kwargs = 
attention.attn_func.call_args diff --git a/tests/unit_tests/models/gpt_oss/test_gptoss_layers.py b/tests/unit_tests/models/gpt_oss/test_gptoss_layers.py index 947c1a225..2baf104b0 100644 --- a/tests/unit_tests/models/gpt_oss/test_gptoss_layers.py +++ b/tests/unit_tests/models/gpt_oss/test_gptoss_layers.py @@ -52,6 +52,27 @@ def gpt_config(): ) +@pytest.fixture +def gpt_config_with_rope_scaling(): + return GptOssConfig( + vocab_size=1000, + hidden_size=128, + num_attention_heads=4, + num_key_value_heads=4, + head_dim=32, + num_hidden_layers=2, + intermediate_size=256, + max_position_embeddings=512, + rms_norm_eps=1e-6, + sliding_window=None, + layer_types=["full_attention", "full_attention"], + num_local_experts=8, + num_experts_per_tok=2, + router_aux_loss_coef=0.01, + rope_scaling={"factor": 2.0, "type": "yarn"}, + ) + + @pytest.fixture def backend_config(): return BackendConfig( @@ -61,6 +82,20 @@ def backend_config(): enable_deepep=False, fake_balanced_gate=False, enable_hf_state_dict_adapter=False, + rope_fusion=False, + ) + + +@pytest.fixture +def backend_config_with_rope_fusion(): + return BackendConfig( + linear="torch", + attn="flex", + rms_norm="torch", + enable_deepep=False, + fake_balanced_gate=False, + enable_hf_state_dict_adapter=False, + rope_fusion=True, ) @@ -188,6 +223,49 @@ def test_rotary_embedding_application(self, gpt_config, backend_config, device): except Exception as e: pytest.fail(f"Forward pass failed with rotary embedding: {e}") + def test_yarn_concentration_not_set_without_rope_fusion(self, gpt_config, backend_config): + """Test that yarn_concentration is None when rope_fusion is False.""" + attention = GptOssAttention(gpt_config, backend_config) + + assert attention.yarn_concentration is None + + def test_yarn_concentration_set_with_rope_fusion(self, gpt_config_with_rope_scaling, backend_config_with_rope_fusion): + """Test that yarn_concentration is correctly computed when rope_fusion is True.""" + import math + attention = GptOssAttention(gpt_config_with_rope_scaling, backend_config_with_rope_fusion) + + assert hasattr(attention, "yarn_concentration") + # yarn_get_mscale(2.0) = 0.1 * 1.0 * math.log(2.0) + 1.0 + expected_concentration = 0.1 * math.log(2.0) + 1.0 + assert abs(attention.yarn_concentration - expected_concentration) < 1e-6 + + def test_yarn_concentration_different_scaling_factors(self, backend_config_with_rope_fusion): + """Test yarn_concentration with different scaling factors.""" + import math + for factor in [1.5, 2.0, 4.0, 8.0]: + config = GptOssConfig( + vocab_size=1000, + hidden_size=128, + num_attention_heads=4, + num_key_value_heads=4, + head_dim=32, + num_hidden_layers=2, + intermediate_size=256, + max_position_embeddings=512, + rms_norm_eps=1e-6, + sliding_window=None, + layer_types=["full_attention", "full_attention"], + num_local_experts=8, + num_experts_per_tok=2, + router_aux_loss_coef=0.01, + rope_scaling={"factor": factor, "type": "yarn"}, + ) + + attention = GptOssAttention(config, backend_config_with_rope_fusion) + + expected_concentration = 0.1 * math.log(factor) + 1.0 + assert abs(attention.yarn_concentration - expected_concentration) < 1e-6 + @pytest.mark.skipif(not is_te_min_version("2.8.0"), reason="TE version 2.8.0 or higher is required") class TestGptOssAttentionWithTE: @@ -202,6 +280,7 @@ def te_backend_config(self): enable_deepep=False, fake_balanced_gate=False, enable_hf_state_dict_adapter=False, + rope_fusion=False, ) def test_te_backend_requires_min_version(self, gpt_config): diff --git 
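The expected values in the `yarn_concentration` tests above come from `yarn_get_mscale`, which the GPT-OSS attention now imports from `deepseek_v3.rope_utils`. A reference re-implementation consistent with those tests, shown only to make the expected numbers concrete (the real function's signature may differ slightly):

import math

def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    # YaRN attention-temperature rule: identity for scale <= 1,
    # 0.1 * mscale * log(scale) + 1.0 otherwise.
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

assert abs(yarn_get_mscale(2.0) - (0.1 * math.log(2.0) + 1.0)) < 1e-12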
a/tests/unit_tests/models/gpt_oss/test_gptoss_model.py b/tests/unit_tests/models/gpt_oss/test_gptoss_model.py index 7c7872aa9..21ec07dbd 100644 --- a/tests/unit_tests/models/gpt_oss/test_gptoss_model.py +++ b/tests/unit_tests/models/gpt_oss/test_gptoss_model.py @@ -92,6 +92,7 @@ def backend_config(): enable_deepep=False, fake_balanced_gate=False, enable_hf_state_dict_adapter=False, + rope_fusion=False, ) diff --git a/tests/unit_tests/models/gpt_oss/test_gptoss_rope_utils.py b/tests/unit_tests/models/gpt_oss/test_gptoss_rope_utils.py index 79357070f..0169dfe92 100644 --- a/tests/unit_tests/models/gpt_oss/test_gptoss_rope_utils.py +++ b/tests/unit_tests/models/gpt_oss/test_gptoss_rope_utils.py @@ -20,6 +20,7 @@ from nemo_automodel.components.models.gpt_oss.rope_utils import ( RotaryEmbedding, apply_rotary_emb, + apply_rotary_emb_qk, position_ids_to_freqs_cis, ) @@ -713,7 +714,7 @@ def test_basic_computation_bshd(self): seq_len = 8 position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, seq_len) - freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd") + freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd", for_fused_rope=False) # Check output shape: should be (batch_size, seq_len, head_dim) assert freqs_cis.shape == (batch_size, seq_len, 64) @@ -730,7 +731,7 @@ def test_basic_computation_thd(self): seq_len = 16 position_ids = torch.arange(seq_len) - freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="thd") + freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="thd", for_fused_rope=False) # Check output shape: should be (seq_len, head_dim) assert freqs_cis.shape == (seq_len, 64) @@ -747,7 +748,7 @@ def test_sequential_positions(self): seq_len = 8 position_ids = torch.arange(seq_len).unsqueeze(0) - freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd") + freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd", for_fused_rope=False) assert freqs_cis.shape == (1, seq_len, 32) @@ -762,7 +763,7 @@ def test_non_sequential_positions(self): # Packed sequences: position IDs restart position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3]]) - freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd") + freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd", for_fused_rope=False) assert freqs_cis.shape == (1, 7, 64) @@ -780,7 +781,7 @@ def test_large_position_ids(self): position_ids = torch.tensor([[1000, 2000, 3000, 4000]]) - freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd") + freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd", for_fused_rope=False) assert freqs_cis.shape == (1, 4, 64) @@ -796,7 +797,7 @@ def test_with_scaling_factor(self): position_ids = torch.arange(16).unsqueeze(0) - freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd") + freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd", for_fused_rope=False) assert freqs_cis.shape == (1, 16, 64) @@ -816,7 +817,7 @@ def test_different_batch_patterns(self): ] ) - freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd") + freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd", for_fused_rope=False) assert freqs_cis.shape == (3, 4, 32) @@ -843,7 +844,7 @@ def test_freqs_cis_with_different_cp_sizes(self, cp_size): position_ids_rank = torch.arange(seq_len_per_rank) # Compute freqs_cis for this rank's position_ids - freqs_cis_rank = 
position_ids_to_freqs_cis(rope, position_ids_rank, qkv_format="thd") + freqs_cis_rank = position_ids_to_freqs_cis(rope, position_ids_rank, qkv_format="thd", for_fused_rope=False) # Verify shape and properties assert freqs_cis_rank.ndim == 2 @@ -865,7 +866,7 @@ def test_freqs_cis_consistency_across_ranks(self, cp_size, cp_rank): position_ids_rank = torch.arange(seq_len_per_rank) % 10 # Compute freqs_cis - freqs_cis_rank = position_ids_to_freqs_cis(rope, position_ids_rank, qkv_format="thd") + freqs_cis_rank = position_ids_to_freqs_cis(rope, position_ids_rank, qkv_format="thd", for_fused_rope=False) # Verify that positions with same ID have same freqs_cis unique_positions = torch.unique(position_ids_rank) @@ -890,7 +891,7 @@ def test_freqs_cis_cp_with_variable_sequence_lengths(self): position_ids_rank = torch.tensor([0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3]) # Compute freqs_cis - freqs_cis_rank = position_ids_to_freqs_cis(rope, position_ids_rank, qkv_format="thd") + freqs_cis_rank = position_ids_to_freqs_cis(rope, position_ids_rank, qkv_format="thd", for_fused_rope=False) # Verify output properties assert freqs_cis_rank.dtype == torch.float32 @@ -913,7 +914,7 @@ def test_freqs_cis_cp_reconstructibility(self): for cp_rank in range(cp_size): # Each rank gets half the sequence position_ids_rank = torch.arange(seq_len // cp_size) - freqs_cis_rank = position_ids_to_freqs_cis(rope, position_ids_rank, qkv_format="thd") + freqs_cis_rank = position_ids_to_freqs_cis(rope, position_ids_rank, qkv_format="thd", for_fused_rope=False) freqs_cis_parts.append(freqs_cis_rank) # Verify that each part has the expected length @@ -939,7 +940,7 @@ def test_freqs_cis_cp_different_sizes_with_rope_scaling(self, cp_size): position_ids_rank = torch.arange(seq_len_per_rank) # Compute freqs_cis - freqs_cis_rank = position_ids_to_freqs_cis(rope, position_ids_rank, qkv_format="thd") + freqs_cis_rank = position_ids_to_freqs_cis(rope, position_ids_rank, qkv_format="thd", for_fused_rope=False) # Verify properties assert freqs_cis_rank.dtype == torch.float32 @@ -963,7 +964,7 @@ def test_full_rope_pipeline(self): # Step 1: Create position IDs and compute freqs_cis position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, seq_len) - freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd") + freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd", for_fused_rope=False) # Step 2: Extract cos and sin from freqs_cis # freqs_cis contains concatenated cos and sin @@ -998,7 +999,7 @@ def test_packed_sequence_scenario(self): position_ids = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 4]) # Compute freqs_cis - freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="thd") + freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="thd", for_fused_rope=False) # Verify output assert freqs_cis.shape == (total_tokens, 64) @@ -1023,7 +1024,7 @@ def test_rope_with_scaling_long_context(self): # Create position IDs position_ids = torch.arange(seq_len).unsqueeze(0) - freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd") + freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd", for_fused_rope=False) # Verify output assert freqs_cis.shape == (batch_size, seq_len, 32) @@ -1112,7 +1113,7 @@ def test_partial_rotary_with_position_ids_to_freqs_cis(self): position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, seq_len) # Compute freqs_cis - freqs_cis = position_ids_to_freqs_cis(rope, position_ids, 
qkv_format="bshd") + freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd", for_fused_rope=False) # freqs_cis should have shape (batch_size, seq_len, rotary_dim) # because it contains concatenated cos and sin, each of size rotary_dim // 2 @@ -1213,7 +1214,7 @@ def test_partial_rotary_packed_sequences(self): position_ids = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 4]) # Compute freqs_cis - freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="thd") + freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="thd", for_fused_rope=False) # Verify output shape: should be (total_tokens, rotary_dim) assert freqs_cis.shape == (total_tokens, rotary_dim) @@ -1221,3 +1222,395 @@ def test_partial_rotary_packed_sequences(self): # Tokens at position 0 in different sequences should have same freqs_cis torch.testing.assert_close(freqs_cis[0], freqs_cis[3]) torch.testing.assert_close(freqs_cis[0], freqs_cis[7]) + + +class TestApplyRotaryEmbQk: + """Tests for apply_rotary_emb_qk function (GPT-OSS version)""" + + def test_non_fused_rope_bshd(self): + """Test apply_rotary_emb_qk with non-fused rope in bshd format""" + batch_size = 2 + seq_len = 4 + num_heads = 8 + head_dim = 64 + + q = torch.randn(batch_size, seq_len, num_heads, head_dim) + k = torch.randn(batch_size, seq_len, num_heads, head_dim) + # Non-fused rope uses [cos, sin] format + cos = torch.randn(seq_len, head_dim // 2) + sin = torch.randn(seq_len, head_dim // 2) + freqs_cis = torch.cat([cos, sin], dim=-1).unsqueeze(0).expand(batch_size, seq_len, head_dim) + + q_out, k_out = apply_rotary_emb_qk(q, k, freqs_cis, format="bshd", rope_fusion=False) + + # Check output shapes + assert q_out.shape == q.shape + assert k_out.shape == k.shape + assert q_out.dtype == q.dtype + assert k_out.dtype == k.dtype + + def test_non_fused_rope_thd(self): + """Test apply_rotary_emb_qk with non-fused rope in thd format""" + total_tokens = 16 + num_heads = 8 + head_dim = 64 + + q = torch.randn(total_tokens, num_heads, head_dim) + k = torch.randn(total_tokens, num_heads, head_dim) + # Non-fused rope uses [cos, sin] format + cos = torch.randn(total_tokens, head_dim // 2) + sin = torch.randn(total_tokens, head_dim // 2) + freqs_cis = torch.cat([cos, sin], dim=-1) + + q_out, k_out = apply_rotary_emb_qk(q, k, freqs_cis, format="thd", rope_fusion=False) + + # Check output shapes + assert q_out.shape == q.shape + assert k_out.shape == k.shape + + def test_non_fused_rope_consistency(self): + """Test that apply_rotary_emb_qk gives same results as individual apply_rotary_emb calls""" + batch_size = 2 + seq_len = 4 + num_heads = 4 + head_dim = 32 + + q = torch.randn(batch_size, seq_len, num_heads, head_dim) + k = torch.randn(batch_size, seq_len, num_heads, head_dim) + + # Create cos/sin + angles = torch.linspace(0, 2 * math.pi, seq_len * (head_dim // 2)).reshape(seq_len, head_dim // 2) + cos = torch.cos(angles) + sin = torch.sin(angles) + freqs_cis = torch.cat([cos, sin], dim=-1).unsqueeze(0).expand(batch_size, seq_len, head_dim) + + # Apply using apply_rotary_emb_qk + q_out_qk, k_out_qk = apply_rotary_emb_qk(q, k, freqs_cis, format="bshd", rope_fusion=False) + + # Apply using individual apply_rotary_emb calls + q_out_individual = apply_rotary_emb(q, cos, sin) + k_out_individual = apply_rotary_emb(k, cos, sin) + + # Results should be close (might have small numerical differences due to broadcast) + torch.testing.assert_close(q_out_qk, q_out_individual, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(k_out_qk, 
k_out_individual, rtol=1e-5, atol=1e-5) + + def test_dtype_preservation(self): + """Test that output dtype matches input dtype""" + batch_size = 2 + seq_len = 4 + num_heads = 4 + head_dim = 32 + + for dtype in [torch.float32, torch.float16, torch.bfloat16]: + q = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype) + k = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype) + cos = torch.randn(seq_len, head_dim // 2) + sin = torch.randn(seq_len, head_dim // 2) + freqs_cis = torch.cat([cos, sin], dim=-1).unsqueeze(0).expand(batch_size, seq_len, head_dim) + + q_out, k_out = apply_rotary_emb_qk(q, k, freqs_cis, format="bshd", rope_fusion=False) + + assert q_out.dtype == dtype + assert k_out.dtype == dtype + + def test_with_partial_rotary(self): + """Test apply_rotary_emb_qk with partial rotary embeddings""" + batch_size = 2 + seq_len = 4 + num_heads = 4 + head_dim = 64 + rotary_dim = 32 # Only rotate first 32 dimensions + + q = torch.randn(batch_size, seq_len, num_heads, head_dim) + k = torch.randn(batch_size, seq_len, num_heads, head_dim) + + # Store pass-through parts + q_pass = q[..., rotary_dim:].clone() + k_pass = k[..., rotary_dim:].clone() + + # Partial cos/sin (rotary_dim // 2 each) + cos = torch.randn(seq_len, rotary_dim // 2) + sin = torch.randn(seq_len, rotary_dim // 2) + freqs_cis = torch.cat([cos, sin], dim=-1).unsqueeze(0).expand(batch_size, seq_len, rotary_dim) + + q_out, k_out = apply_rotary_emb_qk(q, k, freqs_cis, format="bshd", rope_fusion=False) + + # Check output shapes + assert q_out.shape == q.shape + assert k_out.shape == k.shape + + # Pass-through parts should be preserved + torch.testing.assert_close(q_out[..., rotary_dim:], q_pass) + torch.testing.assert_close(k_out[..., rotary_dim:], k_pass) + + def test_concentration_parameter_none(self): + """Test apply_rotary_emb_qk with concentration=None (default behavior)""" + batch_size = 2 + seq_len = 4 + num_heads = 4 + head_dim = 32 + + q = torch.randn(batch_size, seq_len, num_heads, head_dim) + k = torch.randn(batch_size, seq_len, num_heads, head_dim) + cos = torch.randn(seq_len, head_dim // 2) + sin = torch.randn(seq_len, head_dim // 2) + freqs_cis = torch.cat([cos, sin], dim=-1).unsqueeze(0).expand(batch_size, seq_len, head_dim) + + # With concentration=None (default), output should match non-concentration behavior + q_out_default, k_out_default = apply_rotary_emb_qk(q, k, freqs_cis, format="bshd", rope_fusion=False) + q_out_none, k_out_none = apply_rotary_emb_qk(q, k, freqs_cis, format="bshd", rope_fusion=False, concentration=None) + + torch.testing.assert_close(q_out_default, q_out_none) + torch.testing.assert_close(k_out_default, k_out_none) + + def test_concentration_parameter_non_fused_has_no_effect(self): + """Test that concentration parameter has no effect when rope_fusion=False""" + batch_size = 2 + seq_len = 4 + num_heads = 4 + head_dim = 32 + concentration = 1.5 + + q = torch.randn(batch_size, seq_len, num_heads, head_dim) + k = torch.randn(batch_size, seq_len, num_heads, head_dim) + cos = torch.randn(seq_len, head_dim // 2) + sin = torch.randn(seq_len, head_dim // 2) + freqs_cis = torch.cat([cos, sin], dim=-1).unsqueeze(0).expand(batch_size, seq_len, head_dim) + + # With rope_fusion=False, concentration should be ignored + q_out_no_conc, k_out_no_conc = apply_rotary_emb_qk(q, k, freqs_cis, format="bshd", rope_fusion=False) + q_out_with_conc, k_out_with_conc = apply_rotary_emb_qk( + q, k, freqs_cis, format="bshd", rope_fusion=False, concentration=concentration + ) + + # Results 
should be the same since concentration is ignored for non-fused rope + torch.testing.assert_close(q_out_no_conc, q_out_with_conc) + torch.testing.assert_close(k_out_no_conc, k_out_with_conc) + + +class TestPositionIdsToFreqsCisFusedRope: + """Tests for position_ids_to_freqs_cis with fused rope support""" + + def test_non_fused_bshd_format(self): + """Test non-fused rope output format for bshd""" + rope = RotaryEmbedding( + head_dim=64, + base=10000, + dtype=torch.float32, + ) + + batch_size = 2 + seq_len = 4 + position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, seq_len) + + freqs_cis = position_ids_to_freqs_cis( + rope, position_ids, qkv_format="bshd", for_fused_rope=False + ) + + # Non-fused returns [cos, sin] with shape (B, T, head_dim) + assert freqs_cis.shape == (batch_size, seq_len, 64) + assert freqs_cis.dtype == torch.float32 + + def test_fused_bshd_format(self): + """Test fused rope output format for bshd""" + rope = RotaryEmbedding( + head_dim=64, + base=10000, + dtype=torch.float32, + ) + + batch_size = 2 + seq_len = 4 + position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, seq_len) + + freqs_cis = position_ids_to_freqs_cis( + rope, position_ids, qkv_format="bshd", for_fused_rope=True + ) + + # Fused rope expects shape (T, 1, 1, D) where D = head_dim (angles duplicated) + assert freqs_cis.shape == (seq_len, 1, 1, 64) + assert freqs_cis.dtype == torch.float32 + + def test_non_fused_thd_format(self): + """Test non-fused rope output format for thd""" + rope = RotaryEmbedding( + head_dim=64, + base=10000, + dtype=torch.float32, + ) + + seq_len = 8 + position_ids = torch.arange(seq_len) + + freqs_cis = position_ids_to_freqs_cis( + rope, position_ids, qkv_format="thd", for_fused_rope=False + ) + + # Non-fused thd returns [cos, sin] with shape (T, head_dim) + assert freqs_cis.shape == (seq_len, 64) + assert freqs_cis.dtype == torch.float32 + + def test_fused_thd_format(self): + """Test fused rope output format for thd""" + rope = RotaryEmbedding( + head_dim=64, + base=10000, + dtype=torch.float32, + ) + + seq_len = 8 + position_ids = torch.arange(seq_len) + + freqs_cis = position_ids_to_freqs_cis( + rope, position_ids, qkv_format="thd", for_fused_rope=True + ) + + # Fused rope with thd expects shape (T, 1, 1, D) where D = head_dim + assert freqs_cis.shape == (seq_len, 1, 1, 64) + assert freqs_cis.dtype == torch.float32 + + def test_fused_angles_are_duplicated(self): + """Test that fused rope format has angles duplicated [angles, angles]""" + rope = RotaryEmbedding( + head_dim=32, + base=10000, + dtype=torch.float32, + ) + + seq_len = 4 + position_ids = torch.arange(seq_len).unsqueeze(0) + + freqs_cis = position_ids_to_freqs_cis( + rope, position_ids, qkv_format="bshd", for_fused_rope=True + ) + + # Shape should be (T, 1, 1, head_dim) + assert freqs_cis.shape == (seq_len, 1, 1, 32) + + # First half and second half should be identical (angles duplicated) + first_half = freqs_cis[:, 0, 0, :16] + second_half = freqs_cis[:, 0, 0, 16:] + torch.testing.assert_close(first_half, second_half) + + def test_fused_thd_uses_sequential_positions(self): + """Test that fused rope with thd uses sequential positions regardless of input""" + rope = RotaryEmbedding( + head_dim=64, + base=10000, + dtype=torch.float32, + ) + + seq_len = 8 + # Non-sequential position IDs (e.g., from packed sequences) + position_ids = torch.tensor([0, 1, 0, 1, 2, 0, 1, 2]) + + freqs_cis = position_ids_to_freqs_cis( + rope, position_ids, qkv_format="thd", for_fused_rope=True + ) + + # Shape should 
be (T, 1, 1, head_dim) + assert freqs_cis.shape == (seq_len, 1, 1, 64) + + # For fused rope, positions should be sequential 0, 1, 2, ... regardless of input + angles = freqs_cis[:, 0, 0, :32] + + # Each position should have unique angles (since they're sequential) + for i in range(seq_len - 1): + assert not torch.allclose(angles[i], angles[i + 1]) + + def test_non_fused_thd_preserves_positions(self): + """Test that non-fused rope with thd preserves input position IDs""" + rope = RotaryEmbedding( + head_dim=64, + base=10000, + dtype=torch.float32, + ) + + seq_len = 8 + # Non-sequential position IDs (e.g., from packed sequences) + position_ids = torch.tensor([0, 1, 0, 1, 2, 0, 1, 2]) + + freqs_cis = position_ids_to_freqs_cis( + rope, position_ids, qkv_format="thd", for_fused_rope=False + ) + + # Shape should be (T, head_dim) for non-fused thd + assert freqs_cis.shape == (seq_len, 64) + + # Positions with same ID should have same freqs_cis + # Position 0 appears at indices 0, 2, 5 + torch.testing.assert_close(freqs_cis[0], freqs_cis[2]) + torch.testing.assert_close(freqs_cis[0], freqs_cis[5]) + + # Position 1 appears at indices 1, 3, 6 + torch.testing.assert_close(freqs_cis[1], freqs_cis[3]) + torch.testing.assert_close(freqs_cis[1], freqs_cis[6]) + + def test_fused_with_scaling_factor(self): + """Test fused rope with YaRN scaling factor""" + rope = RotaryEmbedding( + head_dim=64, + base=10000, + dtype=torch.float32, + scaling_factor=2.0, + initial_context_length=4096, + ) + + seq_len = 8 + position_ids = torch.arange(seq_len).unsqueeze(0) + + freqs_cis = position_ids_to_freqs_cis( + rope, position_ids, qkv_format="bshd", for_fused_rope=True + ) + + # Shape should be (T, 1, 1, head_dim) + assert freqs_cis.shape == (seq_len, 1, 1, 64) + assert freqs_cis.dtype == torch.float32 + + def test_fused_with_partial_rotary(self): + """Test fused rope with partial rotary factor""" + head_dim = 128 + partial_rotary_factor = 0.5 + rotary_dim = int(head_dim * partial_rotary_factor) + + rope = RotaryEmbedding( + head_dim=head_dim, + base=10000, + dtype=torch.float32, + partial_rotary_factor=partial_rotary_factor, + ) + + seq_len = 8 + position_ids = torch.arange(seq_len).unsqueeze(0) + + freqs_cis = position_ids_to_freqs_cis( + rope, position_ids, qkv_format="bshd", for_fused_rope=True + ) + + # Shape should be (T, 1, 1, rotary_dim) not head_dim + assert freqs_cis.shape == (seq_len, 1, 1, rotary_dim) + + def test_non_fused_with_partial_rotary(self): + """Test non-fused rope with partial rotary factor""" + head_dim = 128 + partial_rotary_factor = 0.5 + rotary_dim = int(head_dim * partial_rotary_factor) + + rope = RotaryEmbedding( + head_dim=head_dim, + base=10000, + dtype=torch.float32, + partial_rotary_factor=partial_rotary_factor, + ) + + batch_size = 2 + seq_len = 8 + position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, seq_len) + + freqs_cis = position_ids_to_freqs_cis( + rope, position_ids, qkv_format="bshd", for_fused_rope=False + ) + + # Shape should be (B, T, rotary_dim) not head_dim + assert freqs_cis.shape == (batch_size, seq_len, rotary_dim) diff --git a/tests/unit_tests/models/qwen3_moe/test_qwen3_moe_layers.py b/tests/unit_tests/models/qwen3_moe/test_qwen3_moe_layers.py index e52c0aa1e..7f5500ad2 100644 --- a/tests/unit_tests/models/qwen3_moe/test_qwen3_moe_layers.py +++ b/tests/unit_tests/models/qwen3_moe/test_qwen3_moe_layers.py @@ -156,7 +156,7 @@ def test_forward_shape_is_preserved(self, config, sdpa_backend): fake_attn = torch.zeros(batch_size, config.num_attention_heads, 
seq_len, config.head_dim) attention.attn_func = MagicMock(return_value=fake_attn.to(torch.bfloat16)) - with patch("nemo_automodel.components.models.qwen3_moe.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + with patch("nemo_automodel.components.models.qwen3_moe.layers.apply_rotary_emb_qk", side_effect=lambda q, k, *_, **__: (q, k)): out = attention(hidden, freqs_cis=freqs_cis) assert out.shape == (batch_size, seq_len, config.hidden_size) @@ -172,7 +172,7 @@ def test_forward_passes_preprocessed_kwargs(self, config, sdpa_backend): fake_attn = torch.zeros(batch, config.num_attention_heads, seq_len, config.head_dim).to(torch.bfloat16) attention.attn_func = MagicMock(return_value=fake_attn.to(torch.bfloat16)) - with patch("nemo_automodel.components.models.qwen3_moe.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + with patch("nemo_automodel.components.models.qwen3_moe.layers.apply_rotary_emb_qk", side_effect=lambda q, k, *_, **__: (q, k)): attention(hidden, freqs_cis=freqs_cis, attention_mask=attention_mask) _, kwargs = attention.attn_func.call_args @@ -188,11 +188,11 @@ def test_forward_applies_rotary_embedding(self, config, sdpa_backend): return_value=torch.zeros(batch_size, config.num_attention_heads, seq_len, config.head_dim).to(torch.bfloat16) ) - with patch("nemo_automodel.components.models.qwen3_moe.layers.apply_rotary_emb") as mock_rotary: - mock_rotary.side_effect = lambda x, *_: x + with patch("nemo_automodel.components.models.qwen3_moe.layers.apply_rotary_emb_qk") as mock_rotary: + mock_rotary.side_effect = lambda q, k, *_, **__: (q, k) attention(hidden, freqs_cis=freqs_cis) - assert mock_rotary.call_count == 2 + assert mock_rotary.call_count == 1 def test_init_weights_resets_norms_and_linears(self, config, sdpa_backend): attention = Qwen3MoeAttention(config, sdpa_backend) @@ -221,7 +221,7 @@ def test_forward_with_te_backend_supports_attention_mask(self, config, te_backen freqs_cis = torch.randn(batch, seq_len, config.head_dim) attention_mask = torch.tensor([[1, 0, 1]], dtype=torch.bool) - with patch("nemo_automodel.components.models.qwen3_moe.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + with patch("nemo_automodel.components.models.qwen3_moe.layers.apply_rotary_emb_qk", side_effect=lambda q, k, *_, **__: (q, k)): attention(hidden, freqs_cis=freqs_cis, attention_mask=attention_mask) _, kwargs = attention.attn_func.call_args diff --git a/tests/unit_tests/models/qwen3_next/test_qwen3_next_layers.py b/tests/unit_tests/models/qwen3_next/test_qwen3_next_layers.py index 334c47182..5c9c1a4fd 100644 --- a/tests/unit_tests/models/qwen3_next/test_qwen3_next_layers.py +++ b/tests/unit_tests/models/qwen3_next/test_qwen3_next_layers.py @@ -208,7 +208,7 @@ def test_forward_shape_is_preserved_bshd_format(self, config, sdpa_backend): fake_attn = torch.zeros(batch_size, config.num_attention_heads, seq_len, config.head_dim) attention.attn_func = MagicMock(return_value=fake_attn.to(torch.bfloat16)) - with patch("nemo_automodel.components.models.qwen3_next.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + with patch("nemo_automodel.components.models.qwen3_next.layers.apply_rotary_emb_qk", side_effect=lambda q, k, *_, **__: (q, k)): out = attention(hidden, freqs_cis=freqs_cis) assert out.shape == (batch_size, seq_len, config.hidden_size) @@ -223,7 +223,7 @@ def test_forward_shape_is_preserved_thd_format(self, config, sdpa_backend): fake_attn = torch.zeros(num_tokens, config.num_attention_heads, config.head_dim) attention.attn_func = 
MagicMock(return_value=fake_attn.to(torch.bfloat16)) - with patch("nemo_automodel.components.models.qwen3_next.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + with patch("nemo_automodel.components.models.qwen3_next.layers.apply_rotary_emb_qk", side_effect=lambda q, k, *_, **__: (q, k)): out = attention(hidden, freqs_cis=freqs_cis) assert out.shape == (num_tokens, config.hidden_size) @@ -239,7 +239,7 @@ def test_forward_applies_query_gating(self, config, sdpa_backend): fake_attn = torch.ones(batch_size, config.num_attention_heads, seq_len, config.head_dim).to(torch.bfloat16) attention.attn_func = MagicMock(return_value=fake_attn) - with patch("nemo_automodel.components.models.qwen3_next.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + with patch("nemo_automodel.components.models.qwen3_next.layers.apply_rotary_emb_qk", side_effect=lambda q, k, *_, **__: (q, k)): with patch("torch.sigmoid", wraps=torch.sigmoid) as mock_sigmoid: out = attention(hidden, freqs_cis=freqs_cis) @@ -257,7 +257,7 @@ def test_forward_passes_preprocessed_kwargs(self, config, sdpa_backend): fake_attn = torch.zeros(batch, config.num_attention_heads, seq_len, config.head_dim).to(torch.bfloat16) attention.attn_func = MagicMock(return_value=fake_attn.to(torch.bfloat16)) - with patch("nemo_automodel.components.models.qwen3_next.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + with patch("nemo_automodel.components.models.qwen3_next.layers.apply_rotary_emb_qk", side_effect=lambda q, k, *_, **__: (q, k)): attention(hidden, freqs_cis=freqs_cis, attention_mask=attention_mask) _, kwargs = attention.attn_func.call_args @@ -273,12 +273,12 @@ def test_forward_applies_rotary_embedding(self, config, sdpa_backend): return_value=torch.zeros(batch_size, config.num_attention_heads, seq_len, config.head_dim).to(torch.bfloat16) ) - with patch("nemo_automodel.components.models.qwen3_next.layers.apply_rotary_emb") as mock_rotary: - mock_rotary.side_effect = lambda x, *_: x + with patch("nemo_automodel.components.models.qwen3_next.layers.apply_rotary_emb_qk") as mock_rotary: + mock_rotary.side_effect = lambda q, k, *_, **__: (q, k) attention(hidden, freqs_cis=freqs_cis) - # Should apply rotary to both q and k - assert mock_rotary.call_count == 2 + # Should apply rotary once (to both q and k together) + assert mock_rotary.call_count == 1 def test_forward_applies_qk_norm(self, config, sdpa_backend): """Test that q_norm and k_norm are applied to query and key""" @@ -290,7 +290,7 @@ def test_forward_applies_qk_norm(self, config, sdpa_backend): fake_attn = torch.zeros(batch_size, config.num_attention_heads, seq_len, config.head_dim).to(torch.bfloat16) attention.attn_func = MagicMock(return_value=fake_attn) - with patch("nemo_automodel.components.models.qwen3_next.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + with patch("nemo_automodel.components.models.qwen3_next.layers.apply_rotary_emb_qk", side_effect=lambda q, k, *_, **__: (q, k)): with patch.object(attention.q_norm, "forward", wraps=attention.q_norm.forward) as mock_q_norm, \ patch.object(attention.k_norm, "forward", wraps=attention.k_norm.forward) as mock_k_norm: attention(hidden, freqs_cis=freqs_cis) @@ -326,7 +326,7 @@ def test_forward_with_te_backend_supports_attention_mask(self, config, te_backen freqs_cis = torch.randn(batch, seq_len, config.head_dim) attention_mask = torch.tensor([[1, 0, 1]], dtype=torch.bool) - with patch("nemo_automodel.components.models.qwen3_next.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + with 
patch("nemo_automodel.components.models.qwen3_next.layers.apply_rotary_emb_qk", side_effect=lambda q, k, *_, **__: (q, k)): attention(hidden, freqs_cis=freqs_cis, attention_mask=attention_mask) _, kwargs = attention.attn_func.call_args