diff --git a/src/transformers/models/mistral/modeling_mistral_moreh.py b/src/transformers/models/mistral/modeling_mistral_moreh.py
index 019ccdac4419..8a1132d70a2b 100644
--- a/src/transformers/models/mistral/modeling_mistral_moreh.py
+++ b/src/transformers/models/mistral/modeling_mistral_moreh.py
@@ -56,11 +56,21 @@
     _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
 
+try:
+    # Moreh extension
+    moreh_ops = torch.ops.moreh
+    MorehRMSNorm = moreh_ops.T5LayerNorm
+except AttributeError:
+    MorehRMSNorm = None
 
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "MistralConfig"
 
+if MorehRMSNorm is not None:
+    logger.warning(
+        "You can't use Masked Structured Growth Training! You should avoid using RMSNorm in any way."
+    )
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
@@ -74,6 +84,12 @@ def _get_unpad_data(attention_mask):
         max_seqlen_in_batch,
     )
 
+def get_moreh_fused_rms_norm(config):
+    moreh_config = getattr(config, "moreh_config", None)
+    if moreh_config is not None and "fused_rms_norm" in moreh_config:
+        return moreh_config["fused_rms_norm"]
+    return False
+
 
 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
 class MistralRMSNorm(nn.Module):
@@ -710,8 +726,12 @@ def __init__(self, config: MistralMorehConfig, layer_idx: int):
         self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
 
         self.mlp = MistralMLP(config)
-        self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if get_moreh_fused_rms_norm(config):
+            self.input_layernorm = MorehRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.post_attention_layernorm = MorehRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -912,7 +932,10 @@ def __init__(self, config: MistralMorehConfig):
             [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self._attn_implementation = config._attn_implementation
-        self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if get_moreh_fused_rms_norm(config):
+            self.norm = MorehRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False
 
         # Initialize weights and apply final processing
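
Below is a minimal sketch, not part of the patch, of how the moreh_config gate introduced above might be exercised. The get_moreh_fused_rms_norm helper is copied from the diff so the snippet runs without the Moreh fork; MistralConfig is the stock transformers class, and attaching a plain dict as config.moreh_config is an assumption about how the fork populates that field.

from transformers import MistralConfig


def get_moreh_fused_rms_norm(config):
    # Copied from the patch: read the optional moreh_config dict on the config.
    moreh_config = getattr(config, "moreh_config", None)
    if moreh_config is not None and "fused_rms_norm" in moreh_config:
        return moreh_config["fused_rms_norm"]
    return False


config = MistralConfig()
print(get_moreh_fused_rms_norm(config))  # False: no moreh_config attached

# Assumption: the Moreh fork attaches a plain dict under config.moreh_config.
config.moreh_config = {"fused_rms_norm": True}
print(get_moreh_fused_rms_norm(config))  # True

When the flag resolves to True and torch.ops.moreh.T5LayerNorm is available, the decoder layers and the final norm are built with MorehRMSNorm instead of MistralRMSNorm; otherwise the patched model falls back to the original MistralRMSNorm path.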