From f71559b4650fc60f6fb693c12f4e4f9229828afa Mon Sep 17 00:00:00 2001
From: Yechan Jun
Date: Thu, 30 Oct 2025 07:53:19 +0000
Subject: [PATCH 1/2] Add fused_rms_norm usage

---
 .../models/mistral/modeling_mistral_moreh.py | 29 +++++++++++++++++--
 1 file changed, 26 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/mistral/modeling_mistral_moreh.py b/src/transformers/models/mistral/modeling_mistral_moreh.py
index 019ccdac4419..f2df57feb1a2 100644
--- a/src/transformers/models/mistral/modeling_mistral_moreh.py
+++ b/src/transformers/models/mistral/modeling_mistral_moreh.py
@@ -56,11 +56,21 @@
 _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
 
+try:
+    # Moreh extension
+    moreh_ops = torch.ops.moreh
+    MorehRMSNorm = moreh_ops.T5LayerNorm
+except AttributeError:
+    MorehRMSNorm = None
 
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "MistralConfig"
 
+if MorehRMSNorm is not None:
+    logger.warning(
+        "You can't use Masked Structured Growth Training..! You should avoid using rmsnorm in any way. "
+    )
 
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
@@ -74,6 +84,12 @@ def _get_unpad_data(attention_mask):
         max_seqlen_in_batch,
     )
 
+def get_moreh_fused_rms_norm(config):
+    moreh_config = getattr(config, "moreh_config", None)
+    if moreh_config is not None and "fused_rms_norm" in moreh_config:
+        return moreh_config["fused_rms_norm"]
+    return False
+
 
 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
 class MistralRMSNorm(nn.Module):
@@ -710,8 +726,12 @@ def __init__(self, config: MistralMorehConfig, layer_idx: int):
         self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
 
         self.mlp = MistralMLP(config)
-        self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if get_moreh_fused_rms_norm(config):
+            self.input_layernorm = MorehRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.post_attention_layernorm = MorehRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -912,7 +932,10 @@ def __init__(self, config: MistralMorehConfig):
             [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self._attn_implementation = config._attn_implementation
-        self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if get_moreh_fused_rms_norm(config):
+            self.norm = MorehRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing

From 6273096bbfa87b2483a65a8f55ee7e62d32e04df Mon Sep 17 00:00:00 2001
From: Yechan Jun
Date: Fri, 31 Oct 2025 01:34:21 +0000
Subject: [PATCH 2/2] Fix term

---
 src/transformers/models/mistral/modeling_mistral_moreh.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/mistral/modeling_mistral_moreh.py b/src/transformers/models/mistral/modeling_mistral_moreh.py
index f2df57feb1a2..8a1132d70a2b 100644
--- a/src/transformers/models/mistral/modeling_mistral_moreh.py
+++ b/src/transformers/models/mistral/modeling_mistral_moreh.py
@@ -69,7 +69,7 @@
 if MorehRMSNorm is not None:
     logger.warning(
-        "You can't use Masked Structured Growth Training..! You should avoid using rmsnorm in any way. "
+        "You can't use Masked Structured Growth Training..! You should avoid using RMSNorm in any way. "
     )
 
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
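Usage note: a minimal sketch of how the fused_rms_norm switch read by get_moreh_fused_rms_norm() could be turned on from a config. It assumes the config object accepts an arbitrary `moreh_config` dict attribute, which the getattr-based lookup in the patch suggests; the use of MistralConfig in place of MistralMorehConfig and the concrete model sizes below are illustrative only, not confirmed by the patches.

from transformers import MistralConfig  # the Moreh fork would use MistralMorehConfig instead

# Build a config and attach the Moreh-specific options dict.
# get_moreh_fused_rms_norm() returns config.moreh_config["fused_rms_norm"] when the
# key is present and False otherwise, so setting the key to True selects
# MorehRMSNorm (torch.ops.moreh.T5LayerNorm) for the decoder-layer and final norms,
# while leaving it out keeps the stock MistralRMSNorm.
config = MistralConfig(hidden_size=4096, num_hidden_layers=32)
config.moreh_config = {"fused_rms_norm": True}

# Mirrors the helper's lookup logic.
assert getattr(config, "moreh_config", {}).get("fused_rms_norm", False)

Because the helper defaults to False, configs without a moreh_config entry fall through to the else branches and keep MistralRMSNorm, so the patch changes nothing for existing checkpoints unless the flag is set explicitly.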