src/transformers/models/gpt2/configuration_gpt2_moreh.py (3 additions, 0 deletions)
@@ -159,6 +159,7 @@ def __init__(
eos_token_id=50256,
scale_attn_by_inverse_layer_idx=False,
reorder_and_upcast_attn=False,
moreh_config=None,
**kwargs,
):
self.vocab_size = vocab_size
@@ -186,6 +187,8 @@ def __init__(
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id

self.moreh_config = moreh_config

super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)


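Note: the new moreh_config argument is stored on the config object as-is; the modeling code reads it back with getattr and only looks at the keys it knows. A minimal usage sketch follows, assuming the class defined in this file is exported as GPT2MorehConfig (the class name is not visible in this diff and is an assumption; the layer indices are illustrative):

# Hypothetical sketch: GPT2MorehConfig and the layer indices are assumptions, not shown in the diff.
from transformers.models.gpt2.configuration_gpt2_moreh import GPT2MorehConfig

config = GPT2MorehConfig(
    moreh_config={"pipeline_layers": [11, 23]},  # read by modeling_gpt2_moreh.py (next file)
)
assert config.moreh_config == {"pipeline_layers": [11, 23]}  # stored without validation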
src/transformers/models/gpt2/modeling_gpt2_moreh.py (9 additions, 1 deletion)
@@ -1016,6 +1016,13 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

# Moreh Config
self.moreh_pipeline_layers = []
moreh_config = getattr(config, "moreh_config", None)
if moreh_config is not None and "pipeline_layers" in moreh_config:
self.moreh_pipeline_layers = moreh_config["pipeline_layers"]


@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
# Check validity of device_map
@@ -1257,6 +1264,8 @@ def forward(
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
if i in self.moreh_pipeline_layers:
hidden_states = torch.moreh.pipeline_assign(hidden_states)

hidden_states = self.ln_f(hidden_states)

@@ -1293,7 +1302,6 @@ class GPT2LMHeadModelMoreh(GPT2PreTrainedModel):

def __init__(self, config):
super().__init__(config)
print("GPT2LMHeadModelMoreh ##################################")
self.transformer = GPT2Model(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

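In both the GPT-2 and the Mistral model in this PR, the decoder loop now calls torch.moreh.pipeline_assign on the hidden states after every block whose index appears in moreh_pipeline_layers, which appears to tag those activations as pipeline-stage boundaries for Moreh's pipeline parallelism. torch.moreh is supplied by the Moreh runtime, not stock PyTorch, so the self-contained sketch below guards the call just to stay runnable; the block stack and indices are illustrative:

# Sketch of the boundary-marking pattern used in both modeling files.
# torch.moreh.pipeline_assign exists only under the Moreh runtime (assumption: it returns the
# tagged tensor), hence the hasattr guard so this sketch also runs on stock PyTorch.
import torch
import torch.nn as nn

def run_blocks(blocks, hidden_states, pipeline_layers):
    for i, block in enumerate(blocks):
        hidden_states = block(hidden_states)
        if i in pipeline_layers:
            if hasattr(torch, "moreh"):
                hidden_states = torch.moreh.pipeline_assign(hidden_states)
    return hidden_states

blocks = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
out = run_blocks(blocks, torch.randn(2, 16), pipeline_layers=[1, 3])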
src/transformers/models/mistral/configuration_mistral_moreh.py (3 additions, 0 deletions)
@@ -116,6 +116,7 @@ def __init__(
rope_theta=10000.0,
sliding_window=4096,
attention_dropout=0.0,
moreh_config=None,
**kwargs,
):
self.vocab_size = vocab_size
@@ -138,6 +139,8 @@ def __init__(
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout

self.moreh_config = moreh_config

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
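The Mistral configuration gets the same treatment, and its dict can carry both Moreh-specific switches consumed by modeling_mistral_moreh.py below: "rope_cache" for the rotary-embedding cache and "pipeline_layers" for pipeline boundaries. A sketch, assuming this file defines MistralMorehConfig at the module path below (only the class name appears elsewhere in the diff); the layer indices are illustrative:

# Hypothetical sketch: the import path and layer indices are assumptions.
from transformers.models.mistral.configuration_mistral_moreh import MistralMorehConfig

config = MistralMorehConfig(
    moreh_config={
        "rope_cache": True,           # MistralRotaryEmbedding precomputes its cos/sin tables
        "pipeline_layers": [15, 31],  # mark a pipeline boundary after these decoder layers
    },
)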
src/transformers/models/mistral/modeling_mistral_moreh.py (43 additions, 3 deletions)
@@ -94,7 +94,7 @@ def forward(self, hidden_states):


class MistralRotaryEmbedding(nn.Module):
-def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, use_rope_cache=False):
super().__init__()

self.dim = dim
@@ -103,9 +103,36 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)

self.use_rope_cache = use_rope_cache
if self.use_rope_cache:
self._set_cos_sin_cache(max_position_embeddings, dtype=torch.float32)

def _set_cos_sin_cache(self, seq_len, dtype):
self.max_seq_len_cached = seq_len

t = torch.arange(seq_len, dtype=torch.float32, device="cpu")
freqs = torch.outer(t, self.inv_freq.cpu()) # [seq_len, dim/2]
emb = torch.cat((freqs, freqs), dim=-1) # [seq_len, dim]

cos = emb.cos()
sin = emb.sin()

cos = cos.to(device='cuda', dtype=dtype)
sin = sin.to(device='cuda', dtype=dtype)

self.register_buffer("cos_cached", cos, persistent=False)
self.register_buffer("sin_cached", sin, persistent=False)

@torch.no_grad()
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward
def forward(self, x, position_ids):
if self.use_rope_cache:
seq_len = position_ids.shape[-1]
assert seq_len <= self.max_position_embeddings, "Sequence length exceeds maximum position embeddings"
cos = self.cos_cached[:seq_len].to(dtype=x.dtype, device=x.device).unsqueeze(0)
sin = self.sin_cached[:seq_len].to(dtype=x.dtype, device=x.device).unsqueeze(0)
return cos, sin

# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
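The cached branch above precomputes cos/sin for every position up to max_position_embeddings once, then slices the first seq_len rows per call, while the original branch recomputes them from position_ids each time. Since the slice depends only on position_ids.shape[-1], the cached path assumes position_ids run from 0 (the usual full-sequence case). A standalone CPU sketch of the equivalence, with a small head dim and sequence length picked for illustration (and without the .to('cuda') moves):

# Minimal sketch: the precomputed table sliced to seq_len matches the on-the-fly
# computation for position_ids = 0..seq_len-1.
import torch

dim, base, max_pos, seq_len = 8, 10000.0, 32, 16
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

# Cached path: build [max_pos, dim] cos/sin tables once.
t = torch.arange(max_pos, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)               # [max_pos, dim/2]
emb = torch.cat((freqs, freqs), dim=-1)        # [max_pos, dim]
cos_cached, sin_cached = emb.cos(), emb.sin()

# On-the-fly path, as in the uncached forward branch.
position_ids = torch.arange(seq_len)[None, :]  # [1, seq_len]
inv_freq_expanded = inv_freq[None, :, None].float()
position_ids_expanded = position_ids[:, None, :].float()
freqs2 = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
emb2 = torch.cat((freqs2, freqs2), dim=-1)     # [1, seq_len, dim]

assert torch.allclose(cos_cached[:seq_len].unsqueeze(0), emb2.cos(), atol=1e-5)
assert torch.allclose(sin_cached[:seq_len].unsqueeze(0), emb2.sin(), atol=1e-5)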
@@ -221,10 +248,16 @@ def __init__(self, config: MistralMorehConfig, layer_idx: Optional[int] = None):
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

use_rope_cache = False
moreh_config = getattr(config, "moreh_config", None)
if moreh_config is not None and "rope_cache" in moreh_config:
use_rope_cache = moreh_config["rope_cache"]

self.rotary_emb = MistralRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
use_rope_cache=use_rope_cache,
)

def forward(
@@ -885,6 +918,12 @@ def __init__(self, config: MistralMorehConfig):
# Initialize weights and apply final processing
self.post_init()

# Moreh Config
self.moreh_pipeline_layers = []
moreh_config = getattr(config, "moreh_config", None)
if moreh_config is not None and "pipeline_layers" in moreh_config:
self.moreh_pipeline_layers = moreh_config["pipeline_layers"]

def get_input_embeddings(self):
return self.embed_tokens

@@ -957,7 +996,7 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = None

-for decoder_layer in self.layers:
+for layer_idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)

@@ -984,6 +1023,8 @@
)

hidden_states = layer_outputs[0]
if layer_idx in self.moreh_pipeline_layers:
hidden_states = torch.moreh.pipeline_assign(hidden_states)

if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
@@ -1123,7 +1164,6 @@ class MistralForCausalLMMoreh(MistralPreTrainedModel):

def __init__(self, config):
super().__init__(config)
print("MistralForCausalLMMoreh #########################################")
self.model = MistralModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
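Both models read the dict defensively, via getattr(config, "moreh_config", None) plus an explicit key check, so a config without moreh_config, or without a given key, leaves the Moreh-specific paths disabled and the models behave like stock GPT-2/Mistral. A condensed sketch of that defaulting logic; the helper name is made up, the real code inlines these checks as shown above:

# Hypothetical helper mirroring the inline checks in both modeling files.
def read_moreh_config(config):
    moreh_config = getattr(config, "moreh_config", None) or {}
    pipeline_layers = moreh_config.get("pipeline_layers", [])  # default: no pipeline boundaries
    use_rope_cache = moreh_config.get("rope_cache", False)     # default: compute RoPE on the fly
    return pipeline_layers, use_rope_cache

class PlainConfig:
    pass

print(read_moreh_config(PlainConfig()))  # -> ([], False)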