From 1016ba6a0c9631c450d95a792bfb6122360eafa5 Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Fri, 20 Oct 2023 03:55:49 +0000 Subject: [PATCH 01/23] add conversion script for fast-llm checkpoints, add rotary to gpt_bigcode --- .../gpt_bigcode/configuration_gpt_bigcode.py | 8 ++ .../convert_fast_llm_checkpoint.py | 127 ++++++++++++++++++ .../gpt_bigcode/merge_fast_llm_checkpoint.py | 114 ++++++++++++++++ .../gpt_bigcode/modeling_gpt_bigcode.py | 52 ++++++- 4 files changed, 297 insertions(+), 4 deletions(-) create mode 100644 src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py create mode 100644 src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index 9cbaf3e184..8ffb15861e 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -14,6 +14,7 @@ # limitations under the License. """ GPTBigCode configuration""" +import math from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -119,6 +120,10 @@ def __init__( attention_softmax_in_fp32=True, scale_attention_softmax_in_fp32=True, multi_query=True, + use_rotary_embeddings=False, + rotary_embedding_scale=-math.log(10000), # - 9.210 + use_position_embeddings=None, + # TODO: add window **kwargs, ): self.vocab_size = vocab_size @@ -138,6 +143,9 @@ def __init__( self.attention_softmax_in_fp32 = attention_softmax_in_fp32 self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32 self.multi_query = multi_query + self.use_rotary_embeddings = use_rotary_embeddings + self.rotary_embedding_scale = rotary_embedding_scale + self.use_position_embeddings = use_position_embeddings if use_position_embeddings is not None else not use_rotary_embeddings self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id diff --git a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py new file mode 100644 index 0000000000..242c42a08c --- /dev/null +++ b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py @@ -0,0 +1,127 @@ +import argparse +import os +import re + +import torch +from merge_fast_llm_checkpoint import merge_checkpoint +from transformers.models.gpt_bigcode import GPTBigCodeConfig, GPTBigCodeForCausalLM, GPTBigCodeModel + + +# The simple map of names for "automated" rules. +NAME_MAP = { + "_mlp._layer_1": "mlp.c_fc", + "_mlp._layer_2": "mlp.c_proj", + "layer_norm_1": "ln_1", + "layer_norm_2": "ln_2", + # "attention.dense": "attn.c_proj", + "self_attn.dense": "attn.c_proj", + # "self_attention.query_key_value": "attn.c_attn", +} + + +def convert_fast_llm_checkpoint(state_dict, config): + # The converted output model. 
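+ # A sketch of the steps below: build a GPTBigCodeConfig from the fast-llm config, truncate
+ # the word embeddings to vocab_size, concatenate the query and key_value weights into a
+ # single c_attn tensor, and remap the remaining "_layers.{i}" parameters to
+ # "transformer.h.{i - 1}" names via NAME_MAP.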
+ output_state_dict = {} + + config = GPTBigCodeConfig( + architectures=["GPTBigCodeLMHeadModel"], + vocab_size=config["vocab_size"], + n_positions=config["max_position_embeddings"], + n_embd=config["hidden_size"], + n_layer=config["num_layers"], + n_head=config["num_attention_heads"], + n_inner=config["ffn_hidden_size"], + activation_function="gelu", # TODO + multi_query=True, # TODO + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, # TODO: can we remove these? + eos_token_id=50256, + attention_softmax_in_fp32=True, + scale_attention_softmax_in_fp32=True, + use_rotary_embeddings=config["use_rotary_embeddings"], + rotary_embedding_scale=config["rotary_embedding_scale"], + use_position_embeddings=config["use_position_embeddings"], + ) + + # Truncate the word embeddings to the vocab-size + word_embeddings = state_dict.pop("_layers.0._word_embeddings_weight")[:config.vocab_size, :] + output_state_dict["transformer.wte.weight"] = word_embeddings + # TODO: positional embeddings + # Layers + # + # _layers.{layer_index}.{op}.{w/b} + + # Concatenate QKV matrix + for layer_index in range(1, config.n_layer + 1): + for weight_or_bias in ["weight", "bias"]: + query = state_dict.pop(f"_layers.{layer_index}.self_attn.query.{weight_or_bias}") + key_value = state_dict.pop(f"_layers.{layer_index}.self_attn.key_value.{weight_or_bias}") + output_state_dict[f"transformer.h.{layer_index - 1}.attn.c_attn.{weight_or_bias}"] = torch.cat([query, key_value], dim=0) + + # Extract the other ops + layer_re = re.compile("_layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") + for name, value in state_dict.items(): + m = layer_re.match(name) + + # The index of the layer. + layer_index = int(m.group(1)) + # The name of the operation. + op_name = m.group(2) + # Is it a weight or a bias? + weight_or_bias = m.group(3) + + # Final layernorm + if op_name == "final_layernorm": + assert layer_index == config.n_layer + 1 + output_state_dict[f"transformer.ln_f.{weight_or_bias}"] = value + else: + output_state_dict[f"transformer.h.{layer_index-1}.{NAME_MAP[op_name]}.{weight_or_bias}"] = value + + # For LM head, transformers' wants the matrix to weight embeddings. + output_state_dict["lm_head.weight"] = word_embeddings + + return output_state_dict, config + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + "--experiment_dir", + type=str, + help="Path to the experiment directory", + default="/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/1B_repo_context_Top-level-Depth-first_pp2_64k_64k_2023_10_17_16_35_27/" # TODO + ) + parser.add_argument( + "--save_dir", help="Path where the converted model is saved. 
Will use the checkpoint directory if not provided", + default="/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/1B_repo_context_Top-level-Depth-first_pp2_64k_64k_2023_10_17_16_35_27/converted" # TODO + ) + args = parser.parse_args() + + state_dict, config = merge_checkpoint( + args.experiment_dir, + dummy_experiment_dir="/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/dev_1B_repo_context_Random_pp2_64k_64k_2023_10_18_22_20_36/" + ) + + output_state_dict, output_config = convert_fast_llm_checkpoint(state_dict, config) + + print("Saving config") + save_dir = args.save_dir or os.path.dirname(args.experiment_dir) + output_config.save_pretrained(save_dir) + + # Store the state_dict to file. + output_checkpoint_file = os.path.join(save_dir, "pytorch_model.bin") + print(f'Saving checkpoint to "{output_checkpoint_file}"') + torch.save(output_state_dict, output_checkpoint_file) + print(f'Done!') diff --git a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py new file mode 100644 index 0000000000..dc67b24657 --- /dev/null +++ b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py @@ -0,0 +1,114 @@ +import re +from tqdm import tqdm +from pathlib import Path + +import numpy as np +import torch +import yaml + + +def get_checkpoint_paths(experiment_path): + checkpoints = (Path(experiment_path) / "checkpoints").glob("*") + # Sort checkpoints by iteration number + checkpoints = sorted(checkpoints, key=lambda x: int(x.name)) + return [ + [c_name for c_name in checkpoint.glob("*") if re.match(r"\d+", c_name.name)] + for checkpoint in checkpoints + ] + + +def extract_stage_shards(state): + # Extract the weight shard and split it into the stage shards + # Reproduce the split done in MultiStageModelBase.setup + total_shard_size = sum(state['stage_shard_sizes']) + if len(state['shard'].shape) == 1: + # Flat buffer + weight_shard = state['shard'][:total_shard_size] + elif len(state['shard'].shape) == 2: + # 2D buffer + weight_shard = state['shard'][0] + else: + raise ValueError(f"Unrecognized buffer shape {state['shard'].shape}") + return weight_shard.split(state['stage_shard_sizes']) + + +def extract_individual_weights(merged_stage_shard, stage_content): + # Get individual weights from shards that are merged across data-parallel + weights_numel = [np.prod(weight_meta['shape']) for weight_meta in stage_content] + weights = merged_stage_shard[:sum(weights_numel)].split(weights_numel) + return [weight.reshape(weight_meta['shape']) for weight, weight_meta in zip(weights, stage_content)] + + +def merge_checkpoint(experiment_dir, dummy_experiment_dir=None): + """Load a fast-llm checkpoint and merge the data, tensor, and pipeline-parallel shards""" + checkpoint_paths = get_checkpoint_paths(experiment_dir) + config = yaml.safe_load((Path(experiment_dir) / "config.yaml").read_text()) + + # Convert the last iteration + # Load the states from all the ranks + states = { + int(c_name.name): torch.load(c_name) + for c_name in tqdm(checkpoint_paths[-1]) + } + num_stages = len(states[0]["stages"]) + data_parallel_size = int(config["world_size"] / (config["tensor_parallel"] * config["pipeline_parallel"])) + + if dummy_experiment_dir is not None: + # Use the meta from the dummy checkpoint, and the shard from the actual checkpoint + dummy_checkpoint_paths = get_checkpoint_paths(dummy_experiment_dir) + dummy_states = { + int(c_name.name): torch.load(c_name) + for c_name in 
tqdm(dummy_checkpoint_paths[-1]) + } + for rank, state in dummy_states.items(): + state['shard'] = states[rank]['shard'] + states = dummy_states + + # Gather the data-parallel shards + # {tp_rank: [[stage_0_shard_0, stage_0_shard_1, ...], [stage_1_shard_0, ...], ...]} + # {tp_rank: [{fsdp_rank: shard}, ...]} + fsdp_shards = { + i: [[None for _ in range(data_parallel_size)] for _ in range(num_stages)] + for i in range(config["tensor_parallel"]) + } + + for rank, state in states.items(): + on_device_stage_shards = extract_stage_shards(state) + on_device_stage_indices = [i for (i, stage_meta) in enumerate(state["stages"]) if stage_meta["on_device"]] + for stage_index, stage_shard in zip(on_device_stage_indices, on_device_stage_shards): + stage_meta = state["stages"][stage_index] + # fsdp_shards[stage_meta["tp_rank"]][stage_index].append((stage_meta, stage_shard)) + fsdp_shards[stage_meta["tp_rank"]][stage_index][stage_meta["fsdp_rank"]] = stage_shard + + # Concatenate the data-parallel shards + # and get individual weights + data_concatenated_shards = { + tp_rank: [ + extract_individual_weights( + torch.cat(stage_shards, dim=0), + states[0]["stages"][stage_index]['content'] + ) + for stage_index, stage_shards in enumerate(fsdp_shards[tp_rank]) + ] + for tp_rank in range(config["tensor_parallel"]) + } + + # In the tensor-parallel case, concatenate the TP tensors along their TP dimensions. + # TODO + tp_concatenated_shards = data_concatenated_shards[0] + + # In the pipeline-parallel case, merge the stages + state_dict = { + weight_meta["name"]: weight + for stage_meta, stage_weights in zip(states[0]["stages"], tp_concatenated_shards) + for weight_meta, weight in zip(stage_meta["content"], stage_weights) + } + + print(f"Total number of parameters: {sum([weight.numel() for weight in state_dict.values()])}") + return state_dict, config + + +if __name__ == "__main__": + merge_checkpoint("/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/1B_repo_context_Top-level-Depth-first_pp2_64k_64k_2023_10_17_16_35_27/", + dummy_experiment_dir="/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/dev_1B_repo_context_Random_pp2_64k_64k_2023_10_18_22_20_36/") + diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 1c34f28a5c..f9aa99a7de 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -78,6 +78,22 @@ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor return x +@torch.compile +def _apply_rotary_embeddings( + tensor: torch.Tensor, + rope_frequencies: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to a tensor: + * Convert it to a complex, full-precision tensor + * Multiply by the frequencies + * Convert back tho the input format. + # TODO: Full precision only needed for bfloat16? 
(Doesn't support complex numbers) + """ + complex_tensor = torch.view_as_complex(tensor.float().view(*tensor.shape[:-1], -1, rope_frequencies.size(-1), 2)) + return torch.view_as_real(complex_tensor * rope_frequencies).view_as(tensor).type_as(tensor) + + class GPTBigCodeAttention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() @@ -212,6 +228,7 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + rotary_embedding_frequencies: Optional[torch.Tensor] = None ) -> Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], @@ -245,6 +262,11 @@ def forward( key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + if self.config.use_rotary_embeddings: + # TODO: check/fix + query = _apply_rotary_embeddings(query, rotary_embedding_frequencies) + key = _apply_rotary_embeddings(key, rotary_embedding_frequencies) + attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) if not self.multi_query: @@ -308,6 +330,7 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + rotary_embedding_frequencies: Optional[torch.Tensor] = None ) -> Union[ Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor] ]: @@ -320,6 +343,7 @@ def forward( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, + rotary_embedding_frequencies=rotary_embedding_frequencies ) attn_output = attn_outputs[0] # output_attn: a, present, (attentions) outputs = attn_outputs[1:] @@ -342,6 +366,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, + rotary_embedding_frequencies=rotary_embedding_frequencies ) attn_output = cross_attn_outputs[0] # residual connection @@ -506,7 +531,22 @@ def __init__(self, config): self.embed_dim = config.hidden_size self.wte = nn.Embedding(config.vocab_size, self.embed_dim) - self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + if config.use_position_embeddings: + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + if config.use_rotary_embeddings: + # Calculate the complex frequencies (https://blog.eleuther.ai/rotary-embeddings/) + # `exp(i * n * a) = cos(n * a) + i sin(n * a)`, + # `a = theta ** - (2 * (channel // 2) / kv_channels)`, + # where n is the position in the sequence. 
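+ # For example, with the default scale of -log(10000) and kv_channels = 64, channel pair k
+ # (channels 2k and 2k + 1) uses a = 10000 ** (-k / 32): the first pair advances by one
+ # radian per position, the last by only 10000 ** (-31 / 32) radians per position.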
+ kv_channels = config.n_embd / config.n_head + angles = torch.outer( + torch.arange(config.max_position_embeddings, dtype=torch.float32), + torch.exp( + config.rotary_embedding_scale + * torch.arange(0, 1, 2 / kv_channels, dtype=torch.float32) + ), + ) + self._rotary_embedding_frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :] self.drop = nn.Dropout(config.embd_pdrop) self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) @@ -632,8 +672,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds + if self.config.use_position_embeddings: + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + else: + hidden_states = inputs_embeds if token_type_ids is not None: token_type_embeds = self.wte(token_type_ids) @@ -656,7 +699,7 @@ def forward( def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, use_cache, output_attentions) + return module(*inputs, use_cache, output_attentions, self._rotary_embedding_frequencies) return custom_forward @@ -679,6 +722,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_attentions=output_attentions, + rotary_embedding_frequencies=self._rotary_embedding_frequencies ) hidden_states = outputs[0] From a1a1433248ba436a315705d57100ca37eb850bd6 Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Fri, 20 Oct 2023 18:52:59 +0000 Subject: [PATCH 02/23] fix model --- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index f9aa99a7de..b02e22ee08 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -111,6 +111,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads})." 
) + self.use_rotary_embeddings = config.use_rotary_embeddings self.scale_attn_weights = config.scale_attn_weights self.is_cross_attention = is_cross_attention @@ -262,7 +263,7 @@ def forward( key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) - if self.config.use_rotary_embeddings: + if self.use_rotary_embeddings: # TODO: check/fix query = _apply_rotary_embeddings(query, rotary_embedding_frequencies) key = _apply_rotary_embeddings(key, rotary_embedding_frequencies) From ed0f4b7ca1d8f67d7722f4f52a787516a15bf86a Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Fri, 20 Oct 2023 19:57:01 +0000 Subject: [PATCH 03/23] fix model --- .../gpt_bigcode/modeling_gpt_bigcode.py | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index b02e22ee08..985d98f43c 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -229,7 +229,8 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, - rotary_embedding_frequencies: Optional[torch.Tensor] = None + rotary_embedding_frequencies_q: Optional[torch.Tensor] = None, + rotary_embedding_frequencies_k: Optional[torch.Tensor] = None ) -> Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], @@ -264,9 +265,8 @@ def forward( key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) if self.use_rotary_embeddings: - # TODO: check/fix - query = _apply_rotary_embeddings(query, rotary_embedding_frequencies) - key = _apply_rotary_embeddings(key, rotary_embedding_frequencies) + query = _apply_rotary_embeddings(query, rotary_embedding_frequencies_q) + key = _apply_rotary_embeddings(key, rotary_embedding_frequencies_k) attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) @@ -331,7 +331,8 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, - rotary_embedding_frequencies: Optional[torch.Tensor] = None + rotary_embedding_frequencies_q: Optional[torch.Tensor] = None, + rotary_embedding_frequencies_k: Optional[torch.Tensor] = None ) -> Union[ Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor] ]: @@ -344,7 +345,8 @@ def forward( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, - rotary_embedding_frequencies=rotary_embedding_frequencies + rotary_embedding_frequencies_q=rotary_embedding_frequencies_q, + rotary_embedding_frequencies_k=rotary_embedding_frequencies_k ) attn_output = attn_outputs[0] # output_attn: a, present, (attentions) outputs = attn_outputs[1:] @@ -367,7 +369,8 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, - rotary_embedding_frequencies=rotary_embedding_frequencies + rotary_embedding_frequencies_q=rotary_embedding_frequencies_q, + rotary_embedding_frequencies_k=rotary_embedding_frequencies_k ) attn_output = cross_attn_outputs[0] # residual connection @@ -636,6 +639,10 @@ def forward( elif position_ids is None: position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) 
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # Rotary frequencies + rotary_embedding_frequencies_q = self._rotary_embedding_frequencies[:, past_length : past_length + input_shape[-1]] + rotary_embedding_frequencies_k = self._rotary_embedding_frequencies[:, :past_length + input_shape[-1], :, :] # Self-attention mask. query_length = input_shape[-1] @@ -700,7 +707,7 @@ def forward( def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, use_cache, output_attentions, self._rotary_embedding_frequencies) + return module(*inputs, use_cache, output_attentions, rotary_embedding_frequencies_q, rotary_embedding_frequencies_k) return custom_forward @@ -723,7 +730,8 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_attentions=output_attentions, - rotary_embedding_frequencies=self._rotary_embedding_frequencies + rotary_embedding_frequencies_q=rotary_embedding_frequencies_q, + rotary_embedding_frequencies_k=rotary_embedding_frequencies_k ) hidden_states = outputs[0] From e3f417c11b2f311c40f58b8601862c2169d2cb2a Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Fri, 20 Oct 2023 20:20:20 +0000 Subject: [PATCH 04/23] fix model --- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 985d98f43c..ae0e52a68a 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -641,8 +641,8 @@ def forward( position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) # Rotary frequencies - rotary_embedding_frequencies_q = self._rotary_embedding_frequencies[:, past_length : past_length + input_shape[-1]] - rotary_embedding_frequencies_k = self._rotary_embedding_frequencies[:, :past_length + input_shape[-1], :, :] + rotary_embedding_frequencies_q = self._rotary_embedding_frequencies[:, past_length : past_length + input_shape[-1]].to(device=device) + rotary_embedding_frequencies_k = self._rotary_embedding_frequencies[:, :past_length + input_shape[-1], :, :].to(device=device) # Self-attention mask. query_length = input_shape[-1] From 5aa788933bf567535a047989e3bbe033b01fc726 Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Fri, 20 Oct 2023 20:21:03 +0000 Subject: [PATCH 05/23] support merge for tp checkpoints --- .../gpt_bigcode/merge_fast_llm_checkpoint.py | 28 +++++++++++++++---- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py index dc67b24657..3532bcd073 100644 --- a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py @@ -39,6 +39,21 @@ def extract_individual_weights(merged_stage_shard, stage_content): return [weight.reshape(weight_meta['shape']) for weight, weight_meta in zip(weights, stage_content)] +def concatenate_tp_shards(stage_tp_shards, stage_content): + # Concatenate the tp-shards in a given stage + # Stage_tp_shards: contains the individual weight shards for each rank + # [[weight1, weight2, ...] for rank in range(tp_size)] + concatenated_weights = [] + # Concatenate each individual weight along their TP dimension if they have one. 
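+ # For example, with tensor_parallel = 2 a weight whose tensor_parallel_dim is 0 arrives as
+ # two half-sized shards and is rebuilt with torch.cat(..., dim=0); a weight with no TP
+ # dimension is replicated across ranks, so the copy from the first rank is used as-is.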
+ for weight_tp_shards, weight_meta in zip(zip(*stage_tp_shards), stage_content): + if weight_meta["tensor_parallel_dim"] is not None: + weight = torch.cat(weight_tp_shards, dim=weight_meta["tensor_parallel_dim"]) + else: + weight = weight_tp_shards[0] + concatenated_weights.append(weight) + return concatenated_weights + + def merge_checkpoint(experiment_dir, dummy_experiment_dir=None): """Load a fast-llm checkpoint and merge the data, tensor, and pipeline-parallel shards""" checkpoint_paths = get_checkpoint_paths(experiment_dir) @@ -51,7 +66,8 @@ def merge_checkpoint(experiment_dir, dummy_experiment_dir=None): for c_name in tqdm(checkpoint_paths[-1]) } num_stages = len(states[0]["stages"]) - data_parallel_size = int(config["world_size"] / (config["tensor_parallel"] * config["pipeline_parallel"])) + tensor_parallel = config["tensor_parallel"] + data_parallel_size = int(config["world_size"] / (tensor_parallel * config["pipeline_parallel"])) if dummy_experiment_dir is not None: # Use the meta from the dummy checkpoint, and the shard from the actual checkpoint @@ -69,7 +85,7 @@ def merge_checkpoint(experiment_dir, dummy_experiment_dir=None): # {tp_rank: [{fsdp_rank: shard}, ...]} fsdp_shards = { i: [[None for _ in range(data_parallel_size)] for _ in range(num_stages)] - for i in range(config["tensor_parallel"]) + for i in range(tensor_parallel) } for rank, state in states.items(): @@ -82,7 +98,7 @@ def merge_checkpoint(experiment_dir, dummy_experiment_dir=None): # Concatenate the data-parallel shards # and get individual weights - data_concatenated_shards = { + dp_concatenated_shards = { tp_rank: [ extract_individual_weights( torch.cat(stage_shards, dim=0), @@ -94,8 +110,10 @@ def merge_checkpoint(experiment_dir, dummy_experiment_dir=None): } # In the tensor-parallel case, concatenate the TP tensors along their TP dimensions. - # TODO - tp_concatenated_shards = data_concatenated_shards[0] + tp_concatenated_shards = [] + for stage_index, stage_tp_shards in enumerate(zip(*(dp_concatenated_shards[i] for i in range(tensor_parallel)))): + stage_content = states[0]["stages"][stage_index]["content"] + tp_concatenated_shards.append(concatenate_tp_shards(stage_tp_shards, stage_content)) # In the pipeline-parallel case, merge the stages state_dict = { From ac8ec9c099ca64c58a861c3a5a5d08e63340b017 Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Fri, 20 Oct 2023 20:48:38 +0000 Subject: [PATCH 06/23] refactor scripts --- .../convert_fast_llm_checkpoint.py | 23 +++++++++++-------- .../gpt_bigcode/merge_fast_llm_checkpoint.py | 23 +++++++++++-------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py index 242c42a08c..5347777ac8 100644 --- a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py @@ -94,30 +94,29 @@ def convert_fast_llm_checkpoint(state_dict, config): return output_state_dict, config - -if __name__ == '__main__': +def main(argv=None): parser = argparse.ArgumentParser() parser.add_argument( - "--experiment_dir", + "--checkpoint_dir", type=str, help="Path to the experiment directory", - default="/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/1B_repo_context_Top-level-Depth-first_pp2_64k_64k_2023_10_17_16_35_27/" # TODO ) parser.add_argument( - "--save_dir", help="Path where the converted model is saved. 
Will use the checkpoint directory if not provided", - default="/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/1B_repo_context_Top-level-Depth-first_pp2_64k_64k_2023_10_17_16_35_27/converted" # TODO + "--save_dir", + type=str, + help="Path where the converted model is saved" ) - args = parser.parse_args() + args = parser.parse_args(argv) state_dict, config = merge_checkpoint( - args.experiment_dir, - dummy_experiment_dir="/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/dev_1B_repo_context_Random_pp2_64k_64k_2023_10_18_22_20_36/" + args.checkpoint_dir, + dummy_experiment_dir="/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/dev_1B_repo_context_Random_tp4_pp2_8k_8k_2023_10_19_18_40_11/" ) output_state_dict, output_config = convert_fast_llm_checkpoint(state_dict, config) print("Saving config") - save_dir = args.save_dir or os.path.dirname(args.experiment_dir) + save_dir = args.save_dir or args.checkpoint_dir + "-converted" output_config.save_pretrained(save_dir) # Store the state_dict to file. @@ -125,3 +124,7 @@ def convert_fast_llm_checkpoint(state_dict, config): print(f'Saving checkpoint to "{output_checkpoint_file}"') torch.save(output_state_dict, output_checkpoint_file) print(f'Done!') + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py index 3532bcd073..5f125abd9c 100644 --- a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py @@ -7,14 +7,15 @@ import yaml -def get_checkpoint_paths(experiment_path): +def get_all_checkpoint_paths(experiment_path): checkpoints = (Path(experiment_path) / "checkpoints").glob("*") # Sort checkpoints by iteration number checkpoints = sorted(checkpoints, key=lambda x: int(x.name)) - return [ - [c_name for c_name in checkpoint.glob("*") if re.match(r"\d+", c_name.name)] - for checkpoint in checkpoints - ] + return [get_checkpoint_paths(checkpoint) for checkpoint in checkpoints] + + +def get_checkpoint_paths(checkpoint_dir: Path): + return [c_name for c_name in checkpoint_dir.glob("*") if re.match(r"\d+", c_name.name)] def extract_stage_shards(state): @@ -54,16 +55,18 @@ def concatenate_tp_shards(stage_tp_shards, stage_content): return concatenated_weights -def merge_checkpoint(experiment_dir, dummy_experiment_dir=None): +def merge_checkpoint(checkpoint_dir, dummy_experiment_dir=None): """Load a fast-llm checkpoint and merge the data, tensor, and pipeline-parallel shards""" - checkpoint_paths = get_checkpoint_paths(experiment_dir) + # checkpoint_dir=experiment_dir/checkpoints/{iteration} + checkpoint_dir = Path(checkpoint_dir) + experiment_dir = checkpoint_dir.parent.parent + checkpoint_paths = get_checkpoint_paths(checkpoint_dir) config = yaml.safe_load((Path(experiment_dir) / "config.yaml").read_text()) - # Convert the last iteration # Load the states from all the ranks states = { int(c_name.name): torch.load(c_name) - for c_name in tqdm(checkpoint_paths[-1]) + for c_name in tqdm(checkpoint_paths) } num_stages = len(states[0]["stages"]) tensor_parallel = config["tensor_parallel"] @@ -71,7 +74,7 @@ def merge_checkpoint(experiment_dir, dummy_experiment_dir=None): if dummy_experiment_dir is not None: # Use the meta from the dummy checkpoint, and the shard from the actual checkpoint - dummy_checkpoint_paths = get_checkpoint_paths(dummy_experiment_dir) + dummy_checkpoint_paths = 
get_all_checkpoint_paths(dummy_experiment_dir) dummy_states = { int(c_name.name): torch.load(c_name) for c_name in tqdm(dummy_checkpoint_paths[-1]) From 07cb73927b4b956d11a3460e86543f65a8ea6a7a Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Fri, 20 Oct 2023 20:55:54 +0000 Subject: [PATCH 07/23] add push_checkpoints --- .../convert_fast_llm_checkpoint.py | 7 +- .../gpt_bigcode/merge_fast_llm_checkpoint.py | 5 +- .../models/gpt_bigcode/push_checkpoints.py | 71 +++++++++++++++++++ 3 files changed, 77 insertions(+), 6 deletions(-) create mode 100644 src/transformers/models/gpt_bigcode/push_checkpoints.py diff --git a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py index 5347777ac8..26fd87025b 100644 --- a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py @@ -1,5 +1,6 @@ import argparse import os +from pathlib import Path import re import torch @@ -98,12 +99,12 @@ def main(argv=None): parser = argparse.ArgumentParser() parser.add_argument( "--checkpoint_dir", - type=str, + type=Path, help="Path to the experiment directory", ) parser.add_argument( "--save_dir", - type=str, + type=Path, help="Path where the converted model is saved" ) args = parser.parse_args(argv) @@ -116,7 +117,7 @@ def main(argv=None): output_state_dict, output_config = convert_fast_llm_checkpoint(state_dict, config) print("Saving config") - save_dir = args.save_dir or args.checkpoint_dir + "-converted" + save_dir = args.save_dir or args.checkpoint_dir / "converted" output_config.save_pretrained(save_dir) # Store the state_dict to file. diff --git a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py index 5f125abd9c..71731559c7 100644 --- a/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/merge_fast_llm_checkpoint.py @@ -55,13 +55,12 @@ def concatenate_tp_shards(stage_tp_shards, stage_content): return concatenated_weights -def merge_checkpoint(checkpoint_dir, dummy_experiment_dir=None): +def merge_checkpoint(checkpoint_dir: Path, dummy_experiment_dir=None): """Load a fast-llm checkpoint and merge the data, tensor, and pipeline-parallel shards""" # checkpoint_dir=experiment_dir/checkpoints/{iteration} - checkpoint_dir = Path(checkpoint_dir) experiment_dir = checkpoint_dir.parent.parent checkpoint_paths = get_checkpoint_paths(checkpoint_dir) - config = yaml.safe_load((Path(experiment_dir) / "config.yaml").read_text()) + config = yaml.safe_load((experiment_dir / "config.yaml").read_text()) # Load the states from all the ranks states = { diff --git a/src/transformers/models/gpt_bigcode/push_checkpoints.py b/src/transformers/models/gpt_bigcode/push_checkpoints.py new file mode 100644 index 0000000000..10c4acdc13 --- /dev/null +++ b/src/transformers/models/gpt_bigcode/push_checkpoints.py @@ -0,0 +1,71 @@ +import argparse +import re +import subprocess +from pathlib import Path + +from huggingface_hub import Repository + +from transformers.models.gpt_bigcode.convert_fast_llm_checkpoint import main as convert + + +""" +Script to upload Fast-llm checkpoints to a HF repo on the Hub. The script clones/creates a repo on the Hub, checks out +a branch `--branch_name`, and converts each `iter_` checkpoint and saves it as a commit on that branch. 
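+
+An illustrative invocation (repository name and paths are placeholders); any extra
+arguments are forwarded to the conversion script:
+
+    python push_checkpoints.py --exp_dir /path/to/experiment --repo_name my-org/my-model --branch_name main --iter_interval 5000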
+""" + + +def get_iter_number(iter_dir: str): + m = re.match(r"(\d+)", iter_dir) + if m is not None: + return int(m.group(1)) + else: + raise ValueError(f"Invalid directory name: {iter_dir}") + + +def main(argv=None): + parser = argparse.ArgumentParser() + parser.add_argument("--exp_dir", type=Path, required=True, help="Path to experiment folder.") + parser.add_argument("--repo_name", required=True, help="Name of repository on the Hub in 'ORG/NAME' format.") + parser.add_argument("--branch_name", required=True, help="Name of branch in repository to save experiments.") + parser.add_argument( + "--save_dir", + type=Path, + help="Path where repository is cloned to locally. Will use {exp_dir}/hf_checkpoints if not provided", + ) + parser.add_argument( + "--iter_interval", + type=int, + default=1, + help="Iteration number must be divisble by iter_interval in order to be pushed", + ) + args, argv = parser.parse_known_args(argv) + + save_dir = args.save_dir or args.exp_dir / "hf_checkpoints" + + hf_repo = Repository(save_dir, clone_from=args.repo_name) + hf_repo.git_checkout(args.branch_name, create_branch_ok=True) + # Find last checkpoint that was uploaded + head_hash = hf_repo.git_head_hash() + commit_msg = subprocess.check_output(["git", "show", "-s", "--format=%B", head_hash], cwd=save_dir).decode() + try: + last_commit_iter = get_iter_number(commit_msg.strip()) + print(f"Last commit iteration: {last_commit_iter}") + except ValueError: + last_commit_iter = -1 + + # The checkpoint dirs should be in ascending iteration order, so that the last commit corresponds to the latest checkpoint + ckpt_dirs = sorted([x for x in (args.exp_dir / "checkpoints").iterdir() if re.match(r"(\d+)", x.name) and x.is_dir()]) + + for ckpt_dir in ckpt_dirs: + iter_number = get_iter_number(ckpt_dir.name) + if iter_number <= last_commit_iter: + continue + if iter_number % args.iter_interval == 0: + print(f"Converting iteration {iter_number}") + convert(argv + [f"--save_dir={str(save_dir)}", f"--checkpoint_dir={ckpt_dir}"]) + print(f"Pushing iteration {iter_number}") + hf_repo.push_to_hub(commit_message=f"{ckpt_dir.name}") + + +if __name__ == "__main__": + main() \ No newline at end of file From 0ba0a9b15cb215edf5fe389c22f51fbd242932e3 Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Fri, 20 Oct 2023 21:46:03 +0000 Subject: [PATCH 08/23] use torchscript instead of torch compile? 
--- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index ae0e52a68a..80d5f19daf 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -78,7 +78,7 @@ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor return x -@torch.compile +@torch.jit.script def _apply_rotary_embeddings( tensor: torch.Tensor, rope_frequencies: torch.Tensor, From a452b099489de055f603e02c2aff410dd63130cc Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Fri, 20 Oct 2023 21:53:49 +0000 Subject: [PATCH 09/23] fix --- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 80d5f19daf..b757ecd96b 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -82,7 +82,7 @@ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor def _apply_rotary_embeddings( tensor: torch.Tensor, rope_frequencies: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to a tensor: * Convert it to a complex, full-precision tensor From ba76887787e6452913fe02834afe8159fbda5122 Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Fri, 20 Oct 2023 21:57:02 +0000 Subject: [PATCH 10/23] support sliding window --- .../models/gpt_bigcode/convert_fast_llm_checkpoint.py | 7 ++++--- .../models/gpt_bigcode/modeling_gpt_bigcode.py | 3 +++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py index 26fd87025b..1dc783fef3 100644 --- a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py @@ -4,7 +4,7 @@ import re import torch -from merge_fast_llm_checkpoint import merge_checkpoint +from transformers.models.gpt_bigcode.merge_fast_llm_checkpoint import merge_checkpoint from transformers.models.gpt_bigcode import GPTBigCodeConfig, GPTBigCodeForCausalLM, GPTBigCodeModel @@ -53,14 +53,15 @@ def convert_fast_llm_checkpoint(state_dict, config): use_rotary_embeddings=config["use_rotary_embeddings"], rotary_embedding_scale=config["rotary_embedding_scale"], use_position_embeddings=config["use_position_embeddings"], + attention_window_size=config["attention_window_size"] ) # Truncate the word embeddings to the vocab-size word_embeddings = state_dict.pop("_layers.0._word_embeddings_weight")[:config.vocab_size, :] output_state_dict["transformer.wte.weight"] = word_embeddings # TODO: positional embeddings - # Layers - # + # Layer-0 is the word embeddings + # Layers 1 to n_layer need to be re-mapped from 0 to n_layer-1. 
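+ # (e.g. "_layers.1.layer_norm_1.weight" becomes "transformer.h.0.ln_1.weight", and the
+ # final layernorm stored at layer n_layer + 1 becomes "transformer.ln_f")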
# _layers.{layer_index}.{op}.{w/b} # Concatenate QKV matrix diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index b757ecd96b..dd9f2bfc54 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -648,6 +648,9 @@ def forward( query_length = input_shape[-1] key_length = past_length + query_length self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] + # Sliding window attention + if self.config.attention_window_size is not None: + self_attention_mask.triu_(-self.config.attention_window_size + 1) if attention_mask is not None: self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to( From 9c46a310a73dfaec8a1bd8641effd1082af6d56e Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Fri, 20 Oct 2023 21:59:57 +0000 Subject: [PATCH 11/23] remove torchscript --- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index dd9f2bfc54..6f116078bf 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -78,7 +78,6 @@ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor return x -@torch.jit.script def _apply_rotary_embeddings( tensor: torch.Tensor, rope_frequencies: torch.Tensor, From 921a6452165179f79491dd0f5bb71cb8fd1fb0f8 Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Fri, 20 Oct 2023 22:14:47 +0000 Subject: [PATCH 12/23] fix config --- .../models/gpt_bigcode/configuration_gpt_bigcode.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index 8ffb15861e..3c5cff3a7e 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -123,7 +123,7 @@ def __init__( use_rotary_embeddings=False, rotary_embedding_scale=-math.log(10000), # - 9.210 use_position_embeddings=None, - # TODO: add window + attention_window_size=None, **kwargs, ): self.vocab_size = vocab_size @@ -146,6 +146,7 @@ def __init__( self.use_rotary_embeddings = use_rotary_embeddings self.rotary_embedding_scale = rotary_embedding_scale self.use_position_embeddings = use_position_embeddings if use_position_embeddings is not None else not use_rotary_embeddings + self.attention_window_size = attention_window_size self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id From 904d6872782e6090d88b694d62bf89f227d24391 Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Mon, 23 Oct 2023 18:34:30 +0000 Subject: [PATCH 13/23] fix model when not using rotary --- .../models/gpt_bigcode/modeling_gpt_bigcode.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 6f116078bf..387e78eb46 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -640,8 +640,11 @@ def forward( position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) # Rotary frequencies - 
rotary_embedding_frequencies_q = self._rotary_embedding_frequencies[:, past_length : past_length + input_shape[-1]].to(device=device) - rotary_embedding_frequencies_k = self._rotary_embedding_frequencies[:, :past_length + input_shape[-1], :, :].to(device=device) + rotary_embedding_frequencies_q = None + rotary_embedding_frequencies_k = None + if self.config.use_rotary_embeddings: + rotary_embedding_frequencies_q = self._rotary_embedding_frequencies[:, past_length : past_length + input_shape[-1]].to(device=device) + rotary_embedding_frequencies_k = self._rotary_embedding_frequencies[:, :past_length + input_shape[-1], :, :].to(device=device) # Self-attention mask. query_length = input_shape[-1] From 7be390b8ed6bc25513258b3093f9ea8aa311d5bc Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Tue, 24 Oct 2023 19:49:17 +0000 Subject: [PATCH 14/23] support absolute positional embeddings, new arg --- .../convert_fast_llm_checkpoint.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py index 1dc783fef3..8859b7650e 100644 --- a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py @@ -23,6 +23,10 @@ def convert_fast_llm_checkpoint(state_dict, config): # The converted output model. output_state_dict = {} + if "window_size" in config: + attention_window_size = config["window_size"] + else: + attention_window_size = config.get("attention_window_size", None) config = GPTBigCodeConfig( architectures=["GPTBigCodeLMHeadModel"], @@ -46,21 +50,23 @@ def convert_fast_llm_checkpoint(state_dict, config): summary_first_dropout=0.1, scale_attn_weights=True, use_cache=True, - bos_token_id=50256, # TODO: can we remove these? - eos_token_id=50256, + bos_token_id=0, # TODO: can we remove these? + eos_token_id=0, attention_softmax_in_fp32=True, scale_attention_softmax_in_fp32=True, use_rotary_embeddings=config["use_rotary_embeddings"], rotary_embedding_scale=config["rotary_embedding_scale"], use_position_embeddings=config["use_position_embeddings"], - attention_window_size=config["attention_window_size"] + attention_window_size=attention_window_size ) # Truncate the word embeddings to the vocab-size word_embeddings = state_dict.pop("_layers.0._word_embeddings_weight")[:config.vocab_size, :] output_state_dict["transformer.wte.weight"] = word_embeddings - # TODO: positional embeddings - # Layer-0 is the word embeddings + if config.use_position_embeddings: + output_state_dict["transformer.wpe.weight"] = state_dict.pop("_layers.0._position_embeddings_weight") + + # Layer-0 is the word/position embeddings # Layers 1 to n_layer need to be re-mapped from 0 to n_layer-1. # _layers.{layer_index}.{op}.{w/b} @@ -75,6 +81,7 @@ def convert_fast_llm_checkpoint(state_dict, config): layer_re = re.compile("_layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") for name, value in state_dict.items(): m = layer_re.match(name) + assert m is not None, f"Invalid layer name: {name}" # The index of the layer. 
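+ # (group(1) of a name like "_layers.5.self_attn.dense.weight" is "5"; group(2) and
+ # group(3) give the op name and weight/bias, here "self_attn.dense" and "weight")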
layer_index = int(m.group(1)) @@ -112,7 +119,7 @@ def main(argv=None): state_dict, config = merge_checkpoint( args.checkpoint_dir, - dummy_experiment_dir="/toolkit_infiniband_example_checkpoints/ngc_checkpoints/sc2_ablations/dev_1B_repo_context_Random_tp4_pp2_8k_8k_2023_10_19_18_40_11/" + dummy_experiment_dir=None ) output_state_dict, output_config = convert_fast_llm_checkpoint(state_dict, config) From 300190de507868b3af7405e77b8f63c88a8ad9ff Mon Sep 17 00:00:00 2001 From: RaymondLi0 Date: Fri, 27 Oct 2023 15:23:52 +0000 Subject: [PATCH 15/23] correctly sort checkpoint dirs --- src/transformers/models/gpt_bigcode/push_checkpoints.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_bigcode/push_checkpoints.py b/src/transformers/models/gpt_bigcode/push_checkpoints.py index 10c4acdc13..65f88d7e7f 100644 --- a/src/transformers/models/gpt_bigcode/push_checkpoints.py +++ b/src/transformers/models/gpt_bigcode/push_checkpoints.py @@ -54,7 +54,11 @@ def main(argv=None): last_commit_iter = -1 # The checkpoint dirs should be in ascending iteration order, so that the last commit corresponds to the latest checkpoint - ckpt_dirs = sorted([x for x in (args.exp_dir / "checkpoints").iterdir() if re.match(r"(\d+)", x.name) and x.is_dir()]) + ckpt_dirs = sorted( + [x for x in (args.exp_dir / "checkpoints").iterdir() if re.match(r"(\d+)", x.name) and x.is_dir()], + key=lambda p: get_iter_number(p.name) + ) + print(f"Found the following checkpoints: {ckpt_dirs}") for ckpt_dir in ckpt_dirs: iter_number = get_iter_number(ckpt_dir.name) From e18d4af26fd025551f20dd94b37f876d3c005933 Mon Sep 17 00:00:00 2001 From: RaymondLi0 Date: Fri, 27 Oct 2023 15:26:42 +0000 Subject: [PATCH 16/23] set bias to zero for input-parallel layers --- .../models/gpt_bigcode/convert_fast_llm_checkpoint.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py index 8859b7650e..b7c0a64698 100644 --- a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py @@ -94,6 +94,12 @@ def convert_fast_llm_checkpoint(state_dict, config): if op_name == "final_layernorm": assert layer_index == config.n_layer + 1 output_state_dict[f"transformer.ln_f.{weight_or_bias}"] = value + # Bias was not used in training for InputParallel layers + elif op_name == "self_attn.dense" and weight_or_bias == "bias": + output_state_dict[f"transformer.h.{layer_index-1}.{NAME_MAP[op_name]}.{weight_or_bias}"] = torch.zeros_like(value) + # MLP layer-2 is also InputParallel + elif op_name == "_mlp._layer_2" and weight_or_bias == "bias": + output_state_dict[f"transformer.h.{layer_index-1}.{NAME_MAP[op_name]}.{weight_or_bias}"] = torch.zeros_like(value) else: output_state_dict[f"transformer.h.{layer_index-1}.{NAME_MAP[op_name]}.{weight_or_bias}"] = value From a84b7d2fac89b70be9e2e7529fd401a644a5701f Mon Sep 17 00:00:00 2001 From: RaymondLi0 Date: Fri, 10 Nov 2023 22:05:36 +0000 Subject: [PATCH 17/23] add args for each bias to set to zero --- .../convert_fast_llm_checkpoint.py | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py index b7c0a64698..97a1394dbd 100644 --- a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py +++ 
b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py @@ -20,7 +20,11 @@ } -def convert_fast_llm_checkpoint(state_dict, config): +def convert_fast_llm_checkpoint(state_dict, config, set_attn_dense_bias_zero, set_mlp_2_bias_zero): + if set_attn_dense_bias_zero: + print("Will set attention output layer biases to zero") + if set_mlp_2_bias_zero: + print("Will set MLP layer-2 biases to zero") # The converted output model. output_state_dict = {} if "window_size" in config: @@ -95,10 +99,10 @@ def convert_fast_llm_checkpoint(state_dict, config): assert layer_index == config.n_layer + 1 output_state_dict[f"transformer.ln_f.{weight_or_bias}"] = value # Bias was not used in training for InputParallel layers - elif op_name == "self_attn.dense" and weight_or_bias == "bias": + elif op_name == "self_attn.dense" and weight_or_bias == "bias" and set_attn_dense_bias_zero: output_state_dict[f"transformer.h.{layer_index-1}.{NAME_MAP[op_name]}.{weight_or_bias}"] = torch.zeros_like(value) # MLP layer-2 is also InputParallel - elif op_name == "_mlp._layer_2" and weight_or_bias == "bias": + elif op_name == "_mlp._layer_2" and weight_or_bias == "bias" and set_mlp_2_bias_zero: output_state_dict[f"transformer.h.{layer_index-1}.{NAME_MAP[op_name]}.{weight_or_bias}"] = torch.zeros_like(value) else: output_state_dict[f"transformer.h.{layer_index-1}.{NAME_MAP[op_name]}.{weight_or_bias}"] = value @@ -121,6 +125,19 @@ def main(argv=None): type=Path, help="Path where the converted model is saved" ) + parser.add_argument( + "--set_attn_dense_bias_zero", + action='store_true', + default=False, + help="Set the attention output layer bias to zero and ignore the value from the checkpoint. Shouldn't be used except to fix a bug from training." + ) + parser.add_argument( + "--set_mlp_2_bias_zero", + action='store_true', + default=False, + help="Set the MLP second layer bias to zero and ignore the value from the checkpoint. Shouldn't be used except to fix a bug from training." 
+ ) + args = parser.parse_args(argv) state_dict, config = merge_checkpoint( @@ -128,7 +145,7 @@ def main(argv=None): dummy_experiment_dir=None ) - output_state_dict, output_config = convert_fast_llm_checkpoint(state_dict, config) + output_state_dict, output_config = convert_fast_llm_checkpoint(state_dict, config, args.set_attn_dense_bias_zero, args.set_mlp_2_bias_zero) print("Saving config") save_dir = args.save_dir or args.checkpoint_dir / "converted" From 13c6eb924cee1c7cf6de66cbfd93fa0146368b1c Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Sat, 11 Nov 2023 03:24:58 +0000 Subject: [PATCH 18/23] support gqa --- .../gpt_bigcode/configuration_gpt_bigcode.py | 6 +- .../gpt_bigcode/modeling_gpt_bigcode.py | 101 +++++++++++------- 2 files changed, 65 insertions(+), 42 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index 3c5cff3a7e..fbcb58a634 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -106,6 +106,7 @@ def __init__( n_embd=768, n_layer=12, n_head=12, + head_groups=None, n_inner=None, activation_function="gelu_pytorch_tanh", resid_pdrop=0.1, @@ -142,11 +143,14 @@ def __init__( self.use_cache = use_cache self.attention_softmax_in_fp32 = attention_softmax_in_fp32 self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32 - self.multi_query = multi_query self.use_rotary_embeddings = use_rotary_embeddings self.rotary_embedding_scale = rotary_embedding_scale self.use_position_embeddings = use_position_embeddings if use_position_embeddings is not None else not use_rotary_embeddings self.attention_window_size = attention_window_size + if head_groups is None: + self.head_groups = 1 if multi_query else n_head + else: + self.head_groups = head_groups self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 387e78eb46..6003e1bb9d 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -98,11 +98,10 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() self.mask_value = None - self.multi_query = config.multi_query self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads - self.kv_heads = 1 if self.multi_query else self.num_heads + self.kv_heads = config.head_groups self.kv_dim = self.kv_heads * self.head_dim self.split_size = self.embed_dim if self.head_dim * self.num_heads != self.embed_dim: @@ -122,8 +121,8 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): ) if self.is_cross_attention: - if self.multi_query: - raise NotImplementedError("Multi-Query Attention not supported for cross_attention") + if self.kv_heads != self.num_heads: + raise NotImplementedError("MQA / GQA not supported for cross_attention") self.c_attn = nn.Linear(self.embed_dim, 2 * self.embed_dim) self.q_attn = nn.Linear(self.embed_dim, self.embed_dim) @@ -151,29 +150,29 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): if self.scale_attn_weights: scale_factor /= self.head_dim**0.5 - # MQA models: (batch_size, query_length, num_heads * head_dim) - # MHA models: (batch_size, num_heads, query_length, head_dim) + # query: 
(batch_size, query_length, num_heads * head_dim) query_shape = query.shape batch_size = query_shape[0] + query_length = query_shape[1] key_length = key.size(-1) - if self.multi_query: + # MQA + if self.kv_heads == 1: # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length) # -> (batch_size, query_length, num_heads, key_length) - query_length = query_shape[1] attn_shape = (batch_size, query_length, self.num_heads, key_length) attn_view = (batch_size, query_length * self.num_heads, key_length) # No copy needed for MQA 2, or when layer_past is provided. query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) else: - # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length) - # -> (batch_size, num_heads, query_length, key_length) - query_length = query_shape[2] - attn_shape = (batch_size, self.num_heads, query_length, key_length) - attn_view = (batch_size * self.num_heads, query_length, key_length) - # Always copies - query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim) - # No copy when layer_past is provided. - key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length) + heads_per_group = self.num_heads // self.kv_heads + attn_shape = (batch_size, self.kv_heads, query_length, heads_per_group, key_length) + attn_view = (batch_size * self.kv_heads, query_length * heads_per_group, key_length) + query = query.reshape(batch_size, query_length, self.kv_heads, heads_per_group, self.head_dim).transpose(1, 2) + query = query.reshape(batch_size * self.kv_heads, query_length * heads_per_group, self.head_dim) + key = key.reshape(batch_size * self.kv_heads, self.head_dim, key_length) + value = value.transpose(1, 2) # (batch, kv_heads * head_dim, key_length) + value = value.reshape(batch_size * self.kv_heads, self.head_dim, key_length).transpose(1, 2) + # Attention Mask: (batch_size, query_length, 1, key_length) attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype) if query.device.type == "cpu": @@ -207,14 +206,18 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): # Mask heads if we want to if head_mask is not None: - if self.multi_query: + if self.kv_heads == 1: head_mask = head_mask.transpose(1, 2) attn_weights = attn_weights * head_mask - if self.multi_query: + # MQA + if self.kv_heads == 1: attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape) else: - attn_output = torch.matmul(attn_weights, value) + # -> (batch_size * self.kv_heads, query_length * heads_per_group, head_dim) + attn_output = torch.bmm(attn_weights.view(attn_view), value) + attn_output = attn_output.reshape(batch_size, self.kv_heads, query_length, heads_per_group, self.head_dim).transpose(1, 2) + attn_output = attn_output.reshape(batch_size, query_length, self.kv_heads * heads_per_group * self.head_dim) return attn_output, attn_weights @@ -234,6 +237,7 @@ def forward( Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], ]: + # hidden: (batch, sequence, hidden_size) if encoder_hidden_states is not None: if not hasattr(self, "q_attn") or not self.is_cross_attention: raise ValueError( @@ -244,24 +248,39 @@ def forward( query = self.q_attn(hidden_states) key_value = self.c_attn(encoder_hidden_states) attention_mask = encoder_attention_mask - elif self.multi_query: + + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + present = key_value 
if use_cache else None + + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + elif self.kv_heads == 1: query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + present = key_value if use_cache else None + + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) else: # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), # i.e., the memory layout is not the same as GPT2. # This makes the concatenation with past_key_value more efficient. - query, key_value = ( - self.c_attn(hidden_states) - .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) - .transpose(1, 2) - .split((self.head_dim, 2 * self.head_dim), dim=3) - ) + # query, key_value = ( + # self.c_attn(hidden_states) + # .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) # (batch, sequence, num_heads, 3*head_dim) + # .transpose(1, 2) # (batch, num_heads, sequence, 3*head_dim) + # .split((self.head_dim, 2 * self.head_dim), dim=3) + # ) + + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + # key_value: (batch, sequence, 2 * kv_heads * head_dim) - if layer_past is not None: - key_value = torch.cat((layer_past, key_value), dim=-2) - present = key_value if use_cache else None + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + present = key_value if use_cache else None - key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + key, value = key_value.split((self.kv_heads * self.head_dim), dim=-1) if self.use_rotary_embeddings: query = _apply_rotary_embeddings(query, rotary_embedding_frequencies_q) @@ -269,16 +288,17 @@ def forward( attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) - if not self.multi_query: - attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) outputs = (attn_output, present) if output_attentions: - if self.multi_query: + if self.kv_heads == 1: # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) attn_weights = attn_weights.transpose(1, 2) + else: + # (batch_size, self.kv_heads, query_length, heads_per_group, key_length) + attn_weights = attn_weights.transpose(2, 3) outputs += (attn_weights,) return outputs # a, present, (attentions) @@ -313,8 +333,8 @@ def __init__(self, config, layer_idx=None): self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) if config.add_cross_attention: - if config.multi_query: - raise NotImplementedError("Cross-attention not implemented for MQA") + if config.head_groups < config.num_heads: + raise NotImplementedError("Cross-attention not implemented for MQA / GQA") self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx) self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) @@ -530,7 +550,7 @@ def _set_gradient_checkpointing(self, module, value=False): class GPTBigCodeModel(GPTBigCodePreTrainedModel): def __init__(self, config): super().__init__(config) - self.multi_query = config.multi_query + self.kv_heads = config.head_groups self.embed_dim = config.hidden_size self.wte = nn.Embedding(config.vocab_size, self.embed_dim) @@ -659,9 +679,8 @@ def forward( dtype=torch.bool, 
device=self_attention_mask.device ) - # MQA models: (batch_size, query_length, n_heads, key_length) - # MHA models: (batch_size, n_heads, query_length, key_length) - attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) + # Attention-shape: (batch_size, query_length, n_heads, key_length) + attention_mask = self_attention_mask.unsqueeze(2) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -673,7 +692,7 @@ def forward( if encoder_attention_mask.dim() == 2: encoder_attention_mask.unsqueeze(1) assert encoder_attention_mask.dim() == 3 - encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1) + encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2) else: encoder_attention_mask = None From 40db6c2f04cf800bf029532132622f94d3267361 Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Sat, 11 Nov 2023 04:10:59 +0000 Subject: [PATCH 19/23] fix attention_mask --- src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 6003e1bb9d..56ce2358c4 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -681,6 +681,9 @@ def forward( # Attention-shape: (batch_size, query_length, n_heads, key_length) attention_mask = self_attention_mask.unsqueeze(2) + if self.kv_heads > 1: + # (batch_size, self.kv_heads, query_length, heads_per_group, key_length) + attention_mask = attention_mask.unsqueeze(1) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] From aa0671f94f42ed633c2eacdb71c0ccfdbd81094e Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Sun, 12 Nov 2023 22:12:46 +0000 Subject: [PATCH 20/23] update conversion --- .../models/gpt_bigcode/convert_fast_llm_checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py index 8859b7650e..d8a085a29f 100644 --- a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py @@ -35,6 +35,7 @@ def convert_fast_llm_checkpoint(state_dict, config): n_embd=config["hidden_size"], n_layer=config["num_layers"], n_head=config["num_attention_heads"], + head_groups=config.get("head_groups", None), n_inner=config["ffn_hidden_size"], activation_function="gelu", # TODO multi_query=True, # TODO From 5cfb845db57ea91732494114fd191b5506557af3 Mon Sep 17 00:00:00 2001 From: RaymondLi0 Date: Tue, 14 Nov 2023 05:40:35 +0000 Subject: [PATCH 21/23] add option to re-push past iterations --- .../models/gpt_bigcode/push_checkpoints.py | 41 ++++++++++++++++--- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/push_checkpoints.py b/src/transformers/models/gpt_bigcode/push_checkpoints.py index 65f88d7e7f..3abd551b71 100644 --- a/src/transformers/models/gpt_bigcode/push_checkpoints.py +++ b/src/transformers/models/gpt_bigcode/push_checkpoints.py @@ -1,3 +1,4 @@ +import os import argparse import re import subprocess @@ -38,12 +39,38 @@ def main(argv=None): default=1, help="Iteration number must be divisble by 
iter_interval in order to be pushed", ) + parser.add_argument( + "--iters", + type=int, + nargs='+', + default=None, + help="Specify a list of iterations to push. If None (default), will potentially push all the checkpoints (subject to iter_interval)", + ) + parser.add_argument( + "--tokenizer", + type=str, + default=None, + help="Path to tokenizer file to commit before the checkoints.", + ) + parser.add_argument( + "--push_past_iters", + action="store_true", + default=False, + help="If True, also push iterations that are lower than the last commit.", + ) args, argv = parser.parse_known_args(argv) save_dir = args.save_dir or args.exp_dir / "hf_checkpoints" hf_repo = Repository(save_dir, clone_from=args.repo_name) hf_repo.git_checkout(args.branch_name, create_branch_ok=True) + + # Pull latest changes + env = os.environ.copy() + env["GIT_LFS_SKIP_SMUDGE"] = "1" + git_pull_output = subprocess.run(["git", "pull"], cwd=save_dir, capture_output=True, env=env) + print(git_pull_output) + # Find last checkpoint that was uploaded head_hash = hf_repo.git_head_hash() commit_msg = subprocess.check_output(["git", "show", "-s", "--format=%B", head_hash], cwd=save_dir).decode() @@ -54,15 +81,19 @@ def main(argv=None): last_commit_iter = -1 # The checkpoint dirs should be in ascending iteration order, so that the last commit corresponds to the latest checkpoint - ckpt_dirs = sorted( - [x for x in (args.exp_dir / "checkpoints").iterdir() if re.match(r"(\d+)", x.name) and x.is_dir()], - key=lambda p: get_iter_number(p.name) - ) + ckpt_dirs = [x for x in (args.exp_dir / "checkpoints").iterdir() if re.match(r"(\d+)", x.name) and x.is_dir()] + if args.iters is not None: + args.iters = [int(n) for n in args.iters] + ckpt_dirs = [p for p in ckpt_dirs if get_iter_number(p.name) in args.iters] + ckpt_dirs = sorted(ckpt_dirs, key=lambda p: get_iter_number(p.name)) print(f"Found the following checkpoints: {ckpt_dirs}") + + if args.tokenizer is not None: + raise NotImplementedError("Push tokenizer not implemented yet") for ckpt_dir in ckpt_dirs: iter_number = get_iter_number(ckpt_dir.name) - if iter_number <= last_commit_iter: + if not args.push_past_iters and iter_number <= last_commit_iter: continue if iter_number % args.iter_interval == 0: print(f"Converting iteration {iter_number}") From 4648c300ccfa65b902330b4e001a8f06448ddec9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 17 Nov 2023 13:17:25 -0500 Subject: [PATCH 22/23] Checkpoint v1 --- .../convert_fast_llm_checkpoint.py | 45 +++++++++---------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py index 1f99d94cee..4f3f16b9ea 100644 --- a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py @@ -5,22 +5,10 @@ import torch from transformers.models.gpt_bigcode.merge_fast_llm_checkpoint import merge_checkpoint -from transformers.models.gpt_bigcode import GPTBigCodeConfig, GPTBigCodeForCausalLM, GPTBigCodeModel +from transformers.models.gpt_bigcode import GPTBigCodeConfig -# The simple map of names for "automated" rules. 
-NAME_MAP = { - "_mlp._layer_1": "mlp.c_fc", - "_mlp._layer_2": "mlp.c_proj", - "layer_norm_1": "ln_1", - "layer_norm_2": "ln_2", - # "attention.dense": "attn.c_proj", - "self_attn.dense": "attn.c_proj", - # "self_attention.query_key_value": "attn.c_attn", -} - - -def convert_fast_llm_checkpoint(state_dict, config, set_attn_dense_bias_zero, set_mlp_2_bias_zero): +def convert_fast_llm_checkpoint(state_dict, config, set_attn_dense_bias_zero, set_mlp_2_bias_zero, version=1): if set_attn_dense_bias_zero: print("Will set attention output layer biases to zero") if set_mlp_2_bias_zero: @@ -66,10 +54,11 @@ def convert_fast_llm_checkpoint(state_dict, config, set_attn_dense_bias_zero, se ) # Truncate the word embeddings to the vocab-size - word_embeddings = state_dict.pop("_layers.0._word_embeddings_weight")[:config.vocab_size, :] + u="_" if version>=1 else "" + word_embeddings = state_dict.pop(f"{u}layers.0.{u}word_embeddings_weight")[:config.vocab_size, :] output_state_dict["transformer.wte.weight"] = word_embeddings if config.use_position_embeddings: - output_state_dict["transformer.wpe.weight"] = state_dict.pop("_layers.0._position_embeddings_weight") + output_state_dict["transformer.wpe.weight"] = state_dict.pop(f"{u}layers.0.{u}position_embeddings_weight") # Layer-0 is the word/position embeddings # Layers 1 to n_layer need to be re-mapped from 0 to n_layer-1. @@ -78,12 +67,22 @@ def convert_fast_llm_checkpoint(state_dict, config, set_attn_dense_bias_zero, se # Concatenate QKV matrix for layer_index in range(1, config.n_layer + 1): for weight_or_bias in ["weight", "bias"]: - query = state_dict.pop(f"_layers.{layer_index}.self_attn.query.{weight_or_bias}") - key_value = state_dict.pop(f"_layers.{layer_index}.self_attn.key_value.{weight_or_bias}") + query = state_dict.pop(f"{u}layers.{layer_index}.self_attn.query.{weight_or_bias}") + key_value = state_dict.pop(f"{u}layers.{layer_index}.self_attn.key_value.{weight_or_bias}") output_state_dict[f"transformer.h.{layer_index - 1}.attn.c_attn.{weight_or_bias}"] = torch.cat([query, key_value], dim=0) - + + # The simple map of names for "automated" rules. 
+ name_map = { + f"{u}mlp.{u}layer_1": "mlp.c_fc", + f"{u}mlp.{u}layer_2": "mlp.c_proj", + "layer_norm_1": "ln_1", + "layer_norm_2": "ln_2", + # "attention.dense": "attn.c_proj", + "self_attn.dense": "attn.c_proj", + # "self_attention.query_key_value": "attn.c_attn", + } # Extract the other ops - layer_re = re.compile("_layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") + layer_re = re.compile(f"{u}layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") for name, value in state_dict.items(): m = layer_re.match(name) assert m is not None, f"Invalid layer name: {name}" @@ -101,12 +100,12 @@ def convert_fast_llm_checkpoint(state_dict, config, set_attn_dense_bias_zero, se output_state_dict[f"transformer.ln_f.{weight_or_bias}"] = value # Bias was not used in training for InputParallel layers elif op_name == "self_attn.dense" and weight_or_bias == "bias" and set_attn_dense_bias_zero: - output_state_dict[f"transformer.h.{layer_index-1}.{NAME_MAP[op_name]}.{weight_or_bias}"] = torch.zeros_like(value) + output_state_dict[f"transformer.h.{layer_index-1}.{name_map[op_name]}.{weight_or_bias}"] = torch.zeros_like(value) # MLP layer-2 is also InputParallel elif op_name == "_mlp._layer_2" and weight_or_bias == "bias" and set_mlp_2_bias_zero: - output_state_dict[f"transformer.h.{layer_index-1}.{NAME_MAP[op_name]}.{weight_or_bias}"] = torch.zeros_like(value) + output_state_dict[f"transformer.h.{layer_index-1}.{name_map[op_name]}.{weight_or_bias}"] = torch.zeros_like(value) else: - output_state_dict[f"transformer.h.{layer_index-1}.{NAME_MAP[op_name]}.{weight_or_bias}"] = value + output_state_dict[f"transformer.h.{layer_index-1}.{name_map[op_name]}.{weight_or_bias}"] = value # For LM head, transformers' wants the matrix to weight embeddings. output_state_dict["lm_head.weight"] = word_embeddings From d924a92a2af1c48cf24a6416dbd19862a1290fff Mon Sep 17 00:00:00 2001 From: RaymondLi0 Date: Sat, 25 Nov 2023 16:15:06 +0000 Subject: [PATCH 23/23] fix tensor names --- .../models/gpt_bigcode/convert_fast_llm_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py index 4f3f16b9ea..d9e78e97aa 100644 --- a/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py +++ b/src/transformers/models/gpt_bigcode/convert_fast_llm_checkpoint.py @@ -54,7 +54,7 @@ def convert_fast_llm_checkpoint(state_dict, config, set_attn_dense_bias_zero, se ) # Truncate the word embeddings to the vocab-size - u="_" if version>=1 else "" + u="_" if version==0 else "" word_embeddings = state_dict.pop(f"{u}layers.0.{u}word_embeddings_weight")[:config.vocab_size, :] output_state_dict["transformer.wte.weight"] = word_embeddings if config.use_position_embeddings: @@ -102,7 +102,7 @@ def convert_fast_llm_checkpoint(state_dict, config, set_attn_dense_bias_zero, se elif op_name == "self_attn.dense" and weight_or_bias == "bias" and set_attn_dense_bias_zero: output_state_dict[f"transformer.h.{layer_index-1}.{name_map[op_name]}.{weight_or_bias}"] = torch.zeros_like(value) # MLP layer-2 is also InputParallel - elif op_name == "_mlp._layer_2" and weight_or_bias == "bias" and set_mlp_2_bias_zero: + elif op_name == f"{u}mlp.{u}layer_2" and weight_or_bias == "bias" and set_mlp_2_bias_zero: output_state_dict[f"transformer.h.{layer_index-1}.{name_map[op_name]}.{weight_or_bias}"] = torch.zeros_like(value) else: output_state_dict[f"transformer.h.{layer_index-1}.{name_map[op_name]}.{weight_or_bias}"] = value
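
The "support gqa" patch above computes grouped-query attention scores by folding the kv-head (head_groups) dimension into the batch dimension and the heads-per-group dimension into the query-length dimension, so every group of query heads shares a single bmm with its key head. Below is a minimal, self-contained sketch of that score computation, checked against a plain repeat-the-key-heads reference; tensor names and sizes are illustrative only and this is not the exact modeling_gpt_bigcode implementation (masking, dropout and the value matmul are omitted).

import torch


def gqa_scores(query, key, num_heads, kv_heads, head_dim):
    # query: (batch, query_length, num_heads * head_dim)
    # key:   (batch, key_length, kv_heads * head_dim)
    batch, query_length, _ = query.shape
    key_length = key.shape[1]
    heads_per_group = num_heads // kv_heads

    # Fold the kv-head (group) dimension into the batch dimension and the
    # heads-per-group dimension into the query-length dimension.
    q = query.view(batch, query_length, kv_heads, heads_per_group, head_dim).transpose(1, 2)
    q = q.reshape(batch * kv_heads, query_length * heads_per_group, head_dim)
    k = key.view(batch, key_length, kv_heads, head_dim).permute(0, 2, 3, 1)
    k = k.reshape(batch * kv_heads, head_dim, key_length)

    # One bmm per (batch, kv_head): (Lq * hpg, d) x (d, Lk) -> (Lq * hpg, Lk)
    scores = torch.bmm(q, k) / head_dim**0.5

    # Unfold back to the conventional (batch, num_heads, query_length, key_length) layout.
    scores = scores.view(batch, kv_heads, query_length, heads_per_group, key_length)
    return scores.permute(0, 1, 3, 2, 4).reshape(batch, num_heads, query_length, key_length)


def reference_scores(query, key, num_heads, kv_heads, head_dim):
    # Plain multi-head attention after repeating each key head for all queries in its group.
    batch, query_length, _ = query.shape
    key_length = key.shape[1]
    q = query.view(batch, query_length, num_heads, head_dim).transpose(1, 2)
    k = key.view(batch, key_length, kv_heads, head_dim).transpose(1, 2)
    k = k.repeat_interleave(num_heads // kv_heads, dim=1)
    return q @ k.transpose(-1, -2) / head_dim**0.5


torch.manual_seed(0)
q = torch.randn(2, 5, 8 * 16)   # batch=2, query_length=5, num_heads=8, head_dim=16
k = torch.randn(2, 7, 2 * 16)   # key_length=7, kv_heads=2
assert torch.allclose(
    gqa_scores(q, k, num_heads=8, kv_heads=2, head_dim=16),
    reference_scores(q, k, num_heads=8, kv_heads=2, head_dim=16),
    atol=1e-4,
)

Setting kv_heads equal to num_heads recovers standard multi-head attention and kv_heads=1 recovers multi-query attention, which is why the patch replaces the boolean multi_query flag with head_groups.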
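
The same patch keeps GPTBigCode's fused KV cache: key and value are produced and cached as a single tensor of shape (batch, sequence, 2 * kv_heads * head_dim), new positions are concatenated along the sequence dimension, and the halves are split out only right before attention. A small sketch of that convention, with made-up sizes and names, assuming a cache handled as a plain tensor:

import torch

kv_heads, head_dim = 4, 16
kv_dim = kv_heads * head_dim


def cache_step(new_key_value, layer_past=None, use_cache=True):
    # new_key_value: (batch, new_tokens, 2 * kv_dim), i.e. the key_value slice of c_attn's output.
    if layer_past is not None:
        new_key_value = torch.cat((layer_past, new_key_value), dim=-2)  # grow along the sequence dim
    present = new_key_value if use_cache else None                      # cached as one fused tensor
    key, value = new_key_value.split(kv_dim, dim=-1)                    # (batch, seq, kv_dim) each
    return key, value, present


key, value, past = cache_step(torch.randn(2, 5, 2 * kv_dim))            # prefill with 5 tokens
key, value, past = cache_step(torch.randn(2, 1, 2 * kv_dim), past)      # decode one more token
print(key.shape, value.shape, past.shape)
# torch.Size([2, 6, 64]) torch.Size([2, 6, 64]) torch.Size([2, 6, 128])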
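
The follow-up "fix attention_mask" patch adds one extra unsqueeze so the usual (batch, query_length, 1, key_length) mask broadcasts over the kv_heads dimension of the grouped score layout (batch, kv_heads, query_length, heads_per_group, key_length). A shape-only sketch with arbitrary sizes; the model applies the mask differently in detail, masked_fill is used here just to show the broadcast:

import torch

batch, kv_heads, query_length, heads_per_group, key_length = 2, 4, 5, 3, 7
attn_weights = torch.randn(batch, kv_heads, query_length, heads_per_group, key_length)

# MQA-style mask layout: (batch, query_length, 1, key_length)
self_attention_mask = torch.rand(batch, query_length, 1, key_length) > 0.5
# Extra unsqueeze(1) so the mask broadcasts over the kv_heads dimension.
gqa_mask = self_attention_mask.unsqueeze(1)                  # (batch, 1, query_length, 1, key_length)
masked = attn_weights.masked_fill(~gqa_mask, torch.finfo(attn_weights.dtype).min)
print(masked.shape)                                          # torch.Size([2, 4, 5, 3, 7])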
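
The "Checkpoint v1" and "fix tensor names" patches make the converter version-aware: version 0 Fast-LLM checkpoints prefix module names with an underscore ("_layers.N._mlp._layer_1"), while later checkpoints drop it ("layers.N.mlp.layer_1"), and the name map plus layer regex are built from that prefix. The sketch below condenses that remapping for the transformer blocks only; the fake state dict, the dummy values, and the omission of embeddings, QKV concatenation, and the zero-bias options are all simplifications for illustration.

import re


def remap_block_tensors(state_dict, n_layer, version=1):
    u = "_" if version == 0 else ""
    name_map = {
        f"{u}mlp.{u}layer_1": "mlp.c_fc",
        f"{u}mlp.{u}layer_2": "mlp.c_proj",
        "layer_norm_1": "ln_1",
        "layer_norm_2": "ln_2",
        "self_attn.dense": "attn.c_proj",
    }
    layer_re = re.compile(rf"{u}layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
    output_state_dict = {}
    for name, value in state_dict.items():
        m = layer_re.match(name)
        assert m is not None, f"Invalid layer name: {name}"
        layer_index, op_name, weight_or_bias = int(m.group(1)), m.group(2), m.group(3)
        if op_name == "final_layernorm":
            assert layer_index == n_layer + 1
            output_state_dict[f"transformer.ln_f.{weight_or_bias}"] = value
        else:
            # Fast-LLM layer 0 holds the embeddings, so transformer block i lives at layer i + 1.
            output_state_dict[f"transformer.h.{layer_index - 1}.{name_map[op_name]}.{weight_or_bias}"] = value
    return output_state_dict


# Hypothetical version-1 state dict with dummy values:
fake_state_dict = {
    "layers.1.mlp.layer_1.weight": 0.0,
    "layers.1.self_attn.dense.bias": 0.0,
    "layers.2.layer_norm_1.weight": 0.0,
    "layers.3.final_layernorm.weight": 0.0,
}
print(sorted(remap_block_tensors(fake_state_dict, n_layer=2)))
# ['transformer.h.0.attn.c_proj.bias', 'transformer.h.0.mlp.c_fc.weight',
#  'transformer.h.1.ln_1.weight', 'transformer.ln_f.weight']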