From 48efb8ddaf0606b122b498b19581626c607349d3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 7 Feb 2023 09:56:47 -0500 Subject: [PATCH 01/27] Add files --- .../models/gpt_bigcode/__init__.py | 0 .../gpt_bigcode/configuration_gpt_bigcode.py | 274 +++ .../gpt_bigcode/modeling_gpt_bigcode.py | 1537 +++++++++++++++++ 3 files changed, 1811 insertions(+) create mode 100644 src/transformers/models/gpt_bigcode/__init__.py create mode 100644 src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py create mode 100644 src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py diff --git a/src/transformers/models/gpt_bigcode/__init__.py b/src/transformers/models/gpt_bigcode/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py new file mode 100644 index 0000000000..fe9c711d73 --- /dev/null +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -0,0 +1,274 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" OpenAI GPT-2 configuration""" +from collections import OrderedDict +from typing import Any, List, Mapping, Optional + +from transformers import PreTrainedTokenizer, TensorType, is_torch_available + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfigWithPast, PatchingSpec +from ...utils import logging + + +logger = logging.get_logger(__name__) + +GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "gpt2": "https://huggingface.co/gpt2/resolve/main/config.json", + "gpt2-medium": "https://huggingface.co/gpt2-medium/resolve/main/config.json", + "gpt2-large": "https://huggingface.co/gpt2-large/resolve/main/config.json", + "gpt2-xl": "https://huggingface.co/gpt2-xl/resolve/main/config.json", + "distilgpt2": "https://huggingface.co/distilgpt2/resolve/main/config.json", +} + + +class GPT2Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to + instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the GPT-2 + [gpt2](https://huggingface.co/gpt2) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. 
Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + n_inner (`int`, *optional*, defaults to None): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`int`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + summary_type (`string`, *optional*, defaults to `"cls_index"`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Has to be one of the following options: + + - `"last"`: Take the last token hidden state (like XLNet). + - `"first"`: Take the first token hidden state (like BERT). + - `"mean"`: Take the mean of all tokens hidden states. + - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). + - `"attn"`: Not implemented now, use multi-head attention. + summary_use_proj (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Whether or not to add a projection after the vector extraction. + summary_activation (`str`, *optional*): + Argument used when doing sequence summary. Used in for the multiple choice head in + [`GPT2DoubleHeadsModel`]. + + Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation. + summary_proj_to_labels (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. + summary_first_dropout (`float`, *optional*, defaults to 0.1): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + The dropout ratio to be used after the projection and activation. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): + Whether to additionally scale attention weights by `1 / layer_idx + 1`. 
+ reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): + Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention + dot-product/softmax to float() when training with mixed precision. + + Example: + + ```python + >>> from transformers import GPT2Config, GPT2Model + + >>> # Initializing a GPT2 configuration + >>> configuration = GPT2Config() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = GPT2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gpt2" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=50257, + n_positions=1024, + n_embd=768, + n_layer=12, + n_head=12, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + scale_attn_by_inverse_layer_idx=False, + reorder_and_upcast_attn=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_first_dropout = summary_first_dropout + self.summary_proj_to_labels = summary_proj_to_labels + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx + self.reorder_and_upcast_attn = reorder_and_upcast_attn + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + +class GPT2OnnxConfig(OnnxConfigWithPast): + def __init__( + self, + config: PretrainedConfig, + task: str = "default", + patching_specs: List[PatchingSpec] = None, + use_past: bool = False, + ): + super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past) + if not getattr(self._config, "pad_token_id", None): + # TODO: how to do that better? 
+ self._config.pad_token_id = 0 + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} + else: + common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} + + return common_inputs + + @property + def num_layers(self) -> int: + return self._config.n_layer + + @property + def num_attention_heads(self) -> int: + return self._config.n_head + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + # We need to order the input in the way they appears in the forward() + ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) + + # Need to add the past_keys + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + past_shape = ( + batch, + self.num_attention_heads, + past_key_values_length, + self._config.hidden_size // self.num_attention_heads, + ) + ordered_inputs["past_key_values"] = [ + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers) + ] + + ordered_inputs["attention_mask"] = common_inputs["attention_mask"] + if self.use_past: + mask_dtype = ordered_inputs["attention_mask"].dtype + ordered_inputs["attention_mask"] = torch.cat( + [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1 + ) + + return ordered_inputs + + @property + def default_onnx_opset(self) -> int: + return 13 diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py new file mode 100644 index 0000000000..5fe33bbca5 --- /dev/null +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -0,0 +1,1537 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
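As a rough illustration of the dummy inputs that the ONNX config above assembles (a minimal sketch; the batch, sequence, and config sizes below are made up only to make the shapes concrete):

```python
import torch

# Hypothetical sizes, chosen only to make the shapes concrete.
batch, seqlen = 2, 3
n_layer, n_head, n_embd = 2, 4, 32

past_length = seqlen + 2  # the dummy cache is deliberately longer than the input
past_shape = (batch, n_head, past_length, n_embd // n_head)

# One (key, value) pair of zeros per layer, plus a mask covering past + current tokens.
past_key_values = [(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(n_layer)]
attention_mask = torch.ones(batch, seqlen + past_length, dtype=torch.int64)

print(past_key_values[0][0].shape)  # torch.Size([2, 4, 5, 8])
print(attention_mask.shape)         # torch.Size([2, 8])
```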
+"""PyTorch OpenAI GPT-2 model.""" + +import math +import os +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel, SequenceSummary +from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ...utils.model_parallel_utils import assert_device_map, get_device_map +from .configuration_gpt2 import GPT2Config + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "gpt2" +_CONFIG_FOR_DOC = "GPT2Config" + +GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "gpt2", + "gpt2-medium", + "gpt2-large", + "gpt2-xl", + "distilgpt2", + # See all GPT-2 models at https://huggingface.co/models?filter=gpt2 +] + + +def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt2_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class GPT2Attention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), + ) + self.register_buffer("masked_bias", torch.tensor(-1e4)) + + self.embed_dim = config.hidden_size + self.num_heads = 
config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + self.reorder_and_upcast_attn = config.reorder_and_upcast_attn + + if self.is_cross_attention: + self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) + self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + else: + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + + # Update hyper params + self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) + self.num_heads = self.num_heads - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + # Preallocate attn_weights for `baddbmm` + attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) + + # Compute Scale Factor + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + if self.scale_attn_by_inverse_layer_idx: + scale_factor /= float(self.layer_idx + 1) + + # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) + with autocast(enabled=False): + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise + if attn_weights.dtype != torch.float32: + raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." 
+ ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +class GPT2MLP(nn.Module): + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class GPT2Block(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPT2Attention(config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx) + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = GPT2MLP(inner_dim, config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block 
for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPT2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPT2Config + load_tf_weights = load_tf_weights_in_gpt2 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["GPT2Block"] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name == "c_proj.weight": + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GPT2Model): + module.gradient_checkpointing = value + + +@dataclass +class GPT2DoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. 
+ mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): + Multiple choice classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). + past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + GPT2Attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + mc_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + mc_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +GPT2_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. 
Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + + - gpt2: 12 + - gpt2-medium: 24 + - gpt2-large: 36 + - gpt2-xl: 48 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained("gpt2-xl") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with gpt2-large: + model = GPT2LMHeadModel.from_pretrained("gpt2-large") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7], + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) +class GPT2Model(GPT2PreTrainedModel): + _keys_to_ignore_on_load_missing = ["attn.masked_bias"] + + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def 
get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 
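A toy example of the additive mask that the next few lines build (a sketch only; `torch.float16` is assumed here, whereas the model uses `self.dtype`):

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])               # 1 = attend, 0 = padded
mask = attention_mask[:, None, None, :].to(torch.float16)   # broadcastable to (batch, heads, q_len, k_len)
additive_mask = (1.0 - mask) * torch.finfo(torch.float16).min

print(additive_mask)  # 0 for kept positions, the dtype's most negative value for masked ones
```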
+ attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). 
+ """, + GPT2_START_DOCSTRING, +) +class GPT2LMHeadModel(GPT2PreTrainedModel): + _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # only last token for inputs_ids if past is defined in kwargs + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) + + +@add_start_docstrings( + """ +The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for +RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the +input embeddings, the classification head takes as input the input of a specified classification token index in the +input sequence). 
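A minimal sketch of the label shift performed inside the language-modeling heads in this file (hypothetical tensors; positions labelled `-100` are ignored by `CrossEntropyLoss`):

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 8
lm_logits = torch.randn(1, 4, vocab_size)           # (batch, seq, vocab)
labels = torch.tensor([[2, 5, 7, -100]])            # labels can simply be input_ids

shift_logits = lm_logits[..., :-1, :].contiguous()  # predictions for positions 0..n-2
shift_labels = labels[..., 1:].contiguous()         # targets are the next tokens, 1..n-1

loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
```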
+""", + GPT2_START_DOCSTRING, +) +class GPT2DoubleHeadsModel(GPT2PreTrainedModel): + _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + config.num_labels = 1 + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.multiple_choice_head = SequenceSummary(config) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.multiple_choice_head = self.multiple_choice_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # only last token for inputs_ids if past is defined in kwargs + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: + r""" + mc_token_ids (`torch.LongTensor` 
of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to + `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + Return: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = GPT2DoubleHeadsModel.from_pretrained("gpt2") + + >>> # Add a [CLS] to the vocabulary (we should train it also!) + >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) + >>> # Update the model embeddings with the new vocabulary size + >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) + + >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] + >>> encoded_choices = [tokenizer.encode(s) for s in choices] + >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] + + >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 + >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 + + >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) + >>> lm_logits = outputs.logits + >>> mc_logits = outputs.mc_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + transformer_outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + 
past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a sequence classification head on top (linear layer). + + [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForSequenceClassification(GPT2PreTrainedModel): + _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint="microsoft/DialogRPT-updown", + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+ """, + GPT2_START_DOCSTRING, +) +class GPT2ForTokenClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = GPT2Model(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + # fmt: off + @add_code_sample_docstrings( + checkpoint="brad1141/gpt2-finetuned-comp2", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_loss=0.25, + expected_output=["Lead", "Lead", "Lead", "Position", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead"], + ) + # fmt: on + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) From 5145458c28d27411afb44b60dd34614dc316ef20 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 7 Feb 2023 10:05:07 -0500 Subject: [PATCH 02/27] Changes --- src/transformers/__init__.py | 24 ++++ src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 3 + src/transformers/models/auto/modeling_auto.py | 6 + .../models/auto/tokenization_auto.py | 1 + .../models/gpt_bigcode/__init__.py | 72 +++++++++++ .../gpt_bigcode/configuration_gpt_bigcode.py | 46 ++++--- .../gpt_bigcode/modeling_gpt_bigcode.py | 122 +++++++++--------- 8 files changed, 188 insertions(+), 87 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b65781ab75..6d8be0ae26 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -274,6 +274,7 @@ "models.git": ["GIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "GitConfig", "GitProcessor", "GitVisionConfig"], "models.glpn": ["GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP", "GLPNConfig"], "models.gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2Tokenizer"], + "models.gpt_bigcode": ["GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTBigCodeConfig", "GPTBigCodeTokenizer"], "models.gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"], "models.gpt_neox": ["GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoXConfig"], "models.gpt_neox_japanese": ["GPT_NEOX_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoXJapaneseConfig"], @@ -1531,6 +1532,18 @@ "load_tf_weights_in_gpt2", ] ) + _import_structure["models.gpt_bigcode"].extend( + [ + "GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST", + "GPTBigCodeDoubleHeadsModel", + "GPTBigCodeForSequenceClassification", + "GPTBigCodeForTokenClassification", + "GPTBigCodeLMHeadModel", + "GPTBigCodeModel", + "GPTBigCodePreTrainedModel", + "load_tf_weights_in_gpt_bigcode", + ] + ) _import_structure["models.gpt_neo"].extend( [ "GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -3714,6 +3727,7 @@ from .models.git import GIT_PRETRAINED_CONFIG_ARCHIVE_MAP, GitConfig, GitProcessor, GitVisionConfig from .models.glpn import GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP, GLPNConfig from .models.gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2Tokenizer + from .models.gpt_bigcode import GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTBigCodeConfig from .models.gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig from .models.gpt_neox import GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoXConfig from .models.gpt_neox_japanese import 
GPT_NEOX_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoXJapaneseConfig @@ -4781,6 +4795,16 @@ GPT2PreTrainedModel, load_tf_weights_in_gpt2, ) + from .models.gpt_bigcode import ( + GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST, + GPTBigCodeDoubleHeadsModel, + GPTBigCodeForSequenceClassification, + GPTBigCodeForTokenClassification, + GPTBigCodeLMHeadModel, + GPTBigCodeModel, + GPTBigCodePreTrainedModel, + load_tf_weights_in_gpt_bigcode, + ) from .models.gpt_neo import ( GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST, GPTNeoForCausalLM, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 9eade475fa..820f90bc36 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -79,6 +79,7 @@ git, glpn, gpt2, + gpt_bigcode, gpt_neo, gpt_neox, gpt_neox_japanese, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 1a77eb0153..70b777f414 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -84,6 +84,7 @@ ("glpn", "GLPNConfig"), ("gpt-sw3", "GPT2Config"), ("gpt2", "GPT2Config"), + ("gpt_bigcode", "GPTBigCodeConfig"), ("gpt_neo", "GPTNeoConfig"), ("gpt_neox", "GPTNeoXConfig"), ("gpt_neox_japanese", "GPTNeoXJapaneseConfig"), @@ -248,6 +249,7 @@ ("git", "GIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("glpn", "GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("gpt_bigcode", "GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("gpt_neox", "GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("gpt_neox_japanese", "GPT_NEOX_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -412,6 +414,7 @@ ("glpn", "GLPN"), ("gpt-sw3", "GPT-Sw3"), ("gpt2", "OpenAI GPT-2"), + ("gpt_bigcode", "GPT BigCode"), ("gpt_neo", "GPT Neo"), ("gpt_neox", "GPT NeoX"), ("gpt_neox_japanese", "GPT NeoX Japanese"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 2fe18122aa..6c5c18fe06 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -83,6 +83,7 @@ ("glpn", "GLPNModel"), ("gpt-sw3", "GPT2Model"), ("gpt2", "GPT2Model"), + ("gpt_bigcode", "GPTBigCodeModel"), ("gpt_neo", "GPTNeoModel"), ("gpt_neox", "GPTNeoXModel"), ("gpt_neox_japanese", "GPTNeoXJapaneseModel"), @@ -210,6 +211,7 @@ ("funnel", "FunnelForPreTraining"), ("gpt-sw3", "GPT2LMHeadModel"), ("gpt2", "GPT2LMHeadModel"), + ("gpt_bigcode", "GPTBigCodeLMHeadModel"), ("ibert", "IBertForMaskedLM"), ("layoutlm", "LayoutLMForMaskedLM"), ("longformer", "LongformerForMaskedLM"), @@ -274,6 +276,7 @@ ("git", "GitForCausalLM"), ("gpt-sw3", "GPT2LMHeadModel"), ("gpt2", "GPT2LMHeadModel"), + ("gpt_bigcode", "GPTBigCodeLMHeadModel"), ("gpt_neo", "GPTNeoForCausalLM"), ("gpt_neox", "GPTNeoXForCausalLM"), ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"), @@ -340,6 +343,7 @@ ("git", "GitForCausalLM"), ("gpt-sw3", "GPT2LMHeadModel"), ("gpt2", "GPT2LMHeadModel"), + ("gpt_bigcode", "GPTBigCodeLMHeadModel"), ("gpt_neo", "GPTNeoForCausalLM"), ("gpt_neox", "GPTNeoXForCausalLM"), ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"), @@ -618,6 +622,7 @@ ("funnel", "FunnelForSequenceClassification"), ("gpt-sw3", "GPT2ForSequenceClassification"), ("gpt2", "GPT2ForSequenceClassification"), + ("gpt_bigcode", "GPTBigCodeForSequenceClassification"), ("gpt_neo", "GPTNeoForSequenceClassification"), ("gptj", 
"GPTJForSequenceClassification"), ("ibert", "IBertForSequenceClassification"), @@ -757,6 +762,7 @@ ("funnel", "FunnelForTokenClassification"), ("gpt-sw3", "GPT2ForTokenClassification"), ("gpt2", "GPT2ForTokenClassification"), + ("gpt_bigcode", "GPTBigCodeForTokenClassification"), ("ibert", "IBertForTokenClassification"), ("layoutlm", "LayoutLMForTokenClassification"), ("layoutlmv2", "LayoutLMv2ForTokenClassification"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 7073221d74..3a26d54d16 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -141,6 +141,7 @@ ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)), ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), + ("gpt_bigcode", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("gpt_neox", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ("gpt_neox_japanese", ("GPTNeoXJapaneseTokenizer", None)), diff --git a/src/transformers/models/gpt_bigcode/__init__.py b/src/transformers/models/gpt_bigcode/__init__.py index e69de29bb2..5b585a8908 100644 --- a/src/transformers/models/gpt_bigcode/__init__.py +++ b/src/transformers/models/gpt_bigcode/__init__.py @@ -0,0 +1,72 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, +) + + +_import_structure = { + "configuration_gpt_bigcode": ["GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTBigCodeConfig", "GPTBigCodeOnnxConfig"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_gpt_bigcode"] = [ + "GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST", + "GPTBigCodeDoubleHeadsModel", + "GPTBigCodeForSequenceClassification", + "GPTBigCodeForTokenClassification", + "GPTBigCodeLMHeadModel", + "GPTBigCodeModel", + "GPTBigCodePreTrainedModel", + "load_tf_weights_in_gpt_bigcode", + ] + +if TYPE_CHECKING: + from .configuration_gpt_bigcode import GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTBigCodeConfig, GPTBigCodeOnnxConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_gpt_bigcode import ( + GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST, + GPTBigCodeDoubleHeadsModel, + GPTBigCodeForSequenceClassification, + GPTBigCodeForTokenClassification, + GPTBigCodeLMHeadModel, + GPTBigCodeModel, + GPTBigCodePreTrainedModel, + load_tf_weights_in_gpt_bigcode, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index fe9c711d73..8fcf554ded 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -26,18 +26,15 @@ logger = logging.get_logger(__name__) -GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "gpt2": "https://huggingface.co/gpt2/resolve/main/config.json", - "gpt2-medium": "https://huggingface.co/gpt2-medium/resolve/main/config.json", - "gpt2-large": "https://huggingface.co/gpt2-large/resolve/main/config.json", - "gpt2-xl": "https://huggingface.co/gpt2-xl/resolve/main/config.json", - "distilgpt2": "https://huggingface.co/distilgpt2/resolve/main/config.json", +GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP = { + # TODO: Add support for santa models. } -class GPT2Config(PretrainedConfig): +class GPTBigCodeConfig(PretrainedConfig): """ - This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to + # TODO: Update doc + This is the configuration class to store the configuration of a [`GPTBigCodeModel`] or a [`TFGPTBigCodeModel`]. It is used to instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the GPT-2 [gpt2](https://huggingface.co/gpt2) architecture. @@ -49,7 +46,7 @@ class GPT2Config(PretrainedConfig): Args: vocab_size (`int`, *optional*, defaults to 50257): Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. + `inputs_ids` passed when calling [`GPTBigCodeModel`] or [`TFGPTBigCodeModel`]. n_positions (`int`, *optional*, defaults to 1024): The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048). 
@@ -74,8 +71,8 @@ class GPT2Config(PretrainedConfig): initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. summary_type (`string`, *optional*, defaults to `"cls_index"`): - Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and - [`TFGPT2DoubleHeadsModel`]. + Argument used when doing sequence summary, used in the models [`GPTBigCodeDoubleHeadsModel`] and + [`TFGPTBigCodeDoubleHeadsModel`]. Has to be one of the following options: @@ -85,23 +82,23 @@ class GPT2Config(PretrainedConfig): - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). - `"attn"`: Not implemented now, use multi-head attention. summary_use_proj (`bool`, *optional*, defaults to `True`): - Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and - [`TFGPT2DoubleHeadsModel`]. + Argument used when doing sequence summary, used in the models [`GPTBigCodeDoubleHeadsModel`] and + [`TFGPTBigCodeDoubleHeadsModel`]. Whether or not to add a projection after the vector extraction. summary_activation (`str`, *optional*): Argument used when doing sequence summary. Used in for the multiple choice head in - [`GPT2DoubleHeadsModel`]. + [`GPTBigCodeDoubleHeadsModel`]. Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation. summary_proj_to_labels (`bool`, *optional*, defaults to `True`): - Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and - [`TFGPT2DoubleHeadsModel`]. + Argument used when doing sequence summary, used in the models [`GPTBigCodeDoubleHeadsModel`] and + [`TFGPTBigCodeDoubleHeadsModel`]. Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. summary_first_dropout (`float`, *optional*, defaults to 0.1): - Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and - [`TFGPT2DoubleHeadsModel`]. + Argument used when doing sequence summary, used in the models [`GPTBigCodeDoubleHeadsModel`] and + [`TFGPTBigCodeDoubleHeadsModel`]. The dropout ratio to be used after the projection and activation. scale_attn_weights (`bool`, *optional*, defaults to `True`): @@ -117,19 +114,19 @@ class GPT2Config(PretrainedConfig): Example: ```python - >>> from transformers import GPT2Config, GPT2Model + >>> from transformers import GPTBigCodeConfig, GPTBigCodeModel - >>> # Initializing a GPT2 configuration - >>> configuration = GPT2Config() + >>> # Initializing a GPTBigCode configuration + >>> configuration = GPTBigCodeConfig() >>> # Initializing a model (with random weights) from the configuration - >>> model = GPT2Model(configuration) + >>> model = GPTBigCodeModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ```""" - model_type = "gpt2" + model_type = "gpt_bigcode" keys_to_ignore_at_inference = ["past_key_values"] attribute_map = { "hidden_size": "n_embd", @@ -193,7 +190,8 @@ def __init__( super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) -class GPT2OnnxConfig(OnnxConfigWithPast): +class GPTBigCodeOnnxConfig(OnnxConfigWithPast): + # TODO: Onnx support? 
def __init__( self, config: PretrainedConfig, diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 5fe33bbca5..8625e77759 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -44,26 +44,22 @@ replace_return_docstrings, ) from ...utils.model_parallel_utils import assert_device_map, get_device_map -from .configuration_gpt2 import GPT2Config +from .configuration_gpt_bigcode import GPTBigCodeConfig logger = logging.get_logger(__name__) -_CHECKPOINT_FOR_DOC = "gpt2" -_CONFIG_FOR_DOC = "GPT2Config" +_CHECKPOINT_FOR_DOC = "gpt_bigcode" +_CONFIG_FOR_DOC = "GPTBigCodeConfig" -GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "gpt2", - "gpt2-medium", - "gpt2-large", - "gpt2-xl", - "distilgpt2", - # See all GPT-2 models at https://huggingface.co/models?filter=gpt2 +GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST = [ + # TODO: Add support for santa models. ] -def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): +def load_tf_weights_in_gpt_bigcode(model, config, gpt_bigcode_checkpoint_path): """Load tf checkpoints in a pytorch model""" + # TODO: Update this. try: import re @@ -74,7 +70,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): "https://www.tensorflow.org/install/ for installation instructions." ) raise - tf_path = os.path.abspath(gpt2_checkpoint_path) + tf_path = os.path.abspath(gpt_bigcode_checkpoint_path) logger.info(f"Converting TensorFlow checkpoint from {tf_path}") # Load weights from TF model init_vars = tf.train.list_variables(tf_path) @@ -119,7 +115,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): return model -class GPT2Attention(nn.Module): +class GPTBigCodeAttention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() @@ -300,7 +296,7 @@ def forward( if not hasattr(self, "q_attn"): raise ValueError( "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." 
) query = self.q_attn(hidden_states) @@ -339,7 +335,7 @@ def forward( return outputs # a, present, (attentions) -class GPT2MLP(nn.Module): +class GPTBigCodeMLP(nn.Module): def __init__(self, intermediate_size, config): super().__init__() embed_dim = config.hidden_size @@ -356,21 +352,21 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl return hidden_states -class GPT2Block(nn.Module): +class GPTBigCodeBlock(nn.Module): def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPT2Attention(config, layer_idx=layer_idx) + self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) if config.add_cross_attention: - self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx) + self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx) self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = GPT2MLP(inner_dim, config) + self.mlp = GPTBigCodeMLP(inner_dim, config) def forward( self, @@ -434,18 +430,18 @@ def forward( return outputs # hidden_states, present, (attentions, cross_attentions) -class GPT2PreTrainedModel(PreTrainedModel): +class GPTBigCodePreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ - config_class = GPT2Config - load_tf_weights = load_tf_weights_in_gpt2 + config_class = GPTBigCodeConfig + load_tf_weights = load_tf_weights_in_gpt_bigcode base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True - _no_split_modules = ["GPT2Block"] + _no_split_modules = ["GPTBigCodeBlock"] def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -478,12 +474,12 @@ def _init_weights(self, module): p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, GPT2Model): + if isinstance(module, GPTBigCodeModel): module.gradient_checkpointing = value @dataclass -class GPT2DoubleHeadsModelOutput(ModelOutput): +class GPTBigCodeDoubleHeadsModelOutput(ModelOutput): """ Base class for outputs of models predicting if two sentences are consecutive or not. @@ -511,7 +507,7 @@ class GPT2DoubleHeadsModelOutput(ModelOutput): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. - GPT2Attentions weights after the attention softmax, used to compute the weighted average in the + GPTBigCodeAttentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ @@ -524,7 +520,7 @@ class GPT2DoubleHeadsModelOutput(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None -GPT2_START_DOCSTRING = r""" +GPTBigCode_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -535,12 +531,12 @@ class GPT2DoubleHeadsModelOutput(ModelOutput): and behavior. 
Parameters: - config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + config ([`GPTBigCodeConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ -GPT2_INPUTS_DOCSTRING = r""" +GPTBigCode_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else @@ -629,7 +625,7 @@ class GPT2DoubleHeadsModelOutput(ModelOutput): ```python # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: - model = GPT2LMHeadModel.from_pretrained("gpt2-xl") + model = GPTBigCodeLMHeadModel.from_pretrained("gpt2-xl") device_map = { 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], @@ -646,7 +642,7 @@ class GPT2DoubleHeadsModelOutput(ModelOutput): ```python # On a 4 GPU machine with gpt2-large: - model = GPT2LMHeadModel.from_pretrained("gpt2-large") + model = GPTBigCodeLMHeadModel.from_pretrained("gpt2-large") device_map = { 0: [0, 1, 2, 3, 4, 5, 6, 7], 1: [8, 9, 10, 11, 12, 13, 14, 15], @@ -660,10 +656,10 @@ class GPT2DoubleHeadsModelOutput(ModelOutput): @add_start_docstrings( - "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", - GPT2_START_DOCSTRING, + "The bare GPTBigCode Model transformer outputting raw hidden-states without any specific head on top.", + GPTBigCode_START_DOCSTRING, ) -class GPT2Model(GPT2PreTrainedModel): +class GPTBigCodeModel(GPTBigCodePreTrainedModel): _keys_to_ignore_on_load_missing = ["attn.masked_bias"] def __init__(self, config): @@ -675,7 +671,7 @@ def __init__(self, config): self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.drop = nn.Dropout(config.embd_pdrop) - self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) # Model parallel @@ -732,7 +728,7 @@ def _prune_heads(self, heads_to_prune): for layer, heads in heads_to_prune.items(): self.h[layer].attn.prune_heads(heads) - @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(GPTBigCode_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=BaseModelOutputWithPastAndCrossAttentions, @@ -789,7 +785,7 @@ def forward( position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - # GPT2Attention mask. + # GPTBigCodeAttention mask. if attention_mask is not None: if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") @@ -935,17 +931,17 @@ def custom_forward(*inputs): @add_start_docstrings( """ - The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input + The GPTBigCode Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings). 
""", - GPT2_START_DOCSTRING, + GPTBigCode_START_DOCSTRING, ) -class GPT2LMHeadModel(GPT2PreTrainedModel): +class GPTBigCodeLMHeadModel(GPTBigCodePreTrainedModel): _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] def __init__(self, config): super().__init__(config) - self.transformer = GPT2Model(config) + self.transformer = GPTBigCodeModel(config) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # Model parallel @@ -1009,7 +1005,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg "token_type_ids": token_type_ids, } - @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(GPTBigCode_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=CausalLMOutputWithCrossAttentions, @@ -1101,20 +1097,20 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> @add_start_docstrings( """ -The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for +The GPTBigCode Model transformer with a language modeling and a multiple-choice classification head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the input embeddings, the classification head takes as input the input of a specified classification token index in the input sequence). """, - GPT2_START_DOCSTRING, + GPTBigCode_START_DOCSTRING, ) -class GPT2DoubleHeadsModel(GPT2PreTrainedModel): +class GPTBigCodeDoubleHeadsModel(GPTBigCodePreTrainedModel): _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] def __init__(self, config): super().__init__(config) config.num_labels = 1 - self.transformer = GPT2Model(config) + self.transformer = GPTBigCodeModel(config) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.multiple_choice_head = SequenceSummary(config) @@ -1182,8 +1178,8 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg "token_type_ids": token_type_ids, } - @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) + @add_start_docstrings_to_model_forward(GPTBigCode_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=GPTBigCodeDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1201,7 +1197,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, - ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: + ) -> Union[Tuple, GPTBigCodeDoubleHeadsModelOutput]: r""" mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - @@ -1220,10 +1216,10 @@ def forward( ```python >>> import torch - >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel + >>> from transformers import AutoTokenizer, GPTBigCodeDoubleHeadsModel >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") - >>> model = GPT2DoubleHeadsModel.from_pretrained("gpt2") + >>> model = GPTBigCodeDoubleHeadsModel.from_pretrained("gpt2") >>> # Add a [CLS] to the vocabulary (we should train it also!) 
>>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) @@ -1284,7 +1280,7 @@ def forward( output = (mc_loss,) + output return ((lm_loss,) + output) if lm_loss is not None else output - return GPT2DoubleHeadsModelOutput( + return GPTBigCodeDoubleHeadsModelOutput( loss=lm_loss, mc_loss=mc_loss, logits=lm_logits, @@ -1309,9 +1305,9 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> @add_start_docstrings( """ - The GPT2 Model transformer with a sequence classification head on top (linear layer). + The GPTBigCode Model transformer with a sequence classification head on top (linear layer). - [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + [`GPTBigCodeForSequenceClassification`] uses the last token in order to do the classification, as other causal models (e.g. GPT-1) do. Since it does classification on the last token, it requires to know the position of the last token. If a @@ -1320,15 +1316,15 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in each row of the batch). """, - GPT2_START_DOCSTRING, + GPTBigCode_START_DOCSTRING, ) -class GPT2ForSequenceClassification(GPT2PreTrainedModel): +class GPTBigCodeForSequenceClassification(GPTBigCodePreTrainedModel): _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"] def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.transformer = GPT2Model(config) + self.transformer = GPTBigCodeModel(config) self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) # Model parallel @@ -1338,7 +1334,7 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(GPTBigCode_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint="microsoft/DialogRPT-updown", output_type=SequenceClassifierOutputWithPast, @@ -1442,17 +1438,17 @@ def forward( @add_start_docstrings( """ - GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + GPTBigCode Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. 
""", - GPT2_START_DOCSTRING, + GPTBigCode_START_DOCSTRING, ) -class GPT2ForTokenClassification(GPT2PreTrainedModel): +class GPTBigCodeForTokenClassification(GPTBigCodePreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.transformer = GPT2Model(config) + self.transformer = GPTBigCodeModel(config) if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: classifier_dropout = config.classifier_dropout elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: @@ -1469,7 +1465,7 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(GPTBigCode_INPUTS_DOCSTRING) # fmt: off @add_code_sample_docstrings( checkpoint="brad1141/gpt2-finetuned-comp2", From d5dd3072c7114c3835c66f555df94ab3e1e41e68 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 7 Feb 2023 11:31:52 -0500 Subject: [PATCH 03/27] tests --- __init__.py | 0 tests/models/gpt_bigcode/__init__.py | 0 .../gpt_bigcode/test_modeling_gpt_bigcode.py | 801 ++++++++++++++++++ 3 files changed, 801 insertions(+) create mode 100644 __init__.py create mode 100644 tests/models/gpt_bigcode/__init__.py create mode 100644 tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/gpt_bigcode/__init__.py b/tests/models/gpt_bigcode/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py new file mode 100644 index 0000000000..52b4669896 --- /dev/null +++ b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py @@ -0,0 +1,801 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import datetime +import math +import unittest + +from transformers import GPTBigCodeConfig, is_torch_available +from transformers.testing_utils import require_torch, slow, torch_device + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask + + +if is_torch_available(): + import torch + + from transformers import ( + GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST, + GPTBigCodeDoubleHeadsModel, + GPTBigCodeForSequenceClassification, + GPTBigCodeForTokenClassification, + GPTBigCodeLMHeadModel, + GPTBigCodeModel, + GPTBigCodeTokenizer, + ) + + +class GPTBigCodeModelTester: + # TODO: Update the tests to use valid pretrained models. 
+ def __init__( + self, + parent, + batch_size=14, + seq_length=7, + is_training=True, + use_token_type_ids=True, + use_input_mask=True, + use_labels=True, + use_mc_token_ids=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_token_type_ids = use_token_type_ids + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.use_mc_token_ids = use_mc_token_ids + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = None + self.bos_token_id = vocab_size - 1 + self.eos_token_id = vocab_size - 1 + self.pad_token_id = vocab_size - 1 + + def get_large_model_config(self): + return GPTBigCodeConfig.from_pretrained("gpt2") + + def prepare_config_and_inputs( + self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False + ): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + mc_token_ids = None + if self.use_mc_token_ids: + mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config( + gradient_checkpointing=gradient_checkpointing, + scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx, + reorder_and_upcast_attn=reorder_and_upcast_attn, + ) + + head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) + + return ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + mc_token_ids, + sequence_labels, + token_labels, + choice_labels, + ) + + def get_config( + self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False + ): + return GPTBigCodeConfig( + vocab_size=self.vocab_size, + n_embd=self.hidden_size, + n_layer=self.num_hidden_layers, + n_head=self.num_attention_heads, + n_inner=self.intermediate_size, + activation_function=self.hidden_act, + resid_pdrop=self.hidden_dropout_prob, + attn_pdrop=self.attention_probs_dropout_prob, + n_positions=self.max_position_embeddings, + 
type_vocab_size=self.type_vocab_size, + initializer_range=self.initializer_range, + use_cache=True, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + gradient_checkpointing=gradient_checkpointing, + scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx, + reorder_and_upcast_attn=reorder_and_upcast_attn, + ) + + def get_pipeline_config(self): + config = self.get_config() + config.vocab_size = 300 + return config + + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + mc_token_ids, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + + def create_and_check_gpt_bigcode_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + model = GPTBigCodeModel(config=config) + model.to(torch_device) + model.eval() + + result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask) + result = model(input_ids, token_type_ids=token_type_ids) + result = model(input_ids) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertEqual(len(result.past_key_values), config.n_layer) + + def create_and_check_gpt_bigcode_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + model = GPTBigCodeModel(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids, token_type_ids=token_type_ids) + outputs_no_past = model(input_ids, token_type_ids=token_type_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + output, past = outputs.to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + next_token_types = ids_tensor([self.batch_size, 1], self.type_vocab_size) + + # append to next input_ids and token_type_ids + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_token_type_ids = torch.cat([token_type_ids, next_token_types], dim=-1) + + output_from_no_past = model(next_input_ids, token_type_ids=next_token_type_ids)["last_hidden_state"] + output_from_past = model(next_tokens, token_type_ids=next_token_types, past_key_values=past)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_gpt_bigcode_model_attention_mask_past( + self, config, input_ids, input_mask, head_mask, token_type_ids, *args + ): + model = GPTBigCodeModel(config=config) + model.to(torch_device) + model.eval() + + # create attention 
mask + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + half_seq_length = self.seq_length // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + output, past = model(input_ids, attention_mask=attn_mask).to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], + dim=1, + ) + + # get two different outputs + output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past, attention_mask=attn_mask)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_gpt_bigcode_model_past_large_inputs( + self, config, input_ids, input_mask, head_mask, token_type_ids, *args + ): + model = GPTBigCodeModel(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + outputs = model(input_ids, token_type_ids=token_type_ids, attention_mask=input_mask, use_cache=True) + + output, past = outputs.to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_token_types = ids_tensor([self.batch_size, 3], self.type_vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and token_type_ids + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_token_type_ids = torch.cat([token_type_ids, next_token_types], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, token_type_ids=next_token_type_ids, attention_mask=next_attention_mask + )["last_hidden_state"] + output_from_past = model( + next_tokens, token_type_ids=next_token_types, attention_mask=next_attention_mask, past_key_values=past + )["last_hidden_state"] + self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1]) + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + model = GPTBigCodeLMHeadModel(config) + model.to(torch_device) + model.eval() + + result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) + 
self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_forward_and_backwards( + self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False + ): + model = GPTBigCodeLMHeadModel(config) + model.to(torch_device) + if gradient_checkpointing: + model.gradient_checkpointing_enable() + + result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + result.loss.backward() + + def create_and_check_double_lm_head_model( + self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args + ): + model = GPTBigCodeDoubleHeadsModel(config) + model.to(torch_device) + model.eval() + + multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + + inputs = { + "input_ids": multiple_choice_inputs_ids, + "mc_token_ids": mc_token_ids, + "attention_mask": multiple_choice_input_mask, + "token_type_ids": multiple_choice_token_type_ids, + "labels": multiple_choice_inputs_ids, + } + + result = model(**inputs) + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual( + result.logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size) + ) + self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices)) + + def create_and_check_gpt_bigcode_for_sequence_classification( + self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args + ): + config.num_labels = self.num_labels + model = GPTBigCodeForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def create_and_check_gpt_bigcode_for_token_classification( + self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args + ): + config.num_labels = self.num_labels + model = GPTBigCodeForTokenClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels)) + + def create_and_check_gpt_bigcode_weight_initialization(self, config, *args): + model = GPTBigCodeModel(config) + model_std = model.config.initializer_range / math.sqrt(2 * model.config.n_layer) + for key in model.state_dict().keys(): + if "c_proj" in key and "weight" in key: + self.parent.assertLessEqual(abs(torch.std(model.state_dict()[key]) - model_std), 0.001) + self.parent.assertLessEqual(abs(torch.mean(model.state_dict()[key]) - 0.0), 0.01) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + + ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + mc_token_ids, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + 
"head_mask": head_mask, + } + + return config, inputs_dict + + +@require_torch +class GPTBigCodeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + # TODO: Update the tests to use valid pretrained models. + + all_model_classes = ( + (GPTBigCodeModel, GPTBigCodeLMHeadModel, GPTBigCodeDoubleHeadsModel, GPTBigCodeForSequenceClassification, GPTBigCodeForTokenClassification) + if is_torch_available() + else () + ) + all_generative_model_classes = (GPTBigCodeLMHeadModel, GPTBigCodeDoubleHeadsModel) if is_torch_available() else () + all_parallelizable_model_classes = (GPTBigCodeLMHeadModel, GPTBigCodeDoubleHeadsModel) if is_torch_available() else () + fx_compatible = True + test_missing_keys = False + test_model_parallel = True + + # special case for DoubleHeads model + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ == "GPTBigCodeDoubleHeadsModel": + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.num_choices, self.model_tester.seq_length), + dtype=torch.long, + device=torch_device, + ) + inputs_dict["input_ids"] = inputs_dict["labels"] + inputs_dict["token_type_ids"] = inputs_dict["labels"] + inputs_dict["mc_token_ids"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.num_choices), + dtype=torch.long, + device=torch_device, + ) + inputs_dict["mc_labels"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + return inputs_dict + + def setUp(self): + self.model_tester = GPTBigCodeModelTester(self) + self.config_tester = ConfigTester(self, config_class=GPTBigCodeConfig, n_embd=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_gpt_bigcode_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_gpt_bigcode_model(*config_and_inputs) + + def test_gpt_bigcode_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_gpt_bigcode_model_past(*config_and_inputs) + + def test_gpt_bigcode_model_att_mask_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_gpt_bigcode_model_attention_mask_past(*config_and_inputs) + + def test_gpt_bigcode_model_past_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_gpt_bigcode_model_past_large_inputs(*config_and_inputs) + + def test_gpt_bigcode_lm_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_lm_head_model(*config_and_inputs) + + def test_gpt_bigcode_double_lm_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_double_lm_head_model(*config_and_inputs) + + def test_gpt_bigcode_sequence_classification_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_gpt_bigcode_for_sequence_classification(*config_and_inputs) + + def test_gpt_bigcode_token_classification_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_gpt_bigcode_for_token_classification(*config_and_inputs) + + def test_gpt_bigcode_gradient_checkpointing(self): + 
config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True) + + def test_gpt_bigcode_scale_attn_by_inverse_layer_idx(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs(scale_attn_by_inverse_layer_idx=True) + self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs) + + def test_gpt_bigcode_reorder_and_upcast_attn(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs(reorder_and_upcast_attn=True) + self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs) + + def test_gpt_bigcode_weight_initialization(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_gpt_bigcode_weight_initialization(*config_and_inputs) + + @slow + def test_batch_generation(self): + model = GPTBigCodeLMHeadModel.from_pretrained("gpt2") + model.to(torch_device) + tokenizer = GPTBigCodeTokenizer.from_pretrained("gpt2") + + tokenizer.padding_side = "left" + + # Define PAD Token = EOS Token = 50256 + tokenizer.pad_token = tokenizer.eos_token + model.config.pad_token_id = model.config.eos_token_id + + # use different length sentences to test batching + sentences = [ + "Hello, my dog is a little", + "Today, I", + ] + + inputs = tokenizer(sentences, return_tensors="pt", padding=True) + input_ids = inputs["input_ids"].to(torch_device) + token_type_ids = torch.cat( + [ + input_ids.new_full((input_ids.shape[0], input_ids.shape[1] - 1), 0), + input_ids.new_full((input_ids.shape[0], 1), 500), + ], + dim=-1, + ) + + outputs = model.generate( + input_ids=input_ids, + attention_mask=inputs["attention_mask"].to(torch_device), + ) + + outputs_tt = model.generate( + input_ids=input_ids, + attention_mask=inputs["attention_mask"].to(torch_device), + token_type_ids=token_type_ids, + ) + + inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device) + output_non_padded = model.generate(input_ids=inputs_non_padded) + + num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item() + inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device) + output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings) + + batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) + batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True) + non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True) + padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True) + + expected_output_sentence = [ + "Hello, my dog is a little bit of a mess. I'm not sure if he's going", + "Today, I'm going to be doing a lot of research on this. 
I", + ] + self.assertListEqual(expected_output_sentence, batch_out_sentence) + self.assertTrue(batch_out_sentence_tt != batch_out_sentence) # token_type_ids should change output + self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence]) + + @slow + def test_batch_generation_2heads(self): + model = GPTBigCodeDoubleHeadsModel.from_pretrained("gpt2") + model.to(torch_device) + tokenizer = GPTBigCodeTokenizer.from_pretrained("gpt2") + + tokenizer.padding_side = "left" + + # This tokenizer has no pad token, so we have to set it in some way + # Define PAD Token = EOS Token = 50256 + tokenizer.pad_token = tokenizer.eos_token + model.config.pad_token_id = model.config.eos_token_id + + # use different length sentences to test batching + sentences = [ + "Hello, my dog is a little", + "Today, I", + ] + + inputs = tokenizer(sentences, return_tensors="pt", padding=True) + input_ids = inputs["input_ids"].to(torch_device) + token_type_ids = torch.cat( + [ + input_ids.new_full((input_ids.shape[0], input_ids.shape[1] - 1), 0), + input_ids.new_full((input_ids.shape[0], 1), 500), + ], + dim=-1, + ) + + outputs = model.generate( + input_ids=input_ids, + attention_mask=inputs["attention_mask"].to(torch_device), + ) + + outputs_tt = model.generate( + input_ids=input_ids, + attention_mask=inputs["attention_mask"].to(torch_device), + token_type_ids=token_type_ids, + ) + + inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device) + output_non_padded = model.generate(input_ids=inputs_non_padded) + + num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item() + inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device) + output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings) + + batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) + batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True) + non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True) + padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True) + + expected_output_sentence = [ + "Hello, my dog is a little bit of a mess. I'm not sure if he's going", + "Today, I'm going to be doing a lot of research on this. 
I", + ] + self.assertListEqual(expected_output_sentence, batch_out_sentence) + self.assertTrue(batch_out_sentence_tt != batch_out_sentence) # token_type_ids should change output + self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence]) + + @slow + def test_model_from_pretrained(self): + for model_name in GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = GPTBigCodeModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +@require_torch +class GPTBigCodeModelLanguageGenerationTest(unittest.TestCase): + def _test_lm_generate_gpt_bigcode_helper( + self, + gradient_checkpointing=False, + reorder_and_upcast_attn=False, + scale_attn_by_inverse_layer_idx=False, + verify_outputs=True, + ): + model = GPTBigCodeLMHeadModel.from_pretrained( + "gpt2", + reorder_and_upcast_attn=reorder_and_upcast_attn, + scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx, + ) + if gradient_checkpointing: + model.gradient_checkpointing_enable() + else: + model.gradient_checkpointing_disable() + model.to(torch_device) + + # The dog + input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) + + # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog + # fmt: off + expected_output_ids = [ + 464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290, + ] + # fmt: on + output_ids = model.generate(input_ids, do_sample=False) + if verify_outputs: + self.assertListEqual(output_ids[0].tolist(), expected_output_ids) + + @slow + def test_lm_generate_gpt_bigcode(self): + self._test_lm_generate_gpt_bigcode_helper() + + @slow + def test_lm_generate_gpt_bigcode_with_gradient_checkpointing(self): + self._test_lm_generate_gpt_bigcode_helper(gradient_checkpointing=True) + + @slow + def test_lm_generate_gpt_bigcode_with_reorder_and_upcast_attn(self): + self._test_lm_generate_gpt_bigcode_helper(reorder_and_upcast_attn=True) + + @slow + def test_lm_generate_gpt_bigcode_with_scale_attn_by_inverse_layer_idx(self): + self._test_lm_generate_gpt_bigcode_helper(scale_attn_by_inverse_layer_idx=True, verify_outputs=False) + + @slow + def test_gpt_bigcode_sample(self): + tokenizer = GPTBigCodeTokenizer.from_pretrained("gpt2") + model = GPTBigCodeLMHeadModel.from_pretrained("gpt2") + model.to(torch_device) + + torch.manual_seed(0) + tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True) + input_ids = tokenized.input_ids.to(torch_device) + output_ids = model.generate(input_ids, do_sample=True) + output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True) + + token_type_ids = tokenized.token_type_ids.to(torch_device) + output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5) + output_seq_tt = model.generate( + input_ids=input_ids, token_type_ids=token_type_ids, do_sample=True, num_return_sequences=5 + ) + output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True) + output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True) + + EXPECTED_OUTPUT_STR = ( + "Today is a nice day and if you don't know anything about the state of play during your holiday" + ) + self.assertEqual(output_str, EXPECTED_OUTPUT_STR) + self.assertTrue( + all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))]) + ) # token_type_ids should change output + + @slow + def test_gpt_bigcode_sample_max_time(self): + tokenizer = 
GPTBigCodeTokenizer.from_pretrained("gpt2") + model = GPTBigCodeLMHeadModel.from_pretrained("gpt2") + model.to(torch_device) + + torch.manual_seed(0) + tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True) + input_ids = tokenized.input_ids.to(torch_device) + + MAX_TIME = 0.5 + + start = datetime.datetime.now() + model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256) + duration = datetime.datetime.now() - start + self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) + self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) + + start = datetime.datetime.now() + model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256) + duration = datetime.datetime.now() - start + self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) + self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) + + start = datetime.datetime.now() + model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256) + duration = datetime.datetime.now() - start + self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) + self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) + + start = datetime.datetime.now() + model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256) + duration = datetime.datetime.now() - start + self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) + self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) + + start = datetime.datetime.now() + model.generate(input_ids, do_sample=False, max_time=None, max_length=256) + duration = datetime.datetime.now() - start + self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) + + @slow + def test_contrastive_search_gpt_bigcode(self): + article = ( + "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research " + "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based" + ) + + gpt_bigcode_tokenizer = GPTBigCodeTokenizer.from_pretrained("gpt2-large") + gpt_bigcode_model = GPTBigCodeLMHeadModel.from_pretrained("gpt2-large").to(torch_device) + input_ids = gpt_bigcode_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + outputs = gpt_bigcode_model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256) + + generated_text = gpt_bigcode_tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research " + "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, " + "United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as " + "Google Now, which helps users find the information they're looking for on the web. But the company " + "is not the only one to collect data on its users. Facebook, for example, has its own facial " + "recognition technology, as well as a database of millions of photos that it uses to personalize its " + "News Feed.\n\nFacebook's use of data is a hot topic in the tech industry, with privacy advocates " + "concerned about the company's ability to keep users' information private. 
In a blog post last " + 'year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our ' + 'data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with ' + 'third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at ' + 'privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, ' + "but said in a statement to The Associated Press that" + ], + ) From 4fdf8c1e25022efd713e9c02d12e28174293f29e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 7 Feb 2023 11:32:12 -0500 Subject: [PATCH 04/27] fix --- __init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 __init__.py diff --git a/__init__.py b/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 From c25f221fb3bcaf78afa3f1f7d2884297afd3484c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 7 Feb 2023 12:08:41 -0500 Subject: [PATCH 05/27] fix --- tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py index 52b4669896..936d4c9f76 100644 --- a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py +++ b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py @@ -36,7 +36,7 @@ GPTBigCodeForTokenClassification, GPTBigCodeLMHeadModel, GPTBigCodeModel, - GPTBigCodeTokenizer, + GPT2Tokenizer, ) @@ -526,7 +526,7 @@ def test_gpt_bigcode_weight_initialization(self): def test_batch_generation(self): model = GPTBigCodeLMHeadModel.from_pretrained("gpt2") model.to(torch_device) - tokenizer = GPTBigCodeTokenizer.from_pretrained("gpt2") + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") tokenizer.padding_side = "left" @@ -585,7 +585,7 @@ def test_batch_generation(self): def test_batch_generation_2heads(self): model = GPTBigCodeDoubleHeadsModel.from_pretrained("gpt2") model.to(torch_device) - tokenizer = GPTBigCodeTokenizer.from_pretrained("gpt2") + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") tokenizer.padding_side = "left" @@ -699,7 +699,7 @@ def test_lm_generate_gpt_bigcode_with_scale_attn_by_inverse_layer_idx(self): @slow def test_gpt_bigcode_sample(self): - tokenizer = GPTBigCodeTokenizer.from_pretrained("gpt2") + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") model = GPTBigCodeLMHeadModel.from_pretrained("gpt2") model.to(torch_device) @@ -773,7 +773,7 @@ def test_contrastive_search_gpt_bigcode(self): "laboratory founded in 2010. DeepMind was acquired by Google in 2014. 
The company is based" ) - gpt_bigcode_tokenizer = GPTBigCodeTokenizer.from_pretrained("gpt2-large") + gpt_bigcode_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large") gpt_bigcode_model = GPTBigCodeLMHeadModel.from_pretrained("gpt2-large").to(torch_device) input_ids = gpt_bigcode_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) From 2134a34d7e9b46db9d5379dcd4e4535f6a2c9170 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 7 Feb 2023 14:55:04 -0500 Subject: [PATCH 06/27] Megatron conversion --- .../models/megatron_gpt_bigcode/__init__.py | 17 + ...eckpoint_reshaping_and_interoperability.py | 900 ++++++++++++++++++ .../convert_megatron_gpt2_checkpoint.py | 364 +++++++ 3 files changed, 1281 insertions(+) create mode 100644 src/transformers/models/megatron_gpt_bigcode/__init__.py create mode 100644 src/transformers/models/megatron_gpt_bigcode/checkpoint_reshaping_and_interoperability.py create mode 100644 src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt2_checkpoint.py diff --git a/src/transformers/models/megatron_gpt_bigcode/__init__.py b/src/transformers/models/megatron_gpt_bigcode/__init__.py new file mode 100644 index 0000000000..8228eea536 --- /dev/null +++ b/src/transformers/models/megatron_gpt_bigcode/__init__.py @@ -0,0 +1,17 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021 NVIDIA Corporation and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/transformers/models/megatron_gpt_bigcode/checkpoint_reshaping_and_interoperability.py b/src/transformers/models/megatron_gpt_bigcode/checkpoint_reshaping_and_interoperability.py new file mode 100644 index 0000000000..f8fb1e7c8f --- /dev/null +++ b/src/transformers/models/megatron_gpt_bigcode/checkpoint_reshaping_and_interoperability.py @@ -0,0 +1,900 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
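+
+# Example usage (paths are placeholders; every flag below is defined by the argument
+# parsers in this file):
+#
+#   python checkpoint_reshaping_and_interoperability.py \
+#       --convert_checkpoint_from_megatron_to_transformers \
+#       --load_path /path/to/megatron/checkpoint \
+#       --save_path /path/to/transformers/checkpoint \
+#       --tokenizer_name gpt2 \
+#       --print-checkpoint-structure
+#
+# Without --convert_checkpoint_from_megatron_to_transformers, the script converts in the
+# opposite direction (Transformers -> Megatron-LM); that direction needs --megatron-path
+# pointing at a Megatron-LM clone and uses the --target_*_model_parallel_size options to
+# choose the parallel layout of the produced checkpoint.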
+ +import argparse +import json +import os +import re +import sys +import types + +import torch + +from transformers import AutoTokenizer, GPT2Config +from transformers.modeling_utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, shard_checkpoint + + +def add_checkpointing_args(parser): + parser.add_argument("--megatron-path", type=str, default=None, help="Base directory of Megatron repository") + parser.add_argument( + "--convert_checkpoint_from_megatron_to_transformers", + action="store_true", + help=( + "If True, convert a Megatron checkpoint to a Transformers checkpoint. " + "If False, convert a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--load_path", + type=str, + required=True, + help="Path to the checkpoint to convert.", + ) + parser.add_argument( + "--save_path", + type=str, + required=True, + help="Path to the converted checkpoint.", + ) + parser.add_argument("--print-checkpoint-structure", action="store_true") + return parser + + +def add_megatron_checkpoint_args(parser): + parser.add_argument( + "--target_tensor_model_parallel_size", + type=int, + default=1, + help=( + "The tensor model parallel size of the converted checkpoint. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--target_pipeline_model_parallel_size", + type=int, + default=1, + help=( + "The pipeline model parallel size of the converted checkpoint. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--target_data_parallel_size", + type=int, + default=1, + help=( + "The data parallel size of the converted checkpoint. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--target_params_dtype", + type=str, + default="fp32", + help=( + "The dtype of the converted checkpoint. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--make_vocab_size_divisible_by", + type=int, + default=128, + help=( + "Pad the vocab size to be divisible by this value. " + "This is added for computational efficieny reasons. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--use_distributed_optimizer", + action="store_true", + help=( + "If True, use the distributed optimizer. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + return parser + + +def add_transformers_checkpoint_args(parser): + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help=( + "The name of the pre-trained tokenizer to save. " + "If not None, the tokenizer will be saved. " + "Only used when converting a Megatron checkpoint to a Transformers checkpoint." + ), + ) + parser.add_argument( + "--max_shard_size", + type=str, + default="10GB", + help=( + "The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size " + "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`). " + "Only used when converting a Megatron checkpoint to a Transformers checkpoint." + ), + ) + + return parser + + +# The simple map of names for "automated" rules. 
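+# Keys are Megatron-LM parameter-name fragments and values are the corresponding
+# Transformers (GPT-2 style) module prefixes; transformers_to_megatron below is simply
+# the inverted mapping.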
+megatron_to_transformers = { + "attention.dense": ".attn.c_proj.", + "self_attention.dense": ".attn.c_proj.", + "mlp.dense_h_to_4h": ".mlp.c_fc.", + "mlp.dense_4h_to_h": ".mlp.c_proj.", +} +transformers_to_megatron = {v[1:-1]: k for k, v in megatron_to_transformers.items()} + +tensor_parallel_params = [ + # megatron-lm layers to merge across tp ranks + "self_attention.query_key_value.weight", + "self_attention.query_key_value.bias", + "self_attention.dense.weight", + "mlp.dense_h_to_4h.weight", + "mlp.dense_h_to_4h.bias", + "mlp.dense_4h_to_h.weight", + # deprecated + "attention.query_key_value.weight", + "attention.query_key_value.bias", + "attention.dense.weight", + # transformers layers to split across tp ranks + "attn.c_attn.weight", + "attn.c_attn.bias", + "attn.c_proj.weight", + "mlp.c_fc.weight", + "mlp.c_fc.bias", + "mlp.c_proj.weight", +] + + +def recursive_print(name, val, spaces=0): + """ + Recursively print the structure of a checkpoint. This function is taken from `convert_megatron_gpt2_checkpoint.py` + + Args: + name (str): the name of the current tensor parameter + val (Tuple(int)): the shape of the current tensor parameter + spaces (int): the number of spaces to print before the output for a nested structure + """ + # Format the message. + if name is None: + msg = None + else: + fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}" + msg = fmt.format(name) + + # Print and recurse (if needed). + if isinstance(val, dict): + if msg is not None: + print(msg) + for k in val.keys(): + recursive_print(k, val[k], spaces + 2) + elif isinstance(val, torch.Tensor): + print(msg, ":", val.size()) + else: + print(msg, ":", val) + + +def megatron_to_transformers_fix_query_key_value_ordering( + param, checkpoint_version, num_splits, num_heads, hidden_size +): + """ + Permutes layout of param tensor to [num_splits * num_heads * hidden_size, :] for compatibility with later versions + of NVIDIA Megatron-LM. The inverse operation is performed inside Megatron-LM to read checkpoints: + https://github.com/NVIDIA/Megatron-LM/blob/v2.4/megatron/checkpointing.py#L209 If param is the weight tensor of the + self-attention block, the returned tensor will have to be transposed one more time to be read by HuggingFace GPT2. + This function is taken from `convert_megatron_gpt2_checkpoint.py` + + Args: + param (torch.Tensor): the tensor to permute + checkpoint_version (int): the version of the checkpoint. + num_splits (int): the number of projections, usually 3 for (Query, Key, Value) + num_heads (int): the number of attention heads + hidden_size (int): the hidden size per head + """ + + input_shape = param.size() + if checkpoint_version == 1.0: + # version 1.0 stores [num_heads * hidden_size * num_splits, :] + saved_shape = (num_heads, hidden_size, num_splits) + input_shape[1:] + param = param.view(*saved_shape) + param = param.transpose(0, 2) + param = param.transpose(1, 2).contiguous() + elif checkpoint_version >= 2.0: + # other versions store [num_heads * num_splits * hidden_size, :] + saved_shape = (num_heads, num_splits, hidden_size) + input_shape[1:] + param = param.view(*saved_shape) + param = param.transpose(0, 1).contiguous() + param = param.view(*input_shape) + return param + + +def transformers_to_megatron_fix_query_key_value_ordering( + param, checkpoint_version, num_splits, num_heads, hidden_size +): + """ + Permutes layout of param tensor to the one compatible with respective NVIDIA Megatron-LM chekpoint versions. 
Input + is [num_splits * num_heads * hidden_size, :] and output is [num_heads * hidden_size * num_splits, :] for version + 1.0 and [num_heads * num_splits * hidden_size, :] for version 2.0 and later. If param is the weight tensor of the + self-attention block, the param needs to be already transposed before calling this function. + + Args: + param (torch.Tensor): the tensor to permute + checkpoint_version (int): the version of the checkpoint. + num_splits (int): the number of projections, usually 3 for (Query, Key, Value) + num_heads (int): the number of attention heads + hidden_size (int): the hidden size per head + """ + + # Input is [num_splits * num_heads * hidden_size, :] + input_shape = param.size() + if checkpoint_version == 1.0: + # version 1.0 stores [num_heads * hidden_size * num_splits, :] + current_shape = (num_splits, num_heads, hidden_size) + input_shape[1:] + param = param.view(*current_shape) + param = param.transpose(0, 2) + param = param.transpose(1, 2).contiguous() + elif checkpoint_version >= 2.0: + # other versions store [num_heads * num_splits * hidden_size, :] + current_shape = (num_splits, num_heads, hidden_size) + input_shape[1:] + param = param.view(*current_shape) + param = param.transpose(0, 1).contiguous() + param = param.view(*input_shape) + return param + + +def merge_transformers_sharded_states(path, num_checkpoints): + """ + Merge sharded checkpoints from transformers into a single checkpoint. + + Args: + path (str): the path to the sharded checkpoints + num_checkpoints (int): the number of checkpoints to merge + """ + state_dict = {} + for i in range(1, num_checkpoints + 1): + checkpoint_path = os.path.join(path, f"pytorch_model-{i:05d}-of-{num_checkpoints:05d}.bin") + current_chunk = torch.load(checkpoint_path, map_location="cpu") + state_dict.update(current_chunk) + return state_dict + + +def get_megatron_sharded_states(args, tp_size, pp_size, pp_rank): + """ + Get sharded checkpoints from NVIDIA Megatron-LM checkpoint based on the provided tensor parallel size, pipeline + parallel size and pipeline parallel rank. + + Args: + args (argparse.Namespace): the arguments to the script + tp_size (int): the tensor parallel size + pp_size (int): the pipeline parallel size + pp_rank (int): the pipeline parallel rank + """ + tp_state_dicts = [] + for i in range(tp_size): + sub_dir_name = f"mp_rank_{i:02d}" if pp_size == 1 else f"mp_rank_{i:02d}_{pp_rank:03d}" + checkpoint_name = os.listdir(os.path.join(args.load_path, sub_dir_name))[0] + checkpoint_path = os.path.join(args.load_path, sub_dir_name, checkpoint_name) + state_dict = torch.load(checkpoint_path, map_location="cpu") + tp_state_dicts.append(state_dict) + return tp_state_dicts + + +def get_element_from_dict_by_path(d, path): + """ + Get element from dictionary by path. If element is not present, recursively add empty dictionaries. + + Args: + d (dict): the dictionary to get the element from + path (list): the path to the element which is delimited by "." + """ + path = path.split(".") + for k in path: + if k not in d: + d[k] = {} + d = d[k] + return d + + +def convert_checkpoint_from_megatron_to_transformers(args): + """ + Convert NVIDIA Megatron-LM checkpoint to HuggingFace Transformers checkpoint. This handles Megatron checkpoints + with different tensor parallelism and pipeline parallelism sizes. It saves the converted checkpoint into shards + using HuggingFace Transformers checkpoint sharding functionality. 
This greatly extends the functionality of + `convert_megatron_gpt2_checkpoint.py` + + Args: + args (argparse.Namespace): the arguments to the script + """ + # Load Megatron-LM checkpoint arguments from the state dict + sub_dirs = os.listdir(args.load_path) + possible_sub_dirs = ["mp_rank_00", "mp_rank_00_000"] + for sub_dir in possible_sub_dirs: + if sub_dir in sub_dirs: + rank0_checkpoint_name = os.listdir(os.path.join(args.load_path, sub_dir))[0] + rank0_checkpoint_path = os.path.join(args.load_path, sub_dir, rank0_checkpoint_name) + break + print(f"Loading Megatron-LM checkpoint arguments from: {rank0_checkpoint_path}") + state_dict = torch.load(rank0_checkpoint_path, map_location="cpu") + megatron_args = state_dict.get("args", None) + if megatron_args is None: + raise ValueError( + "Megatron-LM checkpoint does not contain arguments. This utility only supports Megatron-LM checkpoints" + " containing all the megatron arguments. This is because it loads all config related to model" + " architecture, the tensor and pipeline model parallel size from the checkpoint insead of user having to" + " manually specify all the details. Please save Megatron-LM checkpoint along with all the megatron" + " arguments to use this utility." + ) + + # Create Transformers GPT2 config from Megatron-LM arguments + if megatron_args is not None: + if megatron_args.bias_gelu_fusion: + activation_function = "gelu_fast" + elif megatron_args.openai_gelu: + activation_function = "gelu_new" + else: + activation_function = "gelu" + else: + # in the very early days this used to be "gelu_new" + activation_function = "gelu_new" + vocab_size = ( + megatron_args.padded_vocab_size + if getattr(megatron_args, "orig_vocab_size", None) is None + else megatron_args.orig_vocab_size + ) + print(vocab_size) + + config = GPT2Config( + vocab_size=vocab_size, + n_positions=megatron_args.max_position_embeddings, + n_embd=megatron_args.hidden_size, + n_layer=megatron_args.num_layers, + n_head=megatron_args.num_attention_heads, + n_inner=megatron_args.ffn_hidden_size, + activation_function=activation_function, + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=vocab_size - 1, + eos_token_id=vocab_size - 1, + architectures=["GPT2LMHeadModel"], + ) + + output_state_dict = {} + + checkpoint_version = state_dict.get("checkpoint_version", 0.0) + tp_size = megatron_args.tensor_model_parallel_size + pp_size = megatron_args.pipeline_model_parallel_size + dtype = torch.float32 + # The regex to extract layer names. + layer_re = re.compile("layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") + + # Convert. + print("Converting") + + # Embeddings + print("Converting embeddings") + tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, 0) + + # Convert and store the position embeddings. + position_embeddings = get_element_from_dict_by_path( + tp_state_dicts[0], "model.language_model.embedding.position_embeddings.weight" + ) + output_state_dict["transformer.wpe.weight"] = position_embeddings.to(dtype) + + # Convert and store the word embeddings. 
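+ # The word embedding table is sharded across tensor-parallel ranks along the vocabulary
+ # dimension, so the shards are concatenated on dim 0 and then truncated to vocab_size
+ # (the original, unpadded size whenever the checkpoint records it).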
+ word_embeddings = torch.cat( + [ + get_element_from_dict_by_path( + tp_state_dicts[tp_rank], "model.language_model.embedding.word_embeddings.weight" + ) + for tp_rank in range(tp_size) + ], + dim=0, + ) + word_embeddings = word_embeddings[:vocab_size].to(dtype) + output_state_dict["transformer.wte.weight"] = word_embeddings + + # Transformer Layers + print("Converting transformer layers") + # The number of heads. + heads = config.n_head + # The hidden_size per head. + hidden_size_per_head = config.n_embd // config.n_head + n_positions = config.n_positions + num_layers = config.num_hidden_layers // pp_size + + for pp_rank in range(pp_size): + if pp_size > 0: + print(f"Converting pipeline parallel rank {pp_rank}") + tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, pp_rank) + + # The transformer. + path = ( + "model.language_model.transformer" + if "transformer" in get_element_from_dict_by_path(tp_state_dicts[0], "model.language_model").keys() + else "model.language_model.encoder" + ) + # Extract the layers. + for key, val in get_element_from_dict_by_path(tp_state_dicts[0], path).items(): + # Match the name. + m = layer_re.match(key) + # Stop if that's not a layer + if m is None: + break + + # The index of the layer. + layer_idx = int(m.group(1)) + pp_rank * num_layers + # The name of the operation. + op_name = m.group(2) + # Is it a weight or a bias? + weight_or_bias = m.group(3) + + # The name of the layer. + layer_name = f"transformer.h.{layer_idx}" + + if op_name + "." + weight_or_bias not in tensor_parallel_params: + params = val.to(dtype) + else: + dim = 1 if op_name in ["self_attention.dense", "mlp.dense_4h_to_h", "attention.dense"] else 0 + params = torch.cat( + [val] + + [ + get_element_from_dict_by_path(tp_state_dicts[tp_rank], f"{path}")[key] + for tp_rank in range(1, tp_size) + ], + dim=dim, + ).to(dtype) + + # For layernorm(s), simply store the layer norm. + if op_name.endswith("layernorm"): + + ln_name = "ln_1" if op_name.startswith("input") else "ln_2" + output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = params + + # Transpose the QKV matrix. + elif ( + op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value" + ) and weight_or_bias == "weight": + + # Insert a tensor of 1x1xDxD bias. + causal_mask = torch.tril(torch.ones((n_positions, n_positions), dtype=dtype)).view( + 1, 1, n_positions, n_positions + ) + output_state_dict[layer_name + ".attn.bias"] = causal_mask + + # Insert a "dummy" tensor for masked_bias. + masked_bias = torch.tensor(-1e4, dtype=dtype) + output_state_dict[layer_name + ".attn.masked_bias"] = masked_bias + + out_val = megatron_to_transformers_fix_query_key_value_ordering( + params, + checkpoint_version, + 3, + heads, + hidden_size_per_head, + ) + # Megatron stores (3*D) x D but transformers-GPT2 expects D x 3*D. + out_val = out_val.transpose(0, 1).contiguous() + # Store. + output_state_dict[layer_name + ".attn.c_attn.weight"] = out_val + + # Transpose the bias. + elif ( + op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value" + ) and weight_or_bias == "bias": + + out_val = megatron_to_transformers_fix_query_key_value_ordering( + params, checkpoint_version, 3, heads, hidden_size_per_head + ) + # Store. No change of shape. + output_state_dict[layer_name + ".attn.c_attn.bias"] = out_val + + # Transpose the weights. 
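+ # Megatron-LM linear layers store weights as (out_features, in_features), while the
+ # Conv1D modules of the GPT-2 style model expect (in_features, out_features), hence
+ # the transpose before storing.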
+ elif weight_or_bias == "weight": + + out_name = megatron_to_transformers[op_name] + output_state_dict[layer_name + out_name + "weight"] = params.transpose(0, 1) + + # Copy the bias. + elif weight_or_bias == "bias": + + out_name = megatron_to_transformers[op_name] + output_state_dict[layer_name + out_name + "bias"] = params + + if config.n_layer != (layer_idx + 1): + raise ValueError(f"Expected {config.n_layer} layers but found {layer_idx + 1}") + + # The final layernorm. + print("Converting final layernorm") + params = get_element_from_dict_by_path(tp_state_dicts[0], str(path)) + output_state_dict["transformer.ln_f.weight"] = params["final_layernorm.weight"].to(dtype) + output_state_dict["transformer.ln_f.bias"] = params["final_layernorm.bias"].to(dtype) + + # For LM head, transformers' wants the matrix to weight embeddings. + print("Converting LM head") + output_state_dict["lm_head.weight"] = word_embeddings.to(dtype) + + # It should be done! + print("Conversion from Megatron-LM to Transformers is done!") + + # Print the structure of converted state dict. + if args.print_checkpoint_structure: + recursive_print(None, output_state_dict) + + # Add tokenizer class info to config + # see https://github.com/huggingface/transformers/issues/13906) + + if args.tokenizer_name is None: + tokenizer_name = "gpt2" + else: + tokenizer_name = args.tokenizer_name + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenizer_class = type(tokenizer).__name__ + config.tokenizer_class = tokenizer_class + + # Store the config to file. + print("Saving config") + config.save_pretrained(args.save_path) + + # Save tokenizer based on args + if args.tokenizer_name is not None: + print(f"Adding {tokenizer_class} tokenizer files") + tokenizer.save_pretrained(args.save_path) + + # Store the state_dict to file. + max_shard_size = int(args.max_shard_size) if args.max_shard_size.isdigit() else args.max_shard_size + shards, index = shard_checkpoint(output_state_dict, max_shard_size=max_shard_size) + + # Save the model + for shard_file, shard in shards.items(): + torch.save(shard, os.path.join(args.save_path, shard_file)) + + if index is None: + print(f"Model weights saved in {os.path.join(args.save_path, WEIGHTS_NAME)}") + else: + save_index_file = os.path.join(args.save_path, WEIGHTS_INDEX_NAME) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + print( + f"The model is bigger than the maximum size per checkpoint ({args.max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + +def convert_checkpoint_from_transformers_to_megatron(args): + """ + Convert a checkpoint from HuggingFace Transformers to Megatron-LM. This allows converted checkpoints with variable + tensor parallelism and pipeline parallelism sizes. It takes as input a checkpoint from HuggingFace Transformers + which can have multiple shards. 
+ + Args: + args (argparse.Namespace): the arguments to the script + + """ + os.makedirs(args.save_path, exist_ok=True) + # Search in directory above this + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) + if args.megatron_path is not None: + sys.path.insert(0, args.megatron_path) + + try: + from megatron.tokenizer.tokenizer import _vocab_size_with_padding + except ModuleNotFoundError: + print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.") + exit(1) + + # load the transformers model state dict and config + sub_dirs = [x for x in os.listdir(args.load_path) if x.startswith("pytorch_model")] + if len(sub_dirs) == 1: + checkpoint_name = "pytorch_model.bin" + state_dict = torch.load(os.path.join(args.load_path, checkpoint_name), map_location="cpu") + else: + num_checkpoints = len(sub_dirs) - 1 + state_dict = merge_transformers_sharded_states(args.load_path, num_checkpoints) + + config = GPT2Config.from_pretrained(args.load_path) + + # Saving the tracker file + tracker_filepath = os.path.join(args.save_path, "latest_checkpointed_iteration.txt") + with open(tracker_filepath, "w") as f: + f.write("release") + + # create `release` dir in args.load_path + release_dir = os.path.join(args.save_path, "release") + os.makedirs(release_dir, exist_ok=True) + + # megatron args + megatron_args = { + "orig_vocab_size": config.vocab_size, + "max_position_embeddings": config.n_positions, + "hidden_size": config.n_embd, + "num_layers": config.n_layer, + "num_attention_heads": config.n_head, + "ffn_hidden_size": config.n_inner, + "tensor_model_parallel_size": args.target_tensor_model_parallel_size, + "pipeline_model_parallel_size": args.target_pipeline_model_parallel_size, + "data_parallel_size": args.target_data_parallel_size, + "make_vocab_size_divisible_by": args.make_vocab_size_divisible_by, + "rank": 0, + "tokenizer_type": "GPT2BPETokenizer", + } + + if config.activation_function == "gelu": + megatron_args["bias_gelu_fusion"] = False + megatron_args["openai_gelu"] = False + elif config.activation_function == "gelu_fast": + megatron_args["bias_gelu_fusion"] = True + megatron_args["openai_gelu"] = False + elif config.activation_function == "gelu_new": + megatron_args["bias_gelu_fusion"] = False + megatron_args["openai_gelu"] = True + + margs = types.SimpleNamespace() + for k, v in megatron_args.items(): + setattr(margs, k, v) + + # params dtype + if args.target_params_dtype == "fp16": + dtype = torch.float16 + elif args.target_params_dtype == "bf16": + dtype = torch.bfloat16 + else: + dtype = torch.float32 + setattr(margs, "params_dtype", dtype) + + # save dummy optim state dict + dummy_optim_state_dict = {} + dummy_optim_state_dict["optimizer"] = { + "step": 0, + "param_groups": [ + { + "lr": 0.0, + "beta1": 0.0, + "beta2": 0.0, + "eps": 0.0, + "weight_decay": 0.0, + "correct_bias": False, + "params": [], + } + ], + } + if args.use_distributed_optimizer: + for i in range(args.target_pipeline_model_parallel_size): + for j in range(args.target_tensor_model_parallel_size): + for k in range(args.target_data_parallel_size): + if args.target_pipeline_model_parallel_size == 1: + checkpoint_dir = f"mp_rank_{i:02d}_{k:03d}" + else: + checkpoint_dir = f"mp_rank_{i:02d}_{j:03d}_{k:03d}" + checkpoint_dir = os.path.join(release_dir, checkpoint_dir) + os.makedirs(checkpoint_dir, exist_ok=True) + torch.save( + dummy_optim_state_dict, + os.path.join(checkpoint_dir, "optim.pt"), + ) + + # Convert. 
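+ # One output state dict is built per target tensor-parallel rank; parameters listed in
+ # tensor_parallel_params are chunked across those dicts, everything else is replicated.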
+ print("Converting") + output_state_dict = [] + for i in range(args.target_tensor_model_parallel_size): + output_state_dict.append({}) + + # Embedding layer + print("converting embedding layer") + pos_embedding = state_dict["transformer.wpe.weight"].to(dtype) + word_embedding = state_dict["transformer.wte.weight"].to(dtype) + orig_vocab_size = config.vocab_size + padded_vocab_size = _vocab_size_with_padding(orig_vocab_size, margs) + setattr(margs, "padded_vocab_size", padded_vocab_size) + # Cut out extra padding we don't need + if orig_vocab_size > padded_vocab_size: + full_word_embed = word_embedding[0:padded_vocab_size, :] + # Expanding embedding to larger size by replicating final entry + elif orig_vocab_size < padded_vocab_size: + padding_size = padded_vocab_size - orig_vocab_size + full_word_embed = torch.cat((word_embedding, word_embedding[-1].unsqueeze(0).expand(padding_size, -1))) + # Same size! + else: + full_word_embed = word_embedding + + # Split into new tensor model parallel sizes + out_word_embed = torch.chunk(full_word_embed, args.target_tensor_model_parallel_size, dim=0) + for i in range(args.target_tensor_model_parallel_size): + pos_emb_dict = get_element_from_dict_by_path( + output_state_dict[i], "model.language_model.embedding.position_embeddings" + ) + pos_emb_dict["weight"] = pos_embedding + + word_emb_dict = get_element_from_dict_by_path( + output_state_dict[i], "model.language_model.embedding.word_embeddings" + ) + word_emb_dict["weight"] = out_word_embed[i] + + # Transformer layers + print("converting transformer layers") + if config.num_hidden_layers % args.target_tensor_model_parallel_size != 0: + raise ValueError( + f"Number of layers ({config.num_hidden_layers}) must be divisible by number of tensor parallelism" + f" ({args.target_tensor_model_parallel_size})" + ) + num_layers = config.num_hidden_layers // args.target_pipeline_model_parallel_size + + layer_re = re.compile("transformer.h\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") + # The number of heads. + heads = config.n_head + # The hidden_size per head. + hidden_size_per_head = config.n_embd // config.n_head + for pp_rank in range(args.target_pipeline_model_parallel_size): + layer_offset = pp_rank * num_layers + if pp_rank > 0: + output_state_dict = [] + for i in range(args.target_tensor_model_parallel_size): + output_state_dict.append({}) + + for layer in range(num_layers): + pp_layer_id = layer + layer_offset + layers_to_copy = [ + layer_name + for layer_name in state_dict.keys() + if layer_name.startswith(f"transformer.h.{pp_layer_id}.") + ] + + for layer_name in layers_to_copy: + m = layer_re.match(layer_name) + # Stop if that's not a layer + if m is None: + break + + # The index of the layer. + _ = int(m.group(1)) + # The name of the operation. + op_name = m.group(2) + # Is it a weight or a bias? + weight_or_bias = m.group(3) + + params = state_dict[layer_name].to(dtype) + # handle layernorm + if op_name.startswith("ln"): + out_name = "input_layernorm" if op_name.endswith("1") else "post_attention_layernorm" + layer_name = f"layers.{layer}.{out_name}.{weight_or_bias}" + + # handle attention K, V, Q weights + elif op_name.startswith("attn.c_attn") and weight_or_bias == "weight": + # transformers stores D X (3*D) but Megatron-LM expects (3*D) X D. 
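+ # After the transpose, the weight is reordered for Megatron checkpoint version 3.0,
+ # the version written into the converted state dicts further below.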
+ params = params.transpose(0, 1).contiguous() + + params = transformers_to_megatron_fix_query_key_value_ordering( + params, + 3.0, + 3, + heads, + hidden_size_per_head, + ) + layer_name = f"layers.{layer}.self_attention.query_key_value.{weight_or_bias}" + + # handle attention K, V, Q bias + elif op_name.startswith("attn.c_attn") and weight_or_bias == "bias": + params = transformers_to_megatron_fix_query_key_value_ordering( + params, + 3.0, + 3, + heads, + hidden_size_per_head, + ) + layer_name = f"layers.{layer}.self_attention.query_key_value.{weight_or_bias}" + + # handle attention and mlp weights + elif weight_or_bias == "weight": + out_name = transformers_to_megatron.get(op_name, None) + if out_name is None: + continue + params = params.transpose(0, 1) + layer_name = f"layers.{layer}.{out_name}.{weight_or_bias}" + + # handle attention and mlp bias + elif weight_or_bias == "bias": + out_name = transformers_to_megatron.get(op_name, None) + if out_name is None: + continue + layer_name = f"layers.{layer}.{out_name}.{weight_or_bias}" + + # skip + else: + continue + + if op_name + "." + weight_or_bias in tensor_parallel_params: + dim = 1 if op_name in ["attn.c_proj", "mlp.c_proj"] else 0 + params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=dim) + + for i in range(args.target_tensor_model_parallel_size): + params_dict = get_element_from_dict_by_path(output_state_dict[i], "model.language_model.encoder") + params_dict[layer_name] = ( + params[i] if (op_name + "." + weight_or_bias in tensor_parallel_params) else params + ) + + if pp_rank == args.target_pipeline_model_parallel_size - 1: + # handle final layernorm + for weight_or_bias in ["weight", "bias"]: + params = state_dict[f"transformer.ln_f.{weight_or_bias}"].to(dtype) + layer_name = f"final_layernorm.{weight_or_bias}" + for i in range(args.target_tensor_model_parallel_size): + params_dict = get_element_from_dict_by_path(output_state_dict[i], "model.language_model.encoder") + params_dict[layer_name] = params + + # add the LM head + for i in range(args.target_tensor_model_parallel_size): + params_dict = get_element_from_dict_by_path(output_state_dict[i], "model.word_embeddings_for_head") + params_dict["weight"] = out_word_embed[i] + + # saving the state dict as per the tp_rank and pp_rank + for tp_rank in range(args.target_tensor_model_parallel_size): + output_state_dict[tp_rank]["checkpoint_version"] = 3.0 + output_state_dict[tp_rank]["args"] = margs + checkpoint_dir = ( + f"mp_rank_{tp_rank:02d}" + if args.target_pipeline_model_parallel_size == 1 + else f"mp_rank_{tp_rank:02d}_{pp_rank:03d}" + ) + if args.use_distributed_optimizer: + checkpoint_name = "model_rng.pt" + else: + checkpoint_name = "model_optim_rng.pt" + output_state_dict[tp_rank]["optimizer"] = dummy_optim_state_dict["optimizer"] + checkpoint_dir = os.path.join(release_dir, checkpoint_dir) + os.makedirs(checkpoint_dir, exist_ok=True) + checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name) + if args.print_checkpoint_structure: + print( + f"Checkpoint structure of model state dict shard belonging to TP rank {tp_rank} and PP rank" + f" {pp_rank}:" + ) + recursive_print(None, output_state_dict[tp_rank]) + torch.save(output_state_dict[tp_rank], checkpoint_path) + + +def main(): + parser = argparse.ArgumentParser() + parser = add_checkpointing_args(parser) + parser = add_megatron_checkpoint_args(parser) + parser = add_transformers_checkpoint_args(parser) + args = parser.parse_args() + if args.convert_checkpoint_from_megatron_to_transformers: + 
convert_checkpoint_from_megatron_to_transformers(args) + else: + convert_checkpoint_from_transformers_to_megatron(args) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt2_checkpoint.py b/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt2_checkpoint.py new file mode 100644 index 0000000000..778b1384a2 --- /dev/null +++ b/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt2_checkpoint.py @@ -0,0 +1,364 @@ +#################################################################################################### + +# Copyright (c) 2021-, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#################################################################################################### + +# +# Note: If when running this conversion script you're getting an exception: +# ModuleNotFoundError: No module named 'megatron.model.enums' +# you need to tell python where to find the clone of Megatron-LM, e.g.: +# +# cd /tmp +# git clone https://github.com/NVIDIA/Megatron-LM +# PYTHONPATH=/tmp/Megatron-LM python src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py ... +# +# if you already have it cloned elsewhere, simply adjust the path to the existing path +# +# If the training was done using a Megatron-LM fork, e.g., +# https://github.com/microsoft/Megatron-DeepSpeed/ then chances are that you need to have that one +# in your path, i.e., /path/to/Megatron-DeepSpeed/ +# + +import argparse +import os +import re +import zipfile + +import torch + +from transformers import AutoTokenizer, GPT2Config + + +#################################################################################################### + + +def recursive_print(name, val, spaces=0): + # Format the message. + if name is None: + msg = None + else: + fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}" + msg = fmt.format(name) + + # Print and recurse (if needed). + if isinstance(val, dict): + if msg is not None: + print(msg) + for k in val.keys(): + recursive_print(k, val[k], spaces + 2) + elif isinstance(val, torch.Tensor): + print(msg, ":", val.size()) + else: + print(msg, ":", val) + + +def fix_query_key_value_ordering(param, checkpoint_version, num_splits, num_heads, hidden_size): + # Permutes layout of param tensor to [num_splits * num_heads * hidden_size, :] + # for compatibility with later versions of NVIDIA Megatron-LM. + # The inverse operation is performed inside Megatron-LM to read checkpoints: + # https://github.com/NVIDIA/Megatron-LM/blob/v2.4/megatron/checkpointing.py#L209 + # If param is the weight tensor of the self-attention block, the returned tensor + # will have to be transposed one more time to be read by HuggingFace GPT2. 
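+ # This is the same permutation as megatron_to_transformers_fix_query_key_value_ordering
+ # in checkpoint_reshaping_and_interoperability.py.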
+ input_shape = param.size() + if checkpoint_version == 1.0: + # version 1.0 stores [num_heads * hidden_size * num_splits, :] + saved_shape = (num_heads, hidden_size, num_splits) + input_shape[1:] + param = param.view(*saved_shape) + param = param.transpose(0, 2) + param = param.transpose(1, 2).contiguous() + elif checkpoint_version >= 2.0: + # other versions store [num_heads * num_splits * hidden_size, :] + saved_shape = (num_heads, num_splits, hidden_size) + input_shape[1:] + param = param.view(*saved_shape) + param = param.transpose(0, 1).contiguous() + param = param.view(*input_shape) + return param + + +#################################################################################################### + + +def convert_megatron_checkpoint(args, input_state_dict, config): + # The converted output model. + output_state_dict = {} + + # old versions did not store training args + ds_args = input_state_dict.get("args", None) + if ds_args is not None: + # do not make the user write a config file when the exact dimensions/sizes are already in the checkpoint + # from pprint import pprint + # pprint(vars(ds_args)) + + config.vocab_size = ds_args.padded_vocab_size + config.n_positions = ds_args.max_position_embeddings + config.n_embd = ds_args.hidden_size + config.n_layer = ds_args.num_layers + config.n_head = ds_args.num_attention_heads + config.n_inner = ds_args.ffn_hidden_size + # pprint(config) + + # The number of heads. + heads = config.n_head + # The hidden_size per head. + hidden_size_per_head = config.n_embd // config.n_head + # Megatron-LM checkpoint version + if "checkpoint_version" in input_state_dict.keys(): + checkpoint_version = input_state_dict["checkpoint_version"] + else: + checkpoint_version = 0.0 + + # The model. + model = input_state_dict["model"] + # The language model. + lm = model["language_model"] + # The embeddings. + embeddings = lm["embedding"] + + # The word embeddings. + word_embeddings = embeddings["word_embeddings"]["weight"] + # Truncate the embedding table to vocab_size rows. + word_embeddings = word_embeddings[: config.vocab_size, :] + output_state_dict["transformer.wte.weight"] = word_embeddings + + # The position embeddings. + pos_embeddings = embeddings["position_embeddings"]["weight"] + # Read the causal mask dimension (seqlen). [max_sequence_length, hidden_size] + n_positions = pos_embeddings.size(0) + if n_positions != config.n_positions: + raise ValueError( + f"pos_embeddings.max_sequence_length={n_positions} and config.n_positions={config.n_positions} don't match" + ) + # Store the position embeddings. + output_state_dict["transformer.wpe.weight"] = pos_embeddings + + # The transformer. + transformer = lm["transformer"] if "transformer" in lm.keys() else lm["encoder"] + + # The regex to extract layer names. + layer_re = re.compile("layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") + + # The simple map of names for "automated" rules. + megatron_to_transformers = { + "attention.dense": ".attn.c_proj.", + "self_attention.dense": ".attn.c_proj.", + "mlp.dense_h_to_4h": ".mlp.c_fc.", + "mlp.dense_4h_to_h": ".mlp.c_proj.", + } + + # Extract the layers. + for key, val in transformer.items(): + # Match the name. + m = layer_re.match(key) + + # Stop if that's not a layer + if m is None: + break + + # The index of the layer. + layer_idx = int(m.group(1)) + # The name of the operation. + op_name = m.group(2) + # Is it a weight or a bias? + weight_or_bias = m.group(3) + + # The name of the layer. 
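+        # e.g. the Megatron key "layers.0.self_attention.dense.weight" ends up as
+        # "transformer.h.0.attn.c_proj.weight" on the transformers side.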
+ layer_name = f"transformer.h.{layer_idx}" + + # For layernorm(s), simply store the layer norm. + if op_name.endswith("layernorm"): + + ln_name = "ln_1" if op_name.startswith("input") else "ln_2" + output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = val + + # Transpose the QKV matrix. + elif ( + op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value" + ) and weight_or_bias == "weight": + + # Insert a tensor of 1x1xDxD bias. + causal_mask = torch.tril(torch.ones((n_positions, n_positions), dtype=torch.float16)).view( + 1, 1, n_positions, n_positions + ) + output_state_dict[layer_name + ".attn.bias"] = causal_mask + + # Insert a "dummy" tensor for masked_bias. + masked_bias = torch.tensor(-1e4, dtype=torch.float16) + output_state_dict[layer_name + ".attn.masked_bias"] = masked_bias + + out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head) + # Megatron stores (3*D) x D but transformers-GPT2 expects D x 3*D. + out_val = out_val.transpose(0, 1).contiguous() + # Store. + output_state_dict[layer_name + ".attn.c_attn.weight"] = out_val + + # Transpose the bias. + elif ( + op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value" + ) and weight_or_bias == "bias": + + out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head) + # Store. No change of shape. + output_state_dict[layer_name + ".attn.c_attn.bias"] = out_val + + # Transpose the weights. + elif weight_or_bias == "weight": + + out_name = megatron_to_transformers[op_name] + output_state_dict[layer_name + out_name + "weight"] = val.transpose(0, 1) + + # Copy the bias. + elif weight_or_bias == "bias": + + out_name = megatron_to_transformers[op_name] + output_state_dict[layer_name + out_name + "bias"] = val + + # DEBUG. + assert config.n_layer == layer_idx + 1 + + # The final layernorm. + output_state_dict["transformer.ln_f.weight"] = transformer["final_layernorm.weight"] + output_state_dict["transformer.ln_f.bias"] = transformer["final_layernorm.bias"] + + # For LM head, transformers' wants the matrix to weight embeddings. + output_state_dict["lm_head.weight"] = word_embeddings + + # It should be done! + return output_state_dict + + +#################################################################################################### + + +def main(): + # Create the argument parser. + parser = argparse.ArgumentParser() + parser.add_argument("--print-checkpoint-structure", action="store_true") + parser.add_argument( + "path_to_checkpoint", + type=str, + help="Path to the checkpoint file (.zip archive or direct .pt file)", + ) + parser.add_argument( + "--config_file", + default="", + type=str, + help="An optional config json file describing the pre-trained model.", + ) + args = parser.parse_args() + + # Extract the basename. + basename = os.path.dirname(args.path_to_checkpoint) + + # Load the model. 
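+    # Both branches below load the state dict on CPU (map_location="cpu"),
+    # so the conversion itself does not require a GPU.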
+ # the .zip is very optional, let's keep it for backward compatibility + print(f"Extracting PyTorch state dictionary from {args.path_to_checkpoint}") + if args.path_to_checkpoint.endswith(".zip"): + with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint: + with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict: + input_state_dict = torch.load(pytorch_dict, map_location="cpu") + else: + input_state_dict = torch.load(args.path_to_checkpoint, map_location="cpu") + + ds_args = input_state_dict.get("args", None) + + # Read the config, or default to the model released by NVIDIA. + if args.config_file == "": + + if ds_args is not None: + if ds_args.bias_gelu_fusion: + activation_function = "gelu_fast" + elif ds_args.openai_gelu: + activation_function = "gelu_new" + else: + activation_function = "gelu" + else: + # in the very early days this used to be "gelu_new" + activation_function = "gelu_new" + + # Spell out all parameters in case the defaults change. + config = GPT2Config( + vocab_size=50257, + n_positions=1024, + n_embd=1024, + n_layer=24, + n_head=16, + n_inner=4096, + activation_function=activation_function, + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + ) + else: + config = GPT2Config.from_json_file(args.config_file) + + config.architectures = ["GPT2LMHeadModel"] + + # Convert. + print("Converting") + output_state_dict = convert_megatron_checkpoint(args, input_state_dict, config) + + # Print the structure of converted state dict. + if args.print_checkpoint_structure: + recursive_print(None, output_state_dict) + + # Add tokenizer class info to config + # see https://github.com/huggingface/transformers/issues/13906) + if ds_args is not None: + tokenizer_type = ds_args.tokenizer_type + if tokenizer_type == "GPT2BPETokenizer": + tokenizer_model_name = "gpt2" + elif tokenizer_type == "PretrainedFromHF": + tokenizer_model_name = ds_args.tokenizer_name_or_path + else: + raise ValueError(f"Unrecognized tokenizer_type {tokenizer_type}") + else: + tokenizer_model_name = "gpt2" + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_name) + tokenizer_class = type(tokenizer).__name__ + config.tokenizer_class = tokenizer_class + + # Store the config to file. + print("Saving config") + config.save_pretrained(basename) + + # Save tokenizer based on args + print(f"Adding {tokenizer_class} tokenizer files") + tokenizer.save_pretrained(basename) + + # Store the state_dict to file. 
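+    # The converted weights are written as a single, unsharded pytorch_model.bin
+    # next to the config and tokenizer files saved above.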
+ output_checkpoint_file = os.path.join(basename, "pytorch_model.bin") + print(f'Saving checkpoint to "{output_checkpoint_file}"') + torch.save(output_state_dict, output_checkpoint_file) + + +#################################################################################################### + +if __name__ == "__main__": + main() + +#################################################################################################### From 33bb045f86736eac86c8c36824616c98ab55ee13 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 7 Feb 2023 14:55:13 -0500 Subject: [PATCH 07/27] Add todos --- .../checkpoint_reshaping_and_interoperability.py | 2 ++ .../megatron_gpt_bigcode/convert_megatron_gpt2_checkpoint.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/transformers/models/megatron_gpt_bigcode/checkpoint_reshaping_and_interoperability.py b/src/transformers/models/megatron_gpt_bigcode/checkpoint_reshaping_and_interoperability.py index f8fb1e7c8f..9292cb505c 100644 --- a/src/transformers/models/megatron_gpt_bigcode/checkpoint_reshaping_and_interoperability.py +++ b/src/transformers/models/megatron_gpt_bigcode/checkpoint_reshaping_and_interoperability.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# TODO: Update this file + import argparse import json import os diff --git a/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt2_checkpoint.py b/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt2_checkpoint.py index 778b1384a2..0484f8994f 100644 --- a/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt2_checkpoint.py +++ b/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt2_checkpoint.py @@ -32,6 +32,8 @@ # in your path, i.e., /path/to/Megatron-DeepSpeed/ # +# TODO: Update this file + import argparse import os import re From e92b63f3e73b2741ccca1db04cee1fb19a6d759e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 7 Feb 2023 15:13:47 -0500 Subject: [PATCH 08/27] fix --- .../gpt_bigcode/modeling_gpt_bigcode.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 8625e77759..942cea744b 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -520,7 +520,7 @@ class GPTBigCodeDoubleHeadsModelOutput(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None -GPTBigCode_START_DOCSTRING = r""" +GPT_BIGCODE_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -536,7 +536,7 @@ class GPTBigCodeDoubleHeadsModelOutput(ModelOutput): configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
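
    A minimal usage sketch of the pattern described above (the checkpoint path is a placeholder):

    ```python
    from transformers import GPTBigCodeConfig, GPTBigCodeModel

    config = GPTBigCodeConfig()  # configuration only; weights are randomly initialized
    model = GPTBigCodeModel(config)
    # To load trained weights instead:
    # model = GPTBigCodeModel.from_pretrained("path/to/converted/checkpoint")
    ```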
""" -GPTBigCode_INPUTS_DOCSTRING = r""" +GPT_BIGCODE_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else @@ -657,7 +657,7 @@ class GPTBigCodeDoubleHeadsModelOutput(ModelOutput): @add_start_docstrings( "The bare GPTBigCode Model transformer outputting raw hidden-states without any specific head on top.", - GPTBigCode_START_DOCSTRING, + GPT_BIGCODE_START_DOCSTRING, ) class GPTBigCodeModel(GPTBigCodePreTrainedModel): _keys_to_ignore_on_load_missing = ["attn.masked_bias"] @@ -728,7 +728,7 @@ def _prune_heads(self, heads_to_prune): for layer, heads in heads_to_prune.items(): self.h[layer].attn.prune_heads(heads) - @add_start_docstrings_to_model_forward(GPTBigCode_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=BaseModelOutputWithPastAndCrossAttentions, @@ -934,7 +934,7 @@ def custom_forward(*inputs): The GPTBigCode Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings). """, - GPTBigCode_START_DOCSTRING, + GPT_BIGCODE_START_DOCSTRING, ) class GPTBigCodeLMHeadModel(GPTBigCodePreTrainedModel): _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] @@ -1005,7 +1005,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg "token_type_ids": token_type_ids, } - @add_start_docstrings_to_model_forward(GPTBigCode_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=CausalLMOutputWithCrossAttentions, @@ -1102,7 +1102,7 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> input embeddings, the classification head takes as input the input of a specified classification token index in the input sequence). """, - GPTBigCode_START_DOCSTRING, + GPT_BIGCODE_START_DOCSTRING, ) class GPTBigCodeDoubleHeadsModel(GPTBigCodePreTrainedModel): _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] @@ -1178,7 +1178,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg "token_type_ids": token_type_ids, } - @add_start_docstrings_to_model_forward(GPTBigCode_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=GPTBigCodeDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) def forward( self, @@ -1316,7 +1316,7 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in each row of the batch). 
""", - GPTBigCode_START_DOCSTRING, + GPT_BIGCODE_START_DOCSTRING, ) class GPTBigCodeForSequenceClassification(GPTBigCodePreTrainedModel): _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"] @@ -1334,7 +1334,7 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - @add_start_docstrings_to_model_forward(GPTBigCode_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint="microsoft/DialogRPT-updown", output_type=SequenceClassifierOutputWithPast, @@ -1441,7 +1441,7 @@ def forward( GPTBigCode Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, - GPTBigCode_START_DOCSTRING, + GPT_BIGCODE_START_DOCSTRING, ) class GPTBigCodeForTokenClassification(GPTBigCodePreTrainedModel): def __init__(self, config): @@ -1465,7 +1465,7 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - @add_start_docstrings_to_model_forward(GPTBigCode_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) # fmt: off @add_code_sample_docstrings( checkpoint="brad1141/gpt2-finetuned-comp2", From 2f32703ef4047617e88eada0c338c0b996f6d304 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 7 Feb 2023 15:38:20 -0500 Subject: [PATCH 09/27] format --- .../models/gpt_bigcode/__init__.py | 18 +++++++++++------- .../gpt_bigcode/test_modeling_gpt_bigcode.py | 14 +++++++++++--- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/__init__.py b/src/transformers/models/gpt_bigcode/__init__.py index 5b585a8908..2af1863d70 100644 --- a/src/transformers/models/gpt_bigcode/__init__.py +++ b/src/transformers/models/gpt_bigcode/__init__.py @@ -18,15 +18,15 @@ from typing import TYPE_CHECKING -from ...utils import ( - OptionalDependencyNotAvailable, - _LazyModule, - is_torch_available, -) +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available _import_structure = { - "configuration_gpt_bigcode": ["GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTBigCodeConfig", "GPTBigCodeOnnxConfig"], + "configuration_gpt_bigcode": [ + "GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP", + "GPTBigCodeConfig", + "GPTBigCodeOnnxConfig", + ], } try: @@ -47,7 +47,11 @@ ] if TYPE_CHECKING: - from .configuration_gpt_bigcode import GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTBigCodeConfig, GPTBigCodeOnnxConfig + from .configuration_gpt_bigcode import ( + GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP, + GPTBigCodeConfig, + GPTBigCodeOnnxConfig, + ) try: if not is_torch_available(): diff --git a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py index 936d4c9f76..f5f31ce05f 100644 --- a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py +++ b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py @@ -31,12 +31,12 @@ from transformers import ( GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST, + GPT2Tokenizer, GPTBigCodeDoubleHeadsModel, GPTBigCodeForSequenceClassification, GPTBigCodeForTokenClassification, GPTBigCodeLMHeadModel, GPTBigCodeModel, - GPT2Tokenizer, ) @@ -434,12 +434,20 @@ class GPTBigCodeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test # TODO: Update the tests to use valid pretrained models. 
all_model_classes = ( - (GPTBigCodeModel, GPTBigCodeLMHeadModel, GPTBigCodeDoubleHeadsModel, GPTBigCodeForSequenceClassification, GPTBigCodeForTokenClassification) + ( + GPTBigCodeModel, + GPTBigCodeLMHeadModel, + GPTBigCodeDoubleHeadsModel, + GPTBigCodeForSequenceClassification, + GPTBigCodeForTokenClassification, + ) if is_torch_available() else () ) all_generative_model_classes = (GPTBigCodeLMHeadModel, GPTBigCodeDoubleHeadsModel) if is_torch_available() else () - all_parallelizable_model_classes = (GPTBigCodeLMHeadModel, GPTBigCodeDoubleHeadsModel) if is_torch_available() else () + all_parallelizable_model_classes = ( + (GPTBigCodeLMHeadModel, GPTBigCodeDoubleHeadsModel) if is_torch_available() else () + ) fx_compatible = True test_missing_keys = False test_model_parallel = True From 035c73c89c72744baf0587d0f916811040331349 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 7 Feb 2023 15:55:47 -0500 Subject: [PATCH 10/27] fix --- ...2_checkpoint.py => convert_megatron_gpt_bigcode_checkpoint.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/transformers/models/megatron_gpt_bigcode/{convert_megatron_gpt2_checkpoint.py => convert_megatron_gpt_bigcode_checkpoint.py} (100%) diff --git a/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt2_checkpoint.py b/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py similarity index 100% rename from src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt2_checkpoint.py rename to src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py From 9ab47d99f867f26537aefcb8230ad4999e1a1134 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 8 Feb 2023 14:16:33 -0500 Subject: [PATCH 11/27] Multi-query attention (#4) --- src/transformers/activations.py | 23 ++- .../gpt_bigcode/configuration_gpt_bigcode.py | 13 +- .../gpt_bigcode/modeling_gpt_bigcode.py | 189 +++++++++--------- 3 files changed, 130 insertions(+), 95 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index d9caf8763e..1c59568835 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -25,6 +25,26 @@ logger = logging.get_logger(__name__) +class PytorchGELUTanh(nn.Module): + """ + A fast C implementation of the tanh approximation of the GeLU activation function. See + https://arxiv.org/abs/1606.08415. + This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical + match due to rounding errors. + """ + + def __init__(self): + super().__init__() + if version.parse(torch.__version__) < version.parse("1.12.0"): + raise ImportError( + f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use " + "PytorchGELUTanh. Please upgrade torch." + ) + + def forward(self, input: Tensor) -> Tensor: + return nn.functional.gelu(input, approximate="tanh") + + class NewGELUActivation(nn.Module): """ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see @@ -80,10 +100,8 @@ class ClippedGELUActivation(nn.Module): Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to https://arxiv.org/abs/2004.09602. - Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when initially created. 
- For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415 """ @@ -155,6 +173,7 @@ def __getitem__(self, key): "gelu_fast": FastGELUActivation, "gelu_new": NewGELUActivation, "gelu_python": (GELUActivation, {"use_gelu_python": True}), + "gelu_pytorch_tanh": PytorchGELUTanh, "linear": LinearActivation, "mish": MishActivation, "quick_gelu": QuickGELUActivation, diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index 8fcf554ded..6546b47253 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -15,6 +15,7 @@ # limitations under the License. """ OpenAI GPT-2 configuration""" from collections import OrderedDict +from enum import Enum from typing import Any, List, Mapping, Optional from transformers import PreTrainedTokenizer, TensorType, is_torch_available @@ -31,6 +32,12 @@ } +class AttentionType(Enum): + MULTI_HEAD = 1 + MULTI_QUERY_1 = 2 + MULTI_QUERY_2 = 3 + + class GPTBigCodeConfig(PretrainedConfig): """ # TODO: Update doc @@ -143,7 +150,7 @@ def __init__( n_layer=12, n_head=12, n_inner=None, - activation_function="gelu_new", + activation_function="gelu_pytorch_tanh", resid_pdrop=0.1, embd_pdrop=0.1, attn_pdrop=0.1, @@ -160,6 +167,7 @@ def __init__( eos_token_id=50256, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False, + attention_type=AttentionType.MULTI_HEAD, **kwargs, ): self.vocab_size = vocab_size @@ -187,6 +195,9 @@ def __init__( self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id + # Convert to an int so it's JSON-serializable. + self.attention_type = AttentionType(attention_type).value + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 942cea744b..82938e65ab 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -44,7 +44,7 @@ replace_return_docstrings, ) from ...utils.model_parallel_utils import assert_device_map, get_device_map -from .configuration_gpt_bigcode import GPTBigCodeConfig +from .configuration_gpt_bigcode import AttentionType, GPTBigCodeConfig logger = logging.get_logger(__name__) @@ -121,16 +121,21 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): max_positions = config.max_position_embeddings self.register_buffer( - "bias", - torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( - 1, 1, max_positions, max_positions - ), + "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False ) - self.register_buffer("masked_bias", torch.tensor(-1e4)) + self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) + # We don't use a buffer because the mask value depends on the dtype, + # And the dtype will be different if upcasting. 
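+        # It is created lazily the first time _attn needs it, and cached afterwards.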
+ self.mask_value = None + + self.attention_type = AttentionType(config.attention_type) + self.is_mqa = self.attention_type != AttentionType.MULTI_HEAD self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads + self.kv_heads = 1 if self.is_mqa else self.head_dim + self.kv_dim = self.kv_heads * self.head_dim self.split_size = self.embed_dim if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( @@ -146,11 +151,27 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.layer_idx = layer_idx self.reorder_and_upcast_attn = config.reorder_and_upcast_attn + self.scale_factor = 1.0 + if self.scale_attn_weights: + self.scale_factor /= self.head_dim**0.5 + + if self.scale_attn_by_inverse_layer_idx: + self.scale_factor /= self.layer_idx + 1 + if self.is_cross_attention: + if self.is_mqa: + raise NotImplementedError(f"attention_type {self.attention_type} for cross_attention") + self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) self.q_attn = Conv1D(self.embed_dim, self.embed_dim) else: - self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + if self.attention_type == AttentionType.MULTI_QUERY_2: + self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + # Keys and values are shared across heads + self.kv_attn = Conv1D(2 * self.head_dim, self.embed_dim) + else: + self.c_attn = Conv1D(self.embed_dim + 2 * self.kv_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) self.attn_dropout = nn.Dropout(config.attn_pdrop) @@ -173,27 +194,52 @@ def prune_heads(self, heads): self.num_heads = self.num_heads - len(heads) self.pruned_heads = self.pruned_heads.union(heads) - def _attn(self, query, key, value, attention_mask=None, head_mask=None): - attn_weights = torch.matmul(query, key.transpose(-1, -2)) - - if self.scale_attn_weights: - attn_weights = attn_weights / torch.full( - [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + def _matmul(self, x, y, dtype=None, scale_factor=1.0): + output_shape = (*x.size()[:-1], y.size(-1)) + if self.is_mqa: + # Q x K: (b, sq, nh, hs) x (b, hs, sk) -> (b, sq, nh, sk) + # A X V: (b, sq, nh, sk) x (b, sk, hs) -> (b, sq, nh, hs) + output_view = (x.size(0), x.size(1) * x.size(2), y.size(-1)) + # No copy needed for MQA 2, or when layer_past is provided. + x = x.reshape(*output_view[:-1], x.size(-1)) + else: + # Q x K: (b, nh, sq, hs) x (b, nh, hs, sk) -> (b, nh, sq, sk) + # A X V: (b, nh, sq, sk) x (b, nh, sk, hs) -> (b, nh, sq, hs) + output_view = (x.size(0) * x.size(1), x.size(2), y.size(-1)) + # Always copies + x = x.reshape(output_view[0], *x.size()[2:]) + # No copy when layer_past is provided. 
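+            # (reshape returns a view whenever the memory layout allows it and only copies otherwise)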
+ y = y.reshape(output_view[0], *y.size()[2:]) + # This is identical to matmul when scale_factor==1 + z = torch.empty(output_view, dtype=x.dtype if dtype is None else dtype, device=x.device) + z = torch.baddbmm(z, x, y, beta=0, alpha=scale_factor) + return z.view(output_shape) + + def _attn(self, query, key, value, attention_mask=None, head_mask=None, upcast=False): + with autocast(enabled=False): + attn_weights = self._matmul( + query, key.transpose(-1, -2), dtype=torch.float32 if upcast else None, scale_factor=self.scale_factor ) - # Layer-wise attention scaling - if self.scale_attn_by_inverse_layer_idx: - attn_weights = attn_weights / float(self.layer_idx + 1) - if not self.is_cross_attention: # if only "normal" attention layer implements causal mask - query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) - mask_value = torch.finfo(attn_weights.dtype).min - # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. - # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` - mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device) - attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + key_length = key.size(-2) + if self.is_mqa: + # (b, sq, nh, sk) + causal_mask = self.bias[None, key_length - query.size(1) : key_length, None, :key_length] + else: + # (b, nh, sq, sk) + causal_mask = self.bias[None, None, key_length - query.size(-2) : key_length, :key_length] + # torch.where expects a tensor. We use a cache to avoid recreating it every time. + if ( + self.mask_value is None + or self.mask_value.dtype != attn_weights.dtype + or self.mask_value.device != attn_weights.device + ): + self.mask_value = torch.full( + [], torch.finfo(attn_weights.dtype).min, dtype=attn_weights.dtype, device=attn_weights.device + ) + attn_weights = torch.where(causal_mask, attn_weights, self.mask_value) if attention_mask is not None: # Apply the attention mask @@ -202,57 +248,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): attn_weights = nn.functional.softmax(attn_weights, dim=-1) # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise - attn_weights = attn_weights.type(value.dtype) - attn_weights = self.attn_dropout(attn_weights) - - # Mask heads if we want to - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = torch.matmul(attn_weights, value) - - return attn_output, attn_weights - - def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): - # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) - bsz, num_heads, q_seq_len, dk = query.size() - _, _, k_seq_len, _ = key.size() - - # Preallocate attn_weights for `baddbmm` - attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) - - # Compute Scale Factor - scale_factor = 1.0 - if self.scale_attn_weights: - scale_factor /= float(value.size(-1)) ** 0.5 - - if self.scale_attn_by_inverse_layer_idx: - scale_factor /= float(self.layer_idx + 1) - - # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) - with autocast(enabled=False): - q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) - attn_weights = torch.baddbmm(attn_weights, q.float(), 
k.float(), beta=0, alpha=scale_factor) - attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) - - if not self.is_cross_attention: - # if only "normal" attention layer implements causal mask - query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() - mask_value = torch.finfo(attn_weights.dtype).min - # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. - # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` - mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) - attn_weights = torch.where(causal_mask, attn_weights, mask_value) - - if attention_mask is not None: - # Apply the attention mask - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise - if attn_weights.dtype != torch.float32: + if upcast and attn_weights.dtype != torch.float32: raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") attn_weights = attn_weights.type(value.dtype) attn_weights = self.attn_dropout(attn_weights) @@ -261,39 +257,42 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea if head_mask is not None: attn_weights = attn_weights * head_mask - attn_output = torch.matmul(attn_weights, value) + attn_output = self._matmul(attn_weights, value) return attn_output, attn_weights - def _split_heads(self, tensor, num_heads, attn_head_size): + def _split_heads(self, tensor, num_heads, attn_head_size, permute=True): """ Splits hidden_size dim into attn_head_size and num_heads """ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) tensor = tensor.view(new_shape) - return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + if permute: + tensor = tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + return tensor - def _merge_heads(self, tensor, num_heads, attn_head_size): + def _merge_heads(self, tensor, num_heads, attn_head_size, permute=True): """ Merges attn_head_size dim and num_attn_heads dim into hidden_size """ - tensor = tensor.permute(0, 2, 1, 3).contiguous() + if permute: + tensor = tensor.permute(0, 2, 1, 3).contiguous() new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) return tensor.view(new_shape) def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, - ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], ...]: if encoder_hidden_states is not None: - if not hasattr(self, "q_attn"): + if not hasattr(self, "q_attn") or not self.is_cross_attention: raise ValueError( "If class is used as cross attention, the weights `q_attn` have to be defined. " "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." 
@@ -303,11 +302,16 @@ def forward( key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) attention_mask = encoder_attention_mask else: - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + if self.attention_type == AttentionType.MULTI_QUERY_2: + query = self.q_attn(hidden_states) + key, value = self.kv_attn(hidden_states).split((self.kv_dim, self.kv_dim), dim=2) + else: + query, key, value = self.c_attn(hidden_states).split((self.embed_dim, self.kv_dim, self.kv_dim), dim=2) - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) + query = self._split_heads(query, self.num_heads, self.head_dim, permute=not self.is_mqa) + if not self.is_mqa: + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) if layer_past is not None: past_key, past_value = layer_past @@ -319,12 +323,11 @@ def forward( else: present = None - if self.reorder_and_upcast_attn: - attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) - else: - attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + attn_output, attn_weights = self._attn( + query, key, value, attention_mask, head_mask, upcast=self.reorder_and_upcast_attn + ) - attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim, permute=not self.is_mqa) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) @@ -363,6 +366,8 @@ def __init__(self, config, layer_idx=None): self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) if config.add_cross_attention: + if config.attention_type != AttentionType.MULTI_HEAD: + raise NotImplementedError("Cross-attention not implemented for MQA") self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx) self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) From 878025e10e183f068c7e3f8b8767452f73378127 Mon Sep 17 00:00:00 2001 From: minimario Date: Sat, 11 Feb 2023 16:49:07 +0000 Subject: [PATCH 12/27] add test to ensure mqa and mha have the same behaviour --- .../gpt_bigcode/modeling_gpt_bigcode.py | 2 +- .../gpt_bigcode/test_modeling_gpt_bigcode.py | 84 +++++++++++++++++++ 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 82938e65ab..7539a82d7c 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -134,7 +134,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads - self.kv_heads = 1 if self.is_mqa else self.head_dim + self.kv_heads = 1 if self.is_mqa else self.num_heads self.kv_dim = self.kv_heads * self.head_dim self.split_size = self.embed_dim if self.head_dim * self.num_heads != self.embed_dim: diff --git a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py index f5f31ce05f..2fdbceed29 100644 --- a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py +++ 
b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py @@ -32,6 +32,7 @@ from transformers import ( GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST, GPT2Tokenizer, + GPTBigCodeConfig, GPTBigCodeDoubleHeadsModel, GPTBigCodeForSequenceClassification, GPTBigCodeForTokenClassification, @@ -39,6 +40,13 @@ GPTBigCodeModel, ) + from transformers.models.gpt_bigcode.configuration_gpt_bigcode import ( + AttentionType + ) + from transformers.models.gpt_bigcode.modeling_gpt_bigcode import ( + GPTBigCodeAttention + ) + class GPTBigCodeModelTester: # TODO: Update the tests to use valid pretrained models. @@ -807,3 +815,79 @@ def test_contrastive_search_gpt_bigcode(self): "but said in a statement to The Associated Press that" ], ) + +@require_torch +class GPTBigCodeAttentionTest(unittest.TestCase): + def get_attention(self, attention_type : AttentionType): + config = GPTBigCodeConfig.from_pretrained("bigcode/santacoder-fast-inference") + config.attention_type = attention_type + return GPTBigCodeAttention(config) + + def test_mqa_correctness(self): + embed_dim = 2048 + head_dim = 128 + random_attn_weight = torch.randn(embed_dim, embed_dim) + random_attn_k_weight = torch.randn(embed_dim, head_dim) + random_attn_v_weight = torch.randn(embed_dim, head_dim) + random_attn_bias = torch.randn(embed_dim) + random_attn_k_bias = torch.randn(head_dim) + random_attn_v_bias = torch.randn(head_dim) + random_proj = torch.randn(embed_dim, embed_dim) + random_proj_bias = torch.randn(embed_dim) + + # MULTI-HEAD ATTENTION + num_heads = 16 + c_attn_weight = torch.hstack( + [random_attn_weight] + + num_heads * [random_attn_k_weight] + + num_heads * [random_attn_v_weight]) + + c_attn_bias = torch.hstack( + [random_attn_bias] + + num_heads * [random_attn_k_bias] + + num_heads * [random_attn_v_bias]) + + attention_mh = self.get_attention(AttentionType.MULTI_HEAD) + state_dict = attention_mh.state_dict() + state_dict["c_attn.weight"] = c_attn_weight + state_dict["c_attn.bias"] = c_attn_bias + state_dict["c_proj.weight"] = random_proj + state_dict["c_proj.bias"] = random_proj_bias + attention_mh.load_state_dict(state_dict) + + # MULTI-QUERY ATTENTION 1 + attention_mq1 = self.get_attention(AttentionType.MULTI_QUERY_1) + state_dict_mq1 = attention_mq1.state_dict() + c_attn_weight = torch.hstack([random_attn_weight, random_attn_k_weight, random_attn_v_weight]) + c_attn_bias = torch.hstack([random_attn_bias, random_attn_k_bias, random_attn_v_bias]) + state_dict_mq1["c_attn.weight"] = c_attn_weight + state_dict_mq1["c_attn.bias"] = c_attn_bias + state_dict_mq1["c_proj.weight"] = random_proj + state_dict_mq1["c_proj.bias"] = random_proj_bias + attention_mq1.load_state_dict(state_dict_mq1) + + # MULTI-QUERY ATTENTION 2 + attention_mq2 = self.get_attention(AttentionType.MULTI_QUERY_2) + state_dict_mq2 = attention_mq2.state_dict() + state_dict_mq2["q_attn.weight"] = random_attn_weight + state_dict_mq2["q_attn.bias"] = random_attn_bias + state_dict_mq2["kv_attn.weight"] = torch.hstack([random_attn_k_weight, random_attn_v_weight]) + state_dict_mq2["kv_attn.bias"] = torch.hstack([random_attn_k_bias, random_attn_v_bias]) + state_dict_mq2["c_proj.weight"] = random_proj + state_dict_mq2["c_proj.bias"] = random_proj_bias + attention_mq2.load_state_dict(state_dict_mq2) + + # Run correctness test + attention_mh.eval() + attention_mq1.eval() + attention_mq2.eval() + + num_tokens = 5 + hidden_states = torch.randn(1, num_tokens, embed_dim) + attention_mh_result = attention_mh(hidden_states)[0] + attention_mq1_result = attention_mq1(hidden_states)[0] + 
attention_mq2_result = attention_mq2(hidden_states)[0] + + self.assertTrue(torch.allclose(attention_mh_result, attention_mq1_result)) + self.assertTrue(torch.allclose(attention_mh_result, attention_mq2_result)) + self.assertTrue(torch.allclose(attention_mq1_result, attention_mq2_result)) \ No newline at end of file From 92d98c176efe5c6dee96c3c55197025e84860e22 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 14 Feb 2023 15:20:29 -0500 Subject: [PATCH 13/27] Megatron conversion script (#8) --- .../gpt_bigcode/configuration_gpt_bigcode.py | 7 +- .../gpt_bigcode/modeling_gpt_bigcode.py | 13 +- ...convert_megatron_gpt_bigcode_checkpoint.py | 159 +++++++++++++----- .../megatron_gpt_bigcode/push_checkpoints.py | 74 ++++++++ 4 files changed, 198 insertions(+), 55 deletions(-) create mode 100644 src/transformers/models/megatron_gpt_bigcode/push_checkpoints.py diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index 6546b47253..2522c240d7 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -19,10 +19,9 @@ from typing import Any, List, Mapping, Optional from transformers import PreTrainedTokenizer, TensorType, is_torch_available - -from ...configuration_utils import PretrainedConfig -from ...onnx import OnnxConfigWithPast, PatchingSpec -from ...utils import logging +from transformers.configuration_utils import PretrainedConfig +from transformers.onnx import OnnxConfigWithPast, PatchingSpec +from transformers.utils import logging logger = logging.get_logger(__name__) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 82938e65ab..1642301963 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -26,16 +26,16 @@ from torch.cuda.amp import autocast from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from ...activations import ACT2FN -from ...modeling_outputs import ( +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, SequenceSummary -from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer -from ...utils import ( +from transformers.modeling_utils import PreTrainedModel, SequenceSummary +from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from transformers.utils import ( ModelOutput, add_code_sample_docstrings, add_start_docstrings, @@ -43,7 +43,8 @@ logging, replace_return_docstrings, ) -from ...utils.model_parallel_utils import assert_device_map, get_device_map +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map + from .configuration_gpt_bigcode import AttentionType, GPTBigCodeConfig diff --git a/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py b/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py index 0484f8994f..9de9f4603b 100644 --- a/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py +++ 
b/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py @@ -41,7 +41,7 @@ import torch -from transformers import AutoTokenizer, GPT2Config +from transformers import GPTBigCodeConfig, GPTBigCodeLMHeadModel, GPTBigCodeModel #################################################################################################### @@ -110,6 +110,16 @@ def convert_megatron_checkpoint(args, input_state_dict, config): config.n_layer = ds_args.num_layers config.n_head = ds_args.num_attention_heads config.n_inner = ds_args.ffn_hidden_size + + if ds_args.attention_head_type == "multihead": + config.attention_type = 1 + else: + assert ds_args.attention_head_type == "multiquery" + config.attention_type = 2 if args.merge_qkv else 3 + + # also set `scale_attn_weights` and `scale_attn_by_inverse_layer_idx` ? + # Uncommenting the next line makes the converted model output different logits. + # config.scale_attn_by_inverse_layer_idx = ds_args.apply_query_key_layer_scaling # pprint(config) # The number of heads. @@ -189,23 +199,36 @@ def convert_megatron_checkpoint(args, input_state_dict, config): elif ( op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value" ) and weight_or_bias == "weight": - - # Insert a tensor of 1x1xDxD bias. - causal_mask = torch.tril(torch.ones((n_positions, n_positions), dtype=torch.float16)).view( - 1, 1, n_positions, n_positions - ) - output_state_dict[layer_name + ".attn.bias"] = causal_mask - - # Insert a "dummy" tensor for masked_bias. - masked_bias = torch.tensor(-1e4, dtype=torch.float16) - output_state_dict[layer_name + ".attn.masked_bias"] = masked_bias - out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head) # Megatron stores (3*D) x D but transformers-GPT2 expects D x 3*D. out_val = out_val.transpose(0, 1).contiguous() # Store. output_state_dict[layer_name + ".attn.c_attn.weight"] = out_val + # Tranpose the Q matrix (for MQA) + elif (op_name == "self_attention.query") and weight_or_bias == "weight": + out_val = fix_query_key_value_ordering(val, checkpoint_version, 1, heads, hidden_size_per_head) + # Megatron stores (out x in) but transformers-GPT2 expects (in x out). + out_val = out_val.transpose(0, 1).contiguous() + # Store. + output_state_dict[layer_name + ".attn.q_attn.weight"] = out_val + + # Tranpose the KV matrix (for MQA) + elif (op_name == "self_attention.key_value") and weight_or_bias == "weight": + # Key-values are shared across heads + out_val = fix_query_key_value_ordering(val, checkpoint_version, 2, 1, hidden_size_per_head) + # Megatron stores (out x in) but transformers-GPT2 expects (in x out). + out_val = out_val.transpose(0, 1).contiguous() + if args.merge_qkv: + # Concatenate the tensors. + # Query is before key_value in the dict. + query = output_state_dict.pop(layer_name + ".attn.q_attn.weight") + out_val = torch.cat([query, out_val], dim=1) + output_state_dict[layer_name + ".attn.c_attn.weight"] = out_val + else: + # Store. + output_state_dict[layer_name + ".attn.kv_attn.weight"] = out_val + # Transpose the bias. elif ( op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value" @@ -215,6 +238,27 @@ def convert_megatron_checkpoint(args, input_state_dict, config): # Store. No change of shape. 
output_state_dict[layer_name + ".attn.c_attn.bias"] = out_val + # Transpose the Q bias (MQA) + elif (op_name == "self_attention.query") and weight_or_bias == "bias": + + out_val = fix_query_key_value_ordering(val, checkpoint_version, 1, heads, hidden_size_per_head) + # Store. No change of shape. + output_state_dict[layer_name + ".attn.q_attn.bias"] = out_val + + # Transpose the KV bias (MQA) + elif (op_name == "self_attention.key_value") and weight_or_bias == "bias": + + out_val = fix_query_key_value_ordering(val, checkpoint_version, 2, 1, hidden_size_per_head) + if args.merge_qkv: + # Concatenate the tensors. + # Query is before key_value in the dict. + query = output_state_dict.pop(layer_name + ".attn.q_attn.bias") + out_val = torch.cat([query, out_val], dim=0) + output_state_dict[layer_name + ".attn.c_attn.bias"] = out_val + else: + # Store. No change of shape. + output_state_dict[layer_name + ".attn.kv_attn.bias"] = out_val + # Transpose the weights. elif weight_or_bias == "weight": @@ -244,7 +288,7 @@ def convert_megatron_checkpoint(args, input_state_dict, config): #################################################################################################### -def main(): +def main(argv=None): # Create the argument parser. parser = argparse.ArgumentParser() parser.add_argument("--print-checkpoint-structure", action="store_true") @@ -259,10 +303,24 @@ def main(): type=str, help="An optional config json file describing the pre-trained model.", ) - args = parser.parse_args() + parser.add_argument( + "--no_merge_qkv", + dest="merge_qkv", + action="store_false", + help="Do not merge the query and key_value tensors (MQA).", + ) + parser.add_argument( + "--custom_model", + action="store_true", + help="Save as custom model so it can be used with huggingface transformers.", + ) + parser.add_argument( + "--save_dir", help="Path where the converted model is saved. Will use the checkpoint directory if not provided" + ) + args = parser.parse_args(argv) # Extract the basename. - basename = os.path.dirname(args.path_to_checkpoint) + basename = args.save_dir or os.path.dirname(args.path_to_checkpoint) # Load the model. # the .zip is very optional, let's keep it for backward compatibility @@ -278,10 +336,12 @@ def main(): # Read the config, or default to the model released by NVIDIA. if args.config_file == "": + # TODO: FP32 softmax if ds_args is not None: if ds_args.bias_gelu_fusion: - activation_function = "gelu_fast" + # TODO: This will be in the next release of transformers (4.27). + activation_function = "gelu_pytorch_tanh" elif ds_args.openai_gelu: activation_function = "gelu_new" else: @@ -291,7 +351,7 @@ def main(): activation_function = "gelu_new" # Spell out all parameters in case the defaults change. - config = GPT2Config( + config = GPTBigCodeConfig( vocab_size=50257, n_positions=1024, n_embd=1024, @@ -315,9 +375,9 @@ def main(): eos_token_id=50256, ) else: - config = GPT2Config.from_json_file(args.config_file) + config = GPTBigCodeConfig.from_json_file(args.config_file) - config.architectures = ["GPT2LMHeadModel"] + config.architectures = ["GPTBigCodeLMHeadModel"] # Convert. 
print("Converting") @@ -329,33 +389,42 @@ def main(): # Add tokenizer class info to config # see https://github.com/huggingface/transformers/issues/13906) - if ds_args is not None: - tokenizer_type = ds_args.tokenizer_type - if tokenizer_type == "GPT2BPETokenizer": - tokenizer_model_name = "gpt2" - elif tokenizer_type == "PretrainedFromHF": - tokenizer_model_name = ds_args.tokenizer_name_or_path - else: - raise ValueError(f"Unrecognized tokenizer_type {tokenizer_type}") - else: - tokenizer_model_name = "gpt2" - - tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_name) - tokenizer_class = type(tokenizer).__name__ - config.tokenizer_class = tokenizer_class + # if ds_args is not None: + # tokenizer_type = ds_args.tokenizer_type + # if tokenizer_type == "GPT2BPETokenizer": + # tokenizer_model_name = "gpt2" + # elif tokenizer_type == "PretrainedFromHF": + # tokenizer_model_name = ds_args.tokenizer_name_or_path + # else: + # raise ValueError(f"Unrecognized tokenizer_type {tokenizer_type}") + # else: + # tokenizer_model_name = "gpt2" + + # tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_name) + # tokenizer_class = type(tokenizer).__name__ + # config.tokenizer_class = tokenizer_class + + if args.custom_model: + # Save custom model + GPTBigCodeConfig.register_for_auto_class() + GPTBigCodeModel.register_for_auto_class("AutoModelForCausalLM") + hf_model = GPTBigCodeLMHeadModel(config) + hf_model.load_state_dict(output_state_dict) + hf_model.save_pretrained(basename) - # Store the config to file. - print("Saving config") - config.save_pretrained(basename) - - # Save tokenizer based on args - print(f"Adding {tokenizer_class} tokenizer files") - tokenizer.save_pretrained(basename) - - # Store the state_dict to file. - output_checkpoint_file = os.path.join(basename, "pytorch_model.bin") - print(f'Saving checkpoint to "{output_checkpoint_file}"') - torch.save(output_state_dict, output_checkpoint_file) + else: + # Store the config to file. + print("Saving config") + config.save_pretrained(basename) + + # Save tokenizer based on args + # print(f"Adding {tokenizer_class} tokenizer files") + # tokenizer.save_pretrained(basename) + + # Store the state_dict to file. + output_checkpoint_file = os.path.join(basename, "pytorch_model.bin") + print(f'Saving checkpoint to "{output_checkpoint_file}"') + torch.save(output_state_dict, output_checkpoint_file) #################################################################################################### diff --git a/src/transformers/models/megatron_gpt_bigcode/push_checkpoints.py b/src/transformers/models/megatron_gpt_bigcode/push_checkpoints.py new file mode 100644 index 0000000000..27610525d9 --- /dev/null +++ b/src/transformers/models/megatron_gpt_bigcode/push_checkpoints.py @@ -0,0 +1,74 @@ +import argparse +import re +import subprocess +from pathlib import Path + +from huggingface_hub import Repository + +from .convert_megatron_gpt_bigcode_checkpoint import main as convert + + +""" +Script to upload Megatron checkpoints to a HF repo on the Hub. +The script clones/creates a repo on the Hub, checks out a branch `--branch_name`, +and converts each `iter_` checkpoint and saves it as a commit on that branch. 
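+
+Example invocation (run as a module so the relative import resolves; all paths, repo and
+branch names below are placeholders, and any extra arguments, e.g. --custom_model, are
+forwarded unchanged to the conversion script):
+
+    python -m transformers.models.megatron_gpt_bigcode.push_checkpoints \
+        --exp_dir /path/to/megatron/experiment \
+        --repo_name my-org/my-model-checkpoints \
+        --branch_name main \
+        --iter_interval 5000 \
+        --custom_model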
+""" + + +def get_iter_number(iter_dir: str): + m = re.match(r"iter_(\d+)", iter_dir) + if m is not None: + return int(m.group(1)) + else: + raise ValueError(f"Invalid directory name: {iter_dir}") + + +def main(argv=None): + parser = argparse.ArgumentParser() + parser.add_argument("--exp_dir", type=Path, required=True, help="Path to experiment folder.") + parser.add_argument("--repo_name", required=True, help="Name of repository on the Hub in 'ORG/NAME' format.") + parser.add_argument("--branch_name", required=True, help="Name of branch in repository to save experiments.") + parser.add_argument( + "--save_dir", + type=Path, + help="Path where repository is cloned to locally. Will use {exp_dir}/hf_checkpoints if not provided", + ) + parser.add_argument( + "--iter_interval", + type=int, + default=1, + help="Iteration number must be divisble by iter_interval in order to be pushed", + ) + args, argv = parser.parse_known_args(argv) + + save_dir = args.save_dir or args.exp_dir / "hf_checkpoints" + + hf_repo = Repository(save_dir, clone_from=args.repo_name) + hf_repo.git_checkout(args.branch_name, create_branch_ok=True) + # Find last checkpoint that was uploaded + head_hash = hf_repo.git_head_hash() + commit_msg = subprocess.check_output(["git", "show", "-s", "--format=%B", head_hash], cwd=save_dir).decode() + try: + last_commit_iter = get_iter_number(commit_msg.strip()) + print(f"Last commit iteration: {last_commit_iter}") + except ValueError: + last_commit_iter = -1 + + # The checkpoint dirs should be in ascending iteration order, so that the last commit corresponds to the latest checkpoint + ckpt_dirs = sorted([x for x in args.exp_dir.iterdir() if x.name.startswith("iter_") and x.is_dir()]) + + for ckpt_dir in ckpt_dirs: + iter_number = get_iter_number(ckpt_dir.name) + if iter_number <= last_commit_iter: + continue + if iter_number % args.iter_interval == 0: + print(f"Converting iteration {iter_number}") + # TODO: this only works for 1-way tensor/pipeline parallelism + file_path = next((ckpt_dir / "mp_rank_00").glob("*.pt")) + convert(argv + [f"--save_dir={str(save_dir)}", str(file_path)]) + print(f"Pushing iteration {iter_number}") + hf_repo.push_to_hub(commit_message=f"{ckpt_dir.name}") + + +if __name__ == "__main__": + main() From 03716fae4a240724e79cf4de53451f94107ad90b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 15 Feb 2023 16:59:52 -0500 Subject: [PATCH 14/27] Upcasting, scaling, masking and fused kernels to match Megatron-LM (#10) --- .../gpt_bigcode/configuration_gpt_bigcode.py | 8 +- .../gpt_bigcode/modeling_gpt_bigcode.py | 171 ++++++++++-------- 2 files changed, 96 insertions(+), 83 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index 2522c240d7..ec0fe0bdd9 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -164,8 +164,8 @@ def __init__( use_cache=True, bos_token_id=50256, eos_token_id=50256, - scale_attn_by_inverse_layer_idx=False, - reorder_and_upcast_attn=False, + attention_softmax_in_fp32=True, + scale_attention_softmax_in_fp32=True, attention_type=AttentionType.MULTI_HEAD, **kwargs, ): @@ -188,8 +188,8 @@ def __init__( self.summary_proj_to_labels = summary_proj_to_labels self.scale_attn_weights = scale_attn_weights self.use_cache = use_cache - self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx - self.reorder_and_upcast_attn = 
reorder_and_upcast_attn + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32 self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 1642301963..63021a8f90 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -23,7 +23,6 @@ import torch import torch.utils.checkpoint from torch import nn -from torch.cuda.amp import autocast from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN @@ -57,6 +56,37 @@ # TODO: Add support for santa models. ] +# Fused kernels +# Use separate functions for each case because conditionals prevent kernel fusion. +# TODO: Could have better fused kernels depending on scaling, dropout and head mask. +# Is it doable without writing 32 functions? + + +@torch.jit.script +def upcast_masked_softmax( + x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype +): + input_dtype = x.dtype + x = x.to(softmax_dtype) * scale + x = torch.where(mask, x, mask_value) + x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype) + return x + + +@torch.jit.script +def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype): + input_dtype = x.dtype + x = x.to(softmax_dtype) * scale + x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype) + return x + + +@torch.jit.script +def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor): + x = torch.where(mask, x, mask_value) + x = torch.nn.functional.softmax(x, dim=-1) + return x + def load_tf_weights_in_gpt_bigcode(model, config, gpt_bigcode_checkpoint_path): """Load tf checkpoints in a pytorch model""" @@ -119,14 +149,6 @@ def load_tf_weights_in_gpt_bigcode(model, config, gpt_bigcode_checkpoint_path): class GPTBigCodeAttention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() - - max_positions = config.max_position_embeddings - self.register_buffer( - "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False - ) - self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) - # We don't use a buffer because the mask value depends on the dtype, - # And the dtype will be different if upcasting. 
self.mask_value = None self.attention_type = AttentionType(config.attention_type) @@ -147,17 +169,11 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.scale_attn_weights = config.scale_attn_weights self.is_cross_attention = is_cross_attention - # Layer-wise attention scaling, reordering, and upcasting - self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx self.layer_idx = layer_idx - self.reorder_and_upcast_attn = config.reorder_and_upcast_attn - - self.scale_factor = 1.0 - if self.scale_attn_weights: - self.scale_factor /= self.head_dim**0.5 - - if self.scale_attn_by_inverse_layer_idx: - self.scale_factor /= self.layer_idx + 1 + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + self.scale_attention_softmax_in_fp32 = ( + config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32 + ) if self.is_cross_attention: if self.is_mqa: @@ -216,42 +232,38 @@ def _matmul(self, x, y, dtype=None, scale_factor=1.0): z = torch.baddbmm(z, x, y, beta=0, alpha=scale_factor) return z.view(output_shape) - def _attn(self, query, key, value, attention_mask=None, head_mask=None, upcast=False): - with autocast(enabled=False): - attn_weights = self._matmul( - query, key.transpose(-1, -2), dtype=torch.float32 if upcast else None, scale_factor=self.scale_factor - ) + def _get_mask_value(self, device, dtype): + # torch.where expects a tensor. We use a cache to avoid recreating it every time. + if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: + self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) + return self.mask_value - if not self.is_cross_attention: - # if only "normal" attention layer implements causal mask - key_length = key.size(-2) - if self.is_mqa: - # (b, sq, nh, sk) - causal_mask = self.bias[None, key_length - query.size(1) : key_length, None, :key_length] - else: - # (b, nh, sq, sk) - causal_mask = self.bias[None, None, key_length - query.size(-2) : key_length, :key_length] - # torch.where expects a tensor. We use a cache to avoid recreating it every time. - if ( - self.mask_value is None - or self.mask_value.dtype != attn_weights.dtype - or self.mask_value.device != attn_weights.device - ): - self.mask_value = torch.full( - [], torch.finfo(attn_weights.dtype).min, dtype=attn_weights.dtype, device=attn_weights.device - ) - attn_weights = torch.where(causal_mask, attn_weights, self.mask_value) + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + dtype = query.dtype + softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype + upcast = dtype != softmax_dtype - if attention_mask is not None: - # Apply the attention mask - attn_weights = attn_weights + attention_mask + unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1 + scale_factor = unscale**-1 + if self.scale_attn_weights: + scale_factor /= self.head_dim**0.5 - attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = self._matmul(query, key.transpose(-1, -2), scale_factor=scale_factor) + + if upcast: + # Use a fused kernel to prevent a large overhead from casting and scaling. 
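+            # The 1 / (layer_idx + 1) factor folded into scale_factor above is undone here by passing
+            # `unscale` to the fp32 softmax, mirroring Megatron-LM's query-key layer scaling: the
+            # low-precision matmul stays in a safe numerical range while the overall math is unchanged.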
+ if attention_mask is None: + attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype) + else: + mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) + attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype) + else: + if attention_mask is not None: + mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) + # This can be fused with the softmax, but the fused kernel seems slower. + attn_weights = torch.where(attention_mask, attn_weights, mask_value) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise - if upcast and attn_weights.dtype != torch.float32: - raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") - attn_weights = attn_weights.type(value.dtype) attn_weights = self.attn_dropout(attn_weights) # Mask heads if we want to @@ -324,9 +336,7 @@ def forward( else: present = None - attn_output, attn_weights = self._attn( - query, key, value, attention_mask, head_mask, upcast=self.reorder_and_upcast_attn - ) + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim, permute=not self.is_mqa) attn_output = self.c_proj(attn_output) @@ -670,7 +680,8 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel): def __init__(self, config): super().__init__(config) - + self.attention_type = AttentionType(config.attention_type) + self.is_mqa = self.attention_type != AttentionType.MULTI_HEAD self.embed_dim = config.hidden_size self.wte = nn.Embedding(config.vocab_size, self.embed_dim) @@ -680,6 +691,11 @@ def __init__(self, config): self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False + ) + # Model parallel self.model_parallel = False self.device_map = None @@ -775,6 +791,9 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + device = input_ids.device if input_ids is not None else inputs_embeds.device if token_type_ids is not None: @@ -791,34 +810,28 @@ def forward( position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - # GPTBigCodeAttention mask. + # Self-attention mask. + query_length = input_shape[-1] + key_length = past_length + query_length + self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] + if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 
- attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).bool() + # MQA: (b, sq, nh, sk) + # MHA: (b, nh, sq, sk) + attention_mask = self_attention_mask.unsqueeze(2 if self.is_mqa else 1) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.add_cross_attention and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + if ( + self.config.add_cross_attention + and encoder_hidden_states is not None + and encoder_attention_mask is not None + ): + if encoder_attention_mask.dim() == 2: + encoder_attention_mask.unsqueeze(1) + assert encoder_attention_mask.dim() == 3 + encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.is_mqa else 1) else: encoder_attention_mask = None From 97f734cac7f65af12e360f06ac5687d2ff4af56d Mon Sep 17 00:00:00 2001 From: minimario Date: Mon, 20 Feb 2023 19:12:48 +0000 Subject: [PATCH 15/27] change test to use santacoder, add seed for the random inputs, increase tolerance --- .../gpt_bigcode/test_modeling_gpt_bigcode.py | 94 +++++++++---------- 1 file changed, 42 insertions(+), 52 deletions(-) diff --git a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py index 2fdbceed29..575527cc52 100644 --- a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py +++ b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py @@ -40,12 +40,8 @@ GPTBigCodeModel, ) - from transformers.models.gpt_bigcode.configuration_gpt_bigcode import ( - AttentionType - ) - from transformers.models.gpt_bigcode.modeling_gpt_bigcode import ( - GPTBigCodeAttention - ) + from transformers.models.gpt_bigcode.configuration_gpt_bigcode import AttentionType + from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeAttention class GPTBigCodeModelTester: @@ -816,78 +812,72 @@ def test_contrastive_search_gpt_bigcode(self): ], ) + @require_torch class GPTBigCodeAttentionTest(unittest.TestCase): - def get_attention(self, attention_type : AttentionType): + def get_attention(self, attention_type: AttentionType): config = GPTBigCodeConfig.from_pretrained("bigcode/santacoder-fast-inference") config.attention_type = attention_type return GPTBigCodeAttention(config) def test_mqa_correctness(self): + torch.manual_seed(0) embed_dim = 2048 head_dim = 128 - random_attn_weight = torch.randn(embed_dim, embed_dim) - random_attn_k_weight = torch.randn(embed_dim, head_dim) - random_attn_v_weight = torch.randn(embed_dim, head_dim) - random_attn_bias = torch.randn(embed_dim) - random_attn_k_bias = torch.randn(head_dim) - 
random_attn_v_bias = torch.randn(head_dim) - random_proj = torch.randn(embed_dim, embed_dim) - random_proj_bias = torch.randn(embed_dim) - - # MULTI-HEAD ATTENTION + + # GET THE WEIGHTS FROM MULTI-QUERY ATTENTION 1 + attention_mq1 = self.get_attention(AttentionType.MULTI_QUERY_1) + state_dict_mq1 = attention_mq1.state_dict() + attn_weight, attn_k_weight, attn_v_weight = torch.split(state_dict_mq1["c_attn.weight"], [embed_dim, head_dim, head_dim], dim=1) + attn_bias, attn_k_bias, attn_v_bias = torch.split(state_dict_mq1["c_attn.bias"], [embed_dim, head_dim, head_dim], dim=0) + proj = state_dict_mq1["c_proj.weight"] + proj_bias = state_dict_mq1["c_proj.bias"] + + # PUT THEM INTO THE MULTI-HEAD ATTENTION num_heads = 16 c_attn_weight = torch.hstack( - [random_attn_weight] + - num_heads * [random_attn_k_weight] + - num_heads * [random_attn_v_weight]) + [attn_weight] + + num_heads * [attn_k_weight] + + num_heads * [attn_v_weight]) c_attn_bias = torch.hstack( - [random_attn_bias] + - num_heads * [random_attn_k_bias] + - num_heads * [random_attn_v_bias]) + [attn_bias] + + num_heads * [attn_k_bias] + + num_heads * [attn_v_bias]) attention_mh = self.get_attention(AttentionType.MULTI_HEAD) state_dict = attention_mh.state_dict() state_dict["c_attn.weight"] = c_attn_weight state_dict["c_attn.bias"] = c_attn_bias - state_dict["c_proj.weight"] = random_proj - state_dict["c_proj.bias"] = random_proj_bias + state_dict["c_proj.weight"] = proj + state_dict["c_proj.bias"] = proj_bias attention_mh.load_state_dict(state_dict) - # MULTI-QUERY ATTENTION 1 - attention_mq1 = self.get_attention(AttentionType.MULTI_QUERY_1) - state_dict_mq1 = attention_mq1.state_dict() - c_attn_weight = torch.hstack([random_attn_weight, random_attn_k_weight, random_attn_v_weight]) - c_attn_bias = torch.hstack([random_attn_bias, random_attn_k_bias, random_attn_v_bias]) - state_dict_mq1["c_attn.weight"] = c_attn_weight - state_dict_mq1["c_attn.bias"] = c_attn_bias - state_dict_mq1["c_proj.weight"] = random_proj - state_dict_mq1["c_proj.bias"] = random_proj_bias - attention_mq1.load_state_dict(state_dict_mq1) - - # MULTI-QUERY ATTENTION 2 + # PUT THEM INTO THE MULTI-QUERY ATTENTION 2 attention_mq2 = self.get_attention(AttentionType.MULTI_QUERY_2) state_dict_mq2 = attention_mq2.state_dict() - state_dict_mq2["q_attn.weight"] = random_attn_weight - state_dict_mq2["q_attn.bias"] = random_attn_bias - state_dict_mq2["kv_attn.weight"] = torch.hstack([random_attn_k_weight, random_attn_v_weight]) - state_dict_mq2["kv_attn.bias"] = torch.hstack([random_attn_k_bias, random_attn_v_bias]) - state_dict_mq2["c_proj.weight"] = random_proj - state_dict_mq2["c_proj.bias"] = random_proj_bias + state_dict_mq2["q_attn.weight"] = attn_weight + state_dict_mq2["q_attn.bias"] = attn_bias + state_dict_mq2["kv_attn.weight"] = torch.hstack([attn_k_weight, attn_v_weight]) + state_dict_mq2["kv_attn.bias"] = torch.hstack([attn_k_bias, attn_v_bias]) + state_dict_mq2["c_proj.weight"] = proj + state_dict_mq2["c_proj.bias"] = proj_bias attention_mq2.load_state_dict(state_dict_mq2) - # Run correctness test + # RUN CORRECTNESS TEST attention_mh.eval() attention_mq1.eval() attention_mq2.eval() num_tokens = 5 - hidden_states = torch.randn(1, num_tokens, embed_dim) - attention_mh_result = attention_mh(hidden_states)[0] - attention_mq1_result = attention_mq1(hidden_states)[0] - attention_mq2_result = attention_mq2(hidden_states)[0] - - self.assertTrue(torch.allclose(attention_mh_result, attention_mq1_result)) - self.assertTrue(torch.allclose(attention_mh_result, 
attention_mq2_result)) - self.assertTrue(torch.allclose(attention_mq1_result, attention_mq2_result)) \ No newline at end of file + + for i in range(5): + hidden_states = torch.randn(1, num_tokens, embed_dim) + attention_mh_result = attention_mh(hidden_states)[0] + attention_mq1_result = attention_mq1(hidden_states)[0] + attention_mq2_result = attention_mq2(hidden_states)[0] + + tolerance = 1e-6 + self.assertTrue(torch.allclose(attention_mh_result, attention_mq1_result, atol=tolerance)) + self.assertTrue(torch.allclose(attention_mh_result, attention_mq2_result, atol=tolerance)) + self.assertTrue(torch.allclose(attention_mq1_result, attention_mq2_result, atol=tolerance)) \ No newline at end of file From 8ea03f475617ca94424e25fa87c6cabb27966d56 Mon Sep 17 00:00:00 2001 From: minimario Date: Mon, 20 Feb 2023 19:37:55 +0000 Subject: [PATCH 16/27] add train mode test --- .../gpt_bigcode/test_modeling_gpt_bigcode.py | 31 ++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py index 575527cc52..886b458e93 100644 --- a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py +++ b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py @@ -864,13 +864,42 @@ def test_mqa_correctness(self): state_dict_mq2["c_proj.bias"] = proj_bias attention_mq2.load_state_dict(state_dict_mq2) - # RUN CORRECTNESS TEST + # RUN CORRECTNESS TEST IN EVAL MODE attention_mh.eval() attention_mq1.eval() attention_mq2.eval() num_tokens = 5 + for i in range(5): + hidden_states = torch.randn(1, num_tokens, embed_dim) + attention_mh_result = attention_mh(hidden_states)[0] + attention_mq1_result = attention_mq1(hidden_states)[0] + attention_mq2_result = attention_mq2(hidden_states)[0] + + tolerance = 1e-6 + self.assertTrue(torch.allclose(attention_mh_result, attention_mq1_result, atol=tolerance)) + self.assertTrue(torch.allclose(attention_mh_result, attention_mq2_result, atol=tolerance)) + self.assertTrue(torch.allclose(attention_mq1_result, attention_mq2_result, atol=tolerance)) + + # RUN CORRECTNESS TEST IN TRAIN MODE + attention_mh.train() + attention_mq1.train() + attention_mq2.train() + + # disable dropouts + for module in attention_mh.modules(): + if isinstance(module, torch.nn.Dropout): + module.eval() + for module in attention_mq1.modules(): + if isinstance(module, torch.nn.Dropout): + module.eval() + for module in attention_mq2.modules(): + if isinstance(module, torch.nn.Dropout): + module.eval() + + num_tokens = 5 + for i in range(5): hidden_states = torch.randn(1, num_tokens, embed_dim) attention_mh_result = attention_mh(hidden_states)[0] From 400adc3b12cd19a9826ab76f4709f55557da7469 Mon Sep 17 00:00:00 2001 From: minimario Date: Tue, 21 Feb 2023 21:15:15 +0000 Subject: [PATCH 17/27] add attention parameters to initialization, parameterize test --- .../gpt_bigcode/configuration_gpt_bigcode.py | 4 +- .../gpt_bigcode/test_modeling_gpt_bigcode.py | 106 ++++++++---------- 2 files changed, 48 insertions(+), 62 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index 6546b47253..b50d31baa8 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -15,7 +15,7 @@ # limitations under the License. 
""" OpenAI GPT-2 configuration""" from collections import OrderedDict -from enum import Enum +from enum import IntEnum from typing import Any, List, Mapping, Optional from transformers import PreTrainedTokenizer, TensorType, is_torch_available @@ -32,7 +32,7 @@ } -class AttentionType(Enum): +class AttentionType(IntEnum): MULTI_HEAD = 1 MULTI_QUERY_1 = 2 MULTI_QUERY_2 = 3 diff --git a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py index 886b458e93..4b33ee99b0 100644 --- a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py +++ b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py @@ -816,35 +816,35 @@ def test_contrastive_search_gpt_bigcode(self): @require_torch class GPTBigCodeAttentionTest(unittest.TestCase): def get_attention(self, attention_type: AttentionType): - config = GPTBigCodeConfig.from_pretrained("bigcode/santacoder-fast-inference") - config.attention_type = attention_type + config = GPTBigCodeConfig.from_pretrained( + "bigcode/santacoder-fast-inference", + attention_type=attention_type, + attn_pdrop=0, + resid_pdrop=0, + ) return GPTBigCodeAttention(config) - def test_mqa_correctness(self): - torch.manual_seed(0) + def prepare_mqa_correctness_test(self, seed, test_mode="train"): + torch.manual_seed(seed) embed_dim = 2048 head_dim = 128 # GET THE WEIGHTS FROM MULTI-QUERY ATTENTION 1 attention_mq1 = self.get_attention(AttentionType.MULTI_QUERY_1) state_dict_mq1 = attention_mq1.state_dict() - attn_weight, attn_k_weight, attn_v_weight = torch.split(state_dict_mq1["c_attn.weight"], [embed_dim, head_dim, head_dim], dim=1) - attn_bias, attn_k_bias, attn_v_bias = torch.split(state_dict_mq1["c_attn.bias"], [embed_dim, head_dim, head_dim], dim=0) + attn_weight, attn_k_weight, attn_v_weight = torch.split( + state_dict_mq1["c_attn.weight"], [embed_dim, head_dim, head_dim], dim=1 + ) + attn_bias, attn_k_bias, attn_v_bias = torch.split( + state_dict_mq1["c_attn.bias"], [embed_dim, head_dim, head_dim], dim=0 + ) proj = state_dict_mq1["c_proj.weight"] proj_bias = state_dict_mq1["c_proj.bias"] # PUT THEM INTO THE MULTI-HEAD ATTENTION num_heads = 16 - c_attn_weight = torch.hstack( - [attn_weight] + - num_heads * [attn_k_weight] + - num_heads * [attn_v_weight]) - - c_attn_bias = torch.hstack( - [attn_bias] + - num_heads * [attn_k_bias] + - num_heads * [attn_v_bias]) - + c_attn_weight = torch.hstack([attn_weight] + num_heads * [attn_k_weight] + num_heads * [attn_v_weight]) + c_attn_bias = torch.hstack([attn_bias] + num_heads * [attn_k_bias] + num_heads * [attn_v_bias]) attention_mh = self.get_attention(AttentionType.MULTI_HEAD) state_dict = attention_mh.state_dict() state_dict["c_attn.weight"] = c_attn_weight @@ -864,49 +864,35 @@ def test_mqa_correctness(self): state_dict_mq2["c_proj.bias"] = proj_bias attention_mq2.load_state_dict(state_dict_mq2) - # RUN CORRECTNESS TEST IN EVAL MODE - attention_mh.eval() - attention_mq1.eval() - attention_mq2.eval() - - num_tokens = 5 - - for i in range(5): - hidden_states = torch.randn(1, num_tokens, embed_dim) - attention_mh_result = attention_mh(hidden_states)[0] - attention_mq1_result = attention_mq1(hidden_states)[0] - attention_mq2_result = attention_mq2(hidden_states)[0] - - tolerance = 1e-6 - self.assertTrue(torch.allclose(attention_mh_result, attention_mq1_result, atol=tolerance)) - self.assertTrue(torch.allclose(attention_mh_result, attention_mq2_result, atol=tolerance)) - self.assertTrue(torch.allclose(attention_mq1_result, attention_mq2_result, atol=tolerance)) - - # RUN CORRECTNESS 
TEST IN TRAIN MODE - attention_mh.train() - attention_mq1.train() - attention_mq2.train() - - # disable dropouts - for module in attention_mh.modules(): - if isinstance(module, torch.nn.Dropout): - module.eval() - for module in attention_mq1.modules(): - if isinstance(module, torch.nn.Dropout): - module.eval() - for module in attention_mq2.modules(): - if isinstance(module, torch.nn.Dropout): - module.eval() + # PUT THE MODEL INTO THE CORRECT MODE + if test_mode == "eval": + attention_mh.eval() + attention_mq1.eval() + attention_mq2.eval() + elif test_mode == "train": + attention_mh.train() + attention_mq1.train() + attention_mq2.train() + else: + raise ValueError(f"test_mode must be train or eval, but found: {test_mode}") + # RUN AN INPUT THROUGH THE MODELS num_tokens = 5 - - for i in range(5): - hidden_states = torch.randn(1, num_tokens, embed_dim) - attention_mh_result = attention_mh(hidden_states)[0] - attention_mq1_result = attention_mq1(hidden_states)[0] - attention_mq2_result = attention_mq2(hidden_states)[0] - - tolerance = 1e-6 - self.assertTrue(torch.allclose(attention_mh_result, attention_mq1_result, atol=tolerance)) - self.assertTrue(torch.allclose(attention_mh_result, attention_mq2_result, atol=tolerance)) - self.assertTrue(torch.allclose(attention_mq1_result, attention_mq2_result, atol=tolerance)) \ No newline at end of file + hidden_states = torch.randn(1, num_tokens, embed_dim) + attention_mh_result = attention_mh(hidden_states)[0] + attention_mq1_result = attention_mq1(hidden_states)[0] + attention_mq2_result = attention_mq2(hidden_states)[0] + + # CHECK THAT ALL OUTPUTS ARE THE SAME + tolerance = 1e-5 + self.assertTrue(torch.allclose(attention_mh_result, attention_mq1_result, atol=tolerance)) + self.assertTrue(torch.allclose(attention_mh_result, attention_mq2_result, atol=tolerance)) + self.assertTrue(torch.allclose(attention_mq1_result, attention_mq2_result, atol=tolerance)) + + def test_mqa_correctness_train(self): + for seed in range(5): + self.prepare_mqa_correctness_test(seed=seed, test_mode="train") + + def test_mqa_correctness_eval(self): + for seed in range(5): + self.prepare_mqa_correctness_test(seed=seed, test_mode="eval") From edd1622d8db26ecb0d41cd86f5ebc66c383d8345 Mon Sep 17 00:00:00 2001 From: minimario Date: Fri, 24 Feb 2023 05:32:11 +0000 Subject: [PATCH 18/27] use parameterized test --- .../gpt_bigcode/test_modeling_gpt_bigcode.py | 25 +++++-------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py index 4b33ee99b0..4877064472 100644 --- a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py +++ b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py @@ -17,6 +17,7 @@ import datetime import math import unittest +from parameterized import parameterized from transformers import GPTBigCodeConfig, is_torch_available from transformers.testing_utils import require_torch, slow, torch_device @@ -824,7 +825,8 @@ def get_attention(self, attention_type: AttentionType): ) return GPTBigCodeAttention(config) - def prepare_mqa_correctness_test(self, seed, test_mode="train"): + @parameterized.expand([(seed, is_train_mode) for seed in range(5) for is_train_mode in [True, False]]) + def test_mqa_correctness(self, seed, is_train_mode=True): torch.manual_seed(seed) embed_dim = 2048 head_dim = 128 @@ -865,16 +867,9 @@ def prepare_mqa_correctness_test(self, seed, test_mode="train"): attention_mq2.load_state_dict(state_dict_mq2) # PUT THE MODEL INTO 
THE CORRECT MODE - if test_mode == "eval": - attention_mh.eval() - attention_mq1.eval() - attention_mq2.eval() - elif test_mode == "train": - attention_mh.train() - attention_mq1.train() - attention_mq2.train() - else: - raise ValueError(f"test_mode must be train or eval, but found: {test_mode}") + attention_mh.train(is_train_mode) + attention_mq1.train(is_train_mode) + attention_mq2.train(is_train_mode) # RUN AN INPUT THROUGH THE MODELS num_tokens = 5 @@ -888,11 +883,3 @@ def prepare_mqa_correctness_test(self, seed, test_mode="train"): self.assertTrue(torch.allclose(attention_mh_result, attention_mq1_result, atol=tolerance)) self.assertTrue(torch.allclose(attention_mh_result, attention_mq2_result, atol=tolerance)) self.assertTrue(torch.allclose(attention_mq1_result, attention_mq2_result, atol=tolerance)) - - def test_mqa_correctness_train(self): - for seed in range(5): - self.prepare_mqa_correctness_test(seed=seed, test_mode="train") - - def test_mqa_correctness_eval(self): - for seed in range(5): - self.prepare_mqa_correctness_test(seed=seed, test_mode="eval") From 7f00e1cc10756c0f7aca2b8f879687c4d35d0479 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 28 Feb 2023 09:28:31 -0500 Subject: [PATCH 19/27] Add santacoder (#9) --- .../gpt_bigcode/configuration_gpt_bigcode.py | 4 +++- .../gpt_bigcode/modeling_gpt_bigcode.py | 4 +--- .../gpt_bigcode/test_modeling_gpt_bigcode.py | 24 +++++++++---------- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index 092dab617b..c2fd6125c6 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -27,7 +27,9 @@ logger = logging.get_logger(__name__) GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP = { - # TODO: Add support for santa models. + "bigcode/santacoder-fast-inference": ( + "https://huggingface.co/bigcode/santacoder-fast-inference/resolve/main/config.json" + ), } diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 3d9f9f4220..d2ed08e4ba 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -52,9 +52,7 @@ _CHECKPOINT_FOR_DOC = "gpt_bigcode" _CONFIG_FOR_DOC = "GPTBigCodeConfig" -GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST = [ - # TODO: Add support for santa models. -] +GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST = ["bigcode/santacoder-fast-inference"] # Fused kernels # Use separate functions for each case because conditionals prevent kernel fusion. 
diff --git a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py index 4877064472..98790c67f6 100644 --- a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py +++ b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py @@ -101,7 +101,7 @@ def __init__( self.pad_token_id = vocab_size - 1 def get_large_model_config(self): - return GPTBigCodeConfig.from_pretrained("gpt2") + return GPTBigCodeConfig.from_pretrained("bigcode/santacoder-fast-inference") def prepare_config_and_inputs( self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False @@ -537,9 +537,9 @@ def test_gpt_bigcode_weight_initialization(self): @slow def test_batch_generation(self): - model = GPTBigCodeLMHeadModel.from_pretrained("gpt2") + model = GPTBigCodeLMHeadModel.from_pretrained("bigcode/santacoder-fast-inference") model.to(torch_device) - tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + tokenizer = GPT2Tokenizer.from_pretrained("bigcode/santacoder") tokenizer.padding_side = "left" @@ -596,9 +596,9 @@ def test_batch_generation(self): @slow def test_batch_generation_2heads(self): - model = GPTBigCodeDoubleHeadsModel.from_pretrained("gpt2") + model = GPTBigCodeDoubleHeadsModel.from_pretrained("bigcode/santacoder-fast-inference") model.to(torch_device) - tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + tokenizer = GPT2Tokenizer.from_pretrained("bigcode/santacoder") tokenizer.padding_side = "left" @@ -671,7 +671,7 @@ def _test_lm_generate_gpt_bigcode_helper( verify_outputs=True, ): model = GPTBigCodeLMHeadModel.from_pretrained( - "gpt2", + "bigcode/santacoder-fast-inference", reorder_and_upcast_attn=reorder_and_upcast_attn, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx, ) @@ -712,8 +712,8 @@ def test_lm_generate_gpt_bigcode_with_scale_attn_by_inverse_layer_idx(self): @slow def test_gpt_bigcode_sample(self): - tokenizer = GPT2Tokenizer.from_pretrained("gpt2") - model = GPTBigCodeLMHeadModel.from_pretrained("gpt2") + tokenizer = GPT2Tokenizer.from_pretrained("bigcode/santacoder") + model = GPTBigCodeLMHeadModel.from_pretrained("bigcode/santacoder-fast-inference") model.to(torch_device) torch.manual_seed(0) @@ -740,8 +740,8 @@ def test_gpt_bigcode_sample(self): @slow def test_gpt_bigcode_sample_max_time(self): - tokenizer = GPTBigCodeTokenizer.from_pretrained("gpt2") - model = GPTBigCodeLMHeadModel.from_pretrained("gpt2") + tokenizer = GPT2Tokenizer.from_pretrained("bigcode/santacoder") + model = GPTBigCodeLMHeadModel.from_pretrained("bigcode/santacoder-fast-inference") model.to(torch_device) torch.manual_seed(0) @@ -786,8 +786,8 @@ def test_contrastive_search_gpt_bigcode(self): "laboratory founded in 2010. DeepMind was acquired by Google in 2014. 
The company is based" ) - gpt_bigcode_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large") - gpt_bigcode_model = GPTBigCodeLMHeadModel.from_pretrained("gpt2-large").to(torch_device) + gpt_bigcode_tokenizer = GPT2Tokenizer.from_pretrained("bigcode/santacoder") + gpt_bigcode_model = GPTBigCodeLMHeadModel.from_pretrained("bigcode/santacoder-fast-inference").to(torch_device) input_ids = gpt_bigcode_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) outputs = gpt_bigcode_model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256) From b989169f812f50e7be0fee3c7fb4d9b5e247f419 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 2 Mar 2023 14:23:11 -0500 Subject: [PATCH 20/27] More optimizations (#13) --- .../gpt_bigcode/configuration_gpt_bigcode.py | 3 +- .../gpt_bigcode/modeling_gpt_bigcode.py | 161 ++++----- ...convert_megatron_gpt_bigcode_checkpoint.py | 329 +++++------------- .../megatron_gpt_bigcode/push_checkpoints.py | 3 +- .../gpt_bigcode/test_modeling_gpt_bigcode.py | 3 +- 5 files changed, 164 insertions(+), 335 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index c2fd6125c6..7f05ffa9c2 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -196,8 +196,7 @@ def __init__( self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id - # Convert to an int so it's JSON-serializable. - self.attention_type = AttentionType(attention_type).value + self.attention_type = AttentionType(attention_type) super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index d2ed08e4ba..d30ec4186a 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -18,7 +18,7 @@ import math import os from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -33,7 +33,7 @@ TokenClassifierOutput, ) from transformers.modeling_utils import PreTrainedModel, SequenceSummary -from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from transformers.utils import ( ModelOutput, add_code_sample_docstrings, @@ -177,17 +177,17 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): if self.is_mqa: raise NotImplementedError(f"attention_type {self.attention_type} for cross_attention") - self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) - self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + self.c_attn = torch.nn.Linear(self.embed_dim, 2 * self.embed_dim) + self.q_attn = torch.nn.Linear(self.embed_dim, self.embed_dim) else: if self.attention_type == AttentionType.MULTI_QUERY_2: - self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + self.q_attn = torch.nn.Linear(self.embed_dim, self.embed_dim) # Keys and values are shared across heads - self.kv_attn = Conv1D(2 * self.head_dim, self.embed_dim) + self.kv_attn = torch.nn.Linear(self.embed_dim, 2 * self.head_dim) else: - self.c_attn = Conv1D(self.embed_dim + 2 * self.kv_dim, self.embed_dim) + self.c_attn = 
torch.nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim) - self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + self.c_proj = torch.nn.Linear(self.embed_dim, self.embed_dim) self.attn_dropout = nn.Dropout(config.attn_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop) @@ -195,41 +195,22 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.pruned_heads = set() def prune_heads(self, heads): + if self.is_mqa: + raise NotImplementedError("prune_heads not implemented for MQA") if len(heads) == 0: return heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) - # Prune conv1d layers - self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) - self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + # Prune linear layers + self.c_attn = prune_linear_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_linear_layer(self.c_proj, index, dim=0) # Update hyper params self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) self.num_heads = self.num_heads - len(heads) self.pruned_heads = self.pruned_heads.union(heads) - def _matmul(self, x, y, dtype=None, scale_factor=1.0): - output_shape = (*x.size()[:-1], y.size(-1)) - if self.is_mqa: - # Q x K: (b, sq, nh, hs) x (b, hs, sk) -> (b, sq, nh, sk) - # A X V: (b, sq, nh, sk) x (b, sk, hs) -> (b, sq, nh, hs) - output_view = (x.size(0), x.size(1) * x.size(2), y.size(-1)) - # No copy needed for MQA 2, or when layer_past is provided. - x = x.reshape(*output_view[:-1], x.size(-1)) - else: - # Q x K: (b, nh, sq, hs) x (b, nh, hs, sk) -> (b, nh, sq, sk) - # A X V: (b, nh, sq, sk) x (b, nh, sk, hs) -> (b, nh, sq, hs) - output_view = (x.size(0) * x.size(1), x.size(2), y.size(-1)) - # Always copies - x = x.reshape(output_view[0], *x.size()[2:]) - # No copy when layer_past is provided. - y = y.reshape(output_view[0], *y.size()[2:]) - # This is identical to matmul when scale_factor==1 - z = torch.empty(output_view, dtype=x.dtype if dtype is None else dtype, device=x.device) - z = torch.baddbmm(z, x, y, beta=0, alpha=scale_factor) - return z.view(output_shape) - def _get_mask_value(self, device, dtype): # torch.where expects a tensor. We use a cache to avoid recreating it every time. if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: @@ -246,7 +227,32 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): if self.scale_attn_weights: scale_factor /= self.head_dim**0.5 - attn_weights = self._matmul(query, key.transpose(-1, -2), scale_factor=scale_factor) + # MQA: (b, sq, nh * hs) + # MHA: (b, nh, sq, hs) + query_shape = query.shape + batch_size = query_shape[0] + key_length = key.size(-1) + if self.is_mqa: + # (b, sq, nh, hs) x (b, hs, sk) -> (b, sq, nh, sk) + query_length = query_shape[1] + attn_shape = (batch_size, query_length, self.num_heads, key_length) + attn_view = (batch_size, query_length * self.num_heads, key_length) + # No copy needed for MQA 2, or when layer_past is provided. 
+ query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) + else: + # (b, nh, sq, hs) x (b, nh, hs, sk) -> (b, nh, sq, sk) + query_length = query_shape[2] + attn_shape = (batch_size, self.num_heads, query_length, key_length) + attn_view = (batch_size * self.num_heads, query_length, key_length) + # Always copies + query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim) + # No copy when layer_past is provided. + key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length) + attn_weights = torch.baddbmm( + torch.empty(attn_view, dtype=query.dtype, device=query.device), query, key, beta=0, alpha=scale_factor + ).view(attn_shape) + + # attn_weights = self._matmul(query, key, scale_factor=scale_factor) if upcast: # Use a fused kernel to prevent a large overhead from casting and scaling. @@ -268,33 +274,17 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): if head_mask is not None: attn_weights = attn_weights * head_mask - attn_output = self._matmul(attn_weights, value) + if self.is_mqa: + attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape) + else: + attn_output = torch.matmul(attn_weights, value) return attn_output, attn_weights - def _split_heads(self, tensor, num_heads, attn_head_size, permute=True): - """ - Splits hidden_size dim into attn_head_size and num_heads - """ - new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(new_shape) - if permute: - tensor = tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) - return tensor - - def _merge_heads(self, tensor, num_heads, attn_head_size, permute=True): - """ - Merges attn_head_size dim and num_attn_heads dim into hidden_size - """ - if permute: - tensor = tensor.permute(0, 2, 1, 3).contiguous() - new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) - return tensor.view(new_shape) - def forward( self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + hidden_states: torch.FloatTensor, + layer_past: Optional[torch.Tensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, @@ -310,37 +300,38 @@ def forward( ) query = self.q_attn(hidden_states) - key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + key_value = self.c_attn(encoder_hidden_states) attention_mask = encoder_attention_mask + elif self.attention_type == AttentionType.MULTI_QUERY_2: + query = self.q_attn(hidden_states) + key_value = self.kv_attn(hidden_states) + elif self.attention_type == AttentionType.MULTI_QUERY_1: + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) else: - if self.attention_type == AttentionType.MULTI_QUERY_2: - query = self.q_attn(hidden_states) - key, value = self.kv_attn(hidden_states).split((self.kv_dim, self.kv_dim), dim=2) - else: - query, key, value = self.c_attn(hidden_states).split((self.embed_dim, self.kv_dim, self.kv_dim), dim=2) - - query = self._split_heads(query, self.num_heads, self.head_dim, permute=not self.is_mqa) - if not self.is_mqa: - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) + # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), + # i.e., the memory layout is not the same as GPT2. 
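+            # Concretely, the projection output is viewed as (batch, seq, num_heads, 3 * head_dim), so each
+            # head's [query, key, value] block is contiguous and key/value stay fused in a single cache tensor.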
+ # This makes the concatenation with past_key_value more efficient. + query, key_value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split((self.head_dim, 2 * self.head_dim), dim=3) + ) if layer_past is not None: - past_key, past_value = layer_past - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) + key_value = torch.cat((layer_past, key_value), dim=-2) - if use_cache is True: - present = (key, value) - else: - present = None + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) - attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) - attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim, permute=not self.is_mqa) + if not self.is_mqa: + attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) - outputs = (attn_output, present) + # TODO: Is it ok to send unwrapped present? + outputs = (attn_output, key_value if use_cache else None) if output_attentions: outputs += (attn_weights,) @@ -351,8 +342,8 @@ class GPTBigCodeMLP(nn.Module): def __init__(self, intermediate_size, config): super().__init__() embed_dim = config.hidden_size - self.c_fc = Conv1D(intermediate_size, embed_dim) - self.c_proj = Conv1D(embed_dim, intermediate_size) + self.c_fc = torch.nn.Linear(embed_dim, intermediate_size) + self.c_proj = torch.nn.Linear(intermediate_size, embed_dim) self.act = ACT2FN[config.activation_function] self.dropout = nn.Dropout(config.resid_pdrop) @@ -385,14 +376,14 @@ def __init__(self, config, layer_idx=None): def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, + layer_past: Optional[torch.Tensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, - ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + ) -> Tuple: residual = hidden_states hidden_states = self.ln_1(hidden_states) attn_outputs = self.attn( @@ -462,7 +453,7 @@ def __init__(self, *inputs, **kwargs): def _init_weights(self, module): """Initialize the weights.""" - if isinstance(module, (nn.Linear, Conv1D)): + if isinstance(module, nn.Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) @@ -757,7 +748,7 @@ def _prune_heads(self, heads_to_prune): def forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[List[torch.Tensor]] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -852,7 +843,7 @@ def forward( output_shape = input_shape + (hidden_states.size(-1),) - presents = () if use_cache else None + presents = [] if use_cache else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if 
output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None @@ -909,8 +900,8 @@ def custom_forward(*inputs): ) hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) + if use_cache: + presents.append(outputs[1]) if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) diff --git a/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py b/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py index 9de9f4603b..6ca2e79671 100644 --- a/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py +++ b/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py @@ -32,12 +32,9 @@ # in your path, i.e., /path/to/Megatron-DeepSpeed/ # -# TODO: Update this file - import argparse import os import re -import zipfile import torch @@ -67,109 +64,98 @@ def recursive_print(name, val, spaces=0): print(msg, ":", val) -def fix_query_key_value_ordering(param, checkpoint_version, num_splits, num_heads, hidden_size): - # Permutes layout of param tensor to [num_splits * num_heads * hidden_size, :] - # for compatibility with later versions of NVIDIA Megatron-LM. - # The inverse operation is performed inside Megatron-LM to read checkpoints: - # https://github.com/NVIDIA/Megatron-LM/blob/v2.4/megatron/checkpointing.py#L209 - # If param is the weight tensor of the self-attention block, the returned tensor - # will have to be transposed one more time to be read by HuggingFace GPT2. - input_shape = param.size() - if checkpoint_version == 1.0: - # version 1.0 stores [num_heads * hidden_size * num_splits, :] - saved_shape = (num_heads, hidden_size, num_splits) + input_shape[1:] - param = param.view(*saved_shape) - param = param.transpose(0, 2) - param = param.transpose(1, 2).contiguous() - elif checkpoint_version >= 2.0: - # other versions store [num_heads * num_splits * hidden_size, :] - saved_shape = (num_heads, num_splits, hidden_size) + input_shape[1:] - param = param.view(*saved_shape) - param = param.transpose(0, 1).contiguous() - param = param.view(*input_shape) - return param - - #################################################################################################### +# The simple map of names for "automated" rules. +NAME_MAP = { + "attention.dense": ".attn.c_proj.", + "self_attention.dense": ".attn.c_proj.", + "mlp.dense_h_to_4h": ".mlp.c_fc.", + "mlp.dense_4h_to_h": ".mlp.c_proj.", + "self_attention.query_key_value": ".attn.c_attn.", + "self_attention.query": ".attn.q_attn.", + "self_attention.key_value": ".attn.kv_attn.", +} + -def convert_megatron_checkpoint(args, input_state_dict, config): +def convert_megatron_checkpoint(input_state_dict, merge_qkv): # The converted output model. 
output_state_dict = {} + ds_args = input_state_dict["args"] - # old versions did not store training args - ds_args = input_state_dict.get("args", None) if ds_args is not None: - # do not make the user write a config file when the exact dimensions/sizes are already in the checkpoint - # from pprint import pprint - # pprint(vars(ds_args)) - - config.vocab_size = ds_args.padded_vocab_size - config.n_positions = ds_args.max_position_embeddings - config.n_embd = ds_args.hidden_size - config.n_layer = ds_args.num_layers - config.n_head = ds_args.num_attention_heads - config.n_inner = ds_args.ffn_hidden_size - - if ds_args.attention_head_type == "multihead": - config.attention_type = 1 + if ds_args.bias_gelu_fusion: + activation_function = "gelu_pytorch_tanh" + elif ds_args.openai_gelu: + activation_function = "gelu_new" else: - assert ds_args.attention_head_type == "multiquery" - config.attention_type = 2 if args.merge_qkv else 3 - - # also set `scale_attn_weights` and `scale_attn_by_inverse_layer_idx` ? - # Uncommenting the next line makes the converted model output different logits. - # config.scale_attn_by_inverse_layer_idx = ds_args.apply_query_key_layer_scaling - # pprint(config) - - # The number of heads. - heads = config.n_head - # The hidden_size per head. - hidden_size_per_head = config.n_embd // config.n_head - # Megatron-LM checkpoint version - if "checkpoint_version" in input_state_dict.keys(): - checkpoint_version = input_state_dict["checkpoint_version"] + activation_function = "gelu" + else: + # in the very early days this used to be "gelu_new" + activation_function = "gelu_new" + + if ds_args.attention_head_type == "multihead": + attention_type = 1 else: - checkpoint_version = 0.0 + assert ds_args.attention_head_type == "multiquery" + attention_type = 2 if merge_qkv else 3 + + attention_softmax_in_fp32 = ds_args.attention_softmax_in_fp32 or ds_args.apply_query_key_layer_scaling + + # Spell out all parameters in case the defaults change. + config = GPTBigCodeConfig( + architectures=["GPTBigCodeLMHeadModel"], + vocab_size=ds_args.padded_vocab_size, + n_positions=ds_args.max_position_embeddings, + n_embd=ds_args.hidden_size, + n_layer=ds_args.num_layers, + n_head=ds_args.num_attention_heads, + n_inner=ds_args.ffn_hidden_size, + activation_function=activation_function, + attention_type=attention_type, + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + attention_softmax_in_fp32=attention_softmax_in_fp32, + scale_attention_softmax_in_fp32=True, + ) + + # from pprint import pprint + # pprint(vars(ds_args)) + # pprint(config) + + # Megatron-LM checkpoint version + checkpoint_version = input_state_dict["checkpoint_version"] + if checkpoint_version < 2.0: + raise NotImplementedError(f"Checkpoint version {checkpoint_version} not supported.") # The model. - model = input_state_dict["model"] - # The language model. - lm = model["language_model"] - # The embeddings. - embeddings = lm["embedding"] - - # The word embeddings. - word_embeddings = embeddings["word_embeddings"]["weight"] - # Truncate the embedding table to vocab_size rows. - word_embeddings = word_embeddings[: config.vocab_size, :] + model = input_state_dict["model"]["language_model"] + + # The word embeddings, truncated to to vocab_size rows. 
+ word_embeddings = model["embedding"]["word_embeddings"]["weight"][: config.vocab_size, :] output_state_dict["transformer.wte.weight"] = word_embeddings # The position embeddings. - pos_embeddings = embeddings["position_embeddings"]["weight"] - # Read the causal mask dimension (seqlen). [max_sequence_length, hidden_size] - n_positions = pos_embeddings.size(0) - if n_positions != config.n_positions: - raise ValueError( - f"pos_embeddings.max_sequence_length={n_positions} and config.n_positions={config.n_positions} don't match" - ) - # Store the position embeddings. - output_state_dict["transformer.wpe.weight"] = pos_embeddings + output_state_dict["transformer.wpe.weight"] = model["embedding"]["position_embeddings"]["weight"] # The transformer. - transformer = lm["transformer"] if "transformer" in lm.keys() else lm["encoder"] + transformer = model["transformer"] if "transformer" in model else model["encoder"] # The regex to extract layer names. layer_re = re.compile("layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") - # The simple map of names for "automated" rules. - megatron_to_transformers = { - "attention.dense": ".attn.c_proj.", - "self_attention.dense": ".attn.c_proj.", - "mlp.dense_h_to_4h": ".mlp.c_fc.", - "mlp.dense_4h_to_h": ".mlp.c_proj.", - } - # Extract the layers. for key, val in transformer.items(): # Match the name. @@ -195,81 +181,16 @@ def convert_megatron_checkpoint(args, input_state_dict, config): ln_name = "ln_1" if op_name.startswith("input") else "ln_2" output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = val - # Transpose the QKV matrix. - elif ( - op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value" - ) and weight_or_bias == "weight": - out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head) - # Megatron stores (3*D) x D but transformers-GPT2 expects D x 3*D. - out_val = out_val.transpose(0, 1).contiguous() - # Store. - output_state_dict[layer_name + ".attn.c_attn.weight"] = out_val - - # Tranpose the Q matrix (for MQA) - elif (op_name == "self_attention.query") and weight_or_bias == "weight": - out_val = fix_query_key_value_ordering(val, checkpoint_version, 1, heads, hidden_size_per_head) - # Megatron stores (out x in) but transformers-GPT2 expects (in x out). - out_val = out_val.transpose(0, 1).contiguous() - # Store. - output_state_dict[layer_name + ".attn.q_attn.weight"] = out_val - - # Tranpose the KV matrix (for MQA) - elif (op_name == "self_attention.key_value") and weight_or_bias == "weight": - # Key-values are shared across heads - out_val = fix_query_key_value_ordering(val, checkpoint_version, 2, 1, hidden_size_per_head) - # Megatron stores (out x in) but transformers-GPT2 expects (in x out). - out_val = out_val.transpose(0, 1).contiguous() - if args.merge_qkv: - # Concatenate the tensors. - # Query is before key_value in the dict. - query = output_state_dict.pop(layer_name + ".attn.q_attn.weight") - out_val = torch.cat([query, out_val], dim=1) - output_state_dict[layer_name + ".attn.c_attn.weight"] = out_val - else: - # Store. - output_state_dict[layer_name + ".attn.kv_attn.weight"] = out_val - - # Transpose the bias. - elif ( - op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value" - ) and weight_or_bias == "bias": - - out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head) - # Store. No change of shape. 
- output_state_dict[layer_name + ".attn.c_attn.bias"] = out_val - - # Transpose the Q bias (MQA) - elif (op_name == "self_attention.query") and weight_or_bias == "bias": - - out_val = fix_query_key_value_ordering(val, checkpoint_version, 1, heads, hidden_size_per_head) - # Store. No change of shape. - output_state_dict[layer_name + ".attn.q_attn.bias"] = out_val - - # Transpose the KV bias (MQA) - elif (op_name == "self_attention.key_value") and weight_or_bias == "bias": - - out_val = fix_query_key_value_ordering(val, checkpoint_version, 2, 1, hidden_size_per_head) - if args.merge_qkv: - # Concatenate the tensors. - # Query is before key_value in the dict. - query = output_state_dict.pop(layer_name + ".attn.q_attn.bias") - out_val = torch.cat([query, out_val], dim=0) - output_state_dict[layer_name + ".attn.c_attn.bias"] = out_val - else: - # Store. No change of shape. - output_state_dict[layer_name + ".attn.kv_attn.bias"] = out_val - - # Transpose the weights. - elif weight_or_bias == "weight": - - out_name = megatron_to_transformers[op_name] - output_state_dict[layer_name + out_name + "weight"] = val.transpose(0, 1) - - # Copy the bias. - elif weight_or_bias == "bias": - - out_name = megatron_to_transformers[op_name] - output_state_dict[layer_name + out_name + "bias"] = val + # Concatenate QKV matrix. + elif merge_qkv and (op_name == "self_attention.key_value"): + # Query is before key_value in the dict. + query = output_state_dict.pop(layer_name + ".attn.q_attn." + weight_or_bias) + out_val = torch.cat([query, val], dim=0) + output_state_dict[layer_name + ".attn.c_attn." + weight_or_bias] = out_val + + # Copy the parameters. + else: + output_state_dict[layer_name + NAME_MAP[op_name] + weight_or_bias] = val # DEBUG. assert config.n_layer == layer_idx + 1 @@ -282,7 +203,7 @@ def convert_megatron_checkpoint(args, input_state_dict, config): output_state_dict["lm_head.weight"] = word_embeddings # It should be done! - return output_state_dict + return config, output_state_dict #################################################################################################### @@ -297,12 +218,6 @@ def main(argv=None): type=str, help="Path to the checkpoint file (.zip archive or direct .pt file)", ) - parser.add_argument( - "--config_file", - default="", - type=str, - help="An optional config json file describing the pre-trained model.", - ) parser.add_argument( "--no_merge_qkv", dest="merge_qkv", @@ -323,87 +238,17 @@ def main(argv=None): basename = args.save_dir or os.path.dirname(args.path_to_checkpoint) # Load the model. - # the .zip is very optional, let's keep it for backward compatibility print(f"Extracting PyTorch state dictionary from {args.path_to_checkpoint}") - if args.path_to_checkpoint.endswith(".zip"): - with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint: - with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict: - input_state_dict = torch.load(pytorch_dict, map_location="cpu") - else: - input_state_dict = torch.load(args.path_to_checkpoint, map_location="cpu") - - ds_args = input_state_dict.get("args", None) - - # Read the config, or default to the model released by NVIDIA. - if args.config_file == "": - # TODO: FP32 softmax - - if ds_args is not None: - if ds_args.bias_gelu_fusion: - # TODO: This will be in the next release of transformers (4.27). 
- activation_function = "gelu_pytorch_tanh" - elif ds_args.openai_gelu: - activation_function = "gelu_new" - else: - activation_function = "gelu" - else: - # in the very early days this used to be "gelu_new" - activation_function = "gelu_new" - - # Spell out all parameters in case the defaults change. - config = GPTBigCodeConfig( - vocab_size=50257, - n_positions=1024, - n_embd=1024, - n_layer=24, - n_head=16, - n_inner=4096, - activation_function=activation_function, - resid_pdrop=0.1, - embd_pdrop=0.1, - attn_pdrop=0.1, - layer_norm_epsilon=1e-5, - initializer_range=0.02, - summary_type="cls_index", - summary_use_proj=True, - summary_activation=None, - summary_proj_to_labels=True, - summary_first_dropout=0.1, - scale_attn_weights=True, - use_cache=True, - bos_token_id=50256, - eos_token_id=50256, - ) - else: - config = GPTBigCodeConfig.from_json_file(args.config_file) - - config.architectures = ["GPTBigCodeLMHeadModel"] + input_state_dict = torch.load(args.path_to_checkpoint, map_location="cpu") # Convert. print("Converting") - output_state_dict = convert_megatron_checkpoint(args, input_state_dict, config) + config, output_state_dict = convert_megatron_checkpoint(input_state_dict, args.merge_qkv) # Print the structure of converted state dict. if args.print_checkpoint_structure: recursive_print(None, output_state_dict) - # Add tokenizer class info to config - # see https://github.com/huggingface/transformers/issues/13906) - # if ds_args is not None: - # tokenizer_type = ds_args.tokenizer_type - # if tokenizer_type == "GPT2BPETokenizer": - # tokenizer_model_name = "gpt2" - # elif tokenizer_type == "PretrainedFromHF": - # tokenizer_model_name = ds_args.tokenizer_name_or_path - # else: - # raise ValueError(f"Unrecognized tokenizer_type {tokenizer_type}") - # else: - # tokenizer_model_name = "gpt2" - - # tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_name) - # tokenizer_class = type(tokenizer).__name__ - # config.tokenizer_class = tokenizer_class - if args.custom_model: # Save custom model GPTBigCodeConfig.register_for_auto_class() @@ -417,10 +262,6 @@ def main(argv=None): print("Saving config") config.save_pretrained(basename) - # Save tokenizer based on args - # print(f"Adding {tokenizer_class} tokenizer files") - # tokenizer.save_pretrained(basename) - # Store the state_dict to file. 
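For reference, a minimal standalone sketch of what the merge_qkv path produces for a multi-query layer before the state dict is saved. The sizes below are hypothetical and the tensors random; the point is only that the popped query weight and the shared key_value weight are concatenated along the output dimension into the single c_attn tensor stored above.

import torch

# Hypothetical sizes, for illustration only.
embed_dim, head_dim = 2048, 128

# Per layer: query weight is (embed_dim, embed_dim), the shared key_value weight is (2 * head_dim, embed_dim).
q_weight = torch.randn(embed_dim, embed_dim)
kv_weight = torch.randn(2 * head_dim, embed_dim)

# merge_qkv concatenates along dim 0, i.e. a single linear layer with out_features = embed_dim + 2 * head_dim.
c_attn_weight = torch.cat([q_weight, kv_weight], dim=0)
assert c_attn_weight.shape == (embed_dim + 2 * head_dim, embed_dim)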
output_checkpoint_file = os.path.join(basename, "pytorch_model.bin") print(f'Saving checkpoint to "{output_checkpoint_file}"') diff --git a/src/transformers/models/megatron_gpt_bigcode/push_checkpoints.py b/src/transformers/models/megatron_gpt_bigcode/push_checkpoints.py index 27610525d9..ececb9e5be 100644 --- a/src/transformers/models/megatron_gpt_bigcode/push_checkpoints.py +++ b/src/transformers/models/megatron_gpt_bigcode/push_checkpoints.py @@ -4,8 +4,7 @@ from pathlib import Path from huggingface_hub import Repository - -from .convert_megatron_gpt_bigcode_checkpoint import main as convert +from transformers.models.megatron_gpt_bigcode.convert_megatron_gpt_bigcode_checkpoint import main as convert """ diff --git a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py index 98790c67f6..8c61088699 100644 --- a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py +++ b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py @@ -17,8 +17,8 @@ import datetime import math import unittest -from parameterized import parameterized +from parameterized import parameterized from transformers import GPTBigCodeConfig, is_torch_available from transformers.testing_utils import require_torch, slow, torch_device @@ -40,7 +40,6 @@ GPTBigCodeLMHeadModel, GPTBigCodeModel, ) - from transformers.models.gpt_bigcode.configuration_gpt_bigcode import AttentionType from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeAttention From d4451b4657a9eee6beffcad147ec55c2dfd6b0c0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 2 Mar 2023 14:26:50 -0500 Subject: [PATCH 21/27] Fast inference (#7) --- .../gpt_bigcode/configuration_gpt_bigcode.py | 25 ++ .../models/gpt_bigcode/inference_runner.py | 334 ++++++++++++++++++ .../gpt_bigcode/modeling_gpt_bigcode.py | 44 ++- 3 files changed, 398 insertions(+), 5 deletions(-) create mode 100644 src/transformers/models/gpt_bigcode/inference_runner.py diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index 7f05ffa9c2..ece1fb5437 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -39,6 +39,19 @@ class AttentionType(IntEnum): MULTI_QUERY_2 = 3 +class InferenceRunnerType(IntEnum): + NO_RUNNER = 0 + # Use the inference runner without cuda graphs. + BASE_RUNNER = 1 + # Use cuda graphs in the inference runner. Leave out the attention which has a variable shape. + # This significantly lowers the cpu time and prevent a cpu bottleneck for smaller batches and models. + PARTIAL_GRAPH = 2 + # Turn the whole model into a cuda graph. One graph for each sequence length. + # Note: only useful for small batches and models, graphs take some time to generate, flaky. + # Crashes with jit on A100 but seems to work without jit (PYTORCH_JIT=0) and on V100. 
+ FULL_GRAPH = 3 + + class GPTBigCodeConfig(PretrainedConfig): """ # TODO: Update doc @@ -169,6 +182,10 @@ def __init__( attention_softmax_in_fp32=True, scale_attention_softmax_in_fp32=True, attention_type=AttentionType.MULTI_HEAD, + inference_runner=InferenceRunnerType.NO_RUNNER, + validate_runner_input=True, + runner_max_sequence_length=None, + pad_key_length=True, **kwargs, ): self.vocab_size = vocab_size @@ -198,6 +215,14 @@ def __init__( self.attention_type = AttentionType(attention_type) + self.inference_runner = InferenceRunnerType(inference_runner) + # Set to False to disable input validation of safe inputs, for a small speedup. + self.validate_runner_input = validate_runner_input + # Set if `n_positions` uses too much memory. + self.runner_max_sequence_length = runner_max_sequence_length + # Pad key length to a multiple of 8. + self.pad_key_length = pad_key_length + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/src/transformers/models/gpt_bigcode/inference_runner.py b/src/transformers/models/gpt_bigcode/inference_runner.py new file mode 100644 index 0000000000..7f6e64b3f5 --- /dev/null +++ b/src/transformers/models/gpt_bigcode/inference_runner.py @@ -0,0 +1,334 @@ +import math +from typing import List, Union + +import torch + +from transformers import GPTBigCodeConfig +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from transformers.models.gpt_bigcode.configuration_gpt_bigcode import AttentionType, InferenceRunnerType +from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeBlock, masked_softmax, upcast_masked_softmax + + +def _align_tensor(x): + return x + -x % 128 + + +class GPTBigCodeInferenceRunner: + def __init__(self, config: GPTBigCodeConfig, model): + self.batch_size = None + self.model = model + self.n_layer = len(self.model.h) + + self.inference_runner_type = InferenceRunnerType(config.inference_runner) + assert self.inference_runner_type != InferenceRunnerType.NO_RUNNER + self.validate_input = config.validate_runner_input + self.pad_key_length = 8 if config.pad_key_length else 1 + + # TODO: Support other attention types? + assert model.attention_type == AttentionType.MULTI_QUERY_1 + + self.max_sequence_length = config.runner_max_sequence_length or config.n_positions + + def _allocate(self, batch_size, device, dtype): + block: GPTBigCodeBlock = self.model.h[0] + attn = block.attn + self.batch_size = batch_size + self.dtype = dtype + self.device = device + self.softmax_dtype = torch.float32 if attn.attention_softmax_in_fp32 else self.dtype + self.upcast = self.softmax_dtype != self.dtype + + do_unscale = attn.scale_attention_softmax_in_fp32 and self.upcast + self.unscale = [i + 1.0 if do_unscale else 1.0 for i in range(self.n_layer)] + scale = attn.head_dim**-0.5 if attn.scale_attn_weights else 1 + self.scale = [scale / unscale for unscale in self.unscale] + + factory_kwargs = {"device": self.device, "dtype": self.dtype} + + hidden_end = self.batch_size * attn.embed_dim + # Query: (bs, embed_dim), also used for attn outputs (no overlap with value). + query_begin = _align_tensor(hidden_end) + query_end = query_begin + self.batch_size * attn.embed_dim + # KV: (bs, 2 * kv_dim), combines with query into c_attn. 
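As a side note on the buffer layout here, _align_tensor rounds an offset up to the next multiple of 128 elements so that each tensor carved out of the shared activation pool starts on an aligned boundary. A standalone copy with a few spot checks (values arbitrary):

def _align_tensor(x):
    # x + (-x % 128) rounds x up to the next multiple of 128; it is a no-op when x is already aligned.
    return x + -x % 128

assert _align_tensor(0) == 0
assert _align_tensor(1) == 128
assert _align_tensor(300) == 384
assert _align_tensor(512) == 512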
+ kv_end = query_end + 2 * self.batch_size * attn.kv_dim + # Attn weights: (batch_size, num_heads, key_length), no overlap with value + attn_weights_begin = _align_tensor(kv_end) + attn_weights_end = kv_end + self.batch_size * attn.num_heads * self.max_sequence_length + # Projection: (batch_size, embed_dim), no overlap with attn outputs ~ query. + # Also used for MLP projection () + c_proj_begin = _align_tensor(query_end) + c_proj_end = c_proj_begin + self.batch_size * attn.embed_dim + c_fc_begin = query_begin + c_fc_end = c_fc_begin + self.batch_size * block.inner_dim + pool_size = max(attn_weights_end, c_proj_end, c_fc_end) + kv_cache_shape = (self.n_layer, self.batch_size, attn.kv_heads, self.max_sequence_length, 2 * attn.head_dim) + kv_cache_size = math.prod(kv_cache_shape) + + print( + f"Allocating inference buffers (batch size = {self.batch_size}, max sequence length =" + f" {self.max_sequence_length})..." + ) + print(f" Activation pool size: {pool_size:,}") + print(f" KV cache size: {kv_cache_size:,}") + buffer_memory = (pool_size + kv_cache_size) * torch.finfo( + self.dtype + ).bits / 8 + self.batch_size * self.max_sequence_length + print(f" Memory usage: {buffer_memory/2**20:.0f} MiB") + + activation_pool = torch.empty(pool_size, **factory_kwargs) + kv_cache = torch.empty( + kv_cache_shape, + **factory_kwargs, + ) + self.mask_value = torch.full( + [], torch.finfo(self.softmax_dtype).min, dtype=self.softmax_dtype, device=self.device + ) + # We ensure mask tensors are contiguous to enable more efficient kernels. + attn_mask = torch.empty(self.batch_size * self.max_sequence_length, dtype=torch.bool, device=self.device) + + if self.device.type == "cuda": + print(f" Memory allocated {torch.cuda.memory_allocated()/2**20:.0f} MiB") + # Max stats give some insight on the prefill memory usage. + print(f" Max memory allocated {torch.cuda.max_memory_allocated()/2**20:.0f} MiB") + print(f" Max memory reserved {torch.cuda.max_memory_reserved()/2**20:.0f} MiB") + + key_lengths = range(self.max_sequence_length + 1) + padded_key_lengths = [key_length + -key_length % self.pad_key_length for key_length in key_lengths] + + self.padded_attn_masks = [ + attn_mask[: self.batch_size * key_length].view(self.batch_size, 1, key_length) + for key_length in padded_key_lengths + ] + self.attn_masks = [ + padded_attn_mask[:, :, :key_length].squeeze(1) + for key_length, padded_attn_mask in enumerate(self.padded_attn_masks) + ] + self.attn_mask_pads = [ + padded_attn_mask[:, :, key_length:].squeeze(1) + for key_length, padded_attn_mask in enumerate(self.padded_attn_masks) + ] + + # Hidden: (batch_size, 1, embed_dim), no overlap allowed. + self.hidden_states_squeezed = activation_pool[:hidden_end].view(self.batch_size, -1) + self.hidden_states = self.hidden_states_squeezed.unsqueeze(1) + # QKV: (bs, embed_dim + 2 * kv_dim). + self.c_attn = activation_pool[query_begin:kv_end].view(self.batch_size, -1) + self.query = self.c_attn[:, : attn.embed_dim].view(self.batch_size, attn.num_heads, attn.head_dim) + self.kv_attn = self.c_attn[:, attn.embed_dim :] + + key, value = kv_cache.split((attn.head_dim, attn.head_dim), dim=-1) + key = key.transpose(-1, -2) + head_slice = 0 if attn.is_mqa else slice(None) + + self.padded_keys = [key[:, :, head_slice, :, :key_length].unbind(0) for key_length in padded_key_lengths] + self.padded_values = [value[:, :, head_slice, :key_length, :].unbind(0) for key_length in padded_key_lengths] + + # This is nonsense for key_length == 0, but we never need the value. 
+ self.current_key_values = [ + kv_cache[:, :, head_slice, key_length - 1, :].unbind(0) for key_length in key_lengths + ] + self.past_key_values = [ + kv_cache[:, :, head_slice, : key_length - 1, :].unbind(0) for key_length in key_lengths + ] + + # Attn weights: (batch_size, num_heads, key_length), no overlap with value. + attn_weights = activation_pool[attn_weights_begin:attn_weights_end].view( + self.batch_size, attn.num_heads, self.max_sequence_length + ) + self.padded_attn_weights = [attn_weights[:, :, :key_length] for key_length in padded_key_lengths] + + # Attn outputs: (batch_size, embed_dim), no overlap with value. + self.attn_output = activation_pool[query_begin:query_end].view(self.batch_size, -1) + self.attn_output_expanded = self.attn_output.view(self.batch_size, attn.num_heads, attn.head_dim) + # Attn projection: (batch_size, embed_dim), no overlap with attn outputs. + self.c_proj = activation_pool[c_proj_begin:c_proj_end].view(self.batch_size, -1) + + # MLP first layer: (batch_size, embed_dim) + self.mlp_c_fc = activation_pool[c_fc_begin:c_fc_end].view(self.batch_size, -1) + # MLP projection: (batch_size, inner_dim) + self.mlp_c_proj = activation_pool[query_begin:query_end].view(self.batch_size, -1) + + if self.inference_runner_type != InferenceRunnerType.BASE_RUNNER: + print(f"Generating cuda graphs") + self.memory_pool = None + if self.inference_runner_type == InferenceRunnerType.FULL_GRAPH: + self.cuda_graphs = {} + # The output may not always be at the same memory location. + self.output_hidden_states = {} + # Generate the largest one first to warm up the memory pool. + # The other ones are generated lazily. + self._generate_full_cuda_graph(self.max_sequence_length) + else: + self._generate_cuda_graphs() + + def _generate_cuda_graphs(self): + self.cuda_graphs = {} + for layer_idx in range(self.n_layer + 1): + graph = torch.cuda.CUDAGraph() + + with torch.cuda.graph(graph, pool=self.memory_pool): + if layer_idx > 0: + self._forward_post_attn(self.model.h[layer_idx - 1]) + if layer_idx < self.n_layer: + self._forward_qkv(self.model.h[layer_idx]) + else: + self.output_hidden_states = self._forward_end() + if self.memory_pool is None: + self.memory_pool = graph.pool() + self.cuda_graphs[layer_idx] = graph + + def _generate_full_cuda_graph(self, key_length): + # We need to warmup the jit function before creating the graph, otherwise it will crash. + # Warmup needs to be done for every input shape (key length), and for both scale == 1 and scale != 1 + if self.upcast: + for scale in (1.0, 2.0): + upcast_masked_softmax( + self.padded_attn_weights[key_length], + self.padded_attn_masks[key_length], + self.mask_value, + scale, + self.softmax_dtype, + ) + else: + masked_softmax( + self.padded_attn_weights[key_length], + self.padded_attn_masks[key_length], + self.mask_value, + ) + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, pool=self.memory_pool): + self.output_hidden_states[key_length] = self._forward(key_length) + if self.memory_pool is None: + self.memory_pool = graph.pool() + self.cuda_graphs[key_length] = graph + + def _forward_embed(self, input_ids, position_ids): + # Embedding doesn't support out argument. + inputs_embeds = self.model.wte(input_ids) + position_embeds = self.model.wpe(position_ids) + torch.add(inputs_embeds, position_embeds, out=self.hidden_states) + + def _forward_qkv(self, block): + # LN doesn't support out argument. 
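The cuda graph generation above relies on the standard PyTorch capture/replay pattern. A minimal, self-contained sketch (requires a CUDA device; the shapes and the relu workload are placeholders, not the model's kernels):

import torch

static_input = torch.randn(16, 16, device="cuda")
static_output = torch.empty_like(static_input)

# Warm up on a side stream before capture, as recommended for CUDA graph capture.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    static_output.copy_(torch.relu(static_input))
torch.cuda.current_stream().wait_stream(side_stream)

# Capture: kernels launched inside the context are recorded instead of executed eagerly.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output.copy_(torch.relu(static_input))

# Additional graphs can share the same memory pool via torch.cuda.graph(g, pool=graph.pool()),
# which is what the memory_pool bookkeeping above does.
# Replay reruns the recorded kernels on whatever currently sits in static_input.
static_input.fill_(-1.0)
graph.replay()
torch.cuda.synchronize()
assert torch.all(static_output == 0)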
+ hidden_states = block.ln_1(self.hidden_states_squeezed) + torch.nn.functional.linear( + hidden_states, + block.attn.c_attn.weight, + block.attn.c_attn.bias, + out=self.c_attn, + ) + + def _forward_attn(self, block, key_length): + + layer_idx = block.attn.layer_idx + self.current_key_values[key_length][layer_idx].copy_(self.kv_attn) + attn_weights = self.padded_attn_weights[key_length] + + torch.baddbmm( + attn_weights, + self.query, + self.padded_keys[key_length][layer_idx], + beta=0, + alpha=self.scale[layer_idx], + out=attn_weights, + ) + # Use a fused kernel to prevent a large overhead from casting and scaling. + # Jit doesn't allow inplace kernel. + if self.upcast: + attn_weights = upcast_masked_softmax( + attn_weights, + self.padded_attn_masks[key_length], + self.mask_value, + self.unscale[layer_idx], + self.softmax_dtype, + ) + else: + attn_weights = masked_softmax(attn_weights, self.padded_attn_masks[key_length], self.mask_value) + + torch.bmm(attn_weights, self.padded_values[key_length][layer_idx], out=self.attn_output_expanded) + + def _forward_post_attn(self, block): + torch.nn.functional.linear( + self.attn_output, + block.attn.c_proj.weight, + block.attn.c_proj.bias, + out=self.c_proj, + ) + self.hidden_states_squeezed.add_(self.c_proj) + # LN doesn't support out argument. + hidden_states = block.ln_2(self.hidden_states_squeezed) + torch.nn.functional.linear(hidden_states, block.mlp.c_fc.weight, block.mlp.c_fc.bias, out=self.mlp_c_fc) + # Most activations don't support out argument. + feed_forward_hidden_states = block.mlp.act(self.mlp_c_fc) + torch.nn.functional.linear( + feed_forward_hidden_states, block.mlp.c_proj.weight, block.mlp.c_proj.bias, out=self.mlp_c_proj + ) + self.hidden_states_squeezed.add_(self.mlp_c_proj) + + def _forward_end(self): + # LN doesn't support out argument. 
+ return self.model.ln_f(self.hidden_states) + + def _forward(self, key_length): + for block in self.model.h: + self._forward_qkv(block) + self._forward_attn(block, key_length) + self._forward_post_attn(block) + return self._forward_end() + + def forward( + self, + input_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + position_ids: torch.LongTensor, + past_key_values: Union[List[torch.Tensor], int], + ) -> BaseModelOutputWithPastAndCrossAttentions: + batch_size, query_length = input_ids.shape + assert query_length == 1 + if self.batch_size is None: + self._allocate(batch_size, device=input_ids.device, dtype=self.model.dtype) + elif self.validate_input: + assert batch_size == self.batch_size + assert self.dtype == self.model.dtype + assert self.device == input_ids.device + + if self.validate_input: + assert attention_mask.dim() == 2 + assert attention_mask.shape[0] == batch_size + key_length = attention_mask.shape[1] + assert key_length <= self.max_sequence_length + if isinstance(past_key_values, int): + assert key_length == past_key_values + 1 + else: + key_length = attention_mask.shape[1] + + if not isinstance(past_key_values, int): + for buffer, past_key_value in zip(self.past_key_values[key_length], past_key_values): + buffer.copy_(past_key_value) + + self._forward_embed(input_ids, position_ids) + + self.attn_masks[key_length].copy_(attention_mask) + + attn_mask_pad = self.attn_mask_pads[key_length] + if attn_mask_pad.size(1) > 0: + attn_mask_pad.fill_(False) + + if self.inference_runner_type == InferenceRunnerType.FULL_GRAPH: + if key_length not in self.cuda_graphs: + self._generate_full_cuda_graph(key_length) + self.cuda_graphs[key_length].replay() + hidden_states = self.output_hidden_states[key_length] + elif self.inference_runner_type == InferenceRunnerType.PARTIAL_GRAPH: + for i, block in enumerate(self.model.h): + self.cuda_graphs[i].replay() + self._forward_attn(block, key_length) + self.cuda_graphs[self.n_layer].replay() + hidden_states = self.output_hidden_states + else: + hidden_states = self._forward(key_length) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=key_length, + ) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index d30ec4186a..e9121c3aee 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -44,7 +44,7 @@ ) from transformers.utils.model_parallel_utils import assert_device_map, get_device_map -from .configuration_gpt_bigcode import AttentionType, GPTBigCodeConfig +from .configuration_gpt_bigcode import AttentionType, GPTBigCodeConfig, InferenceRunnerType logger = logging.get_logger(__name__) @@ -256,6 +256,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): if upcast: # Use a fused kernel to prevent a large overhead from casting and scaling. + # Sub-optimal when the key length is not a multiple of 8. if attention_mask is None: attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype) else: @@ -264,7 +265,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): else: if attention_mask is not None: mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) - # This can be fused with the softmax, but the fused kernel seems slower. + # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion. 
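For context on masked_softmax and upcast_masked_softmax, the non-fused path that follows (torch.where then softmax) is equivalent to this eager reference; the sketch folds in the optional upcast and rescale and is not the jit-scripted implementation itself:

import torch

def masked_softmax_reference(attn_weights, attention_mask, mask_value, unscale=1.0, softmax_dtype=torch.float32):
    # Upcast and rescale, mask disallowed positions with the most negative representable value,
    # run the softmax in softmax_dtype, then cast back to the input dtype.
    input_dtype = attn_weights.dtype
    x = attn_weights.to(softmax_dtype) * unscale
    x = torch.where(attention_mask, x, mask_value)
    return torch.nn.functional.softmax(x, dim=-1).to(input_dtype)

w = torch.randn(2, 4, 8, dtype=torch.float16)
mask = torch.ones(2, 1, 8, dtype=torch.bool)
mask[..., -2:] = False
probs = masked_softmax_reference(w, mask, torch.tensor(torch.finfo(torch.float32).min))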
attn_weights = torch.where(attention_mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) @@ -359,7 +360,7 @@ class GPTBigCodeBlock(nn.Module): def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size - inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx) @@ -371,7 +372,7 @@ def __init__(self, config, layer_idx=None): self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx) self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = GPTBigCodeMLP(inner_dim, config) + self.mlp = GPTBigCodeMLP(self.inner_dim, config) def forward( self, @@ -680,6 +681,15 @@ def __init__(self, config): self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.inference_runner_type = InferenceRunnerType(config.inference_runner) + + if self.inference_runner_type == InferenceRunnerType.NO_RUNNER: + self.inference_runner = None + else: + from .inference_runner import GPTBigCodeInferenceRunner + + self.inference_runner = GPTBigCodeInferenceRunner(config, self) + max_positions = config.max_position_embeddings self.register_buffer( "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False @@ -748,7 +758,7 @@ def _prune_heads(self, heads_to_prune): def forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.Tensor]] = None, + past_key_values: Optional[Union[List[torch.Tensor], int]] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -761,6 +771,30 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + if self.inference_runner is not None and past_key_values is not None: + if self.config.validate_runner_input: + assert input_ids is not None + assert past_key_values is not None + assert attention_mask is not None + assert token_type_ids is None + assert position_ids is not None + assert head_mask is None + assert inputs_embeds is None + assert encoder_hidden_states is None + assert encoder_attention_mask is None + use_cache = use_cache if use_cache is not None else self.config.use_cache + assert use_cache is True + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + assert output_attentions is False + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + assert output_hidden_states is False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + assert return_dict is True + return self.inference_runner.forward(input_ids, attention_mask, position_ids, past_key_values) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states From 9c3c5484d831484f96e2bcd2961cfac100e52d0b Mon Sep 17 
00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 2 Mar 2023 14:28:44 -0500 Subject: [PATCH 22/27] Add gpu optimizations to base model (#14) --- .../gpt_bigcode/configuration_gpt_bigcode.py | 14 +++- .../models/gpt_bigcode/inference_runner.py | 41 +++++----- .../gpt_bigcode/modeling_gpt_bigcode.py | 75 +++++++++++++++++-- 3 files changed, 102 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index ece1fb5437..1150f01cb5 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -184,7 +184,9 @@ def __init__( attention_type=AttentionType.MULTI_HEAD, inference_runner=InferenceRunnerType.NO_RUNNER, validate_runner_input=True, - runner_max_sequence_length=None, + pre_allocate_kv_cache=False, + max_sequence_length=None, + max_batch_size=None, pad_key_length=True, **kwargs, ): @@ -218,9 +220,13 @@ def __init__( self.inference_runner = InferenceRunnerType(inference_runner) # Set to False to disable input validation of safe inputs, for a small speedup. self.validate_runner_input = validate_runner_input - # Set if `n_positions` uses too much memory. - self.runner_max_sequence_length = runner_max_sequence_length - # Pad key length to a multiple of 8. + + self.pre_allocate_kv_cache = pre_allocate_kv_cache + # The max sequence length for the pre-allocated KV cache (`n_positions` if not provided). + self.max_sequence_length = max_sequence_length + # The max batch size for the pre-allocated KV cache, (deduce from input if not provided). + self.max_batch_size = max_batch_size + # Pad key length to a multiple of 8 (requires pre_allocate_kv_cache). self.pad_key_length = pad_key_length super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/src/transformers/models/gpt_bigcode/inference_runner.py b/src/transformers/models/gpt_bigcode/inference_runner.py index 7f6e64b3f5..c75be65357 100644 --- a/src/transformers/models/gpt_bigcode/inference_runner.py +++ b/src/transformers/models/gpt_bigcode/inference_runner.py @@ -21,13 +21,14 @@ def __init__(self, config: GPTBigCodeConfig, model): self.inference_runner_type = InferenceRunnerType(config.inference_runner) assert self.inference_runner_type != InferenceRunnerType.NO_RUNNER + assert config.pre_allocate_kv_cache self.validate_input = config.validate_runner_input self.pad_key_length = 8 if config.pad_key_length else 1 # TODO: Support other attention types? assert model.attention_type == AttentionType.MULTI_QUERY_1 - self.max_sequence_length = config.runner_max_sequence_length or config.n_positions + self.max_sequence_length = config.max_sequence_length or config.n_positions def _allocate(self, batch_size, device, dtype): block: GPTBigCodeBlock = self.model.h[0] @@ -55,19 +56,28 @@ def _allocate(self, batch_size, device, dtype): attn_weights_begin = _align_tensor(kv_end) attn_weights_end = kv_end + self.batch_size * attn.num_heads * self.max_sequence_length # Projection: (batch_size, embed_dim), no overlap with attn outputs ~ query. 
- # Also used for MLP projection () + # Also used for MLP projection c_proj_begin = _align_tensor(query_end) c_proj_end = c_proj_begin + self.batch_size * attn.embed_dim c_fc_begin = query_begin c_fc_end = c_fc_begin + self.batch_size * block.inner_dim pool_size = max(attn_weights_end, c_proj_end, c_fc_end) - kv_cache_shape = (self.n_layer, self.batch_size, attn.kv_heads, self.max_sequence_length, 2 * attn.head_dim) - kv_cache_size = math.prod(kv_cache_shape) print( f"Allocating inference buffers (batch size = {self.batch_size}, max sequence length =" f" {self.max_sequence_length})..." ) + + kv_caches = [] + for block in self.model.h: + block.attn.freeze_kv_cache() + kv_cache=block.attn.get_kv_cache(self.batch_size, self.max_sequence_length, self.device, self.dtype) + if attn.is_mqa: + kv_cache=kv_cache.unsqueeze(1) + kv_caches.append(kv_cache) + + kv_cache_size = sum(kv_cache.numel() for kv_cache in kv_caches) + print(f" Activation pool size: {pool_size:,}") print(f" KV cache size: {kv_cache_size:,}") buffer_memory = (pool_size + kv_cache_size) * torch.finfo( @@ -76,10 +86,6 @@ def _allocate(self, batch_size, device, dtype): print(f" Memory usage: {buffer_memory/2**20:.0f} MiB") activation_pool = torch.empty(pool_size, **factory_kwargs) - kv_cache = torch.empty( - kv_cache_shape, - **factory_kwargs, - ) self.mask_value = torch.full( [], torch.finfo(self.softmax_dtype).min, dtype=self.softmax_dtype, device=self.device ) @@ -116,19 +122,22 @@ def _allocate(self, batch_size, device, dtype): self.query = self.c_attn[:, : attn.embed_dim].view(self.batch_size, attn.num_heads, attn.head_dim) self.kv_attn = self.c_attn[:, attn.embed_dim :] - key, value = kv_cache.split((attn.head_dim, attn.head_dim), dim=-1) - key = key.transpose(-1, -2) + keys, values = zip(*(kv_cache.split((attn.head_dim, attn.head_dim), dim=-1) for kv_cache in kv_caches)) head_slice = 0 if attn.is_mqa else slice(None) - self.padded_keys = [key[:, :, head_slice, :, :key_length].unbind(0) for key_length in padded_key_lengths] - self.padded_values = [value[:, :, head_slice, :key_length, :].unbind(0) for key_length in padded_key_lengths] + self.padded_keys = [ + [key[:, head_slice, :key_length, :].transpose(-1, -2) for key in keys] for key_length in padded_key_lengths + ] + self.padded_values = [ + [value[:, head_slice, :key_length, :] for value in values] for key_length in padded_key_lengths + ] # This is nonsense for key_length == 0, but we never need the value. self.current_key_values = [ - kv_cache[:, :, head_slice, key_length - 1, :].unbind(0) for key_length in key_lengths + [kv_cache[:, head_slice, key_length - 1, :] for kv_cache in kv_caches] for key_length in key_lengths ] self.past_key_values = [ - kv_cache[:, :, head_slice, : key_length - 1, :].unbind(0) for key_length in key_lengths + [kv_cache[:, head_slice, : key_length - 1, :] for kv_cache in kv_caches] for key_length in key_lengths ] # Attn weights: (batch_size, num_heads, key_length), no overlap with value. 
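A minimal sketch of the per-layer cache layout consumed here, for the multi-query case where there is a single shared key/value head (sizes are hypothetical): each layer's cache is viewed as (batch, sequence, 2 * head_dim) and split into a key slice, transposed for baddbmm, and a value slice, with one row per generated position.

import torch

# Hypothetical sizes for illustration only (MQA: one shared key/value head).
batch_size, max_sequence_length, head_dim = 4, 16, 64

kv_cache = torch.empty(batch_size, max_sequence_length, 2 * head_dim)
key, value = kv_cache.split((head_dim, head_dim), dim=-1)

key_length = 10
padded_key = key[:, :key_length, :].transpose(-1, -2)  # (batch, head_dim, key_length), ready for baddbmm
padded_value = value[:, :key_length, :]                # (batch, key_length, head_dim)
current_key_value = kv_cache[:, key_length - 1, :]     # the slot written with this step's new key/value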
@@ -302,10 +311,6 @@ def forward( else: key_length = attention_mask.shape[1] - if not isinstance(past_key_values, int): - for buffer, past_key_value in zip(self.past_key_values[key_length], past_key_values): - buffer.copy_(past_key_value) - self._forward_embed(input_ids, position_ids) self.attn_masks[key_length].copy_(attention_mask) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index e9121c3aee..4a682a843c 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -148,10 +148,17 @@ class GPTBigCodeAttention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() self.mask_value = None + self.kv_cache = None + self.kv_cache_max_batch_size = config.max_batch_size or 0 + self.kv_cache_max_sequence_length = config.max_sequence_length or config.n_positions self.attention_type = AttentionType(config.attention_type) self.is_mqa = self.attention_type != AttentionType.MULTI_HEAD + self.pre_allocate_kv_cache = config.pre_allocate_kv_cache + self.pad_key_length = config.pad_key_length and config.pre_allocate_kv_cache + self._frozen_kv_cache = False + self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads @@ -252,8 +259,6 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): torch.empty(attn_view, dtype=query.dtype, device=query.device), query, key, beta=0, alpha=scale_factor ).view(attn_shape) - # attn_weights = self._matmul(query, key, scale_factor=scale_factor) - if upcast: # Use a fused kernel to prevent a large overhead from casting and scaling. # Sub-optimal when the key length is not a multiple of 8. @@ -282,6 +287,35 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): return attn_output, attn_weights + def freeze_kv_cache(self, enable=True): + if self.kv_cache is None: + raise RuntimeError("KV cache not found.") + # Prevent re-allocation of the KV cache. + self._frozen_kv_cache = enable + + def get_kv_cache(self, batch_size, sequence_length, device, dtype, allocate=True): + if ( + self.kv_cache is None + or self.kv_cache.dtype != dtype + or self.kv_cache.device != device + or batch_size > self.kv_cache_max_batch_size + or sequence_length > self.kv_cache_max_sequence_length + ): + if self._frozen_kv_cache or not allocate: + # TODO: Improve error message + raise RuntimeError("KV cache not found." if self.kv_cache is None else "Invalid KV cache.") + # Free memory first. + self.kv_cache = None + self.kv_cache_max_sequence_length = max(sequence_length, self.kv_cache_max_sequence_length) + self.kv_cache_max_batch_size = max(batch_size, self.kv_cache_max_batch_size) + kv_cache_size = 2 * self.kv_cache_max_batch_size * self.kv_cache_max_sequence_length * self.kv_dim + self.kv_cache = torch.empty([kv_cache_size], device=device, dtype=dtype) + # This view ensures the cache is contiguous for all batch sizes. 
+ kv_cache = self.kv_cache[: 2 * batch_size * self.kv_cache_max_sequence_length * self.kv_dim].view( + batch_size, self.kv_heads, self.kv_cache_max_sequence_length, 2 * self.head_dim + ) + return kv_cache[:, 0, :sequence_length, :] if self.is_mqa else kv_cache[:, :, :sequence_length, :] + def forward( self, hidden_states: torch.FloatTensor, @@ -319,8 +353,27 @@ def forward( .split((self.head_dim, 2 * self.head_dim), dim=3) ) - if layer_past is not None: - key_value = torch.cat((layer_past, key_value), dim=-2) + present = None + + if self.pre_allocate_kv_cache: + if use_cache or layer_past is not None: + last_key_length = layer_past or 0 + batch_size = key_value.size(0) + key_length = last_key_length + key_value.size(-2) + padded_key_length = key_length + -key_length % (8 if self.pad_key_length else 1) + kv_cache = self.get_kv_cache( + batch_size, padded_key_length, key_value.device, key_value.dtype, allocate=last_key_length == 0 + ) + if self.is_mqa: + kv_cache[:, last_key_length:key_length, :].copy_(key_value) + key_value = kv_cache + if use_cache: + present = key_length + else: + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + if use_cache: + present = key_value key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) @@ -332,7 +385,7 @@ def forward( attn_output = self.resid_dropout(attn_output) # TODO: Is it ok to send unwrapped present? - outputs = (attn_output, key_value if use_cache else None) + outputs = (attn_output, present) if output_attentions: outputs += (attn_weights,) @@ -681,6 +734,9 @@ def __init__(self, config): self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.pre_allocate_kv_cache = config.pre_allocate_kv_cache + self.pad_key_length = config.pad_key_length and self.pre_allocate_kv_cache + self.inference_runner_type = InferenceRunnerType(config.inference_runner) if self.inference_runner_type == InferenceRunnerType.NO_RUNNER: @@ -827,8 +883,10 @@ def forward( if past_key_values is None: past_length = 0 past_key_values = tuple([None] * len(self.h)) + elif self.pre_allocate_kv_cache: + past_length = past_key_values[0] else: - past_length = past_key_values[0][0].size(-2) + past_length = past_key_values[0].size(-2) if position_ids is None: position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) @@ -844,6 +902,11 @@ def forward( # MHA: (b, nh, sq, sk) attention_mask = self_attention_mask.unsqueeze(2 if self.is_mqa else 1) + if self.pad_key_length: + pad = -key_length % 8 + if pad > 0: + attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False) + # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if ( From 7671e1e81025f25fcc9eece6b1e31c0140a1d7c2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 12 Apr 2023 20:26:32 -0400 Subject: [PATCH 23/27] Fix merge --- .../gpt_bigcode/test_modeling_gpt_bigcode.py | 492 ------------------ 1 file changed, 492 deletions(-) diff --git a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py index 04f18dd7bf..e35e931d4f 100644 --- a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py +++ b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py @@ -1,9 
+1,5 @@ # coding=utf-8 -<<<<<<< HEAD -# Copyright 2020 The HuggingFace Team. All rights reserved. -======= # Copyright 2023 The HuggingFace Team. All rights reserved. ->>>>>>> e0921c6b53310a47b10f01633809b2b9f785a465 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,63 +12,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -<<<<<<< HEAD - - -import datetime -======= ->>>>>>> e0921c6b53310a47b10f01633809b2b9f785a465 import math import unittest from parameterized import parameterized -<<<<<<< HEAD -======= ->>>>>>> e0921c6b53310a47b10f01633809b2b9f785a465 from transformers import GPTBigCodeConfig, is_torch_available from transformers.testing_utils import require_torch, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask -<<<<<<< HEAD -======= from ...test_pipeline_mixin import PipelineTesterMixin ->>>>>>> e0921c6b53310a47b10f01633809b2b9f785a465 if is_torch_available(): import torch from transformers import ( -<<<<<<< HEAD - GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST, - GPT2Tokenizer, - GPTBigCodeConfig, - GPTBigCodeDoubleHeadsModel, - GPTBigCodeForSequenceClassification, - GPTBigCodeForTokenClassification, - GPTBigCodeLMHeadModel, - GPTBigCodeModel, - ) - from transformers.models.gpt_bigcode.configuration_gpt_bigcode import AttentionType -======= GPT2TokenizerFast, GPTBigCodeForCausalLM, GPTBigCodeForSequenceClassification, GPTBigCodeForTokenClassification, GPTBigCodeModel, ) ->>>>>>> e0921c6b53310a47b10f01633809b2b9f785a465 from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeAttention class GPTBigCodeModelTester: -<<<<<<< HEAD - # TODO: Update the tests to use valid pretrained models. 
-======= ->>>>>>> e0921c6b53310a47b10f01633809b2b9f785a465 def __init__( self, parent, @@ -88,11 +55,7 @@ def __init__( num_hidden_layers=5, num_attention_heads=4, intermediate_size=37, -<<<<<<< HEAD - hidden_act="gelu", -======= hidden_act="relu", ->>>>>>> e0921c6b53310a47b10f01633809b2b9f785a465 hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, @@ -101,10 +64,7 @@ def __init__( initializer_range=0.02, num_labels=3, num_choices=4, -<<<<<<< HEAD -======= multi_query=True, ->>>>>>> e0921c6b53310a47b10f01633809b2b9f785a465 scope=None, ): self.parent = parent @@ -131,20 +91,12 @@ def __init__( self.num_choices = num_choices self.scope = None self.bos_token_id = vocab_size - 1 -<<<<<<< HEAD - self.eos_token_id = vocab_size - 1 - self.pad_token_id = vocab_size - 1 - - def get_large_model_config(self): - return GPTBigCodeConfig.from_pretrained("bigcode/santacoder-fast-inference") -======= self.eos_token_id = vocab_size - 2 self.pad_token_id = vocab_size - 3 self.multi_query = multi_query def get_large_model_config(self): return GPTBigCodeConfig.from_pretrained("bigcode/gpt_bigcode-santacoder") ->>>>>>> e0921c6b53310a47b10f01633809b2b9f785a465 def prepare_config_and_inputs( self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False @@ -213,12 +165,9 @@ def get_config( gradient_checkpointing=gradient_checkpointing, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx, reorder_and_upcast_attn=reorder_and_upcast_attn, -<<<<<<< HEAD -======= attention_softmax_in_fp32=False, scale_attention_softmax_in_fp32=False, multi_query=self.multi_query, ->>>>>>> e0921c6b53310a47b10f01633809b2b9f785a465 ) def get_pipeline_config(self): @@ -384,11 +333,7 @@ def create_and_check_gpt_bigcode_model_past_large_inputs( self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): -<<<<<<< HEAD - model = GPTBigCodeLMHeadModel(config) -======= model = GPTBigCodeForCausalLM(config) ->>>>>>> e0921c6b53310a47b10f01633809b2b9f785a465 model.to(torch_device) model.eval() @@ -399,11 +344,7 @@ def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mas def create_and_check_forward_and_backwards( self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False ): -<<<<<<< HEAD - model = GPTBigCodeLMHeadModel(config) -======= model = GPTBigCodeForCausalLM(config) ->>>>>>> e0921c6b53310a47b10f01633809b2b9f785a465 model.to(torch_device) if gradient_checkpointing: model.gradient_checkpointing_enable() @@ -413,35 +354,6 @@ def create_and_check_forward_and_backwards( self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) result.loss.backward() -<<<<<<< HEAD - def create_and_check_double_lm_head_model( - self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args - ): - model = GPTBigCodeDoubleHeadsModel(config) - model.to(torch_device) - model.eval() - - multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() - multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() - multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() - - inputs = { - "input_ids": multiple_choice_inputs_ids, - "mc_token_ids": mc_token_ids, - "attention_mask": 
multiple_choice_input_mask, - "token_type_ids": multiple_choice_token_type_ids, - "labels": multiple_choice_inputs_ids, - } - - result = model(**inputs) - self.parent.assertEqual(result.loss.shape, ()) - self.parent.assertEqual( - result.logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size) - ) - self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices)) - -======= ->>>>>>> e0921c6b53310a47b10f01633809b2b9f785a465 def create_and_check_gpt_bigcode_for_sequence_classification( self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args ): @@ -495,38 +407,18 @@ def prepare_config_and_inputs_for_common(self): @require_torch -<<<<<<< HEAD -class GPTBigCodeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): - # TODO: Update the tests to use valid pretrained models. - - all_model_classes = ( - ( - GPTBigCodeModel, - GPTBigCodeLMHeadModel, - GPTBigCodeDoubleHeadsModel, -======= class GPTBigCodeMQAModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): # TODO: Update the tests to use valid pretrained models. all_model_classes = ( ( GPTBigCodeModel, GPTBigCodeForCausalLM, ->>>>>>> e0921c6b53310a47b10f01633809b2b9f785a465 GPTBigCodeForSequenceClassification, GPTBigCodeForTokenClassification, ) if is_torch_available() else () ) -<<<<<<< HEAD - all_generative_model_classes = (GPTBigCodeLMHeadModel, GPTBigCodeDoubleHeadsModel) if is_torch_available() else () - all_parallelizable_model_classes = ( - (GPTBigCodeLMHeadModel, GPTBigCodeDoubleHeadsModel) if is_torch_available() else () - ) - fx_compatible = True - test_missing_keys = False - test_model_parallel = True -======= all_generative_model_classes = (GPTBigCodeForCausalLM,) if is_torch_available() else () fx_compatible = False test_missing_keys = False @@ -544,40 +436,11 @@ class GPTBigCodeMQAModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTe if is_torch_available() else {} ) ->>>>>>> e0921c6b53310a47b10f01633809b2b9f785a465 # special case for DoubleHeads model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) -<<<<<<< HEAD - if return_labels: - if model_class.__name__ == "GPTBigCodeDoubleHeadsModel": - inputs_dict["labels"] = torch.zeros( - (self.model_tester.batch_size, self.model_tester.num_choices, self.model_tester.seq_length), - dtype=torch.long, - device=torch_device, - ) - inputs_dict["input_ids"] = inputs_dict["labels"] - inputs_dict["token_type_ids"] = inputs_dict["labels"] - inputs_dict["mc_token_ids"] = torch.zeros( - (self.model_tester.batch_size, self.model_tester.num_choices), - dtype=torch.long, - device=torch_device, - ) - inputs_dict["mc_labels"] = torch.zeros( - self.model_tester.batch_size, dtype=torch.long, device=torch_device - ) - return inputs_dict - - def setUp(self): - self.model_tester = GPTBigCodeModelTester(self) - self.config_tester = ConfigTester(self, config_class=GPTBigCodeConfig, n_embd=37) - - def test_config(self): - self.config_tester.run_common_tests() - -======= return inputs_dict def setUp(self): @@ -612,7 +475,6 @@ def test_cpu_offload(self): def test_disk_offload(self): pass ->>>>>>> e0921c6b53310a47b10f01633809b2b9f785a465 def test_gpt_bigcode_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_gpt_bigcode_model(*config_and_inputs) @@ -633,13 +495,6 @@ def 
test_gpt_bigcode_lm_head_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_lm_head_model(*config_and_inputs) -<<<<<<< HEAD - def test_gpt_bigcode_double_lm_head_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_double_lm_head_model(*config_and_inputs) - -======= ->>>>>>> e0921c6b53310a47b10f01633809b2b9f785a465 def test_gpt_bigcode_sequence_classification_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_gpt_bigcode_for_sequence_classification(*config_and_inputs) @@ -664,293 +519,6 @@ def test_gpt_bigcode_weight_initialization(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_gpt_bigcode_weight_initialization(*config_and_inputs) -<<<<<<< HEAD - @slow - def test_batch_generation(self): - model = GPTBigCodeLMHeadModel.from_pretrained("bigcode/santacoder-fast-inference") - model.to(torch_device) - tokenizer = GPT2Tokenizer.from_pretrained("bigcode/santacoder") - - tokenizer.padding_side = "left" - - # Define PAD Token = EOS Token = 50256 - tokenizer.pad_token = tokenizer.eos_token - model.config.pad_token_id = model.config.eos_token_id - - # use different length sentences to test batching - sentences = [ - "Hello, my dog is a little", - "Today, I", - ] - - inputs = tokenizer(sentences, return_tensors="pt", padding=True) - input_ids = inputs["input_ids"].to(torch_device) - token_type_ids = torch.cat( - [ - input_ids.new_full((input_ids.shape[0], input_ids.shape[1] - 1), 0), - input_ids.new_full((input_ids.shape[0], 1), 500), - ], - dim=-1, - ) - - outputs = model.generate( - input_ids=input_ids, - attention_mask=inputs["attention_mask"].to(torch_device), - ) - - outputs_tt = model.generate( - input_ids=input_ids, - attention_mask=inputs["attention_mask"].to(torch_device), - token_type_ids=token_type_ids, - ) - - inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device) - output_non_padded = model.generate(input_ids=inputs_non_padded) - - num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item() - inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device) - output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings) - - batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) - batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True) - non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True) - padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True) - - expected_output_sentence = [ - "Hello, my dog is a little bit of a mess. I'm not sure if he's going", - "Today, I'm going to be doing a lot of research on this. 
I", - ] - self.assertListEqual(expected_output_sentence, batch_out_sentence) - self.assertTrue(batch_out_sentence_tt != batch_out_sentence) # token_type_ids should change output - self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence]) - - @slow - def test_batch_generation_2heads(self): - model = GPTBigCodeDoubleHeadsModel.from_pretrained("bigcode/santacoder-fast-inference") - model.to(torch_device) - tokenizer = GPT2Tokenizer.from_pretrained("bigcode/santacoder") - - tokenizer.padding_side = "left" - - # This tokenizer has no pad token, so we have to set it in some way - # Define PAD Token = EOS Token = 50256 - tokenizer.pad_token = tokenizer.eos_token - model.config.pad_token_id = model.config.eos_token_id - - # use different length sentences to test batching - sentences = [ - "Hello, my dog is a little", - "Today, I", - ] - - inputs = tokenizer(sentences, return_tensors="pt", padding=True) - input_ids = inputs["input_ids"].to(torch_device) - token_type_ids = torch.cat( - [ - input_ids.new_full((input_ids.shape[0], input_ids.shape[1] - 1), 0), - input_ids.new_full((input_ids.shape[0], 1), 500), - ], - dim=-1, - ) - - outputs = model.generate( - input_ids=input_ids, - attention_mask=inputs["attention_mask"].to(torch_device), - ) - - outputs_tt = model.generate( - input_ids=input_ids, - attention_mask=inputs["attention_mask"].to(torch_device), - token_type_ids=token_type_ids, - ) - - inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device) - output_non_padded = model.generate(input_ids=inputs_non_padded) - - num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item() - inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device) - output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings) - - batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) - batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True) - non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True) - padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True) - - expected_output_sentence = [ - "Hello, my dog is a little bit of a mess. I'm not sure if he's going", - "Today, I'm going to be doing a lot of research on this. 
I", - ] - self.assertListEqual(expected_output_sentence, batch_out_sentence) - self.assertTrue(batch_out_sentence_tt != batch_out_sentence) # token_type_ids should change output - self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence]) - - @slow - def test_model_from_pretrained(self): - for model_name in GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - model = GPTBigCodeModel.from_pretrained(model_name) - self.assertIsNotNone(model) - - -@require_torch -class GPTBigCodeModelLanguageGenerationTest(unittest.TestCase): - def _test_lm_generate_gpt_bigcode_helper( - self, - gradient_checkpointing=False, - reorder_and_upcast_attn=False, - scale_attn_by_inverse_layer_idx=False, - verify_outputs=True, - ): - model = GPTBigCodeLMHeadModel.from_pretrained( - "bigcode/santacoder-fast-inference", - reorder_and_upcast_attn=reorder_and_upcast_attn, - scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx, - ) - if gradient_checkpointing: - model.gradient_checkpointing_enable() - else: - model.gradient_checkpointing_disable() - model.to(torch_device) - - # The dog - input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) - - # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog - # fmt: off - expected_output_ids = [ - 464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290, - ] - # fmt: on - output_ids = model.generate(input_ids, do_sample=False) - if verify_outputs: - self.assertListEqual(output_ids[0].tolist(), expected_output_ids) - - @slow - def test_lm_generate_gpt_bigcode(self): - self._test_lm_generate_gpt_bigcode_helper() - - @slow - def test_lm_generate_gpt_bigcode_with_gradient_checkpointing(self): - self._test_lm_generate_gpt_bigcode_helper(gradient_checkpointing=True) - - @slow - def test_lm_generate_gpt_bigcode_with_reorder_and_upcast_attn(self): - self._test_lm_generate_gpt_bigcode_helper(reorder_and_upcast_attn=True) - - @slow - def test_lm_generate_gpt_bigcode_with_scale_attn_by_inverse_layer_idx(self): - self._test_lm_generate_gpt_bigcode_helper(scale_attn_by_inverse_layer_idx=True, verify_outputs=False) - - @slow - def test_gpt_bigcode_sample(self): - tokenizer = GPT2Tokenizer.from_pretrained("bigcode/santacoder") - model = GPTBigCodeLMHeadModel.from_pretrained("bigcode/santacoder-fast-inference") - model.to(torch_device) - - torch.manual_seed(0) - tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True) - input_ids = tokenized.input_ids.to(torch_device) - output_ids = model.generate(input_ids, do_sample=True) - output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True) - - token_type_ids = tokenized.token_type_ids.to(torch_device) - output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5) - output_seq_tt = model.generate( - input_ids=input_ids, token_type_ids=token_type_ids, do_sample=True, num_return_sequences=5 - ) - output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True) - output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True) - - EXPECTED_OUTPUT_STR = ( - "Today is a nice day and if you don't know anything about the state of play during your holiday" - ) - self.assertEqual(output_str, EXPECTED_OUTPUT_STR) - self.assertTrue( - all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))]) - ) # token_type_ids should change output - - @slow - def 
test_gpt_bigcode_sample_max_time(self): - tokenizer = GPT2Tokenizer.from_pretrained("bigcode/santacoder") - model = GPTBigCodeLMHeadModel.from_pretrained("bigcode/santacoder-fast-inference") - model.to(torch_device) - - torch.manual_seed(0) - tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True) - input_ids = tokenized.input_ids.to(torch_device) - - MAX_TIME = 0.5 - - start = datetime.datetime.now() - model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256) - duration = datetime.datetime.now() - start - self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) - self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) - - start = datetime.datetime.now() - model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256) - duration = datetime.datetime.now() - start - self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) - self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) - - start = datetime.datetime.now() - model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256) - duration = datetime.datetime.now() - start - self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) - self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) - - start = datetime.datetime.now() - model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256) - duration = datetime.datetime.now() - start - self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) - self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) - - start = datetime.datetime.now() - model.generate(input_ids, do_sample=False, max_time=None, max_length=256) - duration = datetime.datetime.now() - start - self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) - - @slow - def test_contrastive_search_gpt_bigcode(self): - article = ( - "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research " - "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based" - ) - - gpt_bigcode_tokenizer = GPT2Tokenizer.from_pretrained("bigcode/santacoder") - gpt_bigcode_model = GPTBigCodeLMHeadModel.from_pretrained("bigcode/santacoder-fast-inference").to(torch_device) - input_ids = gpt_bigcode_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - - outputs = gpt_bigcode_model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256) - - generated_text = gpt_bigcode_tokenizer.batch_decode(outputs, skip_special_tokens=True) - - self.assertListEqual( - generated_text, - [ - "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research " - "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, " - "United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as " - "Google Now, which helps users find the information they're looking for on the web. But the company " - "is not the only one to collect data on its users. Facebook, for example, has its own facial " - "recognition technology, as well as a database of millions of photos that it uses to personalize its " - "News Feed.\n\nFacebook's use of data is a hot topic in the tech industry, with privacy advocates " - "concerned about the company's ability to keep users' information private. 
In a blog post last " - 'year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our ' - 'data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with ' - 'third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at ' - 'privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, ' - "but said in a statement to The Associated Press that" - ], - ) - - -@require_torch -class GPTBigCodeAttentionTest(unittest.TestCase): - def get_attention(self, attention_type: AttentionType): - config = GPTBigCodeConfig.from_pretrained( - "bigcode/santacoder-fast-inference", - attention_type=attention_type, -======= @require_torch class GPTBigCodeMHAModelTest(GPTBigCodeMQAModelTest): @@ -999,59 +567,12 @@ def get_attention(self, multi_query): config = GPTBigCodeConfig.from_pretrained( "bigcode/gpt_bigcode-santacoder", multi_query=multi_query, ->>>>>>> e0921c6b53310a47b10f01633809b2b9f785a465 attn_pdrop=0, resid_pdrop=0, ) return GPTBigCodeAttention(config) @parameterized.expand([(seed, is_train_mode) for seed in range(5) for is_train_mode in [True, False]]) -<<<<<<< HEAD - def test_mqa_correctness(self, seed, is_train_mode=True): - torch.manual_seed(seed) - embed_dim = 2048 - head_dim = 128 - - # GET THE WEIGHTS FROM MULTI-QUERY ATTENTION 1 - attention_mq1 = self.get_attention(AttentionType.MULTI_QUERY_1) - state_dict_mq1 = attention_mq1.state_dict() - attn_weight, attn_k_weight, attn_v_weight = torch.split( - state_dict_mq1["c_attn.weight"], [embed_dim, head_dim, head_dim], dim=1 - ) - attn_bias, attn_k_bias, attn_v_bias = torch.split( - state_dict_mq1["c_attn.bias"], [embed_dim, head_dim, head_dim], dim=0 - ) - proj = state_dict_mq1["c_proj.weight"] - proj_bias = state_dict_mq1["c_proj.bias"] - - # PUT THEM INTO THE MULTI-HEAD ATTENTION - num_heads = 16 - c_attn_weight = torch.hstack([attn_weight] + num_heads * [attn_k_weight] + num_heads * [attn_v_weight]) - c_attn_bias = torch.hstack([attn_bias] + num_heads * [attn_k_bias] + num_heads * [attn_v_bias]) - attention_mh = self.get_attention(AttentionType.MULTI_HEAD) - state_dict = attention_mh.state_dict() - state_dict["c_attn.weight"] = c_attn_weight - state_dict["c_attn.bias"] = c_attn_bias - state_dict["c_proj.weight"] = proj - state_dict["c_proj.bias"] = proj_bias - attention_mh.load_state_dict(state_dict) - - # PUT THEM INTO THE MULTI-QUERY ATTENTION 2 - attention_mq2 = self.get_attention(AttentionType.MULTI_QUERY_2) - state_dict_mq2 = attention_mq2.state_dict() - state_dict_mq2["q_attn.weight"] = attn_weight - state_dict_mq2["q_attn.bias"] = attn_bias - state_dict_mq2["kv_attn.weight"] = torch.hstack([attn_k_weight, attn_v_weight]) - state_dict_mq2["kv_attn.bias"] = torch.hstack([attn_k_bias, attn_v_bias]) - state_dict_mq2["c_proj.weight"] = proj - state_dict_mq2["c_proj.bias"] = proj_bias - attention_mq2.load_state_dict(state_dict_mq2) - - # PUT THE MODEL INTO THE CORRECT MODE - attention_mh.train(is_train_mode) - attention_mq1.train(is_train_mode) - attention_mq2.train(is_train_mode) -======= def test_mqa_reduces_to_mha(self, seed, is_train_mode=True): torch.manual_seed(seed) @@ -1085,25 +606,12 @@ def test_mqa_reduces_to_mha(self, seed, is_train_mode=True): # PUT THE MODEL INTO THE CORRECT MODE attention_mha.train(is_train_mode) attention_mqa.train(is_train_mode) ->>>>>>> e0921c6b53310a47b10f01633809b2b9f785a465 # RUN AN INPUT THROUGH THE MODELS num_tokens = 5 hidden_states = torch.randn(1, 
num_tokens, embed_dim) -<<<<<<< HEAD - attention_mh_result = attention_mh(hidden_states)[0] - attention_mq1_result = attention_mq1(hidden_states)[0] - attention_mq2_result = attention_mq2(hidden_states)[0] - - # CHECK THAT ALL OUTPUTS ARE THE SAME - tolerance = 1e-5 - self.assertTrue(torch.allclose(attention_mh_result, attention_mq1_result, atol=tolerance)) - self.assertTrue(torch.allclose(attention_mh_result, attention_mq2_result, atol=tolerance)) - self.assertTrue(torch.allclose(attention_mq1_result, attention_mq2_result, atol=tolerance)) -======= attention_mha_result = attention_mha(hidden_states)[0] attention_mqa_result = attention_mqa(hidden_states)[0] # CHECK THAT ALL OUTPUTS ARE THE SAME self.assertTrue(torch.allclose(attention_mha_result, attention_mqa_result, atol=1e-5)) ->>>>>>> e0921c6b53310a47b10f01633809b2b9f785a465 From 2fe9ae32fc930aca8ede3e333a25b44a494a63e3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 12 Apr 2023 20:30:48 -0400 Subject: [PATCH 24/27] Update conversion script --- ...convert_megatron_gpt_bigcode_checkpoint.py | 20 +++++++------------ 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py b/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py index 6ca2e79671..72b3b6910b 100644 --- a/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py +++ b/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py @@ -38,7 +38,7 @@ import torch -from transformers import GPTBigCodeConfig, GPTBigCodeLMHeadModel, GPTBigCodeModel +from transformers import GPTBigCodeConfig, GPTBigCodeForCausalLM, GPTBigCodeModel #################################################################################################### @@ -78,7 +78,7 @@ def recursive_print(name, val, spaces=0): } -def convert_megatron_checkpoint(input_state_dict, merge_qkv): +def convert_megatron_checkpoint(input_state_dict): # The converted output model. output_state_dict = {} ds_args = input_state_dict["args"] @@ -95,10 +95,10 @@ def convert_megatron_checkpoint(input_state_dict, merge_qkv): activation_function = "gelu_new" if ds_args.attention_head_type == "multihead": - attention_type = 1 + multi_query = False else: assert ds_args.attention_head_type == "multiquery" - attention_type = 2 if merge_qkv else 3 + multi_query = True attention_softmax_in_fp32 = ds_args.attention_softmax_in_fp32 or ds_args.apply_query_key_layer_scaling @@ -112,7 +112,7 @@ def convert_megatron_checkpoint(input_state_dict, merge_qkv): n_head=ds_args.num_attention_heads, n_inner=ds_args.ffn_hidden_size, activation_function=activation_function, - attention_type=attention_type, + multi_query=multi_query, resid_pdrop=0.1, embd_pdrop=0.1, attn_pdrop=0.1, @@ -218,12 +218,6 @@ def main(argv=None): type=str, help="Path to the checkpoint file (.zip archive or direct .pt file)", ) - parser.add_argument( - "--no_merge_qkv", - dest="merge_qkv", - action="store_false", - help="Do not merge the query and key_value tensors (MQA).", - ) parser.add_argument( "--custom_model", action="store_true", @@ -243,7 +237,7 @@ def main(argv=None): # Convert. print("Converting") - config, output_state_dict = convert_megatron_checkpoint(input_state_dict, args.merge_qkv) + config, output_state_dict = convert_megatron_checkpoint(input_state_dict) # Print the structure of converted state dict. 
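The merge resolution earlier in this patch keeps the test asserting that multi-query attention reduces to multi-head attention once the single key/value head is replicated across every head. A minimal sketch of that weight expansion, assuming the Conv1D-style (in_features, out_features) c_attn layout and the sizes used by the removed test (embed_dim 2048, head_dim 128, 16 heads):

import torch

embed_dim, head_dim, num_heads = 2048, 128, 16

# MQA packs [query | one key head | one value head] into a single c_attn weight.
mqa_c_attn = torch.randn(embed_dim, embed_dim + 2 * head_dim)
q_w, k_w, v_w = torch.split(mqa_c_attn, [embed_dim, head_dim, head_dim], dim=1)

# Tiling the single key/value head across all heads yields an equivalent MHA c_attn weight,
# which is what the test loads into the multi-head module before comparing outputs.
mha_c_attn = torch.hstack([q_w] + num_heads * [k_w] + num_heads * [v_w])
assert mha_c_attn.shape == (embed_dim, 3 * embed_dim)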
if args.print_checkpoint_structure: @@ -253,7 +247,7 @@ def main(argv=None): # Save custom model GPTBigCodeConfig.register_for_auto_class() GPTBigCodeModel.register_for_auto_class("AutoModelForCausalLM") - hf_model = GPTBigCodeLMHeadModel(config) + hf_model = GPTBigCodeForCausalLM(config) hf_model.load_state_dict(output_state_dict) hf_model.save_pretrained(basename) From 7c45f0c3c1901fb44af0d664547bc52a335c11c7 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 12 Apr 2023 20:35:09 -0400 Subject: [PATCH 25/27] Fix and style --- src/transformers/activations.py | 14 ++++++-------- .../models/gpt_bigcode/inference_runner.py | 8 +++----- .../checkpoint_reshaping_and_interoperability.py | 5 ----- .../convert_megatron_gpt_bigcode_checkpoint.py | 3 +-- .../megatron_gpt_bigcode/push_checkpoints.py | 6 +++--- 5 files changed, 13 insertions(+), 23 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index d2c3768926..cb41fee9e9 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -28,9 +28,8 @@ class PytorchGELUTanh(nn.Module): """ A fast C implementation of the tanh approximation of the GeLU activation function. See - https://arxiv.org/abs/1606.08415. - This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical - match due to rounding errors. + https://arxiv.org/abs/1606.08415. This implementation is equivalent to NewGELU and FastGELU but much faster. + However, it is not an exact numerical match due to rounding errors. """ def __init__(self): @@ -99,11 +98,10 @@ class ClippedGELUActivation(nn.Module): """ Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to - https://arxiv.org/abs/2004.09602. - Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when - initially created. - For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + - torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415 + https://arxiv.org/abs/2004.09602. Gaussian Error Linear Unit. Original Implementation of the gelu activation + function in Google Bert repo when initially created. For information: OpenAI GPT's gelu is slightly different (and + gives slightly different results): 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, + 3)))). 
See https://arxiv.org/abs/1606.08415 """ def __init__(self, min: float, max: float): diff --git a/src/transformers/models/gpt_bigcode/inference_runner.py b/src/transformers/models/gpt_bigcode/inference_runner.py index c75be65357..0cba0cc7e5 100644 --- a/src/transformers/models/gpt_bigcode/inference_runner.py +++ b/src/transformers/models/gpt_bigcode/inference_runner.py @@ -1,4 +1,3 @@ -import math from typing import List, Union import torch @@ -71,9 +70,9 @@ def _allocate(self, batch_size, device, dtype): kv_caches = [] for block in self.model.h: block.attn.freeze_kv_cache() - kv_cache=block.attn.get_kv_cache(self.batch_size, self.max_sequence_length, self.device, self.dtype) + kv_cache = block.attn.get_kv_cache(self.batch_size, self.max_sequence_length, self.device, self.dtype) if attn.is_mqa: - kv_cache=kv_cache.unsqueeze(1) + kv_cache = kv_cache.unsqueeze(1) kv_caches.append(kv_cache) kv_cache_size = sum(kv_cache.numel() for kv_cache in kv_caches) @@ -158,7 +157,7 @@ def _allocate(self, batch_size, device, dtype): self.mlp_c_proj = activation_pool[query_begin:query_end].view(self.batch_size, -1) if self.inference_runner_type != InferenceRunnerType.BASE_RUNNER: - print(f"Generating cuda graphs") + print("Generating cuda graphs") self.memory_pool = None if self.inference_runner_type == InferenceRunnerType.FULL_GRAPH: self.cuda_graphs = {} @@ -228,7 +227,6 @@ def _forward_qkv(self, block): ) def _forward_attn(self, block, key_length): - layer_idx = block.attn.layer_idx self.current_key_values[key_length][layer_idx].copy_(self.kv_attn) attn_weights = self.padded_attn_weights[key_length] diff --git a/src/transformers/models/megatron_gpt_bigcode/checkpoint_reshaping_and_interoperability.py b/src/transformers/models/megatron_gpt_bigcode/checkpoint_reshaping_and_interoperability.py index 9292cb505c..e854d7df19 100644 --- a/src/transformers/models/megatron_gpt_bigcode/checkpoint_reshaping_and_interoperability.py +++ b/src/transformers/models/megatron_gpt_bigcode/checkpoint_reshaping_and_interoperability.py @@ -477,7 +477,6 @@ def convert_checkpoint_from_megatron_to_transformers(args): # For layernorm(s), simply store the layer norm. if op_name.endswith("layernorm"): - ln_name = "ln_1" if op_name.startswith("input") else "ln_2" output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = params @@ -485,7 +484,6 @@ def convert_checkpoint_from_megatron_to_transformers(args): elif ( op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value" ) and weight_or_bias == "weight": - # Insert a tensor of 1x1xDxD bias. causal_mask = torch.tril(torch.ones((n_positions, n_positions), dtype=dtype)).view( 1, 1, n_positions, n_positions @@ -512,7 +510,6 @@ def convert_checkpoint_from_megatron_to_transformers(args): elif ( op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value" ) and weight_or_bias == "bias": - out_val = megatron_to_transformers_fix_query_key_value_ordering( params, checkpoint_version, 3, heads, hidden_size_per_head ) @@ -521,13 +518,11 @@ def convert_checkpoint_from_megatron_to_transformers(args): # Transpose the weights. elif weight_or_bias == "weight": - out_name = megatron_to_transformers[op_name] output_state_dict[layer_name + out_name + "weight"] = params.transpose(0, 1) # Copy the bias. 
elif weight_or_bias == "bias": - out_name = megatron_to_transformers[op_name] output_state_dict[layer_name + out_name + "bias"] = params diff --git a/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py b/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py index 72b3b6910b..e3d1d85d58 100644 --- a/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py +++ b/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py @@ -177,12 +177,11 @@ def convert_megatron_checkpoint(input_state_dict): # For layernorm(s), simply store the layer norm. if op_name.endswith("layernorm"): - ln_name = "ln_1" if op_name.startswith("input") else "ln_2" output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = val # Concatenate QKV matrix. - elif merge_qkv and (op_name == "self_attention.key_value"): + elif op_name == "self_attention.key_value": # Query is before key_value in the dict. query = output_state_dict.pop(layer_name + ".attn.q_attn." + weight_or_bias) out_val = torch.cat([query, val], dim=0) diff --git a/src/transformers/models/megatron_gpt_bigcode/push_checkpoints.py b/src/transformers/models/megatron_gpt_bigcode/push_checkpoints.py index ececb9e5be..e2cbd38a8f 100644 --- a/src/transformers/models/megatron_gpt_bigcode/push_checkpoints.py +++ b/src/transformers/models/megatron_gpt_bigcode/push_checkpoints.py @@ -4,13 +4,13 @@ from pathlib import Path from huggingface_hub import Repository + from transformers.models.megatron_gpt_bigcode.convert_megatron_gpt_bigcode_checkpoint import main as convert """ -Script to upload Megatron checkpoints to a HF repo on the Hub. -The script clones/creates a repo on the Hub, checks out a branch `--branch_name`, -and converts each `iter_` checkpoint and saves it as a commit on that branch. +Script to upload Megatron checkpoints to a HF repo on the Hub. The script clones/creates a repo on the Hub, checks out +a branch `--branch_name`, and converts each `iter_` checkpoint and saves it as a commit on that branch. """ From 05a6225f0d428037f790154fd4cbe914e6e80165 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 12 Apr 2023 20:38:49 -0400 Subject: [PATCH 26/27] Reduce diff --- src/transformers/activations.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index cb41fee9e9..587dc2e599 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -28,8 +28,10 @@ class PytorchGELUTanh(nn.Module): """ A fast C implementation of the tanh approximation of the GeLU activation function. See - https://arxiv.org/abs/1606.08415. This implementation is equivalent to NewGELU and FastGELU but much faster. - However, it is not an exact numerical match due to rounding errors. + https://arxiv.org/abs/1606.08415. + + This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical + match due to rounding errors. """ def __init__(self): @@ -98,10 +100,13 @@ class ClippedGELUActivation(nn.Module): """ Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to - https://arxiv.org/abs/2004.09602. Gaussian Error Linear Unit. 
Original Implementation of the gelu activation - function in Google Bert repo when initially created. For information: OpenAI GPT's gelu is slightly different (and - gives slightly different results): 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, - 3)))). See https://arxiv.org/abs/1606.08415 + https://arxiv.org/abs/2004.09602. + + Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when + initially created. + + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415 """ def __init__(self, min: float, max: float): From 8b0cb2c6261e65d4d852d6813f071772c1b32665 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 24 Apr 2023 13:04:38 -0400 Subject: [PATCH 27/27] Add back experimental features (#17) --- .../gpt_bigcode/configuration_gpt_bigcode.py | 33 +++++ .../models/gpt_bigcode/inference_runner.py | 14 +-- .../gpt_bigcode/modeling_gpt_bigcode.py | 114 +++++++++++++++++- 3 files changed, 149 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index 9cbaf3e184..1cfba93a71 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -14,6 +14,8 @@ # limitations under the License. """ GPTBigCode configuration""" +from enum import IntEnum + from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -25,6 +27,19 @@ } +class InferenceRunnerType(IntEnum): + NO_RUNNER = 0 + # Use the inference runner without cuda graphs. + BASE_RUNNER = 1 + # Use cuda graphs in the inference runner. Leave out the attention which has a variable shape. + # This significantly lowers the cpu time and prevent a cpu bottleneck for smaller batches and models. + PARTIAL_GRAPH = 2 + # Turn the whole model into a cuda graph. One graph for each sequence length. + # Note: only useful for small batches and models, graphs take some time to generate, flaky. + # Crashes with jit on A100 but seems to work without jit (PYTORCH_JIT=0) and on V100. + FULL_GRAPH = 3 + + class GPTBigCodeConfig(PretrainedConfig): """ This is the configuration class to store the configuration of a [`GPTBigCodeModel`]. It is used to instantiate a @@ -119,6 +134,12 @@ def __init__( attention_softmax_in_fp32=True, scale_attention_softmax_in_fp32=True, multi_query=True, + inference_runner=InferenceRunnerType.NO_RUNNER, + validate_runner_input=True, + pre_allocate_kv_cache=False, + max_sequence_length=None, + max_batch_size=None, + pad_key_length=True, **kwargs, ): self.vocab_size = vocab_size @@ -142,4 +163,16 @@ def __init__( self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id + self.inference_runner = InferenceRunnerType(inference_runner) + # Set to False to disable input validation of safe inputs, for a small speedup. + self.validate_runner_input = validate_runner_input + + self.pre_allocate_kv_cache = pre_allocate_kv_cache + # The max sequence length for the pre-allocated KV cache (`n_positions` if not provided). + self.max_sequence_length = max_sequence_length + # The max batch size for the pre-allocated KV cache, (deduce from input if not provided). 
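Since InferenceRunnerType is an IntEnum, the setting round-trips through plain integers, for example when a config is serialized to JSON; a quick illustration (module path taken from this file):

from transformers.models.gpt_bigcode.configuration_gpt_bigcode import InferenceRunnerType

assert InferenceRunnerType(2) is InferenceRunnerType.PARTIAL_GRAPH
assert InferenceRunnerType.NO_RUNNER == 0  # plain integer comparisons also work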
+ self.max_batch_size = max_batch_size + # Pad key length to a multiple of 8 (requires pre_allocate_kv_cache). + self.pad_key_length = pad_key_length + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/src/transformers/models/gpt_bigcode/inference_runner.py b/src/transformers/models/gpt_bigcode/inference_runner.py index 0cba0cc7e5..1767bf9642 100644 --- a/src/transformers/models/gpt_bigcode/inference_runner.py +++ b/src/transformers/models/gpt_bigcode/inference_runner.py @@ -4,7 +4,7 @@ from transformers import GPTBigCodeConfig from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions -from transformers.models.gpt_bigcode.configuration_gpt_bigcode import AttentionType, InferenceRunnerType +from transformers.models.gpt_bigcode.configuration_gpt_bigcode import InferenceRunnerType from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeBlock, masked_softmax, upcast_masked_softmax @@ -25,7 +25,7 @@ def __init__(self, config: GPTBigCodeConfig, model): self.pad_key_length = 8 if config.pad_key_length else 1 # TODO: Support other attention types? - assert model.attention_type == AttentionType.MULTI_QUERY_1 + assert model.multi_query self.max_sequence_length = config.max_sequence_length or config.n_positions @@ -71,7 +71,7 @@ def _allocate(self, batch_size, device, dtype): for block in self.model.h: block.attn.freeze_kv_cache() kv_cache = block.attn.get_kv_cache(self.batch_size, self.max_sequence_length, self.device, self.dtype) - if attn.is_mqa: + if attn.multi_query: kv_cache = kv_cache.unsqueeze(1) kv_caches.append(kv_cache) @@ -122,7 +122,7 @@ def _allocate(self, batch_size, device, dtype): self.kv_attn = self.c_attn[:, attn.embed_dim :] keys, values = zip(*(kv_cache.split((attn.head_dim, attn.head_dim), dim=-1) for kv_cache in kv_caches)) - head_slice = 0 if attn.is_mqa else slice(None) + head_slice = 0 if attn.multi_query else slice(None) self.padded_keys = [ [key[:, head_slice, :key_length, :].transpose(-1, -2) for key in keys] for key_length in padded_key_lengths @@ -285,9 +285,9 @@ def _forward(self, key_length): def forward( self, - input_ids: torch.LongTensor, - attention_mask: torch.FloatTensor, - position_ids: torch.LongTensor, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, past_key_values: Union[List[torch.Tensor], int], ) -> BaseModelOutputWithPastAndCrossAttentions: batch_size, query_length = input_ids.shape diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 72858532bf..76ba07b73e 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -34,7 +34,7 @@ add_start_docstrings_to_model_forward, logging, ) -from .configuration_gpt_bigcode import GPTBigCodeConfig +from .configuration_gpt_bigcode import GPTBigCodeConfig, InferenceRunnerType logger = logging.get_logger(__name__) @@ -105,6 +105,14 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32 ) + # KV caching and padding + self.kv_cache = None + self.kv_cache_max_batch_size = config.max_batch_size or 0 + self.kv_cache_max_sequence_length = config.max_sequence_length or config.n_positions + self.pre_allocate_kv_cache = config.pre_allocate_kv_cache + self.pad_key_length = config.pad_key_length and config.pre_allocate_kv_cache + 
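Taken together, the new keyword arguments introduced above could be enabled roughly as follows; this is a hedged sketch only, with illustrative values for max_sequence_length and max_batch_size, and PARTIAL_GRAPH chosen as just one of the runner modes described in the enum:

from transformers import GPTBigCodeConfig
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import InferenceRunnerType

config = GPTBigCodeConfig(
    multi_query=True,
    inference_runner=InferenceRunnerType.PARTIAL_GRAPH,  # cuda graphs for everything but attention
    validate_runner_input=True,   # set False to skip input checks for a small speedup
    pre_allocate_kv_cache=True,   # required for pad_key_length to take effect
    max_sequence_length=2048,     # size of the pre-allocated KV cache (defaults to n_positions)
    max_batch_size=32,            # deduced from the input when left unset
    pad_key_length=True,          # pad key length to a multiple of 8
)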
self._frozen_kv_cache = False + if self.is_cross_attention: if self.multi_query: raise NotImplementedError("Multi-Query Attention not supported for cross_attention") @@ -202,6 +210,41 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): return attn_output, attn_weights + def freeze_kv_cache(self, enable=True): + if self.kv_cache is None: + raise RuntimeError("KV cache not found.") + # Prevent re-allocation of the KV cache. + self._frozen_kv_cache = enable + + def get_kv_cache(self, batch_size, sequence_length, device, dtype, allocate=True): + if ( + self.kv_cache is None + or self.kv_cache.dtype != dtype + or self.kv_cache.device != device + or batch_size > self.kv_cache_max_batch_size + or sequence_length > self.kv_cache_max_sequence_length + ): + if self._frozen_kv_cache or not allocate: + if self.kv_cache is None: + raise RuntimeError("KV cache not found.") + else: + raise RuntimeError( + f"Invalid KV cache: " + f"existing = {(self.kv_cache.dtype,self.kv_cache.device,self.kv_cache_max_batch_size,self.kv_cache_max_sequence_length)}, " + f"requested = {(dtype,device,batch_size,sequence_length)}" + ) + # Free memory first. + self.kv_cache = None + self.kv_cache_max_sequence_length = max(sequence_length, self.kv_cache_max_sequence_length) + self.kv_cache_max_batch_size = max(batch_size, self.kv_cache_max_batch_size) + kv_cache_size = 2 * self.kv_cache_max_batch_size * self.kv_cache_max_sequence_length * self.kv_dim + self.kv_cache = torch.empty([kv_cache_size], device=device, dtype=dtype) + # This view ensures the cache is contiguous for all batch sizes. + kv_cache = self.kv_cache[: 2 * batch_size * self.kv_cache_max_sequence_length * self.kv_dim].view( + batch_size, self.kv_heads, self.kv_cache_max_sequence_length, 2 * self.head_dim + ) + return kv_cache[:, 0, :sequence_length, :] if self.multi_query else kv_cache[:, :, :sequence_length, :] + def forward( self, hidden_states: torch.Tensor, @@ -239,9 +282,27 @@ def forward( .split((self.head_dim, 2 * self.head_dim), dim=3) ) - if layer_past is not None: - key_value = torch.cat((layer_past, key_value), dim=-2) - present = key_value if use_cache else None + present = None + + if self.pre_allocate_kv_cache: + if use_cache or layer_past is not None: + last_key_length = layer_past or 0 + batch_size = key_value.size(0) + key_length = last_key_length + key_value.size(-2) + padded_key_length = key_length + -key_length % (8 if self.pad_key_length else 1) + kv_cache = self.get_kv_cache( + batch_size, padded_key_length, key_value.device, key_value.dtype, allocate=last_key_length == 0 + ) + if self.multi_query: + kv_cache[:, last_key_length:key_length, :].copy_(key_value) + key_value = kv_cache + if use_cache: + present = key_length + else: + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + if use_cache: + present = key_value key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) @@ -513,6 +574,17 @@ def __init__(self, config): self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.pre_allocate_kv_cache = config.pre_allocate_kv_cache + self.pad_key_length = config.pad_key_length and self.pre_allocate_kv_cache + self.inference_runner_type = InferenceRunnerType(config.inference_runner) + + if self.inference_runner_type == InferenceRunnerType.NO_RUNNER: + self.inference_runner = None + else: + from .inference_runner import GPTBigCodeInferenceRunner + + 
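The pre-allocated cache above reserves 2 * batch * sequence * kv_dim elements per layer, which is what makes multi-query attention attractive here; a back-of-the-envelope sketch (not part of the model code, with batch 32, sequence length 2048, 16 heads of dimension 128 and fp16 assumed purely for illustration):

def kv_cache_bytes(batch_size, seq_len, num_heads, head_dim, multi_query=True, bytes_per_param=2):
    # Key and value planes; one kv head under MQA, num_heads kv heads under MHA.
    kv_heads = 1 if multi_query else num_heads
    return 2 * batch_size * seq_len * kv_heads * head_dim * bytes_per_param

print(kv_cache_bytes(32, 2048, 16, 128, multi_query=True) / 2**20)   # ~32 MiB per layer
print(kv_cache_bytes(32, 2048, 16, 128, multi_query=False) / 2**20)  # ~512 MiB per layer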
self.inference_runner = GPTBigCodeInferenceRunner(config, self) + max_positions = config.max_position_embeddings self.register_buffer( "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False @@ -551,6 +623,31 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + if self.inference_runner is not None and past_key_values is not None: + if self.config.validate_runner_input: + assert input_ids is not None + assert past_key_values is not None + assert attention_mask is not None + assert token_type_ids is None + assert position_ids is not None + assert head_mask is None + assert inputs_embeds is None + assert encoder_hidden_states is None + assert encoder_attention_mask is None + use_cache = use_cache if use_cache is not None else self.config.use_cache + assert use_cache is True + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + assert output_attentions is False + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + assert output_hidden_states is False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + assert return_dict is True + return self.inference_runner.forward(input_ids, attention_mask, position_ids, past_key_values) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -583,8 +680,10 @@ def forward( if past_key_values is None: past_length = 0 past_key_values = tuple([None] * len(self.h)) + elif self.pre_allocate_kv_cache: + past_length = past_key_values[0] else: - past_length = past_key_values[0][0].size(-2) + past_length = past_key_values[0].size(-2) if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None: # create position_ids on the fly for batch generation @@ -610,6 +709,11 @@ def forward( # MHA models: (batch_size, n_heads, query_length, key_length) attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) + if self.pad_key_length: + pad = -key_length % 8 + if pad > 0: + attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False) + # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if (