From df6fcce097ad30d5e744e74e04f51ed8c42b73e2 Mon Sep 17 00:00:00 2001 From: HyungjunOh Date: Tue, 23 Sep 2025 19:11:58 +0900 Subject: [PATCH 1/3] add pipe for gpt --- .../models/gpt2/configuration_gpt2_moreh.py | 3 +++ .../models/gpt2/modeling_gpt2_moreh.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/src/transformers/models/gpt2/configuration_gpt2_moreh.py b/src/transformers/models/gpt2/configuration_gpt2_moreh.py index 0b46ac45b1d7..590c421a4e81 100644 --- a/src/transformers/models/gpt2/configuration_gpt2_moreh.py +++ b/src/transformers/models/gpt2/configuration_gpt2_moreh.py @@ -159,6 +159,7 @@ def __init__( eos_token_id=50256, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False, + moreh_config=None, **kwargs, ): self.vocab_size = vocab_size @@ -186,6 +187,8 @@ def __init__( self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id + self.moreh_config = moreh_config + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/src/transformers/models/gpt2/modeling_gpt2_moreh.py b/src/transformers/models/gpt2/modeling_gpt2_moreh.py index 6ff2c98b990a..e67845e48fff 100644 --- a/src/transformers/models/gpt2/modeling_gpt2_moreh.py +++ b/src/transformers/models/gpt2/modeling_gpt2_moreh.py @@ -1016,6 +1016,13 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() + # Moreh Config + self.moreh_pipeline_layers = [] + moreh_config = getattr(config, "moreh_config", None) + if moreh_config is not None and "pipeline_layers" in moreh_config: + self.moreh_pipeline_layers = moreh_config["pipeline_layers"] + + @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): # Check validity of device_map @@ -1257,6 +1264,8 @@ def forward( for k, v in self.device_map.items(): if i == v[-1] and "cuda:" + str(k) != self.last_device: hidden_states = hidden_states.to("cuda:" + str(k + 1)) + if i in self.moreh_pipeline_layers: + hidden_states = torch.moreh.pipeline_assign(hidden_states) hidden_states = self.ln_f(hidden_states) @@ -1294,6 +1303,12 @@ class GPT2LMHeadModelMoreh(GPT2PreTrainedModel): def __init__(self, config): super().__init__(config) print("GPT2LMHeadModelMoreh ##################################") + if config.moreh_config is not None: + print("config.moreh_config") + for key, value in config.moreh_config.items(): + print(f"\t {key}, {value}") + else: + print("config.moreh_config is None") self.transformer = GPT2Model(config) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) From 7f3e5f518cdf34d50a78ac56dcf61c96c8ed12f2 Mon Sep 17 00:00:00 2001 From: HyungjunOh Date: Tue, 23 Sep 2025 20:52:26 +0900 Subject: [PATCH 2/3] add pipe and rope cache for mistral --- .../mistral/configuration_mistral_moreh.py | 3 ++ .../models/mistral/modeling_mistral_moreh.py | 53 ++++++++++++++++++- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mistral/configuration_mistral_moreh.py b/src/transformers/models/mistral/configuration_mistral_moreh.py index 600b7e4650dd..8abb6fab88aa 100644 --- a/src/transformers/models/mistral/configuration_mistral_moreh.py +++ b/src/transformers/models/mistral/configuration_mistral_moreh.py @@ -116,6 +116,7 @@ def __init__( rope_theta=10000.0, sliding_window=4096, attention_dropout=0.0, + moreh_config=None, **kwargs, ): self.vocab_size = vocab_size @@ -138,6 +139,8 @@ def __init__( self.rope_theta = rope_theta self.attention_dropout = attention_dropout + self.moreh_config = 
moreh_config + super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/src/transformers/models/mistral/modeling_mistral_moreh.py b/src/transformers/models/mistral/modeling_mistral_moreh.py index 4278a52eafec..9c617d2e8b41 100644 --- a/src/transformers/models/mistral/modeling_mistral_moreh.py +++ b/src/transformers/models/mistral/modeling_mistral_moreh.py @@ -94,7 +94,7 @@ def forward(self, hidden_states): class MistralRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, use_rope_cache=False): super().__init__() self.dim = dim @@ -103,9 +103,36 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) + self.use_rope_cache = use_rope_cache + if self.use_rope_cache: + self._set_cos_sin_cache(max_position_embeddings, dtype=torch.float32) + + def _set_cos_sin_cache(self, seq_len, dtype): + self.max_seq_len_cached = seq_len + + t = torch.arange(seq_len, dtype=torch.float32, device="cpu") + freqs = torch.outer(t, self.inv_freq.cpu()) # [seq_len, dim/2] + emb = torch.cat((freqs, freqs), dim=-1) # [seq_len, dim] + + cos = emb.cos() + sin = emb.sin() + + cos = cos.to(device='cuda', dtype=dtype) + sin = sin.to(device='cuda', dtype=dtype) + + self.register_buffer("cos_cached", cos, persistent=False) + self.register_buffer("sin_cached", sin, persistent=False) + @torch.no_grad() # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward def forward(self, x, position_ids): + if self.use_rope_cache: + seq_len = position_ids.shape[-1] + assert seq_len <= self.max_position_embeddings, "Sequence length exceeds maximum position embeddings" + cos = self.cos_cached[:seq_len].to(dtype=x.dtype, device=x.device).unsqueeze(0) + sin = self.sin_cached[:seq_len].to(dtype=x.dtype, device=x.device).unsqueeze(0) + return cos, sin + # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() @@ -221,10 +248,16 @@ def __init__(self, config: MistralMorehConfig, layer_idx: Optional[int] = None): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + use_rope_cache = False + moreh_config = getattr(config, "moreh_config", None) + if moreh_config is not None and "rope_cache" in moreh_config: + use_rope_cache = moreh_config["rope_cache"] + self.rotary_emb = MistralRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, + use_rope_cache=use_rope_cache, ) def forward( @@ -885,6 +918,12 @@ def __init__(self, config: MistralMorehConfig): # Initialize weights and apply final processing self.post_init() + # Moreh Config + self.moreh_pipeline_layers = [] + moreh_config = getattr(config, "moreh_config", None) + if moreh_config is not None and "pipeline_layers" in moreh_config: + self.moreh_pipeline_layers = moreh_config["pipeline_layers"] + def get_input_embeddings(self): return self.embed_tokens @@ -957,7 +996,7 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = None - for decoder_layer in self.layers: + for 
layer_idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -984,6 +1023,9 @@ def forward( ) hidden_states = layer_outputs[0] + if layer_idx in self.moreh_pipeline_layers: + print(f"Set pipe in mistral L : {layer_idx}") + hidden_states = torch.moreh.pipeline_assign(hidden_states) if use_cache: next_decoder_cache = layer_outputs[2 if output_attentions else 1] @@ -1124,6 +1166,13 @@ class MistralForCausalLMMoreh(MistralPreTrainedModel): def __init__(self, config): super().__init__(config) print("MistralForCausalLMMoreh #########################################") + if config.moreh_config is not None: + print("config.moreh_config") + for key, value in config.moreh_config.items(): + print(f"\t {key}, {value}") + else: + print("config.moreh_config is None") + self.model = MistralModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) From ff1c72b4d0702c790a6df7b2138cfa559aded0ae Mon Sep 17 00:00:00 2001 From: HyungjunOh Date: Tue, 23 Sep 2025 21:04:47 +0900 Subject: [PATCH 3/3] remove debug logs --- src/transformers/models/gpt2/modeling_gpt2_moreh.py | 7 ------- .../models/mistral/modeling_mistral_moreh.py | 9 --------- 2 files changed, 16 deletions(-) diff --git a/src/transformers/models/gpt2/modeling_gpt2_moreh.py b/src/transformers/models/gpt2/modeling_gpt2_moreh.py index e67845e48fff..533e292b81ea 100644 --- a/src/transformers/models/gpt2/modeling_gpt2_moreh.py +++ b/src/transformers/models/gpt2/modeling_gpt2_moreh.py @@ -1302,13 +1302,6 @@ class GPT2LMHeadModelMoreh(GPT2PreTrainedModel): def __init__(self, config): super().__init__(config) - print("GPT2LMHeadModelMoreh ##################################") - if config.moreh_config is not None: - print("config.moreh_config") - for key, value in config.moreh_config.items(): - print(f"\t {key}, {value}") - else: - print("config.moreh_config is None") self.transformer = GPT2Model(config) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) diff --git a/src/transformers/models/mistral/modeling_mistral_moreh.py b/src/transformers/models/mistral/modeling_mistral_moreh.py index 9c617d2e8b41..0ff2da78034a 100644 --- a/src/transformers/models/mistral/modeling_mistral_moreh.py +++ b/src/transformers/models/mistral/modeling_mistral_moreh.py @@ -1024,7 +1024,6 @@ def forward( hidden_states = layer_outputs[0] if layer_idx in self.moreh_pipeline_layers: - print(f"Set pipe in mistral L : {layer_idx}") hidden_states = torch.moreh.pipeline_assign(hidden_states) if use_cache: @@ -1165,14 +1164,6 @@ class MistralForCausalLMMoreh(MistralPreTrainedModel): def __init__(self, config): super().__init__(config) - print("MistralForCausalLMMoreh #########################################") - if config.moreh_config is not None: - print("config.moreh_config") - for key, value in config.moreh_config.items(): - print(f"\t {key}, {value}") - else: - print("config.moreh_config is None") - self.model = MistralModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
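
Usage note (not part of the patches above): a minimal sketch of how the new
moreh_config hook introduced in this series might be wired up. The import paths
follow the files touched by the patches, the small model dimensions are
placeholders, and actually running the snippet assumes the Moreh runtime is
available (the RoPE cache is created on "cuda" and torch.moreh.pipeline_assign
is provided by that runtime, not by stock PyTorch).

    from transformers.models.mistral.configuration_mistral_moreh import MistralMorehConfig
    from transformers.models.mistral.modeling_mistral_moreh import MistralForCausalLMMoreh

    # Placeholder sizes so the model stays small; a real checkpoint config works the same way.
    config = MistralMorehConfig(
        hidden_size=512,
        intermediate_size=1024,
        num_hidden_layers=8,
        num_attention_heads=8,
        moreh_config={
            # Hidden states are passed through torch.moreh.pipeline_assign after
            # decoder layers 3 and 7, i.e. these indices mark pipeline-stage cuts.
            "pipeline_layers": [3, 7],
            # Precompute the RoPE cos/sin tables once at init instead of on every forward.
            "rope_cache": True,
        },
    )

    model = MistralForCausalLMMoreh(config)

Omitting moreh_config, or either key, falls back to the unmodified behaviour:
both models read the config via getattr/membership checks and default to an
empty pipeline_layers list and use_rope_cache=False.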