From 7384efb763b4823c7b4918a46a85fc6b9da44c99 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Wed, 25 Oct 2023 19:57:32 +0000 Subject: [PATCH 01/26] implement GPT Neo's rope --- .../gpt_bigcode/modeling_gpt_bigcode.py | 196 ++++++++++++++---- 1 file changed, 160 insertions(+), 36 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 387e78eb46..b7e447bb51 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -17,25 +17,23 @@ import torch import torch.utils.checkpoint +from configuration_gpt_bigcode import GPTBigCodeConfig from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from ...activations import ACT2FN -from ...modeling_outputs import ( +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import ( +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging, ) -from .configuration_gpt_bigcode import GPTBigCodeConfig - logger = logging.get_logger(__name__) @@ -93,6 +91,97 @@ def _apply_rotary_embeddings( return torch.view_as_real(complex_tensor * rope_frequencies).view_as(tensor).type_as(tensor) +# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...) +@torch.jit.script +def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +class StarcoderRotaryEmbedding(nn.Module): + """Implementation of RotaryEmbedding from GPT-NeoX.""" + + def __init__(self, head_dim: int, base=10000): + super().__init__() + self.base = base + self.head_dim = head_dim + self.seq_len_cached = -1 + # TODO @nouamane: Figure out why we can't set `DTypeInvariantTensor` ... 
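# Illustrative sketch (not part of the patch, assumes only torch): what `rotate_half`
# computes. RoPE is then applied per position as q * cos + rotate_half(q) * sin.
# The tensor values below are made up for the demonstration.
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])                        # one head, head_dim = 4
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
rotated = torch.cat((-x2, x1), dim=-1)
assert rotated.tolist() == [-3.0, -4.0, 1.0, 2.0]
# (end of sketch)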
+ self.inv_freq: torch.Tensor + self.register_buffer( + "inv_freq", + torch.empty(head_dim // 2, dtype=torch.float), + persistent=False, + ) + self.cos_cached: Optional[torch.Tensor] = None + self.sin_cached: Optional[torch.Tensor] = None + self._initialized_buffer = False + + def init_rotary_embeddings(self): + if self._initialized_buffer is True: + # Buffer if already initialized + return + + assert self.inv_freq.device.type == "cuda" + # TODO @nouamane: One we figure out how to do the DTypeInvariantTensor, this can be removed and changed to an assert + if self.inv_freq.dtype != torch.float: + self.inv_freq = self.inv_freq.to(torch.float) + assert self.inv_freq.dtype == torch.float + + self.inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.head_dim, 2, dtype=torch.float, device="cuda") / self.head_dim) + ) + + self._initialized_buffer = True + + def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor: + total_length = seq_len + past_key_values_length + if total_length > self.seq_len_cached: + self.seq_len_cached = total_length + t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) # [seq_len, head_dim] + + if dtype in [torch.float16, torch.bfloat16]: + emb = emb.float() + + self.cos_cached = emb.cos()[None, :, None, :] # [1, seq_len, 1, head_dim] + self.sin_cached = emb.sin()[None, :, None, :] + + self.cos_cached = self.cos_cached.type(dtype) + self.sin_cached = self.sin_cached.type(dtype) + + return ( + self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length], + self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length], + ) + + def forward(self, query, key, past_key_values_length=0): + """ + Args: + query: [batch_size, seq_len, num_heads, head_dim] + key: [batch_size, seq_len, num_heads_k, head_dim] + past_key_values_length: int + + Returns: + query: [batch_size, seq_len, num_heads, head_dim] + key: [batch_size, seq_len, num_heads_k, head_dim] + """ + if self._initialized_buffer is False: + self.init_rotary_embeddings() + seq_len = query.shape[1] + cos, sin = self.cos_sin( + seq_len, past_key_values_length, query.device, query.dtype + ) # [1, seq_len, 1, head_dim] + query = (query * cos) + (rotate_half(query) * sin) + key = (key * cos) + (rotate_half(key) * sin) + if past_key_values_length > 0: + assert ( + query.shape[1] == 1 + ), f"past_key_values_length={past_key_values_length} but query.shape[1]={query.shape[1]}" + return query, key + + class GPTBigCodeAttention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() @@ -130,6 +219,13 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): else: self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim) + if self.use_rotary_embeddings: + self.maybe_rotary = ( + StarcoderRotaryEmbedding(head_dim=self.head_dim) + if config.use_rotary_embeddings + else lambda q, k, t: (q, k) + ) + self.c_proj = nn.Linear(self.embed_dim, self.embed_dim) self.attn_dropout = nn.Dropout(config.attn_pdrop) @@ -229,7 +325,7 @@ def forward( use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, rotary_embedding_frequencies_q: Optional[torch.Tensor] = None, - rotary_embedding_frequencies_k: Optional[torch.Tensor] = None + rotary_embedding_frequencies_k: Optional[torch.Tensor] = None, ) -> Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], 
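# Illustrative sketch (not part of the patch, assumes only torch): how the cached
# cos/sin tables built by cos_sin() look in isolation. head_dim, base and the
# sequence length are made-up example values; shapes mirror the method above.
import torch

head_dim, base, total_length = 8, 10000, 6
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float) / head_dim))
t = torch.arange(total_length, dtype=torch.float)
freqs = torch.einsum("i,j->ij", t, inv_freq)                   # [seq_len, head_dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)                        # [seq_len, head_dim]
cos = emb.cos()[None, :, None, :]                              # [1, seq_len, 1, head_dim]
sin = emb.sin()[None, :, None, :]
assert cos.shape == (1, total_length, 1, head_dim)
# (end of sketch)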
Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], @@ -246,6 +342,8 @@ def forward( attention_mask = encoder_attention_mask elif self.multi_query: query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + # query shape: (batch_size, query_length, num_heads * head_dim) + # key_value shape: (batch_size, query_length, 2 * 1 * head_dim) else: # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), # i.e., the memory layout is not the same as GPT2. @@ -258,14 +356,33 @@ def forward( ) if layer_past is not None: - key_value = torch.cat((layer_past, key_value), dim=-2) - present = key_value if use_cache else None + past_kv_length = layer_past[0].shape[-2] + else: + past_kv_length = 0 - key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + key, value = key_value.split( + (self.head_dim, self.head_dim), dim=-1 + ) # (batch_size, query_length, 1 * head_dim) if self.use_rotary_embeddings: - query = _apply_rotary_embeddings(query, rotary_embedding_frequencies_q) - key = _apply_rotary_embeddings(key, rotary_embedding_frequencies_k) + # query = _apply_rotary_embeddings(query, rotary_embedding_frequencies_q) + # key = _apply_rotary_embeddings(key, rotary_embedding_frequencies_k) + query = query.view(*query.shape[:2], self.num_heads, self.head_dim) + key = key.view(*key.shape[:2], 1, self.head_dim) + query, key = self.maybe_rotary(query, key, past_kv_length) + query = query.view(*query.shape[:2], self.num_heads * self.head_dim) + key = key.view(*key.shape[:2], 1 * self.head_dim) + + if layer_past is not None: + # Concatenate past key/values with new key/values. + if self.use_rotary_embeddings: + key_value = torch.cat((key, value), dim=-1) + key_value = torch.cat((layer_past, key_value), dim=-2) + key, value = key_value.split( + (self.head_dim, self.head_dim), dim=-1 + ) # (batch_size, key_length, 1 * head_dim) + + present = key_value if use_cache else None attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) @@ -331,7 +448,7 @@ def forward( use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, rotary_embedding_frequencies_q: Optional[torch.Tensor] = None, - rotary_embedding_frequencies_k: Optional[torch.Tensor] = None + rotary_embedding_frequencies_k: Optional[torch.Tensor] = None, ) -> Union[ Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor] ]: @@ -345,7 +462,7 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, rotary_embedding_frequencies_q=rotary_embedding_frequencies_q, - rotary_embedding_frequencies_k=rotary_embedding_frequencies_k + rotary_embedding_frequencies_k=rotary_embedding_frequencies_k, ) attn_output = attn_outputs[0] # output_attn: a, present, (attentions) outputs = attn_outputs[1:] @@ -369,7 +486,7 @@ def forward( encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, rotary_embedding_frequencies_q=rotary_embedding_frequencies_q, - rotary_embedding_frequencies_k=rotary_embedding_frequencies_k + rotary_embedding_frequencies_k=rotary_embedding_frequencies_k, ) attn_output = cross_attn_outputs[0] # residual connection @@ -536,20 +653,17 @@ def __init__(self, config): self.wte = nn.Embedding(config.vocab_size, self.embed_dim) if config.use_position_embeddings: self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) - if config.use_rotary_embeddings: - # Calculate the complex 
frequencies (https://blog.eleuther.ai/rotary-embeddings/) - # `exp(i * n * a) = cos(n * a) + i sin(n * a)`, - # `a = theta ** - (2 * (channel // 2) / kv_channels)`, - # where n is the position in the sequence. - kv_channels = config.n_embd / config.n_head - angles = torch.outer( - torch.arange(config.max_position_embeddings, dtype=torch.float32), - torch.exp( - config.rotary_embedding_scale - * torch.arange(0, 1, 2 / kv_channels, dtype=torch.float32) - ), - ) - self._rotary_embedding_frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :] + # if config.use_rotary_embeddings: + # # Calculate the complex frequencies (https://blog.eleuther.ai/rotary-embeddings/) + # # `exp(i * n * a) = cos(n * a) + i sin(n * a)`, + # # `a = theta ** - (2 * (channel // 2) / kv_channels)`, + # # where n is the position in the sequence. + # kv_channels = config.n_embd / config.n_head + # angles = torch.outer( + # torch.arange(config.max_position_embeddings, dtype=torch.float32), + # torch.exp(config.rotary_embedding_scale * torch.arange(0, 1, 2 / kv_channels, dtype=torch.float32)), + # ) + # self._rotary_embedding_frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :] self.drop = nn.Dropout(config.embd_pdrop) self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) @@ -638,13 +752,17 @@ def forward( elif position_ids is None: position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - + # Rotary frequencies rotary_embedding_frequencies_q = None rotary_embedding_frequencies_k = None - if self.config.use_rotary_embeddings: - rotary_embedding_frequencies_q = self._rotary_embedding_frequencies[:, past_length : past_length + input_shape[-1]].to(device=device) - rotary_embedding_frequencies_k = self._rotary_embedding_frequencies[:, :past_length + input_shape[-1], :, :].to(device=device) + # if self.config.use_rotary_embeddings: + # rotary_embedding_frequencies_q = self._rotary_embedding_frequencies[ + # :, past_length : past_length + input_shape[-1] + # ].to(device=device) + # rotary_embedding_frequencies_k = self._rotary_embedding_frequencies[ + # :, : past_length + input_shape[-1], :, : + # ].to(device=device) # Self-attention mask. 
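# Illustrative sketch (not part of the patch, made-up lengths): the q/k slicing of a
# cached per-position table, as in the removed frequency lookup above. Queries only
# need their own window [past, past + seq), keys need the whole prefix [0, past + seq).
import torch

past_length, seq_len = 6, 2
cache = torch.arange(10)                                       # stands in for cached rows
q_rows = cache[past_length : past_length + seq_len]
k_rows = cache[: past_length + seq_len]
assert q_rows.tolist() == [6, 7] and k_rows.numel() == 8
# (end of sketch)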
query_length = input_shape[-1] @@ -712,7 +830,13 @@ def forward( def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, use_cache, output_attentions, rotary_embedding_frequencies_q, rotary_embedding_frequencies_k) + return module( + *inputs, + use_cache, + output_attentions, + rotary_embedding_frequencies_q, + rotary_embedding_frequencies_k, + ) return custom_forward @@ -736,7 +860,7 @@ def custom_forward(*inputs): use_cache=use_cache, output_attentions=output_attentions, rotary_embedding_frequencies_q=rotary_embedding_frequencies_q, - rotary_embedding_frequencies_k=rotary_embedding_frequencies_k + rotary_embedding_frequencies_k=rotary_embedding_frequencies_k, ) hidden_states = outputs[0] From 7ebc9ea7fea17fb0db2713a2dc390260a8de3def Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Sat, 28 Oct 2023 23:28:57 +0000 Subject: [PATCH 02/26] fix imports --- .../models/gpt_bigcode/modeling_gpt_bigcode.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index b7e447bb51..5fac29ef17 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -17,23 +17,25 @@ import torch import torch.utils.checkpoint -from configuration_gpt_bigcode import GPTBigCodeConfig from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( + +from ...activations import ACT2FN +from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ( +from ...modeling_utils import PreTrainedModel +from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging, ) +from .configuration_gpt_bigcode import GPTBigCodeConfig + logger = logging.get_logger(__name__) From 2c28b04f8021a3a8809882388201db462ace8b7e Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Sat, 28 Oct 2023 23:29:11 +0000 Subject: [PATCH 03/26] output logits --- src/transformers/generation/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3b1bef6f04..9cfbe9e7dc 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2469,7 +2469,9 @@ def greedy_search( # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: - scores += (next_tokens_scores,) + scores += (next_tokens_scores,) if outputs.logits.shape[1] == 1 else ( + outputs.logits, + ) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) From 253da5b98139cf7d975579933538797b8b3ecb96 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Mon, 30 Oct 2023 09:41:57 +0000 Subject: [PATCH 04/26] attn mask.all() --- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 5fac29ef17..e34409c0d1 100644 --- 
a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -715,6 +715,8 @@ def forward( ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if attention_mask is not None: + assert attention_mask.all(), f"attention_mask: {attention_mask}" if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") From d3c15bad80f58e45ba03bcd9969cf0449aefd4bf Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Tue, 31 Oct 2023 13:56:24 +0000 Subject: [PATCH 05/26] fix caching in rope --- .../models/gpt_bigcode/modeling_gpt_bigcode.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index e34409c0d1..ee7a01e210 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -140,6 +140,7 @@ def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype total_length = seq_len + past_key_values_length if total_length > self.seq_len_cached: self.seq_len_cached = total_length + assert self.inv_freq.dtype == torch.float t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype) freqs = torch.einsum("i,j->ij", t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) # [seq_len, head_dim] @@ -358,7 +359,7 @@ def forward( ) if layer_past is not None: - past_kv_length = layer_past[0].shape[-2] + past_kv_length = layer_past.shape[-2] else: past_kv_length = 0 @@ -371,21 +372,25 @@ def forward( # key = _apply_rotary_embeddings(key, rotary_embedding_frequencies_k) query = query.view(*query.shape[:2], self.num_heads, self.head_dim) key = key.view(*key.shape[:2], 1, self.head_dim) + # print(past_kv_length) + # print("before",key[0,:,0,0]) query, key = self.maybe_rotary(query, key, past_kv_length) + # print("after",key[0,:,0,0]) query = query.view(*query.shape[:2], self.num_heads * self.head_dim) key = key.view(*key.shape[:2], 1 * self.head_dim) + if use_cache: + key_value = torch.cat((key, value), dim=-1) + if layer_past is not None: # Concatenate past key/values with new key/values. 
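# Illustrative sketch (not part of the patch, toy shapes): the single-tensor KV cache
# layout used here for MQA. Key and value are stored side by side on the last dim,
# and each new step is appended along the sequence dim before splitting again.
import torch

B, T_past, T_new, D = 2, 5, 1, 4
layer_past = torch.zeros(B, T_past, 2 * D)                     # cached [key | value]
new_kv = torch.cat((torch.ones(B, T_new, D), 2 * torch.ones(B, T_new, D)), dim=-1)
key_value = torch.cat((layer_past, new_kv), dim=-2)
key, value = key_value.split((D, D), dim=-1)
assert key.shape == (B, T_past + T_new, D) and value.shape == (B, T_past + T_new, D)
# (end of sketch)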
- if self.use_rotary_embeddings: - key_value = torch.cat((key, value), dim=-1) key_value = torch.cat((layer_past, key_value), dim=-2) key, value = key_value.split( (self.head_dim, self.head_dim), dim=-1 ) # (batch_size, key_length, 1 * head_dim) + # print("after past", key[0,:,0]) present = key_value if use_cache else None - attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) if not self.multi_query: From 1e746646d710f370c16e427079a406785ba88200 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Wed, 1 Nov 2023 10:33:32 +0000 Subject: [PATCH 06/26] GQA generation without cache --- .../gpt_bigcode/configuration_gpt_bigcode.py | 15 ++++ .../gpt_bigcode/modeling_gpt_bigcode.py | 84 +++++++++++++++---- 2 files changed, 85 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index 3c5cff3a7e..a71c362ad7 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -50,6 +50,14 @@ class GPTBigCodeConfig(PretrainedConfig): Number of hidden layers in the Transformer encoder. n_head (`int`, *optional*, defaults to 12): Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. n_inner (`int`, *optional*, defaults to None): Dimensionality of the inner feed-forward layers. 
`None` will set it to 4 times n_embd activation_function (`str`, *optional*, defaults to `"gelu_pytorch_tanh"`): @@ -106,6 +114,7 @@ def __init__( n_embd=768, n_layer=12, n_head=12, + num_key_value_heads=None, n_inner=None, activation_function="gelu_pytorch_tanh", resid_pdrop=0.1, @@ -131,6 +140,12 @@ def __init__( self.n_embd = n_embd self.n_layer = n_layer self.n_head = n_head + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = 1 if multi_query else n_head + self.num_key_value_heads = num_key_value_heads + self.n_inner = n_inner self.activation_function = activation_function self.resid_pdrop = resid_pdrop diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index ee7a01e210..e0d32ece5f 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -190,12 +190,13 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() self.mask_value = None + # self.multi_query = config.num_key_value_heads == 1 self.multi_query = config.multi_query + assert not config.multi_query or (config.multi_query and config.num_key_value_heads == 1), f"Got MQA with num_key_value_heads = {config.num_key_value_heads}" self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads - self.kv_heads = 1 if self.multi_query else self.num_heads - self.kv_dim = self.kv_heads * self.head_dim + self.kv_heads = config.num_key_value_heads self.split_size = self.embed_dim if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( @@ -216,11 +217,13 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): if self.is_cross_attention: if self.multi_query: raise NotImplementedError("Multi-Query Attention not supported for cross_attention") + if self.kv_heads != self.num_heads: + raise NotImplementedError("Cross-Attention not supported for num_key_value_heads != num_attention_heads") self.c_attn = nn.Linear(self.embed_dim, 2 * self.embed_dim) self.q_attn = nn.Linear(self.embed_dim, self.embed_dim) else: - self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim) + self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_heads * self.head_dim) if self.use_rotary_embeddings: self.maybe_rotary = ( @@ -241,6 +244,13 @@ def _get_mask_value(self, device, dtype): return self.mask_value def _attn(self, query, key, value, attention_mask=None, head_mask=None): + """ + Args: + query (batch_size, query_length, num_heads * head_dim) + key (batch_size, kv_heads*head_dim, key_length) + value (batch_size, key_length, kv_heads, head_dim) + + """ dtype = query.dtype softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype upcast = dtype != softmax_dtype @@ -263,6 +273,31 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): attn_view = (batch_size, query_length * self.num_heads, key_length) # No copy needed for MQA 2, or when layer_past is provided. 
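# Illustrative sketch (not part of the patch, made-up sizes): the fused QKV projection
# shape for GQA. c_attn packs num_heads query heads plus 2 * kv_heads key/value heads
# into one linear layer, and the output is split per kv-head group.
import torch
from torch import nn

hidden, num_heads, kv_heads = 64, 8, 2
head_dim = hidden // num_heads
n_repeats = num_heads // kv_heads
c_attn = nn.Linear(hidden, hidden + 2 * kv_heads * head_dim)
x = torch.randn(3, 5, hidden)                                  # (batch, seq, hidden)
qkv = c_attn(x).view(3, 5, kv_heads, n_repeats + 2, head_dim)
q, k, v = qkv.split((n_repeats, 1, 1), dim=3)
assert q.shape == (3, 5, kv_heads, n_repeats, head_dim)
assert k.shape == (3, 5, kv_heads, 1, head_dim) == v.shape
# (end of sketch)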
query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) + # value shape (batch_size, key_length, 1, head_dim) + # print("value", value.shape) + # print(value[0, :, 0, 0]) + elif self.kv_heads < self.num_heads: # GQA + # (batch_size, query_length, num_heads, head_dim) x (batch_size, kv_heads, head_dim, key_length) + # -> (batch_size, num_heads, query_length, key_length) + query_length = query_shape[1] + attn_shape = (batch_size, self.num_heads, query_length, key_length) + attn_view = (batch_size * self.num_heads, query_length, key_length) + # Always copies + query = query.view(batch_size, query_length, self.num_heads, self.head_dim) + query = query.transpose(1,2).reshape(batch_size * self.num_heads, query_length, self.head_dim) + # query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) + # Need to repeat key/value for each query -> we need kv_heads dim to precede key_length dim + # key shape: (batch_size, kv_heads, head_dim, key_length) + n_repeats = self.num_heads // self.kv_heads + key = (key.view(batch_size, self.kv_heads, 1, self.head_dim, key_length) + .expand(batch_size, self.kv_heads, n_repeats, self.head_dim, key_length) + .reshape(batch_size * self.num_heads, self.head_dim, key_length)) + value = value.transpose(1,2) + value = (value.view(batch_size, self.kv_heads, 1, key_length, self.head_dim) + .expand(batch_size, self.kv_heads, n_repeats, key_length, self.head_dim) + .reshape(batch_size * self.num_heads, key_length, self.head_dim)) + # print("value", value.shape) + # print(value[0, :, 0]) else: # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length) # -> (batch_size, num_heads, query_length, key_length) @@ -274,6 +309,11 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): # No copy when layer_past is provided. key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length) + # print("query", query.shape) + # print(query[0, :query_length, 0]) + # print("key", key.shape) + # print(key[0, 0, :]) + attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype) if query.device.type == "cpu": # This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588. @@ -284,6 +324,11 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): else: beta = 0 attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape) + # if not self.multi_query: + # attn_weights = attn_weights.transpose(1, 2) # (batch_size, query_length, num_heads, key_length) + # print("attn_weights", attn_weights.shape) + # print(attn_weights[0, :, 0, :]) + # assert False, "done" if upcast: # Use a fused kernel to prevent a large overhead from casting and scaling. 
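# Illustrative sketch (not part of the patch, toy shapes): repeating each KV head
# across its query-head group for GQA. expand() adds the repeat dim without copying;
# the copy only happens at the final reshape().
import torch

batch, kv_heads, n_repeats, head_dim, key_len = 2, 2, 3, 4, 5
num_heads = kv_heads * n_repeats
key = torch.randn(batch, kv_heads, head_dim, key_len)
key_rep = (key.view(batch, kv_heads, 1, head_dim, key_len)
              .expand(batch, kv_heads, n_repeats, head_dim, key_len)
              .reshape(batch * num_heads, head_dim, key_len))
assert key_rep.shape == (batch * num_heads, head_dim, key_len)
assert torch.equal(key_rep[0], key_rep[1])                     # repeated heads share values
# (end of sketch)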
@@ -311,6 +356,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): attn_weights = attn_weights * head_mask if self.multi_query: + # value shape: (batch_size, key_length, 1, head_dim) attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape) else: attn_output = torch.matmul(attn_weights, value) @@ -333,6 +379,8 @@ def forward( Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], ]: + # self.multi_query = False + # self.multi_query = True if encoder_hidden_states is not None: if not hasattr(self, "q_attn") or not self.is_cross_attention: raise ValueError( @@ -344,43 +392,51 @@ def forward( key_value = self.c_attn(encoder_hidden_states) attention_mask = encoder_attention_mask elif self.multi_query: - query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_heads * self.head_dim), dim=2) # query shape: (batch_size, query_length, num_heads * head_dim) # key_value shape: (batch_size, query_length, 2 * 1 * head_dim) + key, value = key_value.split( + (self.head_dim, self.head_dim), dim=-1 + ) # (batch_size, query_length, 1 * head_dim) else: # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), # i.e., the memory layout is not the same as GPT2. # This makes the concatenation with past_key_value more efficient. - query, key_value = ( + + n_repeats = self.num_heads // self.kv_heads + + query, key, value = ( self.c_attn(hidden_states) - .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) - .transpose(1, 2) - .split((self.head_dim, 2 * self.head_dim), dim=3) - ) + .view(*hidden_states.shape[:2], self.kv_heads, n_repeats + 2, self.head_dim) + .split((n_repeats, 1, 1), dim=3) + ) + # print("c_attn query", query.shape) + # print("c_attn key", key.shape) + # query shape: (batch_size, query_length, kv_heads, n_repeats, head_dim) + # key, value shape: (batch_size, query_length, kv_heads, 1, head_dim) if layer_past is not None: past_kv_length = layer_past.shape[-2] else: past_kv_length = 0 - key, value = key_value.split( - (self.head_dim, self.head_dim), dim=-1 - ) # (batch_size, query_length, 1 * head_dim) if self.use_rotary_embeddings: # query = _apply_rotary_embeddings(query, rotary_embedding_frequencies_q) # key = _apply_rotary_embeddings(key, rotary_embedding_frequencies_k) query = query.view(*query.shape[:2], self.num_heads, self.head_dim) - key = key.view(*key.shape[:2], 1, self.head_dim) + key = key.view(*key.shape[:2], self.kv_heads, self.head_dim) # print(past_kv_length) # print("before",key[0,:,0,0]) query, key = self.maybe_rotary(query, key, past_kv_length) # print("after",key[0,:,0,0]) query = query.view(*query.shape[:2], self.num_heads * self.head_dim) - key = key.view(*key.shape[:2], 1 * self.head_dim) + key = key.view(*key.shape[:2], self.kv_heads * self.head_dim) + value = value.view(*value.shape[:2], self.kv_heads, self.head_dim) if use_cache: key_value = torch.cat((key, value), dim=-1) + # TODO @nouamane: do we need to concat? if layer_past is not None: # Concatenate past key/values with new key/values. 
From 1f424bb5b28e1c6cc8f33734a60a8a31c897b4a5 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Wed, 1 Nov 2023 10:41:04 +0000 Subject: [PATCH 07/26] fix use_cache for GQA --- .../models/gpt_bigcode/modeling_gpt_bigcode.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index e0d32ece5f..ca0b250dea 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -432,7 +432,7 @@ def forward( # print("after",key[0,:,0,0]) query = query.view(*query.shape[:2], self.num_heads * self.head_dim) key = key.view(*key.shape[:2], self.kv_heads * self.head_dim) - value = value.view(*value.shape[:2], self.kv_heads, self.head_dim) + value = value.view(*value.shape[:2], self.kv_heads*self.head_dim) if use_cache: key_value = torch.cat((key, value), dim=-1) @@ -440,13 +440,14 @@ def forward( if layer_past is not None: # Concatenate past key/values with new key/values. - key_value = torch.cat((layer_past, key_value), dim=-2) + key_value = torch.cat((layer_past, key_value), dim=1) key, value = key_value.split( - (self.head_dim, self.head_dim), dim=-1 - ) # (batch_size, key_length, 1 * head_dim) + (self.kv_heads * self.head_dim, self.kv_heads * self.head_dim), dim=-1 + ) # (batch_size, key_length, kv_heads * head_dim) # print("after past", key[0,:,0]) present = key_value if use_cache else None + value = value.view(*value.shape[:2], self.kv_heads, self.head_dim) attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) if not self.multi_query: From 39a3483488a989328ab5b93e6ef54d8ccad4d29d Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Wed, 1 Nov 2023 11:03:12 +0000 Subject: [PATCH 08/26] reshapes fixes for num_heads=2 --- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index ca0b250dea..c2d47f3fd8 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -424,7 +424,7 @@ def forward( if self.use_rotary_embeddings: # query = _apply_rotary_embeddings(query, rotary_embedding_frequencies_q) # key = _apply_rotary_embeddings(key, rotary_embedding_frequencies_k) - query = query.view(*query.shape[:2], self.num_heads, self.head_dim) + query = query.reshape(*query.shape[:2], self.num_heads, self.head_dim) key = key.view(*key.shape[:2], self.kv_heads, self.head_dim) # print(past_kv_length) # print("before",key[0,:,0,0]) @@ -432,7 +432,7 @@ def forward( # print("after",key[0,:,0,0]) query = query.view(*query.shape[:2], self.num_heads * self.head_dim) key = key.view(*key.shape[:2], self.kv_heads * self.head_dim) - value = value.view(*value.shape[:2], self.kv_heads*self.head_dim) + value = value.reshape(*value.shape[:2], self.kv_heads*self.head_dim) if use_cache: key_value = torch.cat((key, value), dim=-1) From 1c79ecd26ac25a176b66893897f230f1fd0cbd83 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Thu, 2 Nov 2023 02:16:35 +0000 Subject: [PATCH 09/26] . 
--- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index c2d47f3fd8..3bb392648c 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -289,13 +289,15 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): # Need to repeat key/value for each query -> we need kv_heads dim to precede key_length dim # key shape: (batch_size, kv_heads, head_dim, key_length) n_repeats = self.num_heads // self.kv_heads + # TODO @nouamane: refactor this key = (key.view(batch_size, self.kv_heads, 1, self.head_dim, key_length) .expand(batch_size, self.kv_heads, n_repeats, self.head_dim, key_length) .reshape(batch_size * self.num_heads, self.head_dim, key_length)) + value = value.view(batch_size, key_length, self.kv_heads, self.head_dim) value = value.transpose(1,2) value = (value.view(batch_size, self.kv_heads, 1, key_length, self.head_dim) .expand(batch_size, self.kv_heads, n_repeats, key_length, self.head_dim) - .reshape(batch_size * self.num_heads, key_length, self.head_dim)) + .reshape(batch_size, self.num_heads, key_length, self.head_dim)) # print("value", value.shape) # print(value[0, :, 0]) else: @@ -447,7 +449,6 @@ def forward( # print("after past", key[0,:,0]) present = key_value if use_cache else None - value = value.view(*value.shape[:2], self.kv_heads, self.head_dim) attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) if not self.multi_query: From 19cf15318b12eab52a21f1fb2a677d82472832c3 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Thu, 7 Dec 2023 11:56:52 +0000 Subject: [PATCH 10/26] add flash_attn_with_kvcache to GQA --- .../gpt_bigcode/modeling_gpt_bigcode.py | 246 +++++++++++++++--- 1 file changed, 210 insertions(+), 36 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 3bb392648c..f1db907b58 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -14,7 +14,11 @@ """PyTorch GPTBigCode model.""" import math from typing import List, Optional, Tuple, Union - +from flash_attn import bert_padding +from flash_attn.flash_attn_interface import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, +) import torch import torch.utils.checkpoint from torch import nn @@ -185,6 +189,33 @@ def forward(self, query, key, past_key_values_length=0): return query, key +def pad_to_right(tensor, mask, new_tensor=None): + """Transform a left-padded tensor into a right-padded tensor. 
(Useful for prefilling key/value states) + Args: + tensor: (batch_size, seqlen, d1, d2) + mask: (batch_size, seqlen) + new_tensor: (batch_size, new_tensor_seqlen, d1, d2) + Returns: + new_tensor: (batch_size, new_tensor_seqlen, d1, d2) + right_padded_mask: (batch_size, seqlen) + """ + # First, we need to find the number of padding for each row + unpad_seqlens = mask.sum(1) + # Then, we need to find the maximum length of the tensor + max_seqlen = mask.shape[1] + # We can then create the indices to select the padded values + # The indices are the same for each row + indices = torch.arange(max_seqlen, device=mask.device) + # We can then create the mask for the padded values + right_padded_mask = indices < unpad_seqlens[:, None] + # We select the useful values + useful_values = tensor[mask] + # We create the new tensor (if not provided) + new_tensor = torch.zeros_like(tensor) if new_tensor is None else new_tensor + # We fill the new tensor with the useful values + new_tensor[:, : right_padded_mask.shape[1], :, :][right_padded_mask] = useful_values + return new_tensor, right_padded_mask + class GPTBigCodeAttention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() @@ -237,6 +268,14 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.attn_dropout = nn.Dropout(config.attn_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.prefill_kv_len = ( + config.max_position_embeddings + ) # TODO @nouamane: compute based on free memory, because in rope we can surpass max_position_embeddings + self.n_local_kv_heads = self.kv_heads + self.n_local_q_heads = self.num_heads + self.n_repeats = self.num_heads // self.kv_heads + + def _get_mask_value(self, device, dtype): # torch.where expects a tensor. We use a cache to avoid recreating it every time. 
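# Illustrative sketch (not part of the patch, tiny made-up batch): what pad_to_right
# does with a left-padded input. Row 0 holds [pad, a, b]; after the move it holds
# [a, b, pad], which is the layout the kv-cache prefill expects.
import torch

mask = torch.tensor([[False, True, True],
                     [True,  True, True]])                     # (batch, seqlen)
tensor = torch.arange(6.0).view(2, 3, 1, 1)                    # (batch, seqlen, 1, 1)
right_mask = torch.arange(mask.shape[1]) < mask.sum(1)[:, None]
new_tensor = torch.zeros_like(tensor)
new_tensor[:, : right_mask.shape[1]][right_mask] = tensor[mask]
assert new_tensor[0, :, 0, 0].tolist() == [1.0, 2.0, 0.0]
assert new_tensor[1, :, 0, 0].tolist() == [3.0, 4.0, 5.0]
# (end of sketch)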
if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: @@ -365,11 +404,135 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): return attn_output, attn_weights + def _flash_attn(self, hidden_states, layer_past, sequence_mask, position_ids): + # hidden_states (batch_size, query_length, embed_dim) + # layer_past (batch_size, past_key_values_length, embed_dim) + # sequence_mask (batch_size, query_length) + + fused_qkv = self.c_attn( + hidden_states + ) # [batch_size, query_length, n_local_q_heads * head_dim + 2 * n_local_kv_heads * head_dim] + batch_size, q_length, _ = fused_qkv.size() + + qkv = fused_qkv.view(batch_size, q_length, self.n_local_kv_heads, self.n_repeats + 2, self.head_dim) + query, key, value = torch.split(qkv, [self.n_repeats, 1, 1], dim=3) + query_states = query.reshape( + batch_size, q_length, self.n_local_q_heads, self.head_dim + ) + key_states = key.reshape(batch_size, q_length, self.n_local_kv_heads, self.head_dim) + value_states = value.reshape(batch_size, q_length, self.n_local_kv_heads, self.head_dim) + + # Compute rotary embeddings + if layer_past is None: + past_key_values_length = 0 + else: + past_key_values_length = layer_past.shape[1] + query_states, key_states = self.maybe_rotary( + query_states, key_states, past_key_values_length=past_key_values_length + ) + + if layer_past is None: + # First inference iteration (Prefill) + # TODO @nouamane: support custom masking + # assert that [ False, False, False, False, True, True, True, True, True, True] is accepted + # but [ False, False, False, False, True, True, False, False, True, True] is not (can't mask in the middle of sequence) + assert ~( + sequence_mask[:, :-1] & (~sequence_mask[:, 1:]) # True is never followed by False + ).any(), f"Can't mask in the middle of sequence, please use USE_FAST=0 instead.\nGot sequence_mask: {sequence_mask}" + + # preallocate k_cache, v_cache to self.prefill_kv_len + k_cache = torch.zeros( + ( + batch_size, + self.prefill_kv_len, + self.n_local_kv_heads, + self.head_dim, + ), + dtype=query_states.dtype, + device=query_states.device, + ) + v_cache = torch.zeros( + (batch_size, self.prefill_kv_len, self.n_local_kv_heads, self.head_dim), + dtype=query_states.dtype, + device=query_states.device, + ) + self.register_buffer("k_cache", k_cache) + self.register_buffer("v_cache", v_cache) + + # Remove pad tokens from key_states and concatenate samples in key_unpad + # cu_seqlens_k is the cumulative sequence lengths of key_states + (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input( + query_states, + sequence_mask, + ) + (key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input( + key_states, sequence_mask + ) + (value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask) + + output_unpad = flash_attn_varlen_func( + q=query_unpad, # (total_q, self.n_local_q_heads, d_qk) + k=key_unpad, # (total_kv, self.n_local_kv_heads, d_qk) + v=value_unpad, # (total_kv, self.n_local_kv_heads, d_v) + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=True, # True in prefill phase, False in subsequent phases + return_attn_probs=False, + ) # (total_unpadded, n_local_q_heads, d_v) + + attention_output = bert_padding.pad_input( + output_unpad, indices_q, batch_size, q_length + ) # (batch_size, q_length, n_local_q_heads, d_v) + pad_to_right(key_states, sequence_mask, 
new_tensor=k_cache) + pad_to_right(value_states, sequence_mask, new_tensor=v_cache) + + else: + # Pull pre-computed key/value states + # Subsequent inference iterations (q_length=1) + k_cache = self.k_cache + v_cache = self.v_cache + # TODO: remove layer_past as it's redundant with k_cache, v_cache + + # [batch_size, seq_length, num_heads, d_qk] + query_states = query_states.view( + batch_size, q_length, self.n_local_q_heads, self.head_dim + ) # [batch_size, q_length, self.n_local_q_heads, self.head_dim] + kv_length = key_states.shape[1] + key_states = key_states.view( + batch_size, kv_length, self.n_local_kv_heads, self.head_dim + ) # [batch_size, kv_length, self.n_local_kv_heads, self.head_dim] + value_states = value_states.view( + batch_size, kv_length, self.n_local_kv_heads, self.head_dim + ) # [batch_size, kv_length, self.n_local_kv_heads, self.head_dim] + + position_offsets = position_ids[:, -1] + # assert False + attention_output = flash_attn_with_kvcache( + query_states, + k_cache, + v_cache, + key_states, + value_states, + rotary_cos=None, + rotary_sin=None, + # TODO @nouamane: seems like this doesnt help to indicate padding in (for first iteration it's just 0) + cache_seqlens=position_offsets.contiguous(), + softmax_scale=None, + causal=True, + rotary_interleaved=False, # GPT-NeoX style + ) + + return key_states, value_states, attention_output + def forward( self, - hidden_states: torch.Tensor, - layer_past: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, + hidden_states: torch.Tensor, # (batch_size, query_length, embed_dim) + layer_past: Optional[torch.Tensor] = None, # (batch_size, past_key_values_length, embed_dim) + attention_mask: Optional[torch.Tensor] = None, # (batch_size, query_length) head_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, @@ -377,10 +540,12 @@ def forward( output_attentions: Optional[bool] = False, rotary_embedding_frequencies_q: Optional[torch.Tensor] = None, rotary_embedding_frequencies_k: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, ) -> Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], ]: + USE_FLASH_ATTN = True # self.multi_query = False # self.multi_query = True if encoder_hidden_states is not None: @@ -400,22 +565,25 @@ def forward( key, value = key_value.split( (self.head_dim, self.head_dim), dim=-1 ) # (batch_size, query_length, 1 * head_dim) - else: - # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), - # i.e., the memory layout is not the same as GPT2. - # This makes the concatenation with past_key_value more efficient. + else: # GQA + if use_cache and USE_FLASH_ATTN: + key, value, attn_output = self._flash_attn(hidden_states, layer_past, attention_mask, position_ids) + attn_output = attn_output.view(hidden_states.shape) - n_repeats = self.num_heads // self.kv_heads + else: + # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), + # i.e., the memory layout is not the same as GPT2. + # This makes the concatenation with past_key_value more efficient. 
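# Illustrative sketch (not part of the patch, plain torch, made-up mask): the
# cu_seqlens bookkeeping behind the prefill path above. Padding sits on the left,
# so a True is never followed by a False, and the cumulative lengths delimit each
# sample inside the packed tensor handed to flash_attn_varlen_func.
import torch

sequence_mask = torch.tensor([[True, True, True, True],
                              [False, False, True, True]])
seqlens = sequence_mask.sum(dim=1, dtype=torch.int32)          # [4, 2]
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, 0), (1, 0))
max_seqlen = int(seqlens.max())
assert cu_seqlens.tolist() == [0, 4, 6] and max_seqlen == 4
# (end of sketch)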
- query, key, value = ( - self.c_attn(hidden_states) - .view(*hidden_states.shape[:2], self.kv_heads, n_repeats + 2, self.head_dim) - .split((n_repeats, 1, 1), dim=3) - ) - # print("c_attn query", query.shape) - # print("c_attn key", key.shape) - # query shape: (batch_size, query_length, kv_heads, n_repeats, head_dim) - # key, value shape: (batch_size, query_length, kv_heads, 1, head_dim) + n_repeats = self.num_heads // self.kv_heads + + query, key, value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.kv_heads, n_repeats + 2, self.head_dim) + .split((n_repeats, 1, 1), dim=3) + ) + # query shape: (batch_size, query_length, kv_heads, n_repeats, head_dim) + # key, value shape: (batch_size, query_length, kv_heads, 1, head_dim) if layer_past is not None: past_kv_length = layer_past.shape[-2] @@ -423,7 +591,7 @@ def forward( past_kv_length = 0 - if self.use_rotary_embeddings: + if self.use_rotary_embeddings and not USE_FLASH_ATTN: # query = _apply_rotary_embeddings(query, rotary_embedding_frequencies_q) # key = _apply_rotary_embeddings(key, rotary_embedding_frequencies_k) query = query.reshape(*query.shape[:2], self.num_heads, self.head_dim) @@ -438,6 +606,7 @@ def forward( if use_cache: key_value = torch.cat((key, value), dim=-1) + key_value = key_value.view(*key_value.shape[:2], 2 * self.kv_heads * self.head_dim) # TODO @nouamane: do we need to concat? if layer_past is not None: @@ -446,13 +615,13 @@ def forward( key, value = key_value.split( (self.kv_heads * self.head_dim, self.kv_heads * self.head_dim), dim=-1 ) # (batch_size, key_length, kv_heads * head_dim) - # print("after past", key[0,:,0]) present = key_value if use_cache else None - attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) + if not USE_FLASH_ATTN: + attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) - if not self.multi_query: - attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) + if not self.multi_query: + attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) @@ -461,7 +630,7 @@ def forward( if self.multi_query: # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) attn_weights = attn_weights.transpose(1, 2) - outputs += (attn_weights,) + outputs += (attn_weights,) #TODO: fix for GQA return outputs # a, present, (attentions) @@ -514,6 +683,7 @@ def forward( output_attentions: Optional[bool] = False, rotary_embedding_frequencies_q: Optional[torch.Tensor] = None, rotary_embedding_frequencies_k: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, ) -> Union[ Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor] ]: @@ -528,6 +698,7 @@ def forward( output_attentions=output_attentions, rotary_embedding_frequencies_q=rotary_embedding_frequencies_q, rotary_embedding_frequencies_k=rotary_embedding_frequencies_k, + position_ids=position_ids, ) attn_output = attn_outputs[0] # output_attn: a, present, (attentions) outputs = attn_outputs[1:] @@ -552,6 +723,7 @@ def forward( output_attentions=output_attentions, rotary_embedding_frequencies_q=rotary_embedding_frequencies_q, rotary_embedding_frequencies_k=rotary_embedding_frequencies_k, + position_ids=position_ids, ) attn_output = cross_attn_outputs[0] # residual connection @@ -778,8 +950,6 @@ def forward( ) use_cache = 
use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if attention_mask is not None: - assert attention_mask.all(), f"attention_mask: {attention_mask}" if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") @@ -819,6 +989,7 @@ def forward( elif position_ids is None: position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + position_ids = position_ids.to(torch.int32) # For flash-attn with kv cache # Rotary frequencies rotary_embedding_frequencies_q = None @@ -834,19 +1005,20 @@ def forward( # Self-attention mask. query_length = input_shape[-1] key_length = past_length + query_length - self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] + # self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] # Sliding window attention - if self.config.attention_window_size is not None: - self_attention_mask.triu_(-self.config.attention_window_size + 1) + # if self.config.attention_window_size is not None: + # self_attention_mask.triu_(-self.config.attention_window_size + 1) - if attention_mask is not None: - self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to( - dtype=torch.bool, device=self_attention_mask.device - ) + # if attention_mask is not None: + # self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to( + # dtype=torch.bool, device=self_attention_mask.device + # ) - # MQA models: (batch_size, query_length, n_heads, key_length) - # MHA models: (batch_size, n_heads, query_length, key_length) - attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) + # # MQA models: (batch_size, query_length, n_heads, key_length) + # # MHA models: (batch_size, n_heads, query_length, key_length) + # attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) + attention_mask = attention_mask.to(dtype=torch.bool) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -903,6 +1075,7 @@ def custom_forward(*inputs): output_attentions, rotary_embedding_frequencies_q, rotary_embedding_frequencies_k, + position_ids, ) return custom_forward @@ -928,6 +1101,7 @@ def custom_forward(*inputs): output_attentions=output_attentions, rotary_embedding_frequencies_q=rotary_embedding_frequencies_q, rotary_embedding_frequencies_k=rotary_embedding_frequencies_k, + position_ids=position_ids, ) hidden_states = outputs[0] From b493268078fe34d6e87ed993a788466ed5cca494 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Fri, 29 Dec 2023 10:28:28 +0000 Subject: [PATCH 11/26] add merging word embedding checkpoints --- .../models/gpt_bigcode/ckps/download_ckp.py | 5 + .../convert_fast_llm_checkpoint.py | 18 ++- .../models/gpt_bigcode/download_ckp.py | 5 + .../gpt_bigcode/drafts/starcoder_model.py | 7 + .../gpt_bigcode/merge_fast_llm_checkpoint.py | 152 ++++++++++++++++-- src/transformers/models/gpt_bigcode/small.py | 3 + 6 files changed, 171 insertions(+), 19 deletions(-) create mode 100644 src/transformers/models/gpt_bigcode/ckps/download_ckp.py create mode 100644 src/transformers/models/gpt_bigcode/download_ckp.py create mode 100644 
src/transformers/models/gpt_bigcode/drafts/starcoder_model.py create mode 100644 src/transformers/models/gpt_bigcode/small.py diff --git a/src/transformers/models/gpt_bigcode/ckps/download_ckp.py b/src/transformers/models/gpt_bigcode/ckps/download_ckp.py new file mode 100644 index 0000000000..a9fdd5ae44 --- /dev/null +++ b/src/transformers/models/gpt_bigcode/ckps/download_ckp.py @@ -0,0 +1,5 @@ +from huggingface_hub import snapshot_download + +if __name__ == "__main__": + snapshot_download("HuggingFaceBR4/starcoder2_7b_4k_smol_data_580000") + print("done") diff --git a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py index 8859b7650e..15d1b6c870 100644 --- a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py @@ -105,20 +105,24 @@ def convert_fast_llm_checkpoint(state_dict, config): def main(argv=None): parser = argparse.ArgumentParser() - parser.add_argument( - "--checkpoint_dir", - type=Path, - help="Path to the experiment directory", - ) + # parser.add_argument( + # "--checkpoint_dir", + # type=Path, + # help="Path to the experiment directory", + # ) parser.add_argument( "--save_dir", type=Path, help="Path where the converted model is saved" ) args = parser.parse_args(argv) - + # TODO(xrsrke): auto convert checkpoint_dir to Path + checkpoint_dir = "/admin/home/phuc_nguyen/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000/snapshots/92b6c25cab25f07c367bcc6d773635700a8a287d" + checkpoint_dir = Path(checkpoint_dir) + state_dict, config = merge_checkpoint( - args.checkpoint_dir, + # args.checkpoint_dir, + checkpoint_dir, dummy_experiment_dir=None ) diff --git a/src/transformers/models/gpt_bigcode/download_ckp.py b/src/transformers/models/gpt_bigcode/download_ckp.py new file mode 100644 index 0000000000..a9fdd5ae44 --- /dev/null +++ b/src/transformers/models/gpt_bigcode/download_ckp.py @@ -0,0 +1,5 @@ +from huggingface_hub import snapshot_download + +if __name__ == "__main__": + snapshot_download("HuggingFaceBR4/starcoder2_7b_4k_smol_data_580000") + print("done") diff --git a/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py b/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py new file mode 100644 index 0000000000..29d5192f18 --- /dev/null +++ b/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py @@ -0,0 +1,7 @@ +from transformers import AutoModelForCausalLM + + +if __name__ == "__main__": + model = AutoModelForCausalLM.from_pretrained("bigcode/starcoderbase-1b") + states = model.state_dict() + print(states.keys()) diff --git a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py index 71731559c7..41221a468a 100644 --- a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py @@ -15,8 +15,19 @@ def get_all_checkpoint_paths(experiment_path): def get_checkpoint_paths(checkpoint_dir: Path): + # model/model return [c_name for c_name in checkpoint_dir.glob("*") if re.match(r"\d+", c_name.name)] +def get_safetensor_checkpoint_paths(checkpoint_dir: Path): + model_dir = checkpoint_dir / "model" / "model" # Targeting the specific directory + safetensor_files = [] + + for file_path in model_dir.rglob("*.safetensors"): # Looking for files with .safetensors extension + if file_path.is_file(): # Ensure it's a file 
+ safetensor_files.append(file_path.absolute()) # Adding the absolute path of the file + + return safetensor_files + def extract_stage_shards(state): # Extract the weight shard and split it into the stage shards @@ -58,18 +69,135 @@ def concatenate_tp_shards(stage_tp_shards, stage_content): def merge_checkpoint(checkpoint_dir: Path, dummy_experiment_dir=None): """Load a fast-llm checkpoint and merge the data, tensor, and pipeline-parallel shards""" # checkpoint_dir=experiment_dir/checkpoints/{iteration} - experiment_dir = checkpoint_dir.parent.parent - checkpoint_paths = get_checkpoint_paths(checkpoint_dir) - config = yaml.safe_load((experiment_dir / "config.yaml").read_text()) + # experiment_dir = "~/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000" + # experiment_dir = checkpoint_dir.parent.parent + + # NOTE: use the checkpoint format from https://huggingface.co/HuggingFaceBR4/starcoder2_7b_4k_smol_data_580000/tree/main/model/model/token_embeddings/pp_block/token_embedding + # where experiment_dir = checkpoint_dir + # checkpoint_paths = get_checkpoint_paths(checkpoint_dir) + checkpoint_dir = Path(checkpoint_dir) + checkpoint_paths = get_safetensor_checkpoint_paths(checkpoint_dir) + config = yaml.safe_load((checkpoint_dir / "config.yaml").read_text()) + + # def path2tfm_name(path): + # name = path + # # remove `_pp-rank*` and what comes after + # name = name.split("_pp-rank-")[0] + + # # remove `.safetensors` + # name = name.split(".safetensors")[0] + + # # remove base path + # name = name.split(str(checkpoint_path) + "/model/")[1] + + # # "/" -> "." + # name = name.replace("/", ".") + + # # remove "model." prefix if lm_head + # if ".lm_head." in name: + # name = name[len("model.") :] + + # # remove ".pp_block." + # name = name.replace(".pp_block.", ".") + + # # apply mapping + # name = apply_mappings(name, BRRR_TRFRS_NAME_MAPPING) + # # print(name, path) + + # # skip buffers + # if name.endswith(".model_inv_freq"): + # continue + # return name # Load the states from all the ranks + + import re + + # def create_state_dict(paths): + # state_dict = {} + # for path in paths: + # # Break down the path and extract relevant parts + # parts = path.parts + # # Find the tp-rank part and extract the rank number + # tp_rank_match = re.search(r'tp-rank-(\d+)-of-\d+', str(path)) + # if tp_rank_match: + # tp_rank = tp_rank_match.group(1) + # else: + # continue # Skip if tp-rank is not found + + # # Construct the key from the path segments + # key_segments = [part for part in parts if part not in ['model_weight', 'pp_block', 'model']] + # key = '.'.join(key_segments[-5:]) # Adjust the index as needed to capture the right segments + # key = key.replace('/', '.').replace('\\', '.') + '.' + tp_rank + + # # Add to the dictionary + # state_dict[key] = path + + # return state_dict + + def create_state_dict(paths): + state_dict = {} + keyword_mapping = { + 'model_bias': 'bias', + 'model_weight': 'weight', + } + + for path in paths: + tp_rank_match = re.search(r'tp-rank-(\d+)-of-\d+', str(path)) + if tp_rank_match: + tp_rank = tp_rank_match.group(1) + else: + continue # Skip if tp-rank is not found + + file_name = path.stem + + for key_word, replacement in keyword_mapping.items(): + file_name = replacement + + key = '.'.join(path.parts[-5:-1]) + '.' + file_name + '.' 
+ tp_rank # Modify indices as needed + state_dict[key] = path + + return state_dict + + + state_dict = create_state_dict(checkpoint_paths) + + from collections import defaultdict + grouped_paths = defaultdict(list) + for key, path in state_dict.items(): + module_name, shard_number = key.rsplit('.', 1) + grouped_paths[module_name].append((int(shard_number), path)) + + sorted_grouped_paths = {module: sorted(paths, key=lambda x: x[0]) for module, paths in grouped_paths.items()} + + from safetensors import safe_open + + # path_demo = list(grouped_paths.values())[0] + _states = {} + _embedding_paths = sorted_grouped_paths["model.token_embeddings.pp_block.token_embedding.weight"] + for shard_id, _path in enumerate(_embedding_paths): + with safe_open(_path[1], framework="pt", device="cpu") as f: + for key in f.keys(): + data = f.get_tensor(key) + _states[shard_id] = data + + tensor_list = [tensor for key, tensor in sorted(_states.items())] + _embeddings = torch.cat(tensor_list, dim=-1) + + assert 1 == 1 + + states = { int(c_name.name): torch.load(c_name) for c_name in tqdm(checkpoint_paths) } - num_stages = len(states[0]["stages"]) - tensor_parallel = config["tensor_parallel"] - data_parallel_size = int(config["world_size"] / (tensor_parallel * config["pipeline_parallel"])) + # num_stages = len(states[0]["stages"]) + + # tensor_parallel = config["tensor_parallel"] + # data_parallel_size = int(config["world_size"] / (tensor_parallel * config["pipeline_parallel"])) + tensor_parallel_size = config["parallelism"]["tp"] + pipeline_parallel_size = config["parallelism"]["pp"] + data_parallel_size = config["parallelism"]["dp"] if dummy_experiment_dir is not None: # Use the meta from the dummy checkpoint, and the shard from the actual checkpoint @@ -78,7 +206,7 @@ def merge_checkpoint(checkpoint_dir: Path, dummy_experiment_dir=None): int(c_name.name): torch.load(c_name) for c_name in tqdm(dummy_checkpoint_paths[-1]) } - for rank, state in dummy_states.items(): + for rank, state in dummy_states.aitems(): state['shard'] = states[rank]['shard'] states = dummy_states @@ -86,8 +214,8 @@ def merge_checkpoint(checkpoint_dir: Path, dummy_experiment_dir=None): # {tp_rank: [[stage_0_shard_0, stage_0_shard_1, ...], [stage_1_shard_0, ...], ...]} # {tp_rank: [{fsdp_rank: shard}, ...]} fsdp_shards = { - i: [[None for _ in range(data_parallel_size)] for _ in range(num_stages)] - for i in range(tensor_parallel) + i: [[None for _ in range(data_parallel_size)] for _ in range(pipeline_parallel_size)] + for i in range(tensor_parallel_size) } for rank, state in states.items(): @@ -128,7 +256,7 @@ def merge_checkpoint(checkpoint_dir: Path, dummy_experiment_dir=None): return state_dict, config -if __name__ == "__main__": - merge_checkpoint("/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/1B_repo_context_Top-level-Depth-first_pp2_64k_64k_2023_10_17_16_35_27/", - dummy_experiment_dir="/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/dev_1B_repo_context_Random_pp2_64k_64k_2023_10_18_22_20_36/") +# if __name__ == "__main__": +# merge_checkpoint("/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/1B_repo_context_Top-level-Depth-first_pp2_64k_64k_2023_10_17_16_35_27/", +# dummy_experiment_dir="/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/dev_1B_repo_context_Random_pp2_64k_64k_2023_10_18_22_20_36/") diff --git a/src/transformers/models/gpt_bigcode/small.py b/src/transformers/models/gpt_bigcode/small.py new file mode 100644 index 
0000000000..56fdfc6b24 --- /dev/null +++ b/src/transformers/models/gpt_bigcode/small.py @@ -0,0 +1,3 @@ +if __name__ == "__main__": + print("works") + assert 1 == 1 From 4446fe0e3ed01099f909caadea654cc1ab556178 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Sun, 31 Dec 2023 06:28:11 +0000 Subject: [PATCH 12/26] add merging quite a bit --- .../gpt_bigcode/merge_fast_llm_checkpoint.py | 57 ++++++++++++++++--- 1 file changed, 48 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py index 41221a468a..f1bbac59aa 100644 --- a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py @@ -171,18 +171,57 @@ def create_state_dict(paths): sorted_grouped_paths = {module: sorted(paths, key=lambda x: x[0]) for module, paths in grouped_paths.items()} from safetensors import safe_open + + assert 1 == 1 + + MERGE_DIM_MAPPING = { + "token_embedding": 0, # row linear parallel + # NOTE: weird + "c_fc": 0, # column linear parallel + "c_proj": 0, # row linear parallel + # NOTE: weird + "query_key_value": 0, # row linear parallel + "dense": 1, # row linear parallel + } + + def find_corresponding_dim(name): + """ + Searches the MERGE_DIM_MAPPING for a key that is a substring of the given name. + Returns the corresponding dimension if found, otherwise None. + """ + for key, value in MERGE_DIM_MAPPING.items(): + if key in name: + return value + return None # path_demo = list(grouped_paths.values())[0] - _states = {} - _embedding_paths = sorted_grouped_paths["model.token_embeddings.pp_block.token_embedding.weight"] - for shard_id, _path in enumerate(_embedding_paths): - with safe_open(_path[1], framework="pt", device="cpu") as f: - for key in f.keys(): - data = f.get_tensor(key) - _states[shard_id] = data + _model_states = {} + for state_key, path in sorted_grouped_paths.items(): + _model_states[state_key] = {} + for shard_id, _path in enumerate(path): + with safe_open(_path[1], framework="pt", device="cpu") as f: + for key in f.keys(): + data = f.get_tensor(key) + _model_states[state_key][shard_id] = data + + tensor_list = [tensor for _, tensor in sorted(_model_states[state_key].items())] + merge_dim = find_corresponding_dim(state_key) + print(f"trying to merge: {state_key}") + if state_key == "28.pp_block.attn.query_key_value.weight": + assert 1 == 1 + + _model_states[state_key] = torch.cat(tensor_list, dim=merge_dim) + + # _states = {} + # _embedding_paths = sorted_grouped_paths["model.token_embeddings.pp_block.token_embedding.weight"] + # for shard_id, _path in enumerate(_embedding_paths): + # with safe_open(_path[1], framework="pt", device="cpu") as f: + # for key in f.keys(): + # data = f.get_tensor(key) + # _states[shard_id] = data - tensor_list = [tensor for key, tensor in sorted(_states.items())] - _embeddings = torch.cat(tensor_list, dim=-1) + # tensor_list = [tensor for key, tensor in sorted(_states.items())] + # _embeddings = torch.cat(tensor_list, dim=0) assert 1 == 1 From 1d949b2026d47c04b814958d9012f6ec0e83926c Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Sun, 31 Dec 2023 08:12:12 +0000 Subject: [PATCH 13/26] add reference starcoder model --- .../gpt_bigcode/drafts/starcoder_model.py | 37 +++++++++++++++++-- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py b/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py index 
29d5192f18..f480132793 100644 --- a/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py +++ b/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py @@ -1,7 +1,36 @@ -from transformers import AutoModelForCausalLM +from transformers import GPTBigCodeForCausalLM, GPTBigCodeConfig + +from pathlib import Path +import json if __name__ == "__main__": - model = AutoModelForCausalLM.from_pretrained("bigcode/starcoderbase-1b") - states = model.state_dict() - print(states.keys()) + checkpoint_dir = "/admin/home/phuc_nguyen/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000/snapshots/92b6c25cab25f07c367bcc6d773635700a8a287d" + checkpoint_dir = Path(checkpoint_dir) + config = json.load(open(checkpoint_dir / "model_config.json")) + + model_config = GPTBigCodeConfig( + vocab_size=config["vocab_size"], + n_positions=config["max_position_embeddings"], + n_embd=config["hidden_size"], + n_layer=config["num_hidden_layers"], + n_head=config["num_attention_heads"], + num_key_value_heads=config["num_kv_heads"], + n_inner=config["hidden_size"], + activation_function=config["activation_function"], + resid_pdrop=config["resid_pdrop"], + embd_pdrop=config["embd_pdrop"], + attn_pdrop=config["attn_pdrop"], + layer_norm_epsilon=config["layer_norm_epsilon"], + scale_attn_weights=config["scale_attn_weights"], + bos_token_id=config["bos_token_id"], + eos_token_id=config["eos_token_id"], + attention_softmax_in_fp32=config["attention_softmax_in_fp32"], + scale_attention_softmax_in_fp32=config["scale_attention_softmax_in_fp32"], + multi_query=config["multi_query"], + use_rotary_embeddings=config["use_rotary_embeddings"], + # rotary_embedding_scale=brrr_model_config.rotary_embedding_scale, #TODO + attention_window_size=config["sliding_window_size"], + ) + + model = GPTBigCodeForCausalLM._from_config(model_config) From a58a947d88ee6672512b1a7ef51c2163b445e902 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Sun, 31 Dec 2023 08:55:57 +0000 Subject: [PATCH 14/26] merged most of the checkpoints --- .../convert_fast_llm_checkpoint.py | 20 +- .../gpt_bigcode/drafts/starcoder_model.py | 5 +- .../gpt_bigcode/merge_fast_llm_checkpoint.py | 174 ++++++++++-------- 3 files changed, 107 insertions(+), 92 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py index 15d1b6c870..7b456e5ce3 100644 --- a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py @@ -126,17 +126,17 @@ def main(argv=None): dummy_experiment_dir=None ) - output_state_dict, output_config = convert_fast_llm_checkpoint(state_dict, config) + # output_state_dict, output_config = convert_fast_llm_checkpoint(state_dict, config) - print("Saving config") - save_dir = args.save_dir or args.checkpoint_dir / "converted" - output_config.save_pretrained(save_dir) - - # Store the state_dict to file. - output_checkpoint_file = os.path.join(save_dir, "pytorch_model.bin") - print(f'Saving checkpoint to "{output_checkpoint_file}"') - torch.save(output_state_dict, output_checkpoint_file) - print(f'Done!') + # print("Saving config") + # save_dir = args.save_dir or args.checkpoint_dir / "converted" + # output_config.save_pretrained(save_dir) + + # # Store the state_dict to file. 
+ # output_checkpoint_file = os.path.join(save_dir, "pytorch_model.bin") + # print(f'Saving checkpoint to "{output_checkpoint_file}"') + # torch.save(output_state_dict, output_checkpoint_file) + # print(f'Done!') if __name__ == "__main__": diff --git a/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py b/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py index f480132793..61c2127eeb 100644 --- a/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py +++ b/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py @@ -16,7 +16,8 @@ n_layer=config["num_hidden_layers"], n_head=config["num_attention_heads"], num_key_value_heads=config["num_kv_heads"], - n_inner=config["hidden_size"], + # NOTE: based on https://github.com/huggingface/brrr/blob/f569b93f80d03c626b24370d5ca4b1fe4f13fd76/brrr/models/fast/starcoder2.py#L194C16-L194C88 + n_inner=config.get("n_inner", 4 * config["hidden_size"]), activation_function=config["activation_function"], resid_pdrop=config["resid_pdrop"], embd_pdrop=config["embd_pdrop"], @@ -34,3 +35,5 @@ ) model = GPTBigCodeForCausalLM._from_config(model_config) + + assert 1 == 1 diff --git a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py index f1bbac59aa..2b897bd505 100644 --- a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py @@ -176,8 +176,7 @@ def create_state_dict(paths): MERGE_DIM_MAPPING = { "token_embedding": 0, # row linear parallel - # NOTE: weird - "c_fc": 0, # column linear parallel + "c_fc": 1, # column linear parallel "c_proj": 0, # row linear parallel # NOTE: weird "query_key_value": 0, # row linear parallel @@ -207,92 +206,105 @@ def find_corresponding_dim(name): tensor_list = [tensor for _, tensor in sorted(_model_states[state_key].items())] merge_dim = find_corresponding_dim(state_key) print(f"trying to merge: {state_key}") - if state_key == "28.pp_block.attn.query_key_value.weight": - assert 1 == 1 + # if state_key == "28.pp_block.attn.query_key_value.weight" or state_key == "2.pp_block.attn.query_key_value.weight": + # assert 1 == 1 + # continue - _model_states[state_key] = torch.cat(tensor_list, dim=merge_dim) - - # _states = {} - # _embedding_paths = sorted_grouped_paths["model.token_embeddings.pp_block.token_embedding.weight"] - # for shard_id, _path in enumerate(_embedding_paths): - # with safe_open(_path[1], framework="pt", device="cpu") as f: - # for key in f.keys(): - # data = f.get_tensor(key) - # _states[shard_id] = data - - # tensor_list = [tensor for key, tensor in sorted(_states.items())] - # _embeddings = torch.cat(tensor_list, dim=0) + # if state_key == "31.pp_block.ff.c_fc.weight" or state_key == "5.pp_block.attn.query_key_value.weight": + # continue + + # if state_key == "5.pp_block.ff.c_fc.weight": + # continue + + # if state_key == "17.pp_block.attn.query_key_value.weight": + # continue + + # if state_key == "0.pp_block.ff.c_fc.weight" or state_key == "20.pp_block.attn.query_key_value.weight": + # continue + + # if state_key == "18.pp_block.ff.c_fc.weight": + # continue + + try: + _model_states[state_key] = torch.cat(tensor_list, dim=merge_dim) + except: + print(f"skipped {state_key}, {[x.shape for x in tensor_list]}") assert 1 == 1 - - states = { - int(c_name.name): torch.load(c_name) - for c_name in tqdm(checkpoint_paths) - } - # num_stages = len(states[0]["stages"]) + # print([f"{key}: {value.shape}" for key, value in 
_model_states.items() if isinstance(value, torch.Tensor)]) - # tensor_parallel = config["tensor_parallel"] - # data_parallel_size = int(config["world_size"] / (tensor_parallel * config["pipeline_parallel"])) - tensor_parallel_size = config["parallelism"]["tp"] - pipeline_parallel_size = config["parallelism"]["pp"] - data_parallel_size = config["parallelism"]["dp"] - - if dummy_experiment_dir is not None: - # Use the meta from the dummy checkpoint, and the shard from the actual checkpoint - dummy_checkpoint_paths = get_all_checkpoint_paths(dummy_experiment_dir) - dummy_states = { - int(c_name.name): torch.load(c_name) - for c_name in tqdm(dummy_checkpoint_paths[-1]) - } - for rank, state in dummy_states.aitems(): - state['shard'] = states[rank]['shard'] - states = dummy_states - - # Gather the data-parallel shards - # {tp_rank: [[stage_0_shard_0, stage_0_shard_1, ...], [stage_1_shard_0, ...], ...]} - # {tp_rank: [{fsdp_rank: shard}, ...]} - fsdp_shards = { - i: [[None for _ in range(data_parallel_size)] for _ in range(pipeline_parallel_size)] - for i in range(tensor_parallel_size) - } + for key, value in _model_states.items(): + if isinstance(value, torch.Tensor): + print(f"key: {key}, value: {value.shape} \n") - for rank, state in states.items(): - on_device_stage_shards = extract_stage_shards(state) - on_device_stage_indices = [i for (i, stage_meta) in enumerate(state["stages"]) if stage_meta["on_device"]] - for stage_index, stage_shard in zip(on_device_stage_indices, on_device_stage_shards): - stage_meta = state["stages"][stage_index] - # fsdp_shards[stage_meta["tp_rank"]][stage_index].append((stage_meta, stage_shard)) - fsdp_shards[stage_meta["tp_rank"]][stage_index][stage_meta["fsdp_rank"]] = stage_shard + # states = { + # int(c_name.name): torch.load(c_name) + # for c_name in tqdm(checkpoint_paths) + # } + # # num_stages = len(states[0]["stages"]) - # Concatenate the data-parallel shards - # and get individual weights - dp_concatenated_shards = { - tp_rank: [ - extract_individual_weights( - torch.cat(stage_shards, dim=0), - states[0]["stages"][stage_index]['content'] - ) - for stage_index, stage_shards in enumerate(fsdp_shards[tp_rank]) - ] - for tp_rank in range(config["tensor_parallel"]) - } - - # In the tensor-parallel case, concatenate the TP tensors along their TP dimensions. 
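The hunk above retires the legacy fast-llm merge path, in which every rank stores a single flat 1-D buffer, data-parallel (FSDP) shards are concatenated along dim 0, and individual weights are then recovered by splitting on their element counts. A minimal sketch of that unflattening step, assuming the same per-stage "content" metadata (a list of dicts carrying "name" and "shape") used by the code being commented out; this is an illustration, not code from the patch:

import numpy as np
import torch

def unflatten_stage_shard(flat_shard: torch.Tensor, stage_content: list) -> dict:
    # Split the merged 1-D stage buffer back into named weights by element count,
    # mirroring extract_individual_weights() above.
    numels = [int(np.prod(meta["shape"])) for meta in stage_content]
    pieces = flat_shard[: sum(numels)].split(numels)
    return {meta["name"]: piece.reshape(meta["shape"]) for meta, piece in zip(stage_content, pieces)}

The safetensors-based path added earlier in this function makes this step unnecessary, since each parameter shard is already stored as its own named tensor.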
- tp_concatenated_shards = [] - for stage_index, stage_tp_shards in enumerate(zip(*(dp_concatenated_shards[i] for i in range(tensor_parallel)))): - stage_content = states[0]["stages"][stage_index]["content"] - tp_concatenated_shards.append(concatenate_tp_shards(stage_tp_shards, stage_content)) - - # In the pipeline-parallel case, merge the stages - state_dict = { - weight_meta["name"]: weight - for stage_meta, stage_weights in zip(states[0]["stages"], tp_concatenated_shards) - for weight_meta, weight in zip(stage_meta["content"], stage_weights) - } - - print(f"Total number of parameters: {sum([weight.numel() for weight in state_dict.values()])}") - return state_dict, config + # # tensor_parallel = config["tensor_parallel"] + # # data_parallel_size = int(config["world_size"] / (tensor_parallel * config["pipeline_parallel"])) + # tensor_parallel_size = config["parallelism"]["tp"] + # pipeline_parallel_size = config["parallelism"]["pp"] + # data_parallel_size = config["parallelism"]["dp"] + + # if dummy_experiment_dir is not None: + # # Use the meta from the dummy checkpoint, and the shard from the actual checkpoint + # dummy_checkpoint_paths = get_all_checkpoint_paths(dummy_experiment_dir) + # dummy_states = { + # int(c_name.name): torch.load(c_name) + # for c_name in tqdm(dummy_checkpoint_paths[-1]) + # } + # for rank, state in dummy_states.aitems(): + # state['shard'] = states[rank]['shard'] + # states = dummy_states + + # # Gather the data-parallel shards + # # {tp_rank: [[stage_0_shard_0, stage_0_shard_1, ...], [stage_1_shard_0, ...], ...]} + # # {tp_rank: [{fsdp_rank: shard}, ...]} + # fsdp_shards = { + # i: [[None for _ in range(data_parallel_size)] for _ in range(pipeline_parallel_size)] + # for i in range(tensor_parallel_size) + # } + + # for rank, state in states.items(): + # on_device_stage_shards = extract_stage_shards(state) + # on_device_stage_indices = [i for (i, stage_meta) in enumerate(state["stages"]) if stage_meta["on_device"]] + # for stage_index, stage_shard in zip(on_device_stage_indices, on_device_stage_shards): + # stage_meta = state["stages"][stage_index] + # # fsdp_shards[stage_meta["tp_rank"]][stage_index].append((stage_meta, stage_shard)) + # fsdp_shards[stage_meta["tp_rank"]][stage_index][stage_meta["fsdp_rank"]] = stage_shard + + # # Concatenate the data-parallel shards + # # and get individual weights + # dp_concatenated_shards = { + # tp_rank: [ + # extract_individual_weights( + # torch.cat(stage_shards, dim=0), + # states[0]["stages"][stage_index]['content'] + # ) + # for stage_index, stage_shards in enumerate(fsdp_shards[tp_rank]) + # ] + # for tp_rank in range(config["tensor_parallel"]) + # } + + # # In the tensor-parallel case, concatenate the TP tensors along their TP dimensions. 
+ # tp_concatenated_shards = [] + # for stage_index, stage_tp_shards in enumerate(zip(*(dp_concatenated_shards[i] for i in range(tensor_parallel)))): + # stage_content = states[0]["stages"][stage_index]["content"] + # tp_concatenated_shards.append(concatenate_tp_shards(stage_tp_shards, stage_content)) + + # # In the pipeline-parallel case, merge the stages + # state_dict = { + # weight_meta["name"]: weight + # for stage_meta, stage_weights in zip(states[0]["stages"], tp_concatenated_shards) + # for weight_meta, weight in zip(stage_meta["content"], stage_weights) + # } + + # print(f"Total number of parameters: {sum([weight.numel() for weight in state_dict.values()])}") + # return state_dict, config # if __name__ == "__main__": From ac559a1b705a040ae77c1a259b98a3a455eb195f Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Mon, 1 Jan 2024 10:01:56 +0000 Subject: [PATCH 15/26] add merged checkpoints --- .../convert_fast_llm_checkpoint.py | 8 +- .../gpt_bigcode/drafts/starcoder_model.py | 8 +- .../gpt_bigcode/merge_fast_llm_checkpoint.py | 169 ++++++++++++++---- 3 files changed, 151 insertions(+), 34 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py index 7b456e5ce3..85abdbb78d 100644 --- a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py @@ -120,7 +120,13 @@ def main(argv=None): checkpoint_dir = "/admin/home/phuc_nguyen/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000/snapshots/92b6c25cab25f07c367bcc6d773635700a8a287d" checkpoint_dir = Path(checkpoint_dir) - state_dict, config = merge_checkpoint( + # state_dict, config = merge_checkpoint( + # # args.checkpoint_dir, + # checkpoint_dir, + # dummy_experiment_dir=None + # ) + + merge_checkpoint( # args.checkpoint_dir, checkpoint_dir, dummy_experiment_dir=None diff --git a/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py b/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py index 61c2127eeb..9340bd7714 100644 --- a/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py +++ b/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py @@ -36,4 +36,10 @@ model = GPTBigCodeForCausalLM._from_config(model_config) - assert 1 == 1 + print([x for x in model.state_dict().keys()]) + + print("----------------------------------------------------------------\n") + + for key, value in model.state_dict().items(): + print(key, value.shape) + diff --git a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py index 2b897bd505..fd570e0510 100644 --- a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py @@ -135,46 +135,145 @@ def merge_checkpoint(checkpoint_dir: Path, dummy_experiment_dir=None): # return state_dict - def create_state_dict(paths): - state_dict = {} - keyword_mapping = { - 'model_bias': 'bias', - 'model_weight': 'weight', - } + # def create_state_dict(paths): + # state_dict = {} + # keyword_mapping = { + # 'model_bias': 'bias', + # 'model_weight': 'weight', + # } - for path in paths: - tp_rank_match = re.search(r'tp-rank-(\d+)-of-\d+', str(path)) - if tp_rank_match: - tp_rank = tp_rank_match.group(1) - else: - continue # Skip if tp-rank is not found + # for path in paths: + # tp_rank_match = re.search(r'tp-rank-(\d+)-of-\d+', 
str(path)) + # if tp_rank_match: + # tp_rank = tp_rank_match.group(1) + # else: + # continue # Skip if tp-rank is not found - file_name = path.stem + # file_name = path.stem - for key_word, replacement in keyword_mapping.items(): - file_name = replacement + # for key_word, replacement in keyword_mapping.items(): + # file_name = replacement + + # key = '.'.join(path.parts[-5:-1]) + '.' + file_name + '.' + tp_rank # Modify indices as needed + # state_dict[key] = path - key = '.'.join(path.parts[-5:-1]) + '.' + file_name + '.' + tp_rank # Modify indices as needed - state_dict[key] = path + # return state_dict + + + # state_dict = create_state_dict(checkpoint_paths) + + from os.path import commonprefix + + def convert_paths_to_dict(paths): + # Convert strings to Path objects + path_objs = [Path(p) for p in paths] + + # Find the common path prefix + common_path_prefix = Path(commonprefix(path_objs)).parent + + # Create a dictionary with the modified paths + path_dict = {str(p.relative_to(common_path_prefix)): str(p) for p in path_objs} + + return path_dict + + paths = convert_paths_to_dict(checkpoint_paths) + + def convert_slashes_to_dots(input_dict): + # Create a new dictionary to store the modified keys and values + converted_dict = {} - return state_dict + # Iterate over the items in the input dictionary + for key, value in input_dict.items(): + # Replace all forward slashes in the key with dots + modified_key = key.replace('/', '.') + # Add the modified key and its corresponding value to the new dictionary + converted_dict[modified_key] = value - state_dict = create_state_dict(checkpoint_paths) + return converted_dict + + paths = convert_slashes_to_dots(paths) + + # def group_by_prefix(input_dict, depth=1): + # grouped_dict = {} + # for key, value in input_dict.items(): + # # Split the key, extract the prefix based on the specified depth + # prefix = '.'.join(key.split('.')[:depth]) + # # Append the item to the corresponding list in the dictionary + # grouped_dict.setdefault(prefix, []).append(value) + # return grouped_dict + + # def group_by_prefix_and_type(input_dict, prefix_depth): + # grouped_dict = {} + # for idx, (key, value) in enumerate(input_dict.items()): + # # Split the key and extract the prefix + # key_parts = key.split('.') + # prefix = '.'.join(key_parts[:prefix_depth]) + + # # Determine if the key is for a weight or bias + # if 'weight' in key_parts: + # prefix += '.weight' + # elif 'bias' in key_parts: + # prefix += '.bias' + + # # Append the index and link to the corresponding list in the dictionary + # grouped_dict.setdefault(prefix, []).append((idx, value)) + # return grouped_dict + + def replace_patterns(paths): + new_paths = {} + for key, value in paths.items(): + # Replace the pattern with 'weight.x' or 'bias.x' + new_key = re.sub(r'model_weight_pp-rank-0-of-1_tp-rank-(\d)-of-4', r'weight.\1', key) + new_key = re.sub(r'model_bias_pp-rank-0-of-1_tp-rank-(\d)-of-4', r'bias.\1', new_key) + new_paths[new_key] = value + return new_paths + + paths = replace_patterns(paths) + + def remove_safetensors_extension(paths): + new_paths = {} + for key, value in paths.items(): + # Remove the '.safetensors' from the key + new_key = key.replace('.safetensors', '') + new_paths[new_key] = value + return new_paths + + paths = remove_safetensors_extension(paths) + + # NOTE: probably the merge checkpoint paths are wrong + assert 1 == 1 from collections import defaultdict grouped_paths = defaultdict(list) - for key, path in state_dict.items(): - module_name, shard_number = key.rsplit('.', 
1) - grouped_paths[module_name].append((int(shard_number), path)) - + for key, path in paths.items(): + try: + module_name, shard_number = key.rsplit('.', 1) + # module_name, shard_number, _ = key.rsplit('.', 2) + grouped_paths[module_name].append((int(shard_number), path)) + except: + # NOTE: these are layer norm's weight, bias + # or other module biases, which are small, so brrr doesn't split them + print(f"skipped {key}, {path}") + grouped_paths[key].append(path) + + def remove_keys_with_empty_lists(input_dict): + # Using dictionary comprehension to filter out keys with empty lists + filtered_dict = {key: value for key, value in input_dict.items() if value} + return filtered_dict + + grouped_paths = remove_keys_with_empty_lists(grouped_paths) + + # TODO(xrsrke): it merged paths for bias and weight in the same group => wrong sorted_grouped_paths = {module: sorted(paths, key=lambda x: x[0]) for module, paths in grouped_paths.items()} - + paths = sorted_grouped_paths + from safetensors import safe_open assert 1 == 1 MERGE_DIM_MAPPING = { + "ff.c_fc.bias": 0, "token_embedding": 0, # row linear parallel "c_fc": 1, # column linear parallel "c_proj": 0, # row linear parallel @@ -195,10 +294,11 @@ def find_corresponding_dim(name): # path_demo = list(grouped_paths.values())[0] _model_states = {} - for state_key, path in sorted_grouped_paths.items(): + for state_key, path in paths.items(): _model_states[state_key] = {} for shard_id, _path in enumerate(path): - with safe_open(_path[1], framework="pt", device="cpu") as f: + checkpoint_path = _path[1] if isinstance(_path, tuple) else _path + with safe_open(checkpoint_path, framework="pt", device="cpu") as f: for key in f.keys(): data = f.get_tensor(key) _model_states[state_key][shard_id] = data @@ -225,18 +325,23 @@ def find_corresponding_dim(name): # if state_key == "18.pp_block.ff.c_fc.weight": # continue - try: - _model_states[state_key] = torch.cat(tensor_list, dim=merge_dim) - except: - print(f"skipped {state_key}, {[x.shape for x in tensor_list]}") + if len(tensor_list) > 1: + try: + _model_states[state_key] = torch.cat(tensor_list, dim=merge_dim) + except: + print(f"skipped {state_key}, {[x.shape for x in tensor_list]}") + else: + # NOTE: these are biases + _model_states[state_key] = tensor_list[0] assert 1 == 1 - - # print([f"{key}: {value.shape}" for key, value in _model_states.items() if isinstance(value, torch.Tensor)]) for key, value in _model_states.items(): if isinstance(value, torch.Tensor): print(f"key: {key}, value: {value.shape} \n") + else: + print(f"skipped key: {key}, shape: {[x.shape for x in value.values()]} \n") + # states = { # int(c_name.name): torch.load(c_name) From 78114b794fa93d7bd05e6c41a5642d757e918ff0 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Mon, 1 Jan 2024 11:12:06 +0000 Subject: [PATCH 16/26] add mapping to target state dict --- .../gpt_bigcode/merge_fast_llm_checkpoint.py | 66 +++++++++++++++++-- 1 file changed, 61 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py index fd570e0510..ca5f4fb717 100644 --- a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py @@ -336,12 +336,68 @@ def find_corresponding_dim(name): assert 1 == 1 - for key, value in _model_states.items(): - if isinstance(value, torch.Tensor): - print(f"key: {key}, value: {value.shape} \n") - else: - print(f"skipped key: {key}, shape: 
{[x.shape for x in value.values()]} \n") + # for key, value in _model_states.items(): + # if isinstance(value, torch.Tensor): + # print(f"key: {key}, value: {value.shape} \n") + # else: + # print(f"skipped key: {key}, shape: {[x.shape for x in value.values()]} \n") + + + assert 1 == 1 + + def remap_keys(target_dict): + new_dict = {} + for key, value in target_dict.items(): + parts = key.split('.') + + # Handling decoder blocks + if 'model.decoder' in key and 'pp_block' in key: + block_number = parts[2] + component_parts = parts[4:] + component = '.'.join(component_parts) + + # Mapping specific components + component_map = { + 'ln_1.model_weight': 'ln_1.weight', + 'ln_1.model_bias': 'ln_1.bias', + 'ln_2.model_weight': 'ln_2.weight', + 'ln_2.model_bias': 'ln_2.bias', + 'attn.query_key_value.weight': 'attn.c_attn.weight', + 'attn.query_key_value.bias': 'attn.c_attn.bias', + 'attn.dense.weight': 'attn.c_proj.weight', + 'attn.dense.model_bias': 'attn.c_proj.bias', + 'ff.c_fc.weight': 'mlp.c_fc.weight', + 'ff.c_fc.bias': 'mlp.c_fc.bias', + 'ff.c_proj.weight': 'mlp.c_proj.weight', + 'ff.c_proj.model_bias': 'mlp.c_proj.bias' + } + + new_component = component_map.get(component, component) + new_key = f"transformer.h.{block_number}.{new_component}" + new_dict[new_key] = value + + # Handling final layer norm + elif key == 'model.final_layer_norm.pp_block.model_weight': + new_dict['transformer.ln_f.weight'] = value + elif key == 'model.final_layer_norm.pp_block.model_bias': + new_dict['transformer.ln_f.bias'] = value + + # Handling token embeddings + elif key == 'model.token_embeddings.pp_block.token_embedding.weight': + new_dict['transformer.wte.weight'] = value + + return new_dict + + _model_states = remap_keys(_model_states) + print("saving merged checkpoint...") + + torch.save(_model_states, './merged_checkpoints.pth') + + print("done") + + assert 1 == 1 + # states = { # int(c_name.name): torch.load(c_name) From 7d50b80597d05f962d6f4b52d3f89ee7196e774b Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Tue, 2 Jan 2024 12:54:18 +0000 Subject: [PATCH 17/26] refactor converting scrip --- .../convert_fast_llm_checkpoint.py | 184 ++++----- .../gpt_bigcode/drafts/starcoder_model.py | 48 ++- .../gpt_bigcode/merge_fast_llm_checkpoint.py | 384 ++++-------------- 3 files changed, 215 insertions(+), 401 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py index 85abdbb78d..bd6a147a63 100644 --- a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py @@ -8,99 +8,99 @@ from transformers.models.gpt_bigcode import GPTBigCodeConfig, GPTBigCodeForCausalLM, GPTBigCodeModel -# The simple map of names for "automated" rules. -NAME_MAP = { - "_mlp._layer_1": "mlp.c_fc", - "_mlp._layer_2": "mlp.c_proj", - "layer_norm_1": "ln_1", - "layer_norm_2": "ln_2", - # "attention.dense": "attn.c_proj", - "self_attn.dense": "attn.c_proj", - # "self_attention.query_key_value": "attn.c_attn", -} - - -def convert_fast_llm_checkpoint(state_dict, config): - # The converted output model. 
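The old converter above is being disabled because the name translation now happens in remap_keys() from the previous patch. A condensed sketch of that remapping, reusing the same component map; the example key at the end is illustrative only:

BRRR_TO_TRANSFORMERS = {
    "ln_1.model_weight": "ln_1.weight",
    "ln_1.model_bias": "ln_1.bias",
    "ln_2.model_weight": "ln_2.weight",
    "ln_2.model_bias": "ln_2.bias",
    "attn.query_key_value.weight": "attn.c_attn.weight",
    "attn.query_key_value.bias": "attn.c_attn.bias",
    "attn.dense.weight": "attn.c_proj.weight",
    "attn.dense.model_bias": "attn.c_proj.bias",
    "ff.c_fc.weight": "mlp.c_fc.weight",
    "ff.c_fc.bias": "mlp.c_fc.bias",
    "ff.c_proj.weight": "mlp.c_proj.weight",
    "ff.c_proj.model_bias": "mlp.c_proj.bias",
}

def remap_decoder_key(key: str) -> str:
    # "model.decoder.<block>.pp_block.<component>" -> "transformer.h.<block>.<component>"
    parts = key.split(".")
    block_number = parts[2]
    component = ".".join(parts[4:])
    return f"transformer.h.{block_number}.{BRRR_TO_TRANSFORMERS.get(component, component)}"

# remap_decoder_key("model.decoder.3.pp_block.attn.query_key_value.weight")
# -> "transformer.h.3.attn.c_attn.weight"

Final layer norm and token embeddings map to transformer.ln_f and transformer.wte as in remap_keys(), and the word-embedding matrix is reused as lm_head.weight, matching the tied LM head noted in the converter's comments.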
- output_state_dict = {} - if "window_size" in config: - attention_window_size = config["window_size"] - else: - attention_window_size = config.get("attention_window_size", None) - - config = GPTBigCodeConfig( - architectures=["GPTBigCodeLMHeadModel"], - vocab_size=config["vocab_size"], - n_positions=config["max_position_embeddings"], - n_embd=config["hidden_size"], - n_layer=config["num_layers"], - n_head=config["num_attention_heads"], - n_inner=config["ffn_hidden_size"], - activation_function="gelu", # TODO - multi_query=True, # TODO - resid_pdrop=0.1, - embd_pdrop=0.1, - attn_pdrop=0.1, - layer_norm_epsilon=1e-5, - initializer_range=0.02, - summary_type="cls_index", - summary_use_proj=True, - summary_activation=None, - summary_proj_to_labels=True, - summary_first_dropout=0.1, - scale_attn_weights=True, - use_cache=True, - bos_token_id=0, # TODO: can we remove these? - eos_token_id=0, - attention_softmax_in_fp32=True, - scale_attention_softmax_in_fp32=True, - use_rotary_embeddings=config["use_rotary_embeddings"], - rotary_embedding_scale=config["rotary_embedding_scale"], - use_position_embeddings=config["use_position_embeddings"], - attention_window_size=attention_window_size - ) - - # Truncate the word embeddings to the vocab-size - word_embeddings = state_dict.pop("_layers.0._word_embeddings_weight")[:config.vocab_size, :] - output_state_dict["transformer.wte.weight"] = word_embeddings - if config.use_position_embeddings: - output_state_dict["transformer.wpe.weight"] = state_dict.pop("_layers.0._position_embeddings_weight") - - # Layer-0 is the word/position embeddings - # Layers 1 to n_layer need to be re-mapped from 0 to n_layer-1. - # _layers.{layer_index}.{op}.{w/b} - - # Concatenate QKV matrix - for layer_index in range(1, config.n_layer + 1): - for weight_or_bias in ["weight", "bias"]: - query = state_dict.pop(f"_layers.{layer_index}.self_attn.query.{weight_or_bias}") - key_value = state_dict.pop(f"_layers.{layer_index}.self_attn.key_value.{weight_or_bias}") - output_state_dict[f"transformer.h.{layer_index - 1}.attn.c_attn.{weight_or_bias}"] = torch.cat([query, key_value], dim=0) +# # The simple map of names for "automated" rules. +# NAME_MAP = { +# "_mlp._layer_1": "mlp.c_fc", +# "_mlp._layer_2": "mlp.c_proj", +# "layer_norm_1": "ln_1", +# "layer_norm_2": "ln_2", +# # "attention.dense": "attn.c_proj", +# "self_attn.dense": "attn.c_proj", +# # "self_attention.query_key_value": "attn.c_attn", +# } + + +# def convert_fast_llm_checkpoint(state_dict, config): +# # The converted output model. +# output_state_dict = {} +# if "window_size" in config: +# attention_window_size = config["window_size"] +# else: +# attention_window_size = config.get("attention_window_size", None) + +# config = GPTBigCodeConfig( +# architectures=["GPTBigCodeLMHeadModel"], +# vocab_size=config["vocab_size"], +# n_positions=config["max_position_embeddings"], +# n_embd=config["hidden_size"], +# n_layer=config["num_layers"], +# n_head=config["num_attention_heads"], +# n_inner=config["ffn_hidden_size"], +# activation_function="gelu", # TODO +# multi_query=True, # TODO +# resid_pdrop=0.1, +# embd_pdrop=0.1, +# attn_pdrop=0.1, +# layer_norm_epsilon=1e-5, +# initializer_range=0.02, +# summary_type="cls_index", +# summary_use_proj=True, +# summary_activation=None, +# summary_proj_to_labels=True, +# summary_first_dropout=0.1, +# scale_attn_weights=True, +# use_cache=True, +# bos_token_id=0, # TODO: can we remove these? 
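With this converter going away, the merged-and-remapped state dict is what gets loaded into the reference model, and the draft scripts print both key sets to compare them. A short sketch of that sanity check, assuming a model_config assembled from model_config.json as in drafts/starcoder_model.py (paths and variable names here are illustrative):

import torch
from transformers.models.gpt_bigcode import GPTBigCodeForCausalLM

# model_config: a GPTBigCodeConfig built as in drafts/starcoder_model.py (assumed to exist)
reference = GPTBigCodeForCausalLM._from_config(model_config, torch_dtype=torch.bfloat16)
merged = torch.load("merged_checkpoint_reversed.pth")

ref_state = reference.state_dict()
missing = sorted(set(ref_state) - set(merged))
unexpected = sorted(set(merged) - set(ref_state))
shape_mismatches = {
    k: (tuple(ref_state[k].shape), tuple(merged[k].shape))
    for k in set(ref_state) & set(merged)
    if ref_state[k].shape != merged[k].shape
}
print(missing, unexpected, shape_mismatches)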
+# eos_token_id=0, +# attention_softmax_in_fp32=True, +# scale_attention_softmax_in_fp32=True, +# use_rotary_embeddings=config["use_rotary_embeddings"], +# rotary_embedding_scale=config["rotary_embedding_scale"], +# use_position_embeddings=config["use_position_embeddings"], +# attention_window_size=attention_window_size +# ) + +# # Truncate the word embeddings to the vocab-size +# word_embeddings = state_dict.pop("_layers.0._word_embeddings_weight")[:config.vocab_size, :] +# output_state_dict["transformer.wte.weight"] = word_embeddings +# if config.use_position_embeddings: +# output_state_dict["transformer.wpe.weight"] = state_dict.pop("_layers.0._position_embeddings_weight") + +# # Layer-0 is the word/position embeddings +# # Layers 1 to n_layer need to be re-mapped from 0 to n_layer-1. +# # _layers.{layer_index}.{op}.{w/b} + +# # Concatenate QKV matrix +# for layer_index in range(1, config.n_layer + 1): +# for weight_or_bias in ["weight", "bias"]: +# query = state_dict.pop(f"_layers.{layer_index}.self_attn.query.{weight_or_bias}") +# key_value = state_dict.pop(f"_layers.{layer_index}.self_attn.key_value.{weight_or_bias}") +# output_state_dict[f"transformer.h.{layer_index - 1}.attn.c_attn.{weight_or_bias}"] = torch.cat([query, key_value], dim=0) - # Extract the other ops - layer_re = re.compile("_layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") - for name, value in state_dict.items(): - m = layer_re.match(name) - assert m is not None, f"Invalid layer name: {name}" - - # The index of the layer. - layer_index = int(m.group(1)) - # The name of the operation. - op_name = m.group(2) - # Is it a weight or a bias? - weight_or_bias = m.group(3) - - # Final layernorm - if op_name == "final_layernorm": - assert layer_index == config.n_layer + 1 - output_state_dict[f"transformer.ln_f.{weight_or_bias}"] = value - else: - output_state_dict[f"transformer.h.{layer_index-1}.{NAME_MAP[op_name]}.{weight_or_bias}"] = value - - # For LM head, transformers' wants the matrix to weight embeddings. - output_state_dict["lm_head.weight"] = word_embeddings - - return output_state_dict, config +# # Extract the other ops +# layer_re = re.compile("_layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") +# for name, value in state_dict.items(): +# m = layer_re.match(name) +# assert m is not None, f"Invalid layer name: {name}" + +# # The index of the layer. +# layer_index = int(m.group(1)) +# # The name of the operation. +# op_name = m.group(2) +# # Is it a weight or a bias? +# weight_or_bias = m.group(3) + +# # Final layernorm +# if op_name == "final_layernorm": +# assert layer_index == config.n_layer + 1 +# output_state_dict[f"transformer.ln_f.{weight_or_bias}"] = value +# else: +# output_state_dict[f"transformer.h.{layer_index-1}.{NAME_MAP[op_name]}.{weight_or_bias}"] = value + +# # For LM head, transformers' wants the matrix to weight embeddings. 
+# output_state_dict["lm_head.weight"] = word_embeddings + +# return output_state_dict, config def main(argv=None): diff --git a/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py b/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py index 9340bd7714..97ceceebde 100644 --- a/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py +++ b/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py @@ -2,6 +2,7 @@ from pathlib import Path import json +import torch if __name__ == "__main__": @@ -9,6 +10,20 @@ checkpoint_dir = Path(checkpoint_dir) config = json.load(open(checkpoint_dir / "model_config.json")) + import random + + import numpy as np + + seed = 42 + + torch.manual_seed(seed) + + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + np.random.seed(seed) + random.seed(seed) + model_config = GPTBigCodeConfig( vocab_size=config["vocab_size"], n_positions=config["max_position_embeddings"], @@ -34,12 +49,33 @@ attention_window_size=config["sliding_window_size"], ) - model = GPTBigCodeForCausalLM._from_config(model_config) - - print([x for x in model.state_dict().keys()]) + model = GPTBigCodeForCausalLM._from_config(model_config, torch_dtype=torch.bfloat16) + + checkpoint_path = Path("/fsx/phuc/projects/starcoder/transformers-starcoder/src/transformers/models/gpt_bigcode/merged_checkpoint_reversed.pth") + checkpoint = torch.load(checkpoint_path) + model.load_state_dict(checkpoint) + model = model.to("cuda") + + # model = GPTBigCodeForCausalLM.from_pretrained(checkpoint_path, torch_dtype=torch.bfloat16, device_map="cuda", config=config) + + assert 1 == 1 + + from transformers import AutoTokenizer + checkpoint = "bigcode/starcoder" - print("----------------------------------------------------------------\n") + tokenizer = AutoTokenizer.from_pretrained(checkpoint) + tokenizer.eos_token_id = tokenizer.pad_token_id + # inputs = tokenizer("123", return_tensors='pt').to("cuda") - for key, value in model.state_dict().items(): - print(key, value.shape) + inputs = tokenizer.encode("def print_hello_world():", return_tensors="pt").to("cuda") + outputs = model.generate(inputs) + print(f"outputs: {outputs}") + + print(tokenizer.decode(outputs[0], clean_up_tokenization_spaces=False)) + + # print([x for x in model.state_dict().keys()]) + # print("----------------------------------------------------------------\n") + + # for key, value in model.state_dict().items(): + # print(key, value.shape) diff --git a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py index ca5f4fb717..4bb95b8643 100644 --- a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py @@ -7,16 +7,40 @@ import yaml -def get_all_checkpoint_paths(experiment_path): - checkpoints = (Path(experiment_path) / "checkpoints").glob("*") - # Sort checkpoints by iteration number - checkpoints = sorted(checkpoints, key=lambda x: int(x.name)) - return [get_checkpoint_paths(checkpoint) for checkpoint in checkpoints] - - -def get_checkpoint_paths(checkpoint_dir: Path): - # model/model - return [c_name for c_name in checkpoint_dir.glob("*") if re.match(r"\d+", c_name.name)] +MERGE_DIM_MAPPING = { + "ff.c_fc.bias": 0, + "token_embedding": 0, # row linear parallel + "c_fc": 0, # column linear parallel + "c_proj": 1, # row linear parallel + # NOTE: weird + "query_key_value": 0, # row linear parallel + "dense": 1, # row linear parallel +} + 
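The mapping above drives the actual merge later in this file: for each parameter, the per-tensor-parallel-rank safetensors shards are loaded and concatenated along the mapped dimension, while parameters with no entry (layer norms and most biases, which brrr does not split) come from a single shard. A minimal sketch of that step, assuming each brrr shard file holds one tensor as in the safe_open loop below; this is an illustration, not code from the patch:

import torch
from safetensors import safe_open

def merge_tp_shards(shard_paths, merge_dim):
    # shard_paths: safetensors files holding the tp shards of one parameter,
    # in the order they should be concatenated
    tensors = []
    for path in shard_paths:
        with safe_open(path, framework="pt", device="cpu") as f:
            tensors.extend(f.get_tensor(k) for k in f.keys())
    if merge_dim is None or len(tensors) == 1:
        # unsharded parameter (e.g. layer norm weight/bias): nothing to concatenate
        return tensors[0]
    return torch.cat(tensors, dim=merge_dim)

For example, a column-parallel c_fc weight split across the tp ranks would be concatenated along dim 0, per the table above.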
+BRRR_TFMS_NAME_MAPPING = { + 'ln_1.model_weight': 'ln_1.weight', + 'ln_1.model_bias': 'ln_1.bias', + 'ln_2.model_weight': 'ln_2.weight', + 'ln_2.model_bias': 'ln_2.bias', + 'attn.query_key_value.weight': 'attn.c_attn.weight', + 'attn.query_key_value.bias': 'attn.c_attn.bias', + 'attn.dense.weight': 'attn.c_proj.weight', + 'attn.dense.model_bias': 'attn.c_proj.bias', + 'ff.c_fc.weight': 'mlp.c_fc.weight', + 'ff.c_fc.bias': 'mlp.c_fc.bias', + 'ff.c_proj.weight': 'mlp.c_proj.weight', + 'ff.c_proj.model_bias': 'mlp.c_proj.bias' +} + +# def get_all_checkpoint_paths(experiment_path): +# checkpoints = (Path(experiment_path) / "checkpoints").glob("*") +# # Sort checkpoints by iteration number +# checkpoints = sorted(checkpoints, key=lambda x: int(x.name)) +# return [get_checkpoint_paths(checkpoint) for checkpoint in checkpoints] + + +# def get_checkpoint_paths(checkpoint_dir: Path): +# return [c_name for c_name in checkpoint_dir.glob("*") if re.match(r"\d+", c_name.name)] def get_safetensor_checkpoint_paths(checkpoint_dir: Path): model_dir = checkpoint_dir / "model" / "model" # Targeting the specific directory @@ -29,41 +53,41 @@ def get_safetensor_checkpoint_paths(checkpoint_dir: Path): return safetensor_files -def extract_stage_shards(state): - # Extract the weight shard and split it into the stage shards - # Reproduce the split done in MultiStageModelBase.setup - total_shard_size = sum(state['stage_shard_sizes']) - if len(state['shard'].shape) == 1: - # Flat buffer - weight_shard = state['shard'][:total_shard_size] - elif len(state['shard'].shape) == 2: - # 2D buffer - weight_shard = state['shard'][0] - else: - raise ValueError(f"Unrecognized buffer shape {state['shard'].shape}") - return weight_shard.split(state['stage_shard_sizes']) - - -def extract_individual_weights(merged_stage_shard, stage_content): - # Get individual weights from shards that are merged across data-parallel - weights_numel = [np.prod(weight_meta['shape']) for weight_meta in stage_content] - weights = merged_stage_shard[:sum(weights_numel)].split(weights_numel) - return [weight.reshape(weight_meta['shape']) for weight, weight_meta in zip(weights, stage_content)] - - -def concatenate_tp_shards(stage_tp_shards, stage_content): - # Concatenate the tp-shards in a given stage - # Stage_tp_shards: contains the individual weight shards for each rank - # [[weight1, weight2, ...] for rank in range(tp_size)] - concatenated_weights = [] - # Concatenate each individual weight along their TP dimension if they have one. 
- for weight_tp_shards, weight_meta in zip(zip(*stage_tp_shards), stage_content): - if weight_meta["tensor_parallel_dim"] is not None: - weight = torch.cat(weight_tp_shards, dim=weight_meta["tensor_parallel_dim"]) - else: - weight = weight_tp_shards[0] - concatenated_weights.append(weight) - return concatenated_weights +# def extract_stage_shards(state): +# # Extract the weight shard and split it into the stage shards +# # Reproduce the split done in MultiStageModelBase.setup +# total_shard_size = sum(state['stage_shard_sizes']) +# if len(state['shard'].shape) == 1: +# # Flat buffer +# weight_shard = state['shard'][:total_shard_size] +# elif len(state['shard'].shape) == 2: +# # 2D buffer +# weight_shard = state['shard'][0] +# else: +# raise ValueError(f"Unrecognized buffer shape {state['shard'].shape}") +# return weight_shard.split(state['stage_shard_sizes']) + + +# def extract_individual_weights(merged_stage_shard, stage_content): +# # Get individual weights from shards that are merged across data-parallel +# weights_numel = [np.prod(weight_meta['shape']) for weight_meta in stage_content] +# weights = merged_stage_shard[:sum(weights_numel)].split(weights_numel) +# return [weight.reshape(weight_meta['shape']) for weight, weight_meta in zip(weights, stage_content)] + + +# def concatenate_tp_shards(stage_tp_shards, stage_content): +# # Concatenate the tp-shards in a given stage +# # Stage_tp_shards: contains the individual weight shards for each rank +# # [[weight1, weight2, ...] for rank in range(tp_size)] +# concatenated_weights = [] +# # Concatenate each individual weight along their TP dimension if they have one. +# for weight_tp_shards, weight_meta in zip(zip(*stage_tp_shards), stage_content): +# if weight_meta["tensor_parallel_dim"] is not None: +# weight = torch.cat(weight_tp_shards, dim=weight_meta["tensor_parallel_dim"]) +# else: +# weight = weight_tp_shards[0] +# concatenated_weights.append(weight) +# return concatenated_weights def merge_checkpoint(checkpoint_dir: Path, dummy_experiment_dir=None): @@ -79,151 +103,30 @@ def merge_checkpoint(checkpoint_dir: Path, dummy_experiment_dir=None): checkpoint_paths = get_safetensor_checkpoint_paths(checkpoint_dir) config = yaml.safe_load((checkpoint_dir / "config.yaml").read_text()) - # def path2tfm_name(path): - # name = path - # # remove `_pp-rank*` and what comes after - # name = name.split("_pp-rank-")[0] - - # # remove `.safetensors` - # name = name.split(".safetensors")[0] - - # # remove base path - # name = name.split(str(checkpoint_path) + "/model/")[1] - - # # "/" -> "." - # name = name.replace("/", ".") - - # # remove "model." prefix if lm_head - # if ".lm_head." in name: - # name = name[len("model.") :] - - # # remove ".pp_block." 
- # name = name.replace(".pp_block.", ".") - - # # apply mapping - # name = apply_mappings(name, BRRR_TRFRS_NAME_MAPPING) - # # print(name, path) - - # # skip buffers - # if name.endswith(".model_inv_freq"): - # continue - # return name - - # Load the states from all the ranks - import re - - # def create_state_dict(paths): - # state_dict = {} - # for path in paths: - # # Break down the path and extract relevant parts - # parts = path.parts - # # Find the tp-rank part and extract the rank number - # tp_rank_match = re.search(r'tp-rank-(\d+)-of-\d+', str(path)) - # if tp_rank_match: - # tp_rank = tp_rank_match.group(1) - # else: - # continue # Skip if tp-rank is not found - - # # Construct the key from the path segments - # key_segments = [part for part in parts if part not in ['model_weight', 'pp_block', 'model']] - # key = '.'.join(key_segments[-5:]) # Adjust the index as needed to capture the right segments - # key = key.replace('/', '.').replace('\\', '.') + '.' + tp_rank - - # # Add to the dictionary - # state_dict[key] = path - - # return state_dict - - # def create_state_dict(paths): - # state_dict = {} - # keyword_mapping = { - # 'model_bias': 'bias', - # 'model_weight': 'weight', - # } - - # for path in paths: - # tp_rank_match = re.search(r'tp-rank-(\d+)-of-\d+', str(path)) - # if tp_rank_match: - # tp_rank = tp_rank_match.group(1) - # else: - # continue # Skip if tp-rank is not found - - # file_name = path.stem - - # for key_word, replacement in keyword_mapping.items(): - # file_name = replacement - - # key = '.'.join(path.parts[-5:-1]) + '.' + file_name + '.' + tp_rank # Modify indices as needed - # state_dict[key] = path - - # return state_dict - - - # state_dict = create_state_dict(checkpoint_paths) - from os.path import commonprefix def convert_paths_to_dict(paths): - # Convert strings to Path objects path_objs = [Path(p) for p in paths] - - # Find the common path prefix common_path_prefix = Path(commonprefix(path_objs)).parent - - # Create a dictionary with the modified paths path_dict = {str(p.relative_to(common_path_prefix)): str(p) for p in path_objs} - return path_dict paths = convert_paths_to_dict(checkpoint_paths) def convert_slashes_to_dots(input_dict): - # Create a new dictionary to store the modified keys and values converted_dict = {} - - # Iterate over the items in the input dictionary for key, value in input_dict.items(): - # Replace all forward slashes in the key with dots modified_key = key.replace('/', '.') - - # Add the modified key and its corresponding value to the new dictionary converted_dict[modified_key] = value return converted_dict paths = convert_slashes_to_dots(paths) - # def group_by_prefix(input_dict, depth=1): - # grouped_dict = {} - # for key, value in input_dict.items(): - # # Split the key, extract the prefix based on the specified depth - # prefix = '.'.join(key.split('.')[:depth]) - # # Append the item to the corresponding list in the dictionary - # grouped_dict.setdefault(prefix, []).append(value) - # return grouped_dict - - # def group_by_prefix_and_type(input_dict, prefix_depth): - # grouped_dict = {} - # for idx, (key, value) in enumerate(input_dict.items()): - # # Split the key and extract the prefix - # key_parts = key.split('.') - # prefix = '.'.join(key_parts[:prefix_depth]) - - # # Determine if the key is for a weight or bias - # if 'weight' in key_parts: - # prefix += '.weight' - # elif 'bias' in key_parts: - # prefix += '.bias' - - # # Append the index and link to the corresponding list in the dictionary - # 
grouped_dict.setdefault(prefix, []).append((idx, value)) - # return grouped_dict - def replace_patterns(paths): new_paths = {} for key, value in paths.items(): - # Replace the pattern with 'weight.x' or 'bias.x' new_key = re.sub(r'model_weight_pp-rank-0-of-1_tp-rank-(\d)-of-4', r'weight.\1', key) new_key = re.sub(r'model_bias_pp-rank-0-of-1_tp-rank-(\d)-of-4', r'bias.\1', new_key) new_paths[new_key] = value @@ -234,22 +137,17 @@ def replace_patterns(paths): def remove_safetensors_extension(paths): new_paths = {} for key, value in paths.items(): - # Remove the '.safetensors' from the key new_key = key.replace('.safetensors', '') new_paths[new_key] = value return new_paths paths = remove_safetensors_extension(paths) - - # NOTE: probably the merge checkpoint paths are wrong - assert 1 == 1 from collections import defaultdict grouped_paths = defaultdict(list) for key, path in paths.items(): try: module_name, shard_number = key.rsplit('.', 1) - # module_name, shard_number, _ = key.rsplit('.', 2) grouped_paths[module_name].append((int(shard_number), path)) except: # NOTE: these are layer norm's weight, bias @@ -265,28 +163,15 @@ def remove_keys_with_empty_lists(input_dict): grouped_paths = remove_keys_with_empty_lists(grouped_paths) # TODO(xrsrke): it merged paths for bias and weight in the same group => wrong - sorted_grouped_paths = {module: sorted(paths, key=lambda x: x[0]) for module, paths in grouped_paths.items()} + # sorted_grouped_paths = {module: sorted(paths, key=lambda x: x[0]) for module, paths in grouped_paths.items()} + sorted_grouped_paths = {module: sorted(paths, key=lambda x: x[0], reverse=True) for module, paths in grouped_paths.items()} paths = sorted_grouped_paths from safetensors import safe_open assert 1 == 1 - MERGE_DIM_MAPPING = { - "ff.c_fc.bias": 0, - "token_embedding": 0, # row linear parallel - "c_fc": 1, # column linear parallel - "c_proj": 0, # row linear parallel - # NOTE: weird - "query_key_value": 0, # row linear parallel - "dense": 1, # row linear parallel - } - def find_corresponding_dim(name): - """ - Searches the MERGE_DIM_MAPPING for a key that is a substring of the given name. - Returns the corresponding dimension if found, otherwise None. 
- """ for key, value in MERGE_DIM_MAPPING.items(): if key in name: return value @@ -306,25 +191,7 @@ def find_corresponding_dim(name): tensor_list = [tensor for _, tensor in sorted(_model_states[state_key].items())] merge_dim = find_corresponding_dim(state_key) print(f"trying to merge: {state_key}") - # if state_key == "28.pp_block.attn.query_key_value.weight" or state_key == "2.pp_block.attn.query_key_value.weight": - # assert 1 == 1 - # continue - - # if state_key == "31.pp_block.ff.c_fc.weight" or state_key == "5.pp_block.attn.query_key_value.weight": - # continue - - # if state_key == "5.pp_block.ff.c_fc.weight": - # continue - - # if state_key == "17.pp_block.attn.query_key_value.weight": - # continue - - # if state_key == "0.pp_block.ff.c_fc.weight" or state_key == "20.pp_block.attn.query_key_value.weight": - # continue - - # if state_key == "18.pp_block.ff.c_fc.weight": - # continue - + if len(tensor_list) > 1: try: _model_states[state_key] = torch.cat(tensor_list, dim=merge_dim) @@ -350,33 +217,14 @@ def remap_keys(target_dict): for key, value in target_dict.items(): parts = key.split('.') - # Handling decoder blocks if 'model.decoder' in key and 'pp_block' in key: block_number = parts[2] component_parts = parts[4:] component = '.'.join(component_parts) - # Mapping specific components - component_map = { - 'ln_1.model_weight': 'ln_1.weight', - 'ln_1.model_bias': 'ln_1.bias', - 'ln_2.model_weight': 'ln_2.weight', - 'ln_2.model_bias': 'ln_2.bias', - 'attn.query_key_value.weight': 'attn.c_attn.weight', - 'attn.query_key_value.bias': 'attn.c_attn.bias', - 'attn.dense.weight': 'attn.c_proj.weight', - 'attn.dense.model_bias': 'attn.c_proj.bias', - 'ff.c_fc.weight': 'mlp.c_fc.weight', - 'ff.c_fc.bias': 'mlp.c_fc.bias', - 'ff.c_proj.weight': 'mlp.c_proj.weight', - 'ff.c_proj.model_bias': 'mlp.c_proj.bias' - } - - new_component = component_map.get(component, component) + new_component = BRRR_TFMS_NAME_MAPPING.get(component, component) new_key = f"transformer.h.{block_number}.{new_component}" new_dict[new_key] = value - - # Handling final layer norm elif key == 'model.final_layer_norm.pp_block.model_weight': new_dict['transformer.ln_f.weight'] = value elif key == 'model.final_layer_norm.pp_block.model_bias': @@ -389,86 +237,16 @@ def remap_keys(target_dict): return new_dict _model_states = remap_keys(_model_states) + _model_states["lm_head.weight"] = _model_states["transformer.wte.weight"] + + for key, value in _model_states.items(): + if isinstance(value, torch.Tensor): + print(f"key: {key}, value: {value.shape} \n") + else: + print(f"skipped key: {key}, shape: {[x.shape for x in value.values()]} \n") print("saving merged checkpoint...") - torch.save(_model_states, './merged_checkpoints.pth') + torch.save(_model_states, './merged_checkpoint_reversed.pth') print("done") - - assert 1 == 1 - - - # states = { - # int(c_name.name): torch.load(c_name) - # for c_name in tqdm(checkpoint_paths) - # } - # # num_stages = len(states[0]["stages"]) - - # # tensor_parallel = config["tensor_parallel"] - # # data_parallel_size = int(config["world_size"] / (tensor_parallel * config["pipeline_parallel"])) - # tensor_parallel_size = config["parallelism"]["tp"] - # pipeline_parallel_size = config["parallelism"]["pp"] - # data_parallel_size = config["parallelism"]["dp"] - - # if dummy_experiment_dir is not None: - # # Use the meta from the dummy checkpoint, and the shard from the actual checkpoint - # dummy_checkpoint_paths = get_all_checkpoint_paths(dummy_experiment_dir) - # dummy_states = { - # 
int(c_name.name): torch.load(c_name) - # for c_name in tqdm(dummy_checkpoint_paths[-1]) - # } - # for rank, state in dummy_states.aitems(): - # state['shard'] = states[rank]['shard'] - # states = dummy_states - - # # Gather the data-parallel shards - # # {tp_rank: [[stage_0_shard_0, stage_0_shard_1, ...], [stage_1_shard_0, ...], ...]} - # # {tp_rank: [{fsdp_rank: shard}, ...]} - # fsdp_shards = { - # i: [[None for _ in range(data_parallel_size)] for _ in range(pipeline_parallel_size)] - # for i in range(tensor_parallel_size) - # } - - # for rank, state in states.items(): - # on_device_stage_shards = extract_stage_shards(state) - # on_device_stage_indices = [i for (i, stage_meta) in enumerate(state["stages"]) if stage_meta["on_device"]] - # for stage_index, stage_shard in zip(on_device_stage_indices, on_device_stage_shards): - # stage_meta = state["stages"][stage_index] - # # fsdp_shards[stage_meta["tp_rank"]][stage_index].append((stage_meta, stage_shard)) - # fsdp_shards[stage_meta["tp_rank"]][stage_index][stage_meta["fsdp_rank"]] = stage_shard - - # # Concatenate the data-parallel shards - # # and get individual weights - # dp_concatenated_shards = { - # tp_rank: [ - # extract_individual_weights( - # torch.cat(stage_shards, dim=0), - # states[0]["stages"][stage_index]['content'] - # ) - # for stage_index, stage_shards in enumerate(fsdp_shards[tp_rank]) - # ] - # for tp_rank in range(config["tensor_parallel"]) - # } - - # # In the tensor-parallel case, concatenate the TP tensors along their TP dimensions. - # tp_concatenated_shards = [] - # for stage_index, stage_tp_shards in enumerate(zip(*(dp_concatenated_shards[i] for i in range(tensor_parallel)))): - # stage_content = states[0]["stages"][stage_index]["content"] - # tp_concatenated_shards.append(concatenate_tp_shards(stage_tp_shards, stage_content)) - - # # In the pipeline-parallel case, merge the stages - # state_dict = { - # weight_meta["name"]: weight - # for stage_meta, stage_weights in zip(states[0]["stages"], tp_concatenated_shards) - # for weight_meta, weight in zip(stage_meta["content"], stage_weights) - # } - - # print(f"Total number of parameters: {sum([weight.numel() for weight in state_dict.values()])}") - # return state_dict, config - - -# if __name__ == "__main__": -# merge_checkpoint("/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/1B_repo_context_Top-level-Depth-first_pp2_64k_64k_2023_10_17_16_35_27/", -# dummy_experiment_dir="/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/dev_1B_repo_context_Random_pp2_64k_64k_2023_10_18_22_20_36/") - From 21ee689f3b991631eef319449b73eee5277f5fcc Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 3 Jan 2024 09:20:33 +0000 Subject: [PATCH 18/26] refactor --- .../convert_fast_llm_checkpoint.py | 121 +----------------- .../gpt_bigcode/merge_fast_llm_checkpoint.py | 80 ++---------- 2 files changed, 9 insertions(+), 192 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py index bd6a147a63..03fd309a12 100644 --- a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py @@ -8,108 +8,8 @@ from transformers.models.gpt_bigcode import GPTBigCodeConfig, GPTBigCodeForCausalLM, GPTBigCodeModel -# # The simple map of names for "automated" rules. 
-# NAME_MAP = { -# "_mlp._layer_1": "mlp.c_fc", -# "_mlp._layer_2": "mlp.c_proj", -# "layer_norm_1": "ln_1", -# "layer_norm_2": "ln_2", -# # "attention.dense": "attn.c_proj", -# "self_attn.dense": "attn.c_proj", -# # "self_attention.query_key_value": "attn.c_attn", -# } - - -# def convert_fast_llm_checkpoint(state_dict, config): -# # The converted output model. -# output_state_dict = {} -# if "window_size" in config: -# attention_window_size = config["window_size"] -# else: -# attention_window_size = config.get("attention_window_size", None) - -# config = GPTBigCodeConfig( -# architectures=["GPTBigCodeLMHeadModel"], -# vocab_size=config["vocab_size"], -# n_positions=config["max_position_embeddings"], -# n_embd=config["hidden_size"], -# n_layer=config["num_layers"], -# n_head=config["num_attention_heads"], -# n_inner=config["ffn_hidden_size"], -# activation_function="gelu", # TODO -# multi_query=True, # TODO -# resid_pdrop=0.1, -# embd_pdrop=0.1, -# attn_pdrop=0.1, -# layer_norm_epsilon=1e-5, -# initializer_range=0.02, -# summary_type="cls_index", -# summary_use_proj=True, -# summary_activation=None, -# summary_proj_to_labels=True, -# summary_first_dropout=0.1, -# scale_attn_weights=True, -# use_cache=True, -# bos_token_id=0, # TODO: can we remove these? -# eos_token_id=0, -# attention_softmax_in_fp32=True, -# scale_attention_softmax_in_fp32=True, -# use_rotary_embeddings=config["use_rotary_embeddings"], -# rotary_embedding_scale=config["rotary_embedding_scale"], -# use_position_embeddings=config["use_position_embeddings"], -# attention_window_size=attention_window_size -# ) - -# # Truncate the word embeddings to the vocab-size -# word_embeddings = state_dict.pop("_layers.0._word_embeddings_weight")[:config.vocab_size, :] -# output_state_dict["transformer.wte.weight"] = word_embeddings -# if config.use_position_embeddings: -# output_state_dict["transformer.wpe.weight"] = state_dict.pop("_layers.0._position_embeddings_weight") - -# # Layer-0 is the word/position embeddings -# # Layers 1 to n_layer need to be re-mapped from 0 to n_layer-1. -# # _layers.{layer_index}.{op}.{w/b} - -# # Concatenate QKV matrix -# for layer_index in range(1, config.n_layer + 1): -# for weight_or_bias in ["weight", "bias"]: -# query = state_dict.pop(f"_layers.{layer_index}.self_attn.query.{weight_or_bias}") -# key_value = state_dict.pop(f"_layers.{layer_index}.self_attn.key_value.{weight_or_bias}") -# output_state_dict[f"transformer.h.{layer_index - 1}.attn.c_attn.{weight_or_bias}"] = torch.cat([query, key_value], dim=0) - -# # Extract the other ops -# layer_re = re.compile("_layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") -# for name, value in state_dict.items(): -# m = layer_re.match(name) -# assert m is not None, f"Invalid layer name: {name}" - -# # The index of the layer. -# layer_index = int(m.group(1)) -# # The name of the operation. -# op_name = m.group(2) -# # Is it a weight or a bias? -# weight_or_bias = m.group(3) - -# # Final layernorm -# if op_name == "final_layernorm": -# assert layer_index == config.n_layer + 1 -# output_state_dict[f"transformer.ln_f.{weight_or_bias}"] = value -# else: -# output_state_dict[f"transformer.h.{layer_index-1}.{NAME_MAP[op_name]}.{weight_or_bias}"] = value - -# # For LM head, transformers' wants the matrix to weight embeddings. 
-# output_state_dict["lm_head.weight"] = word_embeddings - -# return output_state_dict, config - - def main(argv=None): parser = argparse.ArgumentParser() - # parser.add_argument( - # "--checkpoint_dir", - # type=Path, - # help="Path to the experiment directory", - # ) parser.add_argument( "--save_dir", type=Path, @@ -119,30 +19,11 @@ def main(argv=None): # TODO(xrsrke): auto convert checkpoint_dir to Path checkpoint_dir = "/admin/home/phuc_nguyen/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000/snapshots/92b6c25cab25f07c367bcc6d773635700a8a287d" checkpoint_dir = Path(checkpoint_dir) - - # state_dict, config = merge_checkpoint( - # # args.checkpoint_dir, - # checkpoint_dir, - # dummy_experiment_dir=None - # ) - + merge_checkpoint( - # args.checkpoint_dir, checkpoint_dir, dummy_experiment_dir=None ) - - # output_state_dict, output_config = convert_fast_llm_checkpoint(state_dict, config) - - # print("Saving config") - # save_dir = args.save_dir or args.checkpoint_dir / "converted" - # output_config.save_pretrained(save_dir) - - # # Store the state_dict to file. - # output_checkpoint_file = os.path.join(save_dir, "pytorch_model.bin") - # print(f'Saving checkpoint to "{output_checkpoint_file}"') - # torch.save(output_state_dict, output_checkpoint_file) - # print(f'Done!') if __name__ == "__main__": diff --git a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py index 4bb95b8643..d87d483f80 100644 --- a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py @@ -6,6 +6,9 @@ import torch import yaml +import re +from os.path import commonprefix + MERGE_DIM_MAPPING = { "ff.c_fc.bias": 0, @@ -32,64 +35,16 @@ 'ff.c_proj.model_bias': 'mlp.c_proj.bias' } -# def get_all_checkpoint_paths(experiment_path): -# checkpoints = (Path(experiment_path) / "checkpoints").glob("*") -# # Sort checkpoints by iteration number -# checkpoints = sorted(checkpoints, key=lambda x: int(x.name)) -# return [get_checkpoint_paths(checkpoint) for checkpoint in checkpoints] - - -# def get_checkpoint_paths(checkpoint_dir: Path): -# return [c_name for c_name in checkpoint_dir.glob("*") if re.match(r"\d+", c_name.name)] - def get_safetensor_checkpoint_paths(checkpoint_dir: Path): - model_dir = checkpoint_dir / "model" / "model" # Targeting the specific directory + model_dir = checkpoint_dir / "model" / "model" safetensor_files = [] - for file_path in model_dir.rglob("*.safetensors"): # Looking for files with .safetensors extension - if file_path.is_file(): # Ensure it's a file - safetensor_files.append(file_path.absolute()) # Adding the absolute path of the file + for file_path in model_dir.rglob("*.safetensors"): + if file_path.is_file(): + safetensor_files.append(file_path.absolute()) return safetensor_files - -# def extract_stage_shards(state): -# # Extract the weight shard and split it into the stage shards -# # Reproduce the split done in MultiStageModelBase.setup -# total_shard_size = sum(state['stage_shard_sizes']) -# if len(state['shard'].shape) == 1: -# # Flat buffer -# weight_shard = state['shard'][:total_shard_size] -# elif len(state['shard'].shape) == 2: -# # 2D buffer -# weight_shard = state['shard'][0] -# else: -# raise ValueError(f"Unrecognized buffer shape {state['shard'].shape}") -# return weight_shard.split(state['stage_shard_sizes']) - - -# def extract_individual_weights(merged_stage_shard, stage_content): -# # Get 
individual weights from shards that are merged across data-parallel -# weights_numel = [np.prod(weight_meta['shape']) for weight_meta in stage_content] -# weights = merged_stage_shard[:sum(weights_numel)].split(weights_numel) -# return [weight.reshape(weight_meta['shape']) for weight, weight_meta in zip(weights, stage_content)] - - -# def concatenate_tp_shards(stage_tp_shards, stage_content): -# # Concatenate the tp-shards in a given stage -# # Stage_tp_shards: contains the individual weight shards for each rank -# # [[weight1, weight2, ...] for rank in range(tp_size)] -# concatenated_weights = [] -# # Concatenate each individual weight along their TP dimension if they have one. -# for weight_tp_shards, weight_meta in zip(zip(*stage_tp_shards), stage_content): -# if weight_meta["tensor_parallel_dim"] is not None: -# weight = torch.cat(weight_tp_shards, dim=weight_meta["tensor_parallel_dim"]) -# else: -# weight = weight_tp_shards[0] -# concatenated_weights.append(weight) -# return concatenated_weights - - def merge_checkpoint(checkpoint_dir: Path, dummy_experiment_dir=None): """Load a fast-llm checkpoint and merge the data, tensor, and pipeline-parallel shards""" # checkpoint_dir=experiment_dir/checkpoints/{iteration} @@ -97,14 +52,9 @@ def merge_checkpoint(checkpoint_dir: Path, dummy_experiment_dir=None): # experiment_dir = checkpoint_dir.parent.parent # NOTE: use the checkpoint format from https://huggingface.co/HuggingFaceBR4/starcoder2_7b_4k_smol_data_580000/tree/main/model/model/token_embeddings/pp_block/token_embedding - # where experiment_dir = checkpoint_dir - # checkpoint_paths = get_checkpoint_paths(checkpoint_dir) checkpoint_dir = Path(checkpoint_dir) checkpoint_paths = get_safetensor_checkpoint_paths(checkpoint_dir) config = yaml.safe_load((checkpoint_dir / "config.yaml").read_text()) - - import re - from os.path import commonprefix def convert_paths_to_dict(paths): path_objs = [Path(p) for p in paths] @@ -156,7 +106,6 @@ def remove_safetensors_extension(paths): grouped_paths[key].append(path) def remove_keys_with_empty_lists(input_dict): - # Using dictionary comprehension to filter out keys with empty lists filtered_dict = {key: value for key, value in input_dict.items() if value} return filtered_dict @@ -177,7 +126,6 @@ def find_corresponding_dim(name): return value return None - # path_demo = list(grouped_paths.values())[0] _model_states = {} for state_key, path in paths.items(): _model_states[state_key] = {} @@ -200,18 +148,7 @@ def find_corresponding_dim(name): else: # NOTE: these are biases _model_states[state_key] = tensor_list[0] - - assert 1 == 1 - - # for key, value in _model_states.items(): - # if isinstance(value, torch.Tensor): - # print(f"key: {key}, value: {value.shape} \n") - # else: - # print(f"skipped key: {key}, shape: {[x.shape for x in value.values()]} \n") - - - assert 1 == 1 - + def remap_keys(target_dict): new_dict = {} for key, value in target_dict.items(): @@ -230,7 +167,6 @@ def remap_keys(target_dict): elif key == 'model.final_layer_norm.pp_block.model_bias': new_dict['transformer.ln_f.bias'] = value - # Handling token embeddings elif key == 'model.token_embeddings.pp_block.token_embedding.weight': new_dict['transformer.wte.weight'] = value From 210311b22d1c1506297aa423097622730b7a3e1a Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 3 Jan 2024 10:23:41 +0000 Subject: [PATCH 19/26] add inference script --- .../gpt_bigcode/drafts/starcoder_model.py | 27 ++++--------------- 1 file changed, 5 insertions(+), 22 deletions(-) diff --git 
a/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py b/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py index 97ceceebde..06eb3a6dfc 100644 --- a/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py +++ b/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py @@ -1,26 +1,23 @@ from transformers import GPTBigCodeForCausalLM, GPTBigCodeConfig +from transformers import AutoTokenizer from pathlib import Path import json import torch +import random +import numpy as np + if __name__ == "__main__": + seed = 42 checkpoint_dir = "/admin/home/phuc_nguyen/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000/snapshots/92b6c25cab25f07c367bcc6d773635700a8a287d" checkpoint_dir = Path(checkpoint_dir) config = json.load(open(checkpoint_dir / "model_config.json")) - - import random - - import numpy as np - - seed = 42 torch.manual_seed(seed) - if torch.cuda.is_available(): torch.cuda.manual_seed(seed) - np.random.seed(seed) random.seed(seed) @@ -56,26 +53,12 @@ model.load_state_dict(checkpoint) model = model.to("cuda") - # model = GPTBigCodeForCausalLM.from_pretrained(checkpoint_path, torch_dtype=torch.bfloat16, device_map="cuda", config=config) - - assert 1 == 1 - - from transformers import AutoTokenizer checkpoint = "bigcode/starcoder" - tokenizer = AutoTokenizer.from_pretrained(checkpoint) tokenizer.eos_token_id = tokenizer.pad_token_id - # inputs = tokenizer("123", return_tensors='pt').to("cuda") inputs = tokenizer.encode("def print_hello_world():", return_tensors="pt").to("cuda") outputs = model.generate(inputs) print(f"outputs: {outputs}") print(tokenizer.decode(outputs[0], clean_up_tokenization_spaces=False)) - - # print([x for x in model.state_dict().keys()]) - - # print("----------------------------------------------------------------\n") - - # for key, value in model.state_dict().items(): - # print(key, value.shape) From 09c086af0d8463651ff49091cdbe21a41be916f5 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 3 Jan 2024 13:23:28 +0000 Subject: [PATCH 20/26] refactor --- .../convert_fast_llm_checkpoint.py | 48 ++++- .../gpt_bigcode/drafts/starcoder_model.py | 4 +- .../gpt_bigcode/merge_fast_llm_checkpoint.py | 199 ++++++++---------- 3 files changed, 130 insertions(+), 121 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py index 03fd309a12..57834a98e4 100644 --- a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py @@ -1,29 +1,61 @@ import argparse import os from pathlib import Path -import re import torch from transformers.models.gpt_bigcode.merge_fast_llm_checkpoint import merge_checkpoint -from transformers.models.gpt_bigcode import GPTBigCodeConfig, GPTBigCodeForCausalLM, GPTBigCodeModel def main(argv=None): parser = argparse.ArgumentParser() + parser.add_argument( + "--checkpoint_dir", + type=Path, + default="/admin/home/phuc_nguyen/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000/snapshots/92b6c25cab25f07c367bcc6d773635700a8a287d", + help="Path where the converted model is saved" + ) parser.add_argument( "--save_dir", type=Path, + default="./", help="Path where the converted model is saved" ) args = parser.parse_args(argv) + + print("start") + # TODO(xrsrke): auto convert checkpoint_dir to Path - checkpoint_dir = 
"/admin/home/phuc_nguyen/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000/snapshots/92b6c25cab25f07c367bcc6d773635700a8a287d" - checkpoint_dir = Path(checkpoint_dir) + # checkpoint_dir = "/admin/home/phuc_nguyen/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000/snapshots/92b6c25cab25f07c367bcc6d773635700a8a287d" + # checkpoint_dir = Path(checkpoint_dir) + + state_dict = merge_checkpoint(args.checkpoint_dir) + + # save_dir = args.save_dir or args.checkpoint_dir / "converted" + print("save") + # output_checkpoint_file = os.path.join(Path("./"), "pytorch_model_0.bin") + # print(f'Saving checkpoint to "{output_checkpoint_file}"') + # torch.save(state_dict, output_checkpoint_file) + # print(f'Done!') + + # Compare + def compare_state_dicts(dict1, dict2): + # Compare keys + if set(dict1.keys()) != set(dict2.keys()): + return "Different keys" + + # Compare shapes and values + for key in dict1: + if dict1[key].shape != dict2[key].shape: + return f"Different shape for key: {key}" + if not torch.allclose(dict1[key], dict2[key]): + return f"Different values for key: {key}" + + return "State dictionaries are identical" + + ref_state_dict = torch.load("/fsx/phuc/projects/starcoder/transformers-starcoder/src/transformers/models/gpt_bigcode/merged_checkpoint.pth") + result = compare_state_dicts(state_dict, ref_state_dict) + print(result) - merge_checkpoint( - checkpoint_dir, - dummy_experiment_dir=None - ) if __name__ == "__main__": diff --git a/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py b/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py index 06eb3a6dfc..54e542d4de 100644 --- a/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py +++ b/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py @@ -48,12 +48,12 @@ model = GPTBigCodeForCausalLM._from_config(model_config, torch_dtype=torch.bfloat16) - checkpoint_path = Path("/fsx/phuc/projects/starcoder/transformers-starcoder/src/transformers/models/gpt_bigcode/merged_checkpoint_reversed.pth") + checkpoint_path = Path("/fsx/phuc/projects/starcoder/transformers-starcoder/src/transformers/models/gpt_bigcode/merged_checkpoint.pth") checkpoint = torch.load(checkpoint_path) model.load_state_dict(checkpoint) model = model.to("cuda") - checkpoint = "bigcode/starcoder" + checkpoint = "bigcode/starcoder2-tokenizer" tokenizer = AutoTokenizer.from_pretrained(checkpoint) tokenizer.eos_token_id = tokenizer.pad_token_id diff --git a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py index d87d483f80..9570eb41a9 100644 --- a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py @@ -5,6 +5,8 @@ import numpy as np import torch import yaml +from safetensors import safe_open +from collections import defaultdict import re from os.path import commonprefix @@ -45,109 +47,94 @@ def get_safetensor_checkpoint_paths(checkpoint_dir: Path): return safetensor_files -def merge_checkpoint(checkpoint_dir: Path, dummy_experiment_dir=None): - """Load a fast-llm checkpoint and merge the data, tensor, and pipeline-parallel shards""" - # checkpoint_dir=experiment_dir/checkpoints/{iteration} - # experiment_dir = "~/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000" - # experiment_dir = checkpoint_dir.parent.parent - - # NOTE: use the checkpoint format from 
https://huggingface.co/HuggingFaceBR4/starcoder2_7b_4k_smol_data_580000/tree/main/model/model/token_embeddings/pp_block/token_embedding - checkpoint_dir = Path(checkpoint_dir) +def merge_checkpoint(checkpoint_dir: Path): + """Load a checkpoint from the BRRR format and merge tensor parallel shards.""" checkpoint_paths = get_safetensor_checkpoint_paths(checkpoint_dir) - config = yaml.safe_load((checkpoint_dir / "config.yaml").read_text()) - - def convert_paths_to_dict(paths): + + def transform_paths(paths): + # Convert to Path objects and find common prefix path_objs = [Path(p) for p in paths] common_path_prefix = Path(commonprefix(path_objs)).parent - path_dict = {str(p.relative_to(common_path_prefix)): str(p) for p in path_objs} - return path_dict - - paths = convert_paths_to_dict(checkpoint_paths) - - def convert_slashes_to_dots(input_dict): - converted_dict = {} - for key, value in input_dict.items(): - modified_key = key.replace('/', '.') - converted_dict[modified_key] = value - return converted_dict - - paths = convert_slashes_to_dots(paths) - - def replace_patterns(paths): - new_paths = {} - for key, value in paths.items(): - new_key = re.sub(r'model_weight_pp-rank-0-of-1_tp-rank-(\d)-of-4', r'weight.\1', key) - new_key = re.sub(r'model_bias_pp-rank-0-of-1_tp-rank-(\d)-of-4', r'bias.\1', new_key) - new_paths[new_key] = value - return new_paths - - paths = replace_patterns(paths) - - def remove_safetensors_extension(paths): - new_paths = {} - for key, value in paths.items(): - new_key = key.replace('.safetensors', '') - new_paths[new_key] = value - return new_paths - - paths = remove_safetensors_extension(paths) - - from collections import defaultdict - grouped_paths = defaultdict(list) - for key, path in paths.items(): - try: - module_name, shard_number = key.rsplit('.', 1) - grouped_paths[module_name].append((int(shard_number), path)) - except: - # NOTE: these are layer norm's weight, bias - # or other module biases, which are small, so brrr doesn't split them - print(f"skipped {key}, {path}") - grouped_paths[key].append(path) - - def remove_keys_with_empty_lists(input_dict): - filtered_dict = {key: value for key, value in input_dict.items() if value} - return filtered_dict - - grouped_paths = remove_keys_with_empty_lists(grouped_paths) - - # TODO(xrsrke): it merged paths for bias and weight in the same group => wrong - # sorted_grouped_paths = {module: sorted(paths, key=lambda x: x[0]) for module, paths in grouped_paths.items()} - sorted_grouped_paths = {module: sorted(paths, key=lambda x: x[0], reverse=True) for module, paths in grouped_paths.items()} - paths = sorted_grouped_paths - - from safetensors import safe_open - - assert 1 == 1 + # Initialize the final paths dictionary + final_paths = {} + + for path in path_objs: + # Relative path + relative_path = str(path.relative_to(common_path_prefix)) + + # Convert slashes to dots + dot_path = relative_path.replace('/', '.') + + # Replace patterns for model weights and biases + weight_replaced = re.sub(r'model_weight_pp-rank-0-of-1_tp-rank-(\d)-of-4', r'weight.\1', dot_path) + bias_replaced = re.sub(r'model_bias_pp-rank-0-of-1_tp-rank-(\d)-of-4', r'bias.\1', weight_replaced) + + # Remove '.safetensors' extension + cleaned_path = bias_replaced.replace('.safetensors', '') + + # Add to final dictionary + final_paths[cleaned_path] = str(path) + + return final_paths + + paths = transform_paths(checkpoint_paths) - def find_corresponding_dim(name): - for key, value in MERGE_DIM_MAPPING.items(): - if key in name: - return value - return None 
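As a quick illustration of how the refactored pieces above compose (the path-to-key rewrite, the per-module grouping by shard index, and the merge-dimension lookup), here is a minimal, self-contained sketch on a hypothetical tp=4 shard path; the path, the local mapping copy, and the values noted in the comments are illustrative assumptions rather than output read from a real checkpoint.

import re

# Local copy for the sketch; the authoritative values are whatever revision of
# MERGE_DIM_MAPPING is in effect in this file.
_EXAMPLE_MERGE_DIM_MAPPING = {
    "token_embedding": 0,
    "c_fc": 0,
    "c_proj": 1,
    "query_key_value": 0,
    "dense": 1,
}

def _example_key(relative_path: str) -> str:
    # Same transformation as transform_paths: slashes become dots, the shard file
    # name becomes "weight.<tp_rank>" / "bias.<tp_rank>", extension is dropped.
    key = relative_path.replace("/", ".")
    key = re.sub(r"model_weight_pp-rank-0-of-1_tp-rank-(\d)-of-4", r"weight.\1", key)
    key = re.sub(r"model_bias_pp-rank-0-of-1_tp-rank-(\d)-of-4", r"bias.\1", key)
    return key.replace(".safetensors", "")

def _example_merge_dim(name):
    # Same substring lookup as find_corresponding_dim above.
    for key, value in _EXAMPLE_MERGE_DIM_MAPPING.items():
        if key in name:
            return value
    return None

# Hypothetical shard file from a tp=4 run, following the naming scheme matched above.
_path = "model/decoder/0/pp_block/attn/query_key_value/model_weight_pp-rank-0-of-1_tp-rank-2-of-4.safetensors"
_key = _example_key(_path)             # "model.decoder.0.pp_block.attn.query_key_value.weight.2"
_module, _shard = _key.rsplit(".", 1)  # grouped under _module, sorted by int(_shard)
_dim = _example_merge_dim(_module)     # 0 here, so the four shards are torch.cat'ed along dim 0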
- - _model_states = {} - for state_key, path in paths.items(): - _model_states[state_key] = {} - for shard_id, _path in enumerate(path): - checkpoint_path = _path[1] if isinstance(_path, tuple) else _path - with safe_open(checkpoint_path, framework="pt", device="cpu") as f: - for key in f.keys(): - data = f.get_tensor(key) - _model_states[state_key][shard_id] = data - - tensor_list = [tensor for _, tensor in sorted(_model_states[state_key].items())] - merge_dim = find_corresponding_dim(state_key) - print(f"trying to merge: {state_key}") - - if len(tensor_list) > 1: + def group_and_sort_paths(paths): + grouped_paths = defaultdict(list) + + for key, path in paths.items(): try: - _model_states[state_key] = torch.cat(tensor_list, dim=merge_dim) - except: - print(f"skipped {state_key}, {[x.shape for x in tensor_list]}") - else: - # NOTE: these are biases - _model_states[state_key] = tensor_list[0] + module_name, shard_number = key.rsplit('.', 1) + grouped_paths[module_name].append((int(shard_number), path)) + except ValueError: + # Handle cases where the key does not split into two parts + print(f"skipped {key}, {path}") + grouped_paths[key].append(path) + + # Remove any entries with empty lists + grouped_paths = {k: v for k, v in grouped_paths.items() if v} + + # Sort paths in each group + sorted_grouped_paths = {module: sorted(paths, key=lambda x: x[0]) + for module, paths in grouped_paths.items()} + + return sorted_grouped_paths + + paths = group_and_sort_paths(paths) + + def merge_checkpoints(paths): + def find_corresponding_dim(name): + for key, value in MERGE_DIM_MAPPING.items(): + if key in name: + return value + return None + + model_states = {} + for state_key, path in paths.items(): + model_states[state_key] = {} + for shard_id, _path in enumerate(path): + checkpoint_path = _path[1] if isinstance(_path, tuple) else _path + with safe_open(checkpoint_path, framework="pt", device="cpu") as f: + for key in f.keys(): + data = f.get_tensor(key) + model_states[state_key][shard_id] = data + + tensor_list = [tensor for _, tensor in sorted(model_states[state_key].items())] + merge_dim = find_corresponding_dim(state_key) + print(f"trying to merge: {state_key}") + + if len(tensor_list) > 1: + try: + model_states[state_key] = torch.cat(tensor_list, dim=merge_dim) + except: + print(f"skipped {state_key}, {[x.shape for x in tensor_list]}") + else: + # NOTE: these are biases + model_states[state_key] = tensor_list[0] + return model_states + + model_states = merge_checkpoints(paths) def remap_keys(target_dict): new_dict = {} @@ -170,19 +157,9 @@ def remap_keys(target_dict): elif key == 'model.token_embeddings.pp_block.token_embedding.weight': new_dict['transformer.wte.weight'] = value + new_dict["lm_head.weight"] = new_dict["transformer.wte.weight"] return new_dict - _model_states = remap_keys(_model_states) - _model_states["lm_head.weight"] = _model_states["transformer.wte.weight"] - - for key, value in _model_states.items(): - if isinstance(value, torch.Tensor): - print(f"key: {key}, value: {value.shape} \n") - else: - print(f"skipped key: {key}, shape: {[x.shape for x in value.values()]} \n") - - print("saving merged checkpoint...") - - torch.save(_model_states, './merged_checkpoint_reversed.pth') - - print("done") + model_states = remap_keys(model_states) + + return model_states From ae54653bff3222a11f9df92c8226503157f7c91f Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 3 Jan 2024 13:32:47 +0000 Subject: [PATCH 21/26] refactor all functions --- .../gpt_bigcode/merge_fast_llm_checkpoint.py | 
188 ++++++++---------- 1 file changed, 84 insertions(+), 104 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py index 9570eb41a9..6176d86873 100644 --- a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py @@ -1,10 +1,7 @@ import re -from tqdm import tqdm from pathlib import Path -import numpy as np import torch -import yaml from safetensors import safe_open from collections import defaultdict @@ -47,119 +44,102 @@ def get_safetensor_checkpoint_paths(checkpoint_dir: Path): return safetensor_files -def merge_checkpoint(checkpoint_dir: Path): - """Load a checkpoint from the BRRR format and merge tensor parallel shards.""" - checkpoint_paths = get_safetensor_checkpoint_paths(checkpoint_dir) - - def transform_paths(paths): - # Convert to Path objects and find common prefix - path_objs = [Path(p) for p in paths] - common_path_prefix = Path(commonprefix(path_objs)).parent - # Initialize the final paths dictionary - final_paths = {} +def transform_paths(paths): + path_objs = [Path(p) for p in paths] + common_path_prefix = Path(commonprefix(path_objs)).parent - for path in path_objs: - # Relative path - relative_path = str(path.relative_to(common_path_prefix)) + final_paths = {} + for path in path_objs: + relative_path = str(path.relative_to(common_path_prefix)) + dot_path = relative_path.replace('/', '.') + + weight_replaced = re.sub(r'model_weight_pp-rank-0-of-1_tp-rank-(\d)-of-4', r'weight.\1', dot_path) + bias_replaced = re.sub(r'model_bias_pp-rank-0-of-1_tp-rank-(\d)-of-4', r'bias.\1', weight_replaced) + cleaned_path = bias_replaced.replace('.safetensors', '') - # Convert slashes to dots - dot_path = relative_path.replace('/', '.') + final_paths[cleaned_path] = str(path) - # Replace patterns for model weights and biases - weight_replaced = re.sub(r'model_weight_pp-rank-0-of-1_tp-rank-(\d)-of-4', r'weight.\1', dot_path) - bias_replaced = re.sub(r'model_bias_pp-rank-0-of-1_tp-rank-(\d)-of-4', r'bias.\1', weight_replaced) + return final_paths - # Remove '.safetensors' extension - cleaned_path = bias_replaced.replace('.safetensors', '') +def group_and_sort_paths(paths): + grouped_paths = defaultdict(list) - # Add to final dictionary - final_paths[cleaned_path] = str(path) + for key, path in paths.items(): + try: + module_name, shard_number = key.rsplit('.', 1) + grouped_paths[module_name].append((int(shard_number), path)) + except ValueError: + # NOTE: these are layer norm's weight and biases + # so it don't have shard number + grouped_paths[key].append(path) - return final_paths + # Remove any entries with empty lists + grouped_paths = {k: v for k, v in grouped_paths.items() if v} - paths = transform_paths(checkpoint_paths) - - def group_and_sort_paths(paths): - grouped_paths = defaultdict(list) - - for key, path in paths.items(): - try: - module_name, shard_number = key.rsplit('.', 1) - grouped_paths[module_name].append((int(shard_number), path)) - except ValueError: - # Handle cases where the key does not split into two parts - print(f"skipped {key}, {path}") - grouped_paths[key].append(path) - - # Remove any entries with empty lists - grouped_paths = {k: v for k, v in grouped_paths.items() if v} - - # Sort paths in each group - sorted_grouped_paths = {module: sorted(paths, key=lambda x: x[0]) - for module, paths in grouped_paths.items()} - - return sorted_grouped_paths - - paths = 
group_and_sort_paths(paths) - - def merge_checkpoints(paths): - def find_corresponding_dim(name): - for key, value in MERGE_DIM_MAPPING.items(): - if key in name: - return value - return None - - model_states = {} - for state_key, path in paths.items(): - model_states[state_key] = {} - for shard_id, _path in enumerate(path): - checkpoint_path = _path[1] if isinstance(_path, tuple) else _path - with safe_open(checkpoint_path, framework="pt", device="cpu") as f: - for key in f.keys(): - data = f.get_tensor(key) - model_states[state_key][shard_id] = data - - tensor_list = [tensor for _, tensor in sorted(model_states[state_key].items())] - merge_dim = find_corresponding_dim(state_key) - print(f"trying to merge: {state_key}") - - if len(tensor_list) > 1: - try: - model_states[state_key] = torch.cat(tensor_list, dim=merge_dim) - except: - print(f"skipped {state_key}, {[x.shape for x in tensor_list]}") - else: - # NOTE: these are biases - model_states[state_key] = tensor_list[0] - return model_states - - model_states = merge_checkpoints(paths) + # NOTE: Sort paths in each group + # module: [(4, path), (0, path), (3, path) ...] -> module: [(0, path), (1, path), (2, path) ...] + sorted_grouped_paths = {module: sorted(paths, key=lambda x: x[0]) + for module, paths in grouped_paths.items()} - def remap_keys(target_dict): - new_dict = {} - for key, value in target_dict.items(): - parts = key.split('.') - - if 'model.decoder' in key and 'pp_block' in key: - block_number = parts[2] - component_parts = parts[4:] - component = '.'.join(component_parts) - - new_component = BRRR_TFMS_NAME_MAPPING.get(component, component) - new_key = f"transformer.h.{block_number}.{new_component}" - new_dict[new_key] = value - elif key == 'model.final_layer_norm.pp_block.model_weight': - new_dict['transformer.ln_f.weight'] = value - elif key == 'model.final_layer_norm.pp_block.model_bias': - new_dict['transformer.ln_f.bias'] = value + return sorted_grouped_paths + +def merge_checkpoints(paths): + def find_corresponding_dim(name): + for key, value in MERGE_DIM_MAPPING.items(): + if key in name: + return value + return None + + model_states = {} + for state_key, path in paths.items(): + model_states[state_key] = {} + for shard_id, _path in enumerate(path): + checkpoint_path = _path[1] if isinstance(_path, tuple) else _path + with safe_open(checkpoint_path, framework="pt", device="cpu") as f: + for key in f.keys(): + data = f.get_tensor(key) + model_states[state_key][shard_id] = data + + tensor_list = [tensor for _, tensor in sorted(model_states[state_key].items())] + merge_dim = find_corresponding_dim(state_key) + + if len(tensor_list) > 1: + model_states[state_key] = torch.cat(tensor_list, dim=merge_dim) + else: + # NOTE: these are biases + model_states[state_key] = tensor_list[0] + return model_states - elif key == 'model.token_embeddings.pp_block.token_embedding.weight': - new_dict['transformer.wte.weight'] = value +def remap_keys(target_dict): + key_mapping = { + 'model.final_layer_norm.pp_block.model_weight': 'transformer.ln_f.weight', + 'model.final_layer_norm.pp_block.model_bias': 'transformer.ln_f.bias', + 'model.token_embeddings.pp_block.token_embedding.weight': 'transformer.wte.weight' + } - new_dict["lm_head.weight"] = new_dict["transformer.wte.weight"] - return new_dict + def get_new_key(key): + if 'model.decoder' in key and 'pp_block' in key: + parts = key.split('.') + block_number = parts[2] + component_parts = parts[4:] + component = '.'.join(component_parts) + new_component = 
BRRR_TFMS_NAME_MAPPING.get(component, component) + return f"transformer.h.{block_number}.{new_component}" + else: + return key_mapping.get(key, key) + + new_dict = {get_new_key(key): value for key, value in target_dict.items()} + new_dict["lm_head.weight"] = new_dict.get("transformer.wte.weight", new_dict.get("lm_head.weight")) + + return new_dict +def merge_checkpoint(checkpoint_dir: Path): + """Load a checkpoint from the BRRR format and merge tensor parallel shards.""" + checkpoint_paths = get_safetensor_checkpoint_paths(checkpoint_dir) + paths = transform_paths(checkpoint_paths) + paths = group_and_sort_paths(paths) + model_states = merge_checkpoints(paths) model_states = remap_keys(model_states) return model_states From 594099c8d188b8795d0a3c052a6cb00c9801dbdd Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 3 Jan 2024 13:58:48 +0000 Subject: [PATCH 22/26] save some files before cleaning it all --- .../convert_fast_llm_checkpoint.py | 52 +++++++++---------- .../models/gpt_bigcode/download_ckp.py | 5 -- .../gpt_bigcode/drafts/check_weird_shape.py | 9 ++++ .../gpt_bigcode/drafts/starcoder_model.py | 2 +- .../gpt_bigcode/merge_fast_llm_checkpoint.py | 45 ++++++++-------- 5 files changed, 60 insertions(+), 53 deletions(-) delete mode 100644 src/transformers/models/gpt_bigcode/download_ckp.py create mode 100644 src/transformers/models/gpt_bigcode/drafts/check_weird_shape.py diff --git a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py index 57834a98e4..ecae175829 100644 --- a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py @@ -11,13 +11,13 @@ def main(argv=None): parser.add_argument( "--checkpoint_dir", type=Path, - default="/admin/home/phuc_nguyen/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000/snapshots/92b6c25cab25f07c367bcc6d773635700a8a287d", + # default="/admin/home/phuc_nguyen/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000/snapshots/92b6c25cab25f07c367bcc6d773635700a8a287d", help="Path where the converted model is saved" ) parser.add_argument( "--save_dir", type=Path, - default="./", + # default="./", help="Path where the converted model is saved" ) args = parser.parse_args(argv) @@ -30,31 +30,31 @@ def main(argv=None): state_dict = merge_checkpoint(args.checkpoint_dir) - # save_dir = args.save_dir or args.checkpoint_dir / "converted" - print("save") - # output_checkpoint_file = os.path.join(Path("./"), "pytorch_model_0.bin") - # print(f'Saving checkpoint to "{output_checkpoint_file}"') - # torch.save(state_dict, output_checkpoint_file) - # print(f'Done!') + save_dir = args.save_dir or args.checkpoint_dir / "converted" + output_checkpoint_file = os.path.join(save_dir, "pytorch_model.bin") - # Compare - def compare_state_dicts(dict1, dict2): - # Compare keys - if set(dict1.keys()) != set(dict2.keys()): - return "Different keys" - - # Compare shapes and values - for key in dict1: - if dict1[key].shape != dict2[key].shape: - return f"Different shape for key: {key}" - if not torch.allclose(dict1[key], dict2[key]): - return f"Different values for key: {key}" - - return "State dictionaries are identical" - - ref_state_dict = torch.load("/fsx/phuc/projects/starcoder/transformers-starcoder/src/transformers/models/gpt_bigcode/merged_checkpoint.pth") - result = compare_state_dicts(state_dict, ref_state_dict) - print(result) + print(f'Saving 
checkpoint to "{output_checkpoint_file}"') + torch.save(state_dict, output_checkpoint_file) + print(f'Done!') + + # # Compare + # def compare_state_dicts(dict1, dict2): + # # Compare keys + # if set(dict1.keys()) != set(dict2.keys()): + # return "Different keys" + + # # Compare shapes and values + # for key in dict1: + # if dict1[key].shape != dict2[key].shape: + # return f"Different shape for key: {key}" + # if not torch.allclose(dict1[key], dict2[key]): + # return f"Different values for key: {key}" + + # return "State dictionaries are identical" + + # ref_state_dict = torch.load("/fsx/phuc/projects/starcoder/transformers-starcoder/src/transformers/models/gpt_bigcode/merged_checkpoint.pth") + # result = compare_state_dicts(state_dict, ref_state_dict) + # print(result) diff --git a/src/transformers/models/gpt_bigcode/download_ckp.py b/src/transformers/models/gpt_bigcode/download_ckp.py deleted file mode 100644 index a9fdd5ae44..0000000000 --- a/src/transformers/models/gpt_bigcode/download_ckp.py +++ /dev/null @@ -1,5 +0,0 @@ -from huggingface_hub import snapshot_download - -if __name__ == "__main__": - snapshot_download("HuggingFaceBR4/starcoder2_7b_4k_smol_data_580000") - print("done") diff --git a/src/transformers/models/gpt_bigcode/drafts/check_weird_shape.py b/src/transformers/models/gpt_bigcode/drafts/check_weird_shape.py new file mode 100644 index 0000000000..e8c26a801a --- /dev/null +++ b/src/transformers/models/gpt_bigcode/drafts/check_weird_shape.py @@ -0,0 +1,9 @@ +from pathlib import Path +from safetensors import safe_open + + +if __name__ == "__main__": + checkpoint_dir = "/admin/home/phuc_nguyen/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000/snapshots/92b6c25cab25f07c367bcc6d773635700a8a287d" + checkpoint_dir = Path(checkpoint_dir) + + assert 1 == 1 diff --git a/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py b/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py index 54e542d4de..84a5d4185d 100644 --- a/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py +++ b/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py @@ -48,7 +48,7 @@ model = GPTBigCodeForCausalLM._from_config(model_config, torch_dtype=torch.bfloat16) - checkpoint_path = Path("/fsx/phuc/projects/starcoder/transformers-starcoder/src/transformers/models/gpt_bigcode/merged_checkpoint.pth") + checkpoint_path = Path("/fsx/phuc/projects/starcoder/transformers-starcoder/src/transformers/models/gpt_bigcode/pytorch_model.pth") checkpoint = torch.load(checkpoint_path) model.load_state_dict(checkpoint) model = model.to("cuda") diff --git a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py index 6176d86873..f426166eca 100644 --- a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py @@ -9,16 +9,6 @@ from os.path import commonprefix -MERGE_DIM_MAPPING = { - "ff.c_fc.bias": 0, - "token_embedding": 0, # row linear parallel - "c_fc": 0, # column linear parallel - "c_proj": 1, # row linear parallel - # NOTE: weird - "query_key_value": 0, # row linear parallel - "dense": 1, # row linear parallel -} - BRRR_TFMS_NAME_MAPPING = { 'ln_1.model_weight': 'ln_1.weight', 'ln_1.model_bias': 'ln_1.bias', @@ -34,7 +24,19 @@ 'ff.c_proj.model_bias': 'mlp.c_proj.bias' } -def get_safetensor_checkpoint_paths(checkpoint_dir: Path): +MERGE_DIM_MAPPING = { + "token_embedding": 0, # row linear parallel + # NOTE: MLP 
layer + "c_fc": 0, # column linear parallel + "ff.c_fc.bias": 0, + "c_proj": 1, # row linear parallel + # NOTE: attention layer + "query_key_value": 0, # row linear parallel + "dense": 1, # row linear parallel +} + + +def _get_safetensor_checkpoint_paths(checkpoint_dir: Path): model_dir = checkpoint_dir / "model" / "model" safetensor_files = [] @@ -45,7 +47,7 @@ def get_safetensor_checkpoint_paths(checkpoint_dir: Path): return safetensor_files -def transform_paths(paths): +def _transform_paths(paths): path_objs = [Path(p) for p in paths] common_path_prefix = Path(commonprefix(path_objs)).parent @@ -62,7 +64,7 @@ def transform_paths(paths): return final_paths -def group_and_sort_paths(paths): +def _group_and_sort_paths(paths): grouped_paths = defaultdict(list) for key, path in paths.items(): @@ -84,7 +86,7 @@ def group_and_sort_paths(paths): return sorted_grouped_paths -def merge_checkpoints(paths): +def _merge_checkpoints(paths): def find_corresponding_dim(name): for key, value in MERGE_DIM_MAPPING.items(): if key in name: @@ -111,7 +113,7 @@ def find_corresponding_dim(name): model_states[state_key] = tensor_list[0] return model_states -def remap_keys(target_dict): +def _remap_keys(target_dict): key_mapping = { 'model.final_layer_norm.pp_block.model_weight': 'transformer.ln_f.weight', 'model.final_layer_norm.pp_block.model_bias': 'transformer.ln_f.bias', @@ -130,16 +132,17 @@ def get_new_key(key): return key_mapping.get(key, key) new_dict = {get_new_key(key): value for key, value in target_dict.items()} + # NOTE: starcoder uses the same embedding matrix for token embedding and lm head new_dict["lm_head.weight"] = new_dict.get("transformer.wte.weight", new_dict.get("lm_head.weight")) return new_dict -def merge_checkpoint(checkpoint_dir: Path): +def merge_checkpoint(checkpoint_dir: Path) -> dict: """Load a checkpoint from the BRRR format and merge tensor parallel shards.""" - checkpoint_paths = get_safetensor_checkpoint_paths(checkpoint_dir) - paths = transform_paths(checkpoint_paths) - paths = group_and_sort_paths(paths) - model_states = merge_checkpoints(paths) - model_states = remap_keys(model_states) + checkpoint_paths = _get_safetensor_checkpoint_paths(checkpoint_dir) + paths = _transform_paths(checkpoint_paths) + paths = _group_and_sort_paths(paths) + model_states = _merge_checkpoints(paths) + model_states = _remap_keys(model_states) return model_states From fb8a86b666eb22a3339908d323b899fc23501a5b Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 3 Jan 2024 14:01:24 +0000 Subject: [PATCH 23/26] delete uncessary files --- .../models/gpt_bigcode/ckps/download_ckp.py | 5 -- .../gpt_bigcode/drafts/check_weird_shape.py | 9 --- .../gpt_bigcode/drafts/starcoder_model.py | 64 ------------------- src/transformers/models/gpt_bigcode/small.py | 3 - 4 files changed, 81 deletions(-) delete mode 100644 src/transformers/models/gpt_bigcode/ckps/download_ckp.py delete mode 100644 src/transformers/models/gpt_bigcode/drafts/check_weird_shape.py delete mode 100644 src/transformers/models/gpt_bigcode/drafts/starcoder_model.py delete mode 100644 src/transformers/models/gpt_bigcode/small.py diff --git a/src/transformers/models/gpt_bigcode/ckps/download_ckp.py b/src/transformers/models/gpt_bigcode/ckps/download_ckp.py deleted file mode 100644 index a9fdd5ae44..0000000000 --- a/src/transformers/models/gpt_bigcode/ckps/download_ckp.py +++ /dev/null @@ -1,5 +0,0 @@ -from huggingface_hub import snapshot_download - -if __name__ == "__main__": - 
snapshot_download("HuggingFaceBR4/starcoder2_7b_4k_smol_data_580000") - print("done") diff --git a/src/transformers/models/gpt_bigcode/drafts/check_weird_shape.py b/src/transformers/models/gpt_bigcode/drafts/check_weird_shape.py deleted file mode 100644 index e8c26a801a..0000000000 --- a/src/transformers/models/gpt_bigcode/drafts/check_weird_shape.py +++ /dev/null @@ -1,9 +0,0 @@ -from pathlib import Path -from safetensors import safe_open - - -if __name__ == "__main__": - checkpoint_dir = "/admin/home/phuc_nguyen/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000/snapshots/92b6c25cab25f07c367bcc6d773635700a8a287d" - checkpoint_dir = Path(checkpoint_dir) - - assert 1 == 1 diff --git a/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py b/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py deleted file mode 100644 index 84a5d4185d..0000000000 --- a/src/transformers/models/gpt_bigcode/drafts/starcoder_model.py +++ /dev/null @@ -1,64 +0,0 @@ -from transformers import GPTBigCodeForCausalLM, GPTBigCodeConfig -from transformers import AutoTokenizer - -from pathlib import Path -import json -import torch - -import random -import numpy as np - - -if __name__ == "__main__": - seed = 42 - checkpoint_dir = "/admin/home/phuc_nguyen/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000/snapshots/92b6c25cab25f07c367bcc6d773635700a8a287d" - checkpoint_dir = Path(checkpoint_dir) - config = json.load(open(checkpoint_dir / "model_config.json")) - - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - np.random.seed(seed) - random.seed(seed) - - model_config = GPTBigCodeConfig( - vocab_size=config["vocab_size"], - n_positions=config["max_position_embeddings"], - n_embd=config["hidden_size"], - n_layer=config["num_hidden_layers"], - n_head=config["num_attention_heads"], - num_key_value_heads=config["num_kv_heads"], - # NOTE: based on https://github.com/huggingface/brrr/blob/f569b93f80d03c626b24370d5ca4b1fe4f13fd76/brrr/models/fast/starcoder2.py#L194C16-L194C88 - n_inner=config.get("n_inner", 4 * config["hidden_size"]), - activation_function=config["activation_function"], - resid_pdrop=config["resid_pdrop"], - embd_pdrop=config["embd_pdrop"], - attn_pdrop=config["attn_pdrop"], - layer_norm_epsilon=config["layer_norm_epsilon"], - scale_attn_weights=config["scale_attn_weights"], - bos_token_id=config["bos_token_id"], - eos_token_id=config["eos_token_id"], - attention_softmax_in_fp32=config["attention_softmax_in_fp32"], - scale_attention_softmax_in_fp32=config["scale_attention_softmax_in_fp32"], - multi_query=config["multi_query"], - use_rotary_embeddings=config["use_rotary_embeddings"], - # rotary_embedding_scale=brrr_model_config.rotary_embedding_scale, #TODO - attention_window_size=config["sliding_window_size"], - ) - - model = GPTBigCodeForCausalLM._from_config(model_config, torch_dtype=torch.bfloat16) - - checkpoint_path = Path("/fsx/phuc/projects/starcoder/transformers-starcoder/src/transformers/models/gpt_bigcode/pytorch_model.pth") - checkpoint = torch.load(checkpoint_path) - model.load_state_dict(checkpoint) - model = model.to("cuda") - - checkpoint = "bigcode/starcoder2-tokenizer" - tokenizer = AutoTokenizer.from_pretrained(checkpoint) - tokenizer.eos_token_id = tokenizer.pad_token_id - - inputs = tokenizer.encode("def print_hello_world():", return_tensors="pt").to("cuda") - outputs = model.generate(inputs) - print(f"outputs: {outputs}") - - print(tokenizer.decode(outputs[0], 
clean_up_tokenization_spaces=False)) diff --git a/src/transformers/models/gpt_bigcode/small.py b/src/transformers/models/gpt_bigcode/small.py deleted file mode 100644 index 56fdfc6b24..0000000000 --- a/src/transformers/models/gpt_bigcode/small.py +++ /dev/null @@ -1,3 +0,0 @@ -if __name__ == "__main__": - print("works") - assert 1 == 1 From c26472c62f013456f9c7b3a574d141e1b7d44aae Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Fri, 5 Jan 2024 17:04:15 +0000 Subject: [PATCH 24/26] add rope_theta to config matching logits without using cache --- .../gpt_bigcode/configuration_gpt_bigcode.py | 4 ++ .../gpt_bigcode/modeling_gpt_bigcode.py | 47 +++++++++++++++---- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index a71c362ad7..cd1978e4d9 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -71,6 +71,8 @@ class GPTBigCodeConfig(PretrainedConfig): The dropout ratio for the attention. layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): The epsilon to use in the layer normalization layers. + rope_theta (`int`, *optional*, defaults to 10000): + The theta value to use in the rotary position embeddings. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. scale_attn_weights (`bool`, *optional*, defaults to `True`): @@ -121,6 +123,7 @@ def __init__( embd_pdrop=0.1, attn_pdrop=0.1, layer_norm_epsilon=1e-5, + rope_theta=10000, initializer_range=0.02, scale_attn_weights=True, use_cache=True, @@ -152,6 +155,7 @@ def __init__( self.embd_pdrop = embd_pdrop self.attn_pdrop = attn_pdrop self.layer_norm_epsilon = layer_norm_epsilon + self.rope_theta = rope_theta self.initializer_range = initializer_range self.scale_attn_weights = scale_attn_weights self.use_cache = use_cache diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index f1db907b58..c3a9226bf1 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -39,7 +39,7 @@ logging, ) from .configuration_gpt_bigcode import GPTBigCodeConfig - +from apex.normalization import FusedLayerNorm as LayerNorm logger = logging.get_logger(__name__) @@ -258,7 +258,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): if self.use_rotary_embeddings: self.maybe_rotary = ( - StarcoderRotaryEmbedding(head_dim=self.head_dim) + StarcoderRotaryEmbedding(head_dim=self.head_dim, base=config.rope_theta) if config.use_rotary_embeddings else lambda q, k, t: (q, k) ) @@ -408,13 +408,21 @@ def _flash_attn(self, hidden_states, layer_past, sequence_mask, position_ids): # hidden_states (batch_size, query_length, embed_dim) # layer_past (batch_size, past_key_values_length, embed_dim) # sequence_mask (batch_size, query_length) + # print("hidden_states", hidden_states.shape) + # print(hidden_states[0]) fused_qkv = self.c_attn( hidden_states ) # [batch_size, query_length, n_local_q_heads * head_dim + 2 * n_local_kv_heads * head_dim] batch_size, q_length, _ = fused_qkv.size() + # print("fused_qkv", fused_qkv.shape) + # print(fused_qkv[0]) + qkv = fused_qkv.view(batch_size, q_length, self.n_local_kv_heads, self.n_repeats + 2, self.head_dim) + # 
print("qkv", qkv.shape) + # print(qkv[0]) + query, key, value = torch.split(qkv, [self.n_repeats, 1, 1], dim=3) query_states = query.reshape( batch_size, q_length, self.n_local_q_heads, self.head_dim @@ -422,6 +430,10 @@ def _flash_attn(self, hidden_states, layer_past, sequence_mask, position_ids): key_states = key.reshape(batch_size, q_length, self.n_local_kv_heads, self.head_dim) value_states = value.reshape(batch_size, q_length, self.n_local_kv_heads, self.head_dim) + # print("query_states", query_states.shape) + # print(query_states[0]) + # print("key_states", key_states.shape) + # print(key_states[0]) # Compute rotary embeddings if layer_past is None: past_key_values_length = 0 @@ -430,6 +442,12 @@ def _flash_attn(self, hidden_states, layer_past, sequence_mask, position_ids): query_states, key_states = self.maybe_rotary( query_states, key_states, past_key_values_length=past_key_values_length ) + # print("after rotary") + # print("query_states", query_states.shape) + # print(query_states[0]) + # print("key_states", key_states.shape) + # print(key_states[0]) + # assert False if layer_past is None: # First inference iteration (Prefill) @@ -490,6 +508,10 @@ def _flash_attn(self, hidden_states, layer_past, sequence_mask, position_ids): pad_to_right(key_states, sequence_mask, new_tensor=k_cache) pad_to_right(value_states, sequence_mask, new_tensor=v_cache) + # print("attention_output", attention_output.shape) + # print(attention_output[0]) + # # assert False + else: # Pull pre-computed key/value states # Subsequent inference iterations (q_length=1) @@ -509,6 +531,11 @@ def _flash_attn(self, hidden_states, layer_past, sequence_mask, position_ids): batch_size, kv_length, self.n_local_kv_heads, self.head_dim ) # [batch_size, kv_length, self.n_local_kv_heads, self.head_dim] + # print("query_states", query_states.shape) + # print(query_states[0]) + # print("key_states", key_states.shape) + # print(key_states[0]) + # assert False position_offsets = position_ids[:, -1] # assert False attention_output = flash_attn_with_kvcache( @@ -566,7 +593,7 @@ def forward( (self.head_dim, self.head_dim), dim=-1 ) # (batch_size, query_length, 1 * head_dim) else: # GQA - if use_cache and USE_FLASH_ATTN: + if USE_FLASH_ATTN: key, value, attn_output = self._flash_attn(hidden_states, layer_past, attention_mask, position_ids) attn_output = attn_output.view(hidden_states.shape) @@ -659,15 +686,15 @@ def __init__(self, config, layer_idx=None): hidden_size = config.hidden_size self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size - self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.ln_1 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx) - self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.ln_2 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) if config.add_cross_attention: if config.multi_query: raise NotImplementedError("Cross-attention not implemented for MQA") self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx) - self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.ln_cross_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPTBigCodeMLP(self.inner_dim, config) @@ -688,7 +715,11 @@ def forward( Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor] ]: residual = hidden_states + # print("hidden_states", hidden_states.shape) + 
# print(hidden_states[0]) hidden_states = self.ln_1(hidden_states) + # print("after ln 1", hidden_states.shape) + # print(hidden_states[0]) attn_outputs = self.attn( hidden_states, layer_past=layer_past, @@ -782,7 +813,7 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): + elif isinstance(module, LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) @@ -904,7 +935,7 @@ def __init__(self, config): self.drop = nn.Dropout(config.embd_pdrop) self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) - self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) max_positions = config.max_position_embeddings self.register_buffer( From 9c9cfbb4f823cbc47528c41181684b3d13192d02 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Mon, 8 Jan 2024 10:35:26 +0000 Subject: [PATCH 25/26] fix config.attn_pdrop for flash attn --- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index c3a9226bf1..1d0969ec0c 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -266,6 +266,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.c_proj = nn.Linear(self.embed_dim, self.embed_dim) self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.attn_pdrop = config.attn_pdrop self.resid_dropout = nn.Dropout(config.resid_pdrop) self.prefill_kv_len = ( @@ -496,7 +497,7 @@ def _flash_attn(self, hidden_states, layer_past, sequence_mask, position_ids): cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, - dropout_p=0.0, + dropout_p=self.attn_pdrop, softmax_scale=None, causal=True, # True in prefill phase, False in subsequent phases return_attn_probs=False, From 15077983d17f926bbf61cff744bb4a3221b566f7 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Mon, 8 Jan 2024 12:44:53 +0000 Subject: [PATCH 26/26] Refactor GPTBigCode model conversion code --- src/transformers/generation/utils.py | 4 +- .../convert_fast_llm_checkpoint.py | 141 +++++++--- .../gpt_bigcode/merge_fast_llm_checkpoint.py | 262 +++++++++--------- .../gpt_bigcode/modeling_gpt_bigcode.py | 76 ++--- 4 files changed, 253 insertions(+), 230 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9cfbe9e7dc..3b1bef6f04 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2469,9 +2469,7 @@ def greedy_search( # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: - scores += (next_tokens_scores,) if outputs.logits.shape[1] == 1 else ( - outputs.logits, - ) + scores += (next_tokens_scores,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) diff --git a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py index ecae175829..8859b7650e 100644 --- a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py 
+++ b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py @@ -1,9 +1,106 @@ import argparse import os from pathlib import Path +import re import torch from transformers.models.gpt_bigcode.merge_fast_llm_checkpoint import merge_checkpoint +from transformers.models.gpt_bigcode import GPTBigCodeConfig, GPTBigCodeForCausalLM, GPTBigCodeModel + + +# The simple map of names for "automated" rules. +NAME_MAP = { + "_mlp._layer_1": "mlp.c_fc", + "_mlp._layer_2": "mlp.c_proj", + "layer_norm_1": "ln_1", + "layer_norm_2": "ln_2", + # "attention.dense": "attn.c_proj", + "self_attn.dense": "attn.c_proj", + # "self_attention.query_key_value": "attn.c_attn", +} + + +def convert_fast_llm_checkpoint(state_dict, config): + # The converted output model. + output_state_dict = {} + if "window_size" in config: + attention_window_size = config["window_size"] + else: + attention_window_size = config.get("attention_window_size", None) + + config = GPTBigCodeConfig( + architectures=["GPTBigCodeLMHeadModel"], + vocab_size=config["vocab_size"], + n_positions=config["max_position_embeddings"], + n_embd=config["hidden_size"], + n_layer=config["num_layers"], + n_head=config["num_attention_heads"], + n_inner=config["ffn_hidden_size"], + activation_function="gelu", # TODO + multi_query=True, # TODO + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=0, # TODO: can we remove these? + eos_token_id=0, + attention_softmax_in_fp32=True, + scale_attention_softmax_in_fp32=True, + use_rotary_embeddings=config["use_rotary_embeddings"], + rotary_embedding_scale=config["rotary_embedding_scale"], + use_position_embeddings=config["use_position_embeddings"], + attention_window_size=attention_window_size + ) + + # Truncate the word embeddings to the vocab-size + word_embeddings = state_dict.pop("_layers.0._word_embeddings_weight")[:config.vocab_size, :] + output_state_dict["transformer.wte.weight"] = word_embeddings + if config.use_position_embeddings: + output_state_dict["transformer.wpe.weight"] = state_dict.pop("_layers.0._position_embeddings_weight") + + # Layer-0 is the word/position embeddings + # Layers 1 to n_layer need to be re-mapped from 0 to n_layer-1. + # _layers.{layer_index}.{op}.{w/b} + + # Concatenate QKV matrix + for layer_index in range(1, config.n_layer + 1): + for weight_or_bias in ["weight", "bias"]: + query = state_dict.pop(f"_layers.{layer_index}.self_attn.query.{weight_or_bias}") + key_value = state_dict.pop(f"_layers.{layer_index}.self_attn.key_value.{weight_or_bias}") + output_state_dict[f"transformer.h.{layer_index - 1}.attn.c_attn.{weight_or_bias}"] = torch.cat([query, key_value], dim=0) + + # Extract the other ops + layer_re = re.compile("_layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") + for name, value in state_dict.items(): + m = layer_re.match(name) + assert m is not None, f"Invalid layer name: {name}" + + # The index of the layer. + layer_index = int(m.group(1)) + # The name of the operation. + op_name = m.group(2) + # Is it a weight or a bias? 
+        weight_or_bias = m.group(3)
+
+        # Final layernorm
+        if op_name == "final_layernorm":
+            assert layer_index == config.n_layer + 1
+            output_state_dict[f"transformer.ln_f.{weight_or_bias}"] = value
+        else:
+            output_state_dict[f"transformer.h.{layer_index-1}.{NAME_MAP[op_name]}.{weight_or_bias}"] = value
+
+    # For the LM head, transformers expects the weight to be tied to the word embeddings.
+    output_state_dict["lm_head.weight"] = word_embeddings
+
+    return output_state_dict, config


 def main(argv=None):
@@ -11,51 +108,31 @@ def main(argv=None):
     parser.add_argument(
         "--checkpoint_dir",
         type=Path,
-        # default="/admin/home/phuc_nguyen/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000/snapshots/92b6c25cab25f07c367bcc6d773635700a8a287d",
-        help="Path where the converted model is saved"
+        help="Path to the experiment directory",
     )
     parser.add_argument(
         "--save_dir",
         type=Path,
-        # default="./",
        help="Path where the converted model is saved"
    )
     args = parser.parse_args(argv)
-
-    print("start")
-
-    # TODO(xrsrke): auto convert checkpoint_dir to Path
-    # checkpoint_dir = "/admin/home/phuc_nguyen/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000/snapshots/92b6c25cab25f07c367bcc6d773635700a8a287d"
-    # checkpoint_dir = Path(checkpoint_dir)
-    state_dict = merge_checkpoint(args.checkpoint_dir)
+    state_dict, config = merge_checkpoint(
+        args.checkpoint_dir,
+        dummy_experiment_dir=None
+    )
+
+    output_state_dict, output_config = convert_fast_llm_checkpoint(state_dict, config)
+
    print("Saving config")
     save_dir = args.save_dir or args.checkpoint_dir / "converted"
+    output_config.save_pretrained(save_dir)
+
+    # Store the state_dict to file.
     output_checkpoint_file = os.path.join(save_dir, "pytorch_model.bin")
-    print(f'Saving checkpoint to "{output_checkpoint_file}"')
-    torch.save(state_dict, output_checkpoint_file)
+    torch.save(output_state_dict, output_checkpoint_file)
     print(f'Done!')
-
-    # # Compare
-    # def compare_state_dicts(dict1, dict2):
-    #     # Compare keys
-    #     if set(dict1.keys()) != set(dict2.keys()):
-    #         return "Different keys"
-
-    #     # Compare shapes and values
-    #     for key in dict1:
-    #         if dict1[key].shape != dict2[key].shape:
-    #             return f"Different shape for key: {key}"
-    #         if not torch.allclose(dict1[key], dict2[key]):
-    #             return f"Different values for key: {key}"
-
-    #     return "State dictionaries are identical"
-
-    # ref_state_dict = torch.load("/fsx/phuc/projects/starcoder/transformers-starcoder/src/transformers/models/gpt_bigcode/merged_checkpoint.pth")
-    # result = compare_state_dicts(state_dict, ref_state_dict)
-    # print(result)
-


 if __name__ == "__main__":
diff --git a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py
index f426166eca..71731559c7 100644
--- a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py
+++ b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py
@@ -1,148 +1,134 @@
 import re
+from tqdm import tqdm
 from pathlib import Path
+import numpy as np
 import torch
-from safetensors import safe_open
-from collections import defaultdict
-
-import re
-from os.path import commonprefix
-
-
-BRRR_TFMS_NAME_MAPPING = {
-    'ln_1.model_weight': 'ln_1.weight',
-    'ln_1.model_bias': 'ln_1.bias',
-    'ln_2.model_weight': 'ln_2.weight',
-    'ln_2.model_bias': 'ln_2.bias',
-    'attn.query_key_value.weight': 'attn.c_attn.weight',
-    'attn.query_key_value.bias': 'attn.c_attn.bias',
-    'attn.dense.weight': 'attn.c_proj.weight',
-    'attn.dense.model_bias': 
'attn.c_proj.bias', - 'ff.c_fc.weight': 'mlp.c_fc.weight', - 'ff.c_fc.bias': 'mlp.c_fc.bias', - 'ff.c_proj.weight': 'mlp.c_proj.weight', - 'ff.c_proj.model_bias': 'mlp.c_proj.bias' -} - -MERGE_DIM_MAPPING = { - "token_embedding": 0, # row linear parallel - # NOTE: MLP layer - "c_fc": 0, # column linear parallel - "ff.c_fc.bias": 0, - "c_proj": 1, # row linear parallel - # NOTE: attention layer - "query_key_value": 0, # row linear parallel - "dense": 1, # row linear parallel -} - - -def _get_safetensor_checkpoint_paths(checkpoint_dir: Path): - model_dir = checkpoint_dir / "model" / "model" - safetensor_files = [] - - for file_path in model_dir.rglob("*.safetensors"): - if file_path.is_file(): - safetensor_files.append(file_path.absolute()) - - return safetensor_files - - -def _transform_paths(paths): - path_objs = [Path(p) for p in paths] - common_path_prefix = Path(commonprefix(path_objs)).parent - - final_paths = {} - for path in path_objs: - relative_path = str(path.relative_to(common_path_prefix)) - dot_path = relative_path.replace('/', '.') - - weight_replaced = re.sub(r'model_weight_pp-rank-0-of-1_tp-rank-(\d)-of-4', r'weight.\1', dot_path) - bias_replaced = re.sub(r'model_bias_pp-rank-0-of-1_tp-rank-(\d)-of-4', r'bias.\1', weight_replaced) - cleaned_path = bias_replaced.replace('.safetensors', '') - - final_paths[cleaned_path] = str(path) - - return final_paths - -def _group_and_sort_paths(paths): - grouped_paths = defaultdict(list) - - for key, path in paths.items(): - try: - module_name, shard_number = key.rsplit('.', 1) - grouped_paths[module_name].append((int(shard_number), path)) - except ValueError: - # NOTE: these are layer norm's weight and biases - # so it don't have shard number - grouped_paths[key].append(path) - - # Remove any entries with empty lists - grouped_paths = {k: v for k, v in grouped_paths.items() if v} - - # NOTE: Sort paths in each group - # module: [(4, path), (0, path), (3, path) ...] -> module: [(0, path), (1, path), (2, path) ...] 
- sorted_grouped_paths = {module: sorted(paths, key=lambda x: x[0]) - for module, paths in grouped_paths.items()} - - return sorted_grouped_paths - -def _merge_checkpoints(paths): - def find_corresponding_dim(name): - for key, value in MERGE_DIM_MAPPING.items(): - if key in name: - return value - return None - - model_states = {} - for state_key, path in paths.items(): - model_states[state_key] = {} - for shard_id, _path in enumerate(path): - checkpoint_path = _path[1] if isinstance(_path, tuple) else _path - with safe_open(checkpoint_path, framework="pt", device="cpu") as f: - for key in f.keys(): - data = f.get_tensor(key) - model_states[state_key][shard_id] = data - - tensor_list = [tensor for _, tensor in sorted(model_states[state_key].items())] - merge_dim = find_corresponding_dim(state_key) - - if len(tensor_list) > 1: - model_states[state_key] = torch.cat(tensor_list, dim=merge_dim) +import yaml + + +def get_all_checkpoint_paths(experiment_path): + checkpoints = (Path(experiment_path) / "checkpoints").glob("*") + # Sort checkpoints by iteration number + checkpoints = sorted(checkpoints, key=lambda x: int(x.name)) + return [get_checkpoint_paths(checkpoint) for checkpoint in checkpoints] + + +def get_checkpoint_paths(checkpoint_dir: Path): + return [c_name for c_name in checkpoint_dir.glob("*") if re.match(r"\d+", c_name.name)] + + +def extract_stage_shards(state): + # Extract the weight shard and split it into the stage shards + # Reproduce the split done in MultiStageModelBase.setup + total_shard_size = sum(state['stage_shard_sizes']) + if len(state['shard'].shape) == 1: + # Flat buffer + weight_shard = state['shard'][:total_shard_size] + elif len(state['shard'].shape) == 2: + # 2D buffer + weight_shard = state['shard'][0] + else: + raise ValueError(f"Unrecognized buffer shape {state['shard'].shape}") + return weight_shard.split(state['stage_shard_sizes']) + + +def extract_individual_weights(merged_stage_shard, stage_content): + # Get individual weights from shards that are merged across data-parallel + weights_numel = [np.prod(weight_meta['shape']) for weight_meta in stage_content] + weights = merged_stage_shard[:sum(weights_numel)].split(weights_numel) + return [weight.reshape(weight_meta['shape']) for weight, weight_meta in zip(weights, stage_content)] + + +def concatenate_tp_shards(stage_tp_shards, stage_content): + # Concatenate the tp-shards in a given stage + # Stage_tp_shards: contains the individual weight shards for each rank + # [[weight1, weight2, ...] for rank in range(tp_size)] + concatenated_weights = [] + # Concatenate each individual weight along their TP dimension if they have one. 
+ for weight_tp_shards, weight_meta in zip(zip(*stage_tp_shards), stage_content): + if weight_meta["tensor_parallel_dim"] is not None: + weight = torch.cat(weight_tp_shards, dim=weight_meta["tensor_parallel_dim"]) else: - # NOTE: these are biases - model_states[state_key] = tensor_list[0] - return model_states - -def _remap_keys(target_dict): - key_mapping = { - 'model.final_layer_norm.pp_block.model_weight': 'transformer.ln_f.weight', - 'model.final_layer_norm.pp_block.model_bias': 'transformer.ln_f.bias', - 'model.token_embeddings.pp_block.token_embedding.weight': 'transformer.wte.weight' + weight = weight_tp_shards[0] + concatenated_weights.append(weight) + return concatenated_weights + + +def merge_checkpoint(checkpoint_dir: Path, dummy_experiment_dir=None): + """Load a fast-llm checkpoint and merge the data, tensor, and pipeline-parallel shards""" + # checkpoint_dir=experiment_dir/checkpoints/{iteration} + experiment_dir = checkpoint_dir.parent.parent + checkpoint_paths = get_checkpoint_paths(checkpoint_dir) + config = yaml.safe_load((experiment_dir / "config.yaml").read_text()) + + # Load the states from all the ranks + states = { + int(c_name.name): torch.load(c_name) + for c_name in tqdm(checkpoint_paths) + } + num_stages = len(states[0]["stages"]) + tensor_parallel = config["tensor_parallel"] + data_parallel_size = int(config["world_size"] / (tensor_parallel * config["pipeline_parallel"])) + + if dummy_experiment_dir is not None: + # Use the meta from the dummy checkpoint, and the shard from the actual checkpoint + dummy_checkpoint_paths = get_all_checkpoint_paths(dummy_experiment_dir) + dummy_states = { + int(c_name.name): torch.load(c_name) + for c_name in tqdm(dummy_checkpoint_paths[-1]) + } + for rank, state in dummy_states.items(): + state['shard'] = states[rank]['shard'] + states = dummy_states + + # Gather the data-parallel shards + # {tp_rank: [[stage_0_shard_0, stage_0_shard_1, ...], [stage_1_shard_0, ...], ...]} + # {tp_rank: [{fsdp_rank: shard}, ...]} + fsdp_shards = { + i: [[None for _ in range(data_parallel_size)] for _ in range(num_stages)] + for i in range(tensor_parallel) + } + + for rank, state in states.items(): + on_device_stage_shards = extract_stage_shards(state) + on_device_stage_indices = [i for (i, stage_meta) in enumerate(state["stages"]) if stage_meta["on_device"]] + for stage_index, stage_shard in zip(on_device_stage_indices, on_device_stage_shards): + stage_meta = state["stages"][stage_index] + # fsdp_shards[stage_meta["tp_rank"]][stage_index].append((stage_meta, stage_shard)) + fsdp_shards[stage_meta["tp_rank"]][stage_index][stage_meta["fsdp_rank"]] = stage_shard + + # Concatenate the data-parallel shards + # and get individual weights + dp_concatenated_shards = { + tp_rank: [ + extract_individual_weights( + torch.cat(stage_shards, dim=0), + states[0]["stages"][stage_index]['content'] + ) + for stage_index, stage_shards in enumerate(fsdp_shards[tp_rank]) + ] + for tp_rank in range(config["tensor_parallel"]) } - def get_new_key(key): - if 'model.decoder' in key and 'pp_block' in key: - parts = key.split('.') - block_number = parts[2] - component_parts = parts[4:] - component = '.'.join(component_parts) - new_component = BRRR_TFMS_NAME_MAPPING.get(component, component) - return f"transformer.h.{block_number}.{new_component}" - else: - return key_mapping.get(key, key) + # In the tensor-parallel case, concatenate the TP tensors along their TP dimensions. 
+ tp_concatenated_shards = [] + for stage_index, stage_tp_shards in enumerate(zip(*(dp_concatenated_shards[i] for i in range(tensor_parallel)))): + stage_content = states[0]["stages"][stage_index]["content"] + tp_concatenated_shards.append(concatenate_tp_shards(stage_tp_shards, stage_content)) + + # In the pipeline-parallel case, merge the stages + state_dict = { + weight_meta["name"]: weight + for stage_meta, stage_weights in zip(states[0]["stages"], tp_concatenated_shards) + for weight_meta, weight in zip(stage_meta["content"], stage_weights) + } + + print(f"Total number of parameters: {sum([weight.numel() for weight in state_dict.values()])}") + return state_dict, config - new_dict = {get_new_key(key): value for key, value in target_dict.items()} - # NOTE: starcoder uses the same embedding matrix for token embedding and lm head - new_dict["lm_head.weight"] = new_dict.get("transformer.wte.weight", new_dict.get("lm_head.weight")) - - return new_dict -def merge_checkpoint(checkpoint_dir: Path) -> dict: - """Load a checkpoint from the BRRR format and merge tensor parallel shards.""" - checkpoint_paths = _get_safetensor_checkpoint_paths(checkpoint_dir) - paths = _transform_paths(checkpoint_paths) - paths = _group_and_sort_paths(paths) - model_states = _merge_checkpoints(paths) - model_states = _remap_keys(model_states) +if __name__ == "__main__": + merge_checkpoint("/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/1B_repo_context_Top-level-Depth-first_pp2_64k_64k_2023_10_17_16_35_27/", + dummy_experiment_dir="/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/dev_1B_repo_context_Random_pp2_64k_64k_2023_10_18_22_20_36/") - return model_states diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 1d0969ec0c..6103d31b3c 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -14,11 +14,16 @@ """PyTorch GPTBigCode model.""" import math from typing import List, Optional, Tuple, Union -from flash_attn import bert_padding -from flash_attn.flash_attn_interface import ( - flash_attn_varlen_func, - flash_attn_with_kvcache, -) + +try: + from flash_attn import bert_padding + from flash_attn.flash_attn_interface import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, + ) + FLASHATTN_IS_AVAILABLE = True +except ImportError: + FLASHATTN_IS_AVAILABLE = False import torch import torch.utils.checkpoint from torch import nn @@ -39,7 +44,12 @@ logging, ) from .configuration_gpt_bigcode import GPTBigCodeConfig -from apex.normalization import FusedLayerNorm as LayerNorm +import warnings +try: + from apex.normalization import FusedLayerNorm as LayerNorm +except ImportError: + from torch.nn import LayerNorm + warnings.warn("Install NVIDIA Apex for apex.normalization.FusedLayerNorm to be available") logger = logging.get_logger(__name__) @@ -314,8 +324,6 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): # No copy needed for MQA 2, or when layer_past is provided. 
query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) # value shape (batch_size, key_length, 1, head_dim) - # print("value", value.shape) - # print(value[0, :, 0, 0]) elif self.kv_heads < self.num_heads: # GQA # (batch_size, query_length, num_heads, head_dim) x (batch_size, kv_heads, head_dim, key_length) # -> (batch_size, num_heads, query_length, key_length) @@ -338,8 +346,6 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): value = (value.view(batch_size, self.kv_heads, 1, key_length, self.head_dim) .expand(batch_size, self.kv_heads, n_repeats, key_length, self.head_dim) .reshape(batch_size, self.num_heads, key_length, self.head_dim)) - # print("value", value.shape) - # print(value[0, :, 0]) else: # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length) # -> (batch_size, num_heads, query_length, key_length) @@ -351,11 +357,6 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): # No copy when layer_past is provided. key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length) - # print("query", query.shape) - # print(query[0, :query_length, 0]) - # print("key", key.shape) - # print(key[0, 0, :]) - attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype) if query.device.type == "cpu": # This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588. @@ -368,9 +369,6 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape) # if not self.multi_query: # attn_weights = attn_weights.transpose(1, 2) # (batch_size, query_length, num_heads, key_length) - # print("attn_weights", attn_weights.shape) - # print(attn_weights[0, :, 0, :]) - # assert False, "done" if upcast: # Use a fused kernel to prevent a large overhead from casting and scaling. 
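Editor's note (not part of the patch): the `_attn` hunk above handles the grouped-query attention (GQA) case by expanding the shared key/value heads with a view -> expand -> reshape so the usual per-head matmul can be reused. The sketch below is a minimal, self-contained illustration of that idea under stated assumptions: the tensor sizes are made up, plain torch.matmul plus softmax stands in for the fused torch.baddbmm path used in the patch, and the repetition is applied to both keys and values here to keep the example short.

import torch

# Illustrative sizes (assumptions, not the model's actual configuration).
batch_size, query_length, key_length = 2, 4, 6
num_heads, kv_heads, head_dim = 8, 2, 16
n_repeats = num_heads // kv_heads  # query heads served by each key/value head

query = torch.randn(batch_size, num_heads, query_length, head_dim)
key = torch.randn(batch_size, kv_heads, key_length, head_dim)
value = torch.randn(batch_size, kv_heads, key_length, head_dim)

# Repeat each key/value head n_repeats times so they line up with the query heads,
# mirroring the view -> expand -> reshape pattern in the hunk above.
key = (key.view(batch_size, kv_heads, 1, key_length, head_dim)
          .expand(batch_size, kv_heads, n_repeats, key_length, head_dim)
          .reshape(batch_size, num_heads, key_length, head_dim))
value = (value.view(batch_size, kv_heads, 1, key_length, head_dim)
              .expand(batch_size, kv_heads, n_repeats, key_length, head_dim)
              .reshape(batch_size, num_heads, key_length, head_dim))

# Scaled dot-product attention over the expanded heads.
scale = 1.0 / head_dim ** 0.5
attn_weights = torch.softmax(torch.matmul(query, key.transpose(-1, -2)) * scale, dim=-1)
attn_output = torch.matmul(attn_weights, value)  # (batch_size, num_heads, query_length, head_dim)

Note that expand only creates a broadcasted view; the final reshape is what materializes the repeated heads, a copy the multi-query branch above avoids by folding the query heads into the sequence dimension instead (the "No copy needed" comment). End of editor's note.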
@@ -409,20 +407,13 @@ def _flash_attn(self, hidden_states, layer_past, sequence_mask, position_ids): # hidden_states (batch_size, query_length, embed_dim) # layer_past (batch_size, past_key_values_length, embed_dim) # sequence_mask (batch_size, query_length) - # print("hidden_states", hidden_states.shape) - # print(hidden_states[0]) fused_qkv = self.c_attn( hidden_states ) # [batch_size, query_length, n_local_q_heads * head_dim + 2 * n_local_kv_heads * head_dim] batch_size, q_length, _ = fused_qkv.size() - # print("fused_qkv", fused_qkv.shape) - # print(fused_qkv[0]) - qkv = fused_qkv.view(batch_size, q_length, self.n_local_kv_heads, self.n_repeats + 2, self.head_dim) - # print("qkv", qkv.shape) - # print(qkv[0]) query, key, value = torch.split(qkv, [self.n_repeats, 1, 1], dim=3) query_states = query.reshape( @@ -431,10 +422,6 @@ def _flash_attn(self, hidden_states, layer_past, sequence_mask, position_ids): key_states = key.reshape(batch_size, q_length, self.n_local_kv_heads, self.head_dim) value_states = value.reshape(batch_size, q_length, self.n_local_kv_heads, self.head_dim) - # print("query_states", query_states.shape) - # print(query_states[0]) - # print("key_states", key_states.shape) - # print(key_states[0]) # Compute rotary embeddings if layer_past is None: past_key_values_length = 0 @@ -443,12 +430,6 @@ def _flash_attn(self, hidden_states, layer_past, sequence_mask, position_ids): query_states, key_states = self.maybe_rotary( query_states, key_states, past_key_values_length=past_key_values_length ) - # print("after rotary") - # print("query_states", query_states.shape) - # print(query_states[0]) - # print("key_states", key_states.shape) - # print(key_states[0]) - # assert False if layer_past is None: # First inference iteration (Prefill) @@ -509,9 +490,6 @@ def _flash_attn(self, hidden_states, layer_past, sequence_mask, position_ids): pad_to_right(key_states, sequence_mask, new_tensor=k_cache) pad_to_right(value_states, sequence_mask, new_tensor=v_cache) - # print("attention_output", attention_output.shape) - # print(attention_output[0]) - # # assert False else: # Pull pre-computed key/value states @@ -532,13 +510,7 @@ def _flash_attn(self, hidden_states, layer_past, sequence_mask, position_ids): batch_size, kv_length, self.n_local_kv_heads, self.head_dim ) # [batch_size, kv_length, self.n_local_kv_heads, self.head_dim] - # print("query_states", query_states.shape) - # print(query_states[0]) - # print("key_states", key_states.shape) - # print(key_states[0]) - # assert False position_offsets = position_ids[:, -1] - # assert False attention_output = flash_attn_with_kvcache( query_states, k_cache, @@ -573,9 +545,6 @@ def forward( Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], ]: - USE_FLASH_ATTN = True - # self.multi_query = False - # self.multi_query = True if encoder_hidden_states is not None: if not hasattr(self, "q_attn") or not self.is_cross_attention: raise ValueError( @@ -594,7 +563,7 @@ def forward( (self.head_dim, self.head_dim), dim=-1 ) # (batch_size, query_length, 1 * head_dim) else: # GQA - if USE_FLASH_ATTN: + if FLASHATTN_IS_AVAILABLE: key, value, attn_output = self._flash_attn(hidden_states, layer_past, attention_mask, position_ids) attn_output = attn_output.view(hidden_states.shape) @@ -619,15 +588,12 @@ def forward( past_kv_length = 0 - if self.use_rotary_embeddings and not USE_FLASH_ATTN: + if self.use_rotary_embeddings and not FLASHATTN_IS_AVAILABLE: # query = _apply_rotary_embeddings(query, 
rotary_embedding_frequencies_q) # key = _apply_rotary_embeddings(key, rotary_embedding_frequencies_k) query = query.reshape(*query.shape[:2], self.num_heads, self.head_dim) key = key.view(*key.shape[:2], self.kv_heads, self.head_dim) - # print(past_kv_length) - # print("before",key[0,:,0,0]) query, key = self.maybe_rotary(query, key, past_kv_length) - # print("after",key[0,:,0,0]) query = query.view(*query.shape[:2], self.num_heads * self.head_dim) key = key.view(*key.shape[:2], self.kv_heads * self.head_dim) value = value.reshape(*value.shape[:2], self.kv_heads*self.head_dim) @@ -645,7 +611,7 @@ def forward( ) # (batch_size, key_length, kv_heads * head_dim) present = key_value if use_cache else None - if not USE_FLASH_ATTN: + if not FLASHATTN_IS_AVAILABLE: attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) if not self.multi_query: @@ -716,11 +682,7 @@ def forward( Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor] ]: residual = hidden_states - # print("hidden_states", hidden_states.shape) - # print(hidden_states[0]) hidden_states = self.ln_1(hidden_states) - # print("after ln 1", hidden_states.shape) - # print(hidden_states[0]) attn_outputs = self.attn( hidden_states, layer_past=layer_past,