1 change: 1 addition & 0 deletions examples/benchmark/configs/deepseek_v3_te_deepep.yaml
@@ -72,6 +72,7 @@ model:
attn: te
linear: te
rms_norm: te
rope_fusion: true
enable_deepep: true
fake_balanced_gate: true
enable_hf_state_dict_adapter: false
@@ -53,6 +53,7 @@ model:
attn: te
linear: te
rms_norm: te
rope_fusion: true
enable_deepep: true
fake_balanced_gate: true
enable_hf_state_dict_adapter: false
1 change: 1 addition & 0 deletions examples/benchmark/configs/glm_4.5_air_te_deepep.yaml
@@ -71,6 +71,7 @@ model:
attn: te
linear: te
rms_norm: te
rope_fusion: true
enable_deepep: true
fake_balanced_gate: true
enable_hf_state_dict_adapter: true
1 change: 1 addition & 0 deletions examples/benchmark/configs/gptoss_120b_te_deepep.yaml
@@ -71,6 +71,7 @@ model:
attn: flex
linear: te
rms_norm: te
rope_fusion: true
enable_deepep: true
fake_balanced_gate: true
enable_hf_state_dict_adapter: false
1 change: 1 addition & 0 deletions examples/benchmark/configs/gptoss_20b_te_deepep.yaml
@@ -70,6 +70,7 @@ model:
attn: flex
linear: te
rms_norm: te
rope_fusion: true
enable_deepep: true
fake_balanced_gate: true
enable_hf_state_dict_adapter: false
1 change: 1 addition & 0 deletions examples/benchmark/configs/kimi_k2_te_deepep.yaml
@@ -70,6 +70,7 @@ model:
attn: te
linear: te
rms_norm: te
rope_fusion: true
enable_deepep: true
fake_balanced_gate: true
enable_hf_state_dict_adapter: false
1 change: 1 addition & 0 deletions examples/benchmark/configs/moonlight_16b_te_deepep.yaml
@@ -71,6 +71,7 @@ model:
attn: te
linear: te
rms_norm: te
rope_fusion: true
enable_deepep: true
fake_balanced_gate: true
enable_hf_state_dict_adapter: true
1 change: 1 addition & 0 deletions examples/benchmark/configs/qwen3_moe_235b_te_deepep.yaml
@@ -72,6 +72,7 @@ model:
attn: te
linear: te
rms_norm: te
rope_fusion: true
enable_deepep: true
fake_balanced_gate: true
enable_hf_state_dict_adapter: true
1 change: 1 addition & 0 deletions examples/benchmark/configs/qwen3_moe_30b_te_deepep.yaml
@@ -70,6 +70,7 @@ model:
attn: te
linear: te
rms_norm: te
rope_fusion: true
enable_deepep: true
fake_balanced_gate: true
enable_hf_state_dict_adapter: true
1 change: 1 addition & 0 deletions examples/benchmark/configs/qwen3_next_te_deepep.yaml
@@ -71,6 +71,7 @@ model:
attn: te
linear: te
rms_norm: te
rope_fusion: true
enable_deepep: true
fake_balanced_gate: true
enable_hf_state_dict_adapter: true
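All of the benchmark configs above gain the same one-line switch. For orientation only, here is a minimal sketch of how a `rope_fusion:` key under `model:` could map onto the `BackendConfig` that the attention layers below read as `backend.rope_fusion`; only the four fields this diff actually reads are modeled, and the loader helper is hypothetical rather than the repository's real plumbing.

```python
# Sketch only: a stand-in for the real BackendConfig in
# nemo_automodel.components.moe.utils, reduced to the fields this diff reads.
from dataclasses import dataclass


@dataclass
class BackendConfigSketch:
    attn: str = "te"            # "te" or "flex"
    linear: str = "te"
    rms_norm: str = "te"
    rope_fusion: bool = False   # the flag these configs now turn on


def backend_from_model_section(model_cfg: dict) -> BackendConfigSketch:
    # Hypothetical helper: pull the backend-related keys out of the YAML
    # `model:` section, defaulting rope_fusion to False when the key is absent.
    return BackendConfigSketch(
        attn=model_cfg.get("attn", "te"),
        linear=model_cfg.get("linear", "te"),
        rms_norm=model_cfg.get("rms_norm", "te"),
        rope_fusion=model_cfg.get("rope_fusion", False),
    )
```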
4 changes: 4 additions & 0 deletions nemo_automodel/components/distributed/cp_utils.py
@@ -288,6 +288,8 @@ def make_cp_batch_for_te(
"max_seqlen": torch.stack([chunk["max_seqlen"] for chunk in chunks]),
"qkv_format": qkv_format,
"padding_mask": torch.stack([chunk["padding_mask"] for chunk in chunks]),
"cp_size": cp_mesh.size() if cp_mesh is not None else 1,
"cp_rank": torch.distributed.get_rank(group=cp_mesh.get_group()) if cp_mesh is not None else 0,
}


@@ -329,5 +331,7 @@ def _shard_thd_chunk_for_te(
"max_seqlen": torch.tensor(max_seqlen).to(torch.int32).to(device=cu_seqlens_padded.device),
"qkv_format": qkv_format,
"padding_mask": (batch["input_ids"] == padding_token_id).bool().contiguous(),
"cp_size": cp_size,
"cp_rank": cp_rank,
}
return output_batch
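Both batch builders now carry the context-parallel world size and rank alongside the existing TE metadata; the attention layers below read them back via `attn_kwargs.get("cp_size", 1)` and `attn_kwargs.get("cp_rank", 0)`. A small self-contained restatement of the fallback logic, as an illustrative helper that is not itself part of the module:

```python
import torch.distributed as dist


def cp_metadata(cp_mesh):
    # Same convention as the lines added above: (cp_size, cp_rank) within the
    # context-parallel process group, or (1, 0) when no CP mesh is configured.
    if cp_mesh is None:
        return 1, 0
    return cp_mesh.size(), dist.get_rank(group=cp_mesh.get_group())
```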
34 changes: 29 additions & 5 deletions nemo_automodel/components/models/deepseek_v3/layers.py
@@ -23,7 +23,10 @@
postprocess_output_for_attn,
preprocess_args_and_kwargs_for_attn,
)
from nemo_automodel.components.models.deepseek_v3.rope_utils import apply_rotary_emb, yarn_get_mscale
from nemo_automodel.components.models.deepseek_v3.rope_utils import (
apply_rotary_emb_qk,
yarn_get_mscale,
)
from nemo_automodel.components.moe.utils import (
BackendConfig,
initialize_linear_module,
@@ -46,6 +49,7 @@ def __init__(self, config: DeepseekV3Config, backend: BackendConfig):
self.v_head_dim = config.v_head_dim

self.backend = backend
self.rope_fusion = backend.rope_fusion
attn_impl = backend.attn
linear_impl = backend.linear
rms_norm_impl = backend.rms_norm
@@ -135,14 +139,34 @@ def forward(
q = q.view(bsz, local_seq_len, self.n_heads, self.qk_head_dim)

q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
q_pe = apply_rotary_emb(q_pe, freqs_cis, qkv_format, unsqueeze_dim=None)

q = torch.cat([q_nope, q_pe], dim=-1)

kv = self.kv_a_proj_with_mqa(x)
kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv = self.kv_a_layernorm(kv)
k_pe = apply_rotary_emb(k_pe, freqs_cis, qkv_format, unsqueeze_dim=2)

# For MLA, k_pe needs an extra head dimension for apply_rotary_emb_qk
# k_pe shape: (B, S, rope_dim) -> (B, S, 1, rope_dim) for bshd
# k_pe shape: (T, rope_dim) -> (T, 1, rope_dim) for thd
head_unsqueeze_dim = 2 if qkv_format == "bshd" else 1
k_pe = k_pe.unsqueeze(head_unsqueeze_dim)

# Apply rotary embeddings to q_pe and k_pe
cu_seqlens = attn_kwargs.get("cu_seqlens", None)
q_pe, k_pe = apply_rotary_emb_qk(
q_pe,
k_pe,
freqs_cis,
format=qkv_format,
rope_fusion=self.rope_fusion,
cu_seqlens=cu_seqlens,
cp_size=attn_kwargs.get("cp_size", 1),
cp_rank=attn_kwargs.get("cp_rank", 0),
)

# Remove the head dimension we added to k_pe
k_pe = k_pe.squeeze(head_unsqueeze_dim)

q = torch.cat([q_nope, q_pe], dim=-1)

kv = self.kv_b_proj(kv)
if qkv_format == "thd":
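The MLA path adds a transient head dimension so that `k_pe` matches the (..., heads, rope_dim) layout `apply_rotary_emb_qk` expects, then drops it again after the rotation. A toy shape walk-through of that bookkeeping, with the rotation itself elided and arbitrary sizes:

```python
import torch

B, S, rope_dim = 2, 16, 64

# bshd: (B, S, rope_dim) -> (B, S, 1, rope_dim) -> rotate -> (B, S, rope_dim)
k_pe = torch.randn(B, S, rope_dim)
k_pe = k_pe.unsqueeze(2)       # head dim inserted at dim 2 for bshd
# ... q_pe, k_pe = apply_rotary_emb_qk(q_pe, k_pe, freqs_cis, format="bshd", ...)
k_pe = k_pe.squeeze(2)

# thd: (T, rope_dim) -> (T, 1, rope_dim) -> rotate -> (T, rope_dim)
k_pe_thd = torch.randn(B * S, rope_dim)
k_pe_thd = k_pe_thd.unsqueeze(1)   # head dim inserted at dim 1 for thd
k_pe_thd = k_pe_thd.squeeze(1)
```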
8 changes: 7 additions & 1 deletion nemo_automodel/components/models/deepseek_v3/model.py
@@ -167,7 +167,13 @@ def forward(
)

with torch.no_grad():
freqs_cis = freqs_cis_from_position_ids(position_ids, self.freqs_cis)
freqs_cis = freqs_cis_from_position_ids(
position_ids,
self.freqs_cis,
qkv_format=attn_kwargs.get("qkv_format", "bshd"),
for_fused_rope=self.backend.rope_fusion,
cp_size=attn_kwargs.get("cp_size", 1),
)

h = self.embed_tokens(input_ids) if self.embed_tokens is not None else input_ids

90 changes: 86 additions & 4 deletions nemo_automodel/components/models/deepseek_v3/rope_utils.py
@@ -135,7 +135,7 @@ def apply_rotary_emb(
if unsqueeze_dim is not None:
x = x.unsqueeze(unsqueeze_dim)

x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) # [b, s, h, d] -> [b, s, h, d/2, 2] -> complex [b, s, h, d/2]
freqs_cis = freqs_cis.view(x.size(0), x.size(1), 1, x.size(-1))
y = torch.view_as_real(x * freqs_cis).flatten(3)
y = y.to(dtype)
@@ -148,7 +148,89 @@ def apply_rotary_emb(
return y


def freqs_cis_from_position_ids(position_ids: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
freqs = torch.matmul(position_ids.unsqueeze(-1).float(), freqs.unsqueeze(0))
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
def apply_rotary_emb_qk(
q: torch.Tensor,
k: torch.Tensor,
freqs_cis: torch.Tensor,
format: str = "bshd",
rope_fusion: bool = True,
cu_seqlens: torch.Tensor | None = None,
cp_size: int = 1,
cp_rank: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
if rope_fusion:
from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb

q = apply_rotary_pos_emb(
q,
freqs_cis,
tensor_format=format,
interleaved=True,
fused=True,
cu_seqlens=cu_seqlens,
cp_size=cp_size,
cp_rank=cp_rank,
)
k = apply_rotary_pos_emb(
k,
freqs_cis,
tensor_format=format,
interleaved=True,
fused=True,
cu_seqlens=cu_seqlens,
cp_size=cp_size,
cp_rank=cp_rank,
)
return q, k
else:
q = apply_rotary_emb(q, freqs_cis, qkv_format=format)
k = apply_rotary_emb(k, freqs_cis, qkv_format=format)
return q, k


@torch.no_grad()
def freqs_cis_from_position_ids(
position_ids: torch.Tensor,
freqs: torch.Tensor,
qkv_format: str = "bshd",
for_fused_rope: bool = False,
cp_size: int = 1,
) -> torch.Tensor:
if qkv_format == "thd":
if for_fused_rope:
# For fused rope with thd, use sequential positions
position_ids = torch.arange(
position_ids.shape[0] * cp_size, device=position_ids.device, dtype=torch.int32
).unsqueeze(0)
else:
# For non-fused thd, ensure 2D
if position_ids.dim() == 1:
position_ids = position_ids.unsqueeze(0)

freqs = freqs.to(device=position_ids.device)

# Compute angles: (B, T, D/2) - same as original implementation
# angles = torch.matmul(position_ids.unsqueeze(-1).float(), freqs.unsqueeze(0))
angles = torch.einsum("bt,d->btd", position_ids.to(dtype=torch.float32), freqs)

if for_fused_rope:
if qkv_format == "thd":
freqs_cis = angles.squeeze(0)
else:
# For bshd, take first batch (assumes uniform positions)
freqs_cis = angles[0]

# Duplicate each angle into interleaved order [a0, a0, a1, a1, ...] to match
# interleaved=True in the fused TE call; the non-interleaved alternative would
# be torch.cat((angles, angles), dim=-1).
freqs_cis = torch.stack((freqs_cis.view(-1, 1), freqs_cis.view(-1, 1)), dim=-1).view(freqs_cis.shape[0], -1)

# Reshape to (T, 1, 1, D) for TE fused rope
freqs_cis = freqs_cis.reshape(freqs_cis.size(0), 1, 1, freqs_cis.size(1)).contiguous()
else:
# Return complex exponentials for non-fused rope (original behavior)
freqs_cis = torch.polar(torch.ones_like(angles), angles)

if qkv_format == "thd":
freqs_cis = freqs_cis.squeeze(0)

return freqs_cis
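A minimal usage sketch of the non-fused (`rope_fusion=False`) path defined in this file, with a plain inverse-frequency table standing in for the model's precomputed frequency tensor and arbitrary toy sizes; the fused path additionally requires Transformer Engine, `cu_seqlens` for thd layouts, and the CP metadata threaded through above.

```python
import torch

from nemo_automodel.components.models.deepseek_v3.rope_utils import (
    apply_rotary_emb_qk,
    freqs_cis_from_position_ids,
)

B, S, H, D = 2, 8, 4, 64                                             # D = rotary head dim (even)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2).float() / D))    # (D/2,)
position_ids = torch.arange(S).unsqueeze(0).expand(B, S)             # (B, S)

# Complex exponentials of shape (B, S, D/2) for the non-fused path.
freqs_cis = freqs_cis_from_position_ids(
    position_ids, inv_freq, qkv_format="bshd", for_fused_rope=False
)

q = torch.randn(B, S, H, D)
k = torch.randn(B, S, 1, D)    # e.g. the single-head k_pe in MLA
q, k = apply_rotary_emb_qk(q, k, freqs_cis, format="bshd", rope_fusion=False)
print(q.shape, k.shape)        # torch.Size([2, 8, 4, 64]) torch.Size([2, 8, 1, 64])
```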
17 changes: 11 additions & 6 deletions nemo_automodel/components/models/glm4_moe/layers.py
@@ -23,7 +23,7 @@
postprocess_output_for_attn,
preprocess_args_and_kwargs_for_attn,
)
from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb
from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk
from nemo_automodel.components.moe.utils import (
BackendConfig,
initialize_linear_module,
@@ -121,11 +121,16 @@ def forward(
k = self.k_norm(k)

# Partial RoPE (only apply to first partial_rotary_factor of head_dim)
rotary_dim = int(self.head_dim * self.partial_rotary_factor)
cos, sin = freqs_cis.split(rotary_dim // 2, dim=-1)

q = apply_rotary_emb(q, cos, sin)
k = apply_rotary_emb(k, cos, sin)
q, k = apply_rotary_emb_qk(
q,
k,
freqs_cis,
format=qkv_format,
rope_fusion=self.backend.rope_fusion,
cu_seqlens=attn_kwargs.get("cu_seqlens", None),
cp_size=attn_kwargs.get("cp_size", 1),
cp_rank=attn_kwargs.get("cp_rank", 0),
)

# Backend-specific attention
q, k, v, _attn_kwargs = preprocess_args_and_kwargs_for_attn(
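The removed lines handled GLM's partial RoPE by hand: as the split implies, `freqs_cis` here is the concatenation [cos, sin] over just the rotary channels (`rotary_dim = head_dim * partial_rotary_factor`), so splitting at `rotary_dim // 2` recovers the two halves before rotating q and k. A toy restatement of that convention, for reference only; how the partial dimension is handled inside the new `apply_rotary_emb_qk` / `position_ids_to_freqs_cis` path is not visible in this hunk.

```python
import torch

head_dim, partial_rotary_factor = 128, 0.5
rotary_dim = int(head_dim * partial_rotary_factor)           # 64 rotary channels
B, T = 2, 8

cos = torch.randn(B, T, rotary_dim // 2)
sin = torch.randn(B, T, rotary_dim // 2)
freqs_cis = torch.cat([cos, sin], dim=-1)                     # (B, T, rotary_dim)

cos_back, sin_back = freqs_cis.split(rotary_dim // 2, dim=-1)
assert torch.equal(cos_back, cos) and torch.equal(sin_back, sin)
```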
6 changes: 5 additions & 1 deletion nemo_automodel/components/models/glm4_moe/model.py
@@ -154,7 +154,11 @@ def forward(

# Compute freqs_cis from RotaryEmbedding inv_freq and current position_ids; then concat [cos, sin]
freqs_cis = position_ids_to_freqs_cis(
self.rotary_emb, position_ids, qkv_format=attn_kwargs.get("qkv_format", "bshd")
self.rotary_emb,
position_ids,
qkv_format=attn_kwargs.get("qkv_format", "bshd"),
for_fused_rope=self.backend.rope_fusion,
cp_size=attn_kwargs.get("cp_size", 1),
)

h = self.embed_tokens(input_ids) if self.embed_tokens is not None else input_ids
29 changes: 22 additions & 7 deletions nemo_automodel/components/models/gpt_oss/layers.py
@@ -18,6 +18,7 @@
from torch import nn
from torch.distributed.tensor import DTensor

from nemo_automodel.components.models.deepseek_v3.rope_utils import yarn_get_mscale
from nemo_automodel.shared.import_utils import is_te_min_version

if TYPE_CHECKING:
@@ -28,7 +29,7 @@
postprocess_output_for_attn,
preprocess_args_and_kwargs_for_attn,
)
from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb
from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk
from nemo_automodel.components.moe.utils import (
BackendConfig,
initialize_linear_module,
@@ -60,6 +61,12 @@ def __init__(self, config: "GptOssConfig", backend: BackendConfig, use_sliding_a
)

self.softmax_scale = self.head_dim**-0.5
# When using fused rope, YaRN concentration is not baked into freqs_cis,
# so we need to apply concentration to q and k after fused rope
if backend.rope_fusion:
self.yarn_concentration = yarn_get_mscale(config.rope_scaling["factor"])
else:
self.yarn_concentration = None

assert backend.attn in ("flex", "te"), "Only Flex and TE Attention are supported for GPT-OSS"
if backend.attn == "te" and not is_te_min_version("2.8.0"):
@@ -107,10 +114,18 @@ def forward(
k = k.view(bsz, seqlen, self.num_key_value_heads, self.head_dim)
v = v.view(bsz, seqlen, self.num_key_value_heads, self.head_dim)

# freqs_cis is concatenated [cos, sin] along last dim with shape (B, T, head_dim)
cos, sin = freqs_cis.split(self.head_dim // 2, dim=-1)
q = apply_rotary_emb(q, cos, sin)
k = apply_rotary_emb(k, cos, sin)
# Apply rotary positional embeddings
q, k = apply_rotary_emb_qk(
q,
k,
freqs_cis,
format=qkv_format,
rope_fusion=self.backend.rope_fusion,
cu_seqlens=attn_kwargs.get("cu_seqlens", None),
concentration=self.yarn_concentration,
cp_size=attn_kwargs.get("cp_size", 1),
cp_rank=attn_kwargs.get("cp_rank", 0),
)

if self.backend.attn == "flex":
updated_attn_kwargs = {
@@ -146,8 +161,8 @@ def init_weights(self, buffer_device: torch.device, init_std: float = 0.02):
]

if self.backend.attn == "flex":
nn.init.trunc_normal_(self.sinks, mean=0.0, std=init_std)
nn.init.normal_(self.sinks, mean=0.0, std=init_std)
else:
nn.init.trunc_normal_(self.attn_module.softmax_offset, mean=0.0, std=init_std)
nn.init.normal_(self.attn_module.softmax_offset, mean=0.0, std=init_std)
for linear in linear_list:
nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
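On the YaRN handling above: with the non-fused path the concentration (mscale) factor can be folded into cos/sin when `freqs_cis` is built, but the TE fused kernel is handed raw angles, so the factor is applied to q and k afterwards instead. The two are equivalent because the rotation output is linear in cos and sin. A quick self-contained check of that claim, using the generic rotate_half form rather than the interleaved layout used here, and 1.2 standing in for `yarn_get_mscale(factor)`:

```python
import torch


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rope(x, cos, sin):
    return x * cos + rotate_half(x) * sin


x = torch.randn(2, 4, 8)
ang = torch.rand(2, 4, 8)
c = 1.2  # stand-in for yarn_get_mscale(config.rope_scaling["factor"])

baked_in = rope(x, torch.cos(ang) * c, torch.sin(ang) * c)    # concentration inside cos/sin
applied_after = rope(x, torch.cos(ang), torch.sin(ang)) * c   # concentration after RoPE
assert torch.allclose(baked_in, applied_after, atol=1e-6)
```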
6 changes: 5 additions & 1 deletion nemo_automodel/components/models/gpt_oss/model.py
@@ -153,7 +153,11 @@ def forward(

# Compute cos/sin from RotaryEmbedding inv_freq and current position_ids; then concat [cos, sin]
freqs_cis = position_ids_to_freqs_cis(
self.rotary_emb, position_ids, qkv_format=attn_kwargs.get("qkv_format", "bshd")
self.rotary_emb,
position_ids,
qkv_format=attn_kwargs.get("qkv_format", "bshd"),
for_fused_rope=self.backend.rope_fusion,
cp_size=attn_kwargs.get("cp_size", 1),
)

h = self.embed_tokens(input_ids) if self.embed_tokens is not None else input_ids