diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index 9cbaf3e184..1cfba93a71 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -14,6 +14,8 @@ # limitations under the License. """ GPTBigCode configuration""" +from enum import IntEnum + from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -25,6 +27,19 @@ } + +class InferenceRunnerType(IntEnum): + NO_RUNNER = 0 + # Use the inference runner without cuda graphs. + BASE_RUNNER = 1 + # Use cuda graphs in the inference runner. Leave out the attention, which has a variable shape. + # This significantly lowers the CPU time and prevents a CPU bottleneck for smaller batches and models. + PARTIAL_GRAPH = 2 + # Turn the whole model into a cuda graph. One graph for each sequence length. + # Note: only useful for small batches and models; graphs take some time to generate and can be flaky. + # Crashes with jit on A100 but seems to work without jit (PYTORCH_JIT=0) and on V100. + FULL_GRAPH = 3 + + class GPTBigCodeConfig(PretrainedConfig): """ This is the configuration class to store the configuration of a [`GPTBigCodeModel`]. It is used to instantiate a @@ -119,6 +134,12 @@ def __init__( attention_softmax_in_fp32=True, scale_attention_softmax_in_fp32=True, multi_query=True, + inference_runner=InferenceRunnerType.NO_RUNNER, + validate_runner_input=True, + pre_allocate_kv_cache=False, + max_sequence_length=None, + max_batch_size=None, + pad_key_length=True, **kwargs, ): self.vocab_size = vocab_size @@ -142,4 +163,16 @@ def __init__( self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id + self.inference_runner = InferenceRunnerType(inference_runner) + # Set to False to disable validation of inputs that are known to be safe, for a small speedup. + self.validate_runner_input = validate_runner_input + + self.pre_allocate_kv_cache = pre_allocate_kv_cache + # The max sequence length for the pre-allocated KV cache (`n_positions` if not provided). + self.max_sequence_length = max_sequence_length + # The max batch size for the pre-allocated KV cache (deduced from the input if not provided). + self.max_batch_size = max_batch_size + # Pad key length to a multiple of 8 (requires pre_allocate_kv_cache).
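For reference, a minimal sketch of what these new config options look like from the user side; the values below are illustrative assumptions, not recommendations, and only the new kwargs are shown.

from transformers.models.gpt_bigcode.configuration_gpt_bigcode import GPTBigCodeConfig, InferenceRunnerType

config = GPTBigCodeConfig(
    inference_runner=InferenceRunnerType.PARTIAL_GRAPH,  # cuda graphs for everything except attention
    pre_allocate_kv_cache=True,                          # required by the inference runner
    max_sequence_length=2048,                            # capacity of the pre-allocated KV cache
    max_batch_size=8,                                    # deduced from the first input if omitted
    pad_key_length=True,                                 # pad key length to a multiple of 8
)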
+ self.pad_key_length = pad_key_length + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/src/transformers/models/gpt_bigcode/inference_runner.py b/src/transformers/models/gpt_bigcode/inference_runner.py new file mode 100644 index 0000000000..1767bf9642 --- /dev/null +++ b/src/transformers/models/gpt_bigcode/inference_runner.py @@ -0,0 +1,337 @@ +from typing import List, Union + +import torch + +from transformers import GPTBigCodeConfig +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from transformers.models.gpt_bigcode.configuration_gpt_bigcode import InferenceRunnerType +from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeBlock, masked_softmax, upcast_masked_softmax + + +def _align_tensor(x): + return x + -x % 128 + + +class GPTBigCodeInferenceRunner: + def __init__(self, config: GPTBigCodeConfig, model): + self.batch_size = None + self.model = model + self.n_layer = len(self.model.h) + + self.inference_runner_type = InferenceRunnerType(config.inference_runner) + assert self.inference_runner_type != InferenceRunnerType.NO_RUNNER + assert config.pre_allocate_kv_cache + self.validate_input = config.validate_runner_input + self.pad_key_length = 8 if config.pad_key_length else 1 + + # TODO: Support other attention types? + assert model.multi_query + + self.max_sequence_length = config.max_sequence_length or config.n_positions + + def _allocate(self, batch_size, device, dtype): + block: GPTBigCodeBlock = self.model.h[0] + attn = block.attn + self.batch_size = batch_size + self.dtype = dtype + self.device = device + self.softmax_dtype = torch.float32 if attn.attention_softmax_in_fp32 else self.dtype + self.upcast = self.softmax_dtype != self.dtype + + do_unscale = attn.scale_attention_softmax_in_fp32 and self.upcast + self.unscale = [i + 1.0 if do_unscale else 1.0 for i in range(self.n_layer)] + scale = attn.head_dim**-0.5 if attn.scale_attn_weights else 1 + self.scale = [scale / unscale for unscale in self.unscale] + + factory_kwargs = {"device": self.device, "dtype": self.dtype} + + hidden_end = self.batch_size * attn.embed_dim + # Query: (bs, embed_dim), also used for attn outputs (no overlap with value). + query_begin = _align_tensor(hidden_end) + query_end = query_begin + self.batch_size * attn.embed_dim + # KV: (bs, 2 * kv_dim), combines with query into c_attn. + kv_end = query_end + 2 * self.batch_size * attn.kv_dim + # Attn weights: (batch_size, num_heads, key_length), no overlap with value + attn_weights_begin = _align_tensor(kv_end) + attn_weights_end = kv_end + self.batch_size * attn.num_heads * self.max_sequence_length + # Projection: (batch_size, embed_dim), no overlap with attn outputs ~ query. + # Also used for MLP projection + c_proj_begin = _align_tensor(query_end) + c_proj_end = c_proj_begin + self.batch_size * attn.embed_dim + c_fc_begin = query_begin + c_fc_end = c_fc_begin + self.batch_size * block.inner_dim + pool_size = max(attn_weights_end, c_proj_end, c_fc_end) + + print( + f"Allocating inference buffers (batch size = {self.batch_size}, max sequence length =" + f" {self.max_sequence_length})..." 
+ ) + + kv_caches = [] + for block in self.model.h: + block.attn.freeze_kv_cache() + kv_cache = block.attn.get_kv_cache(self.batch_size, self.max_sequence_length, self.device, self.dtype) + if attn.multi_query: + kv_cache = kv_cache.unsqueeze(1) + kv_caches.append(kv_cache) + + kv_cache_size = sum(kv_cache.numel() for kv_cache in kv_caches) + + print(f" Activation pool size: {pool_size:,}") + print(f" KV cache size: {kv_cache_size:,}") + buffer_memory = (pool_size + kv_cache_size) * torch.finfo( + self.dtype + ).bits / 8 + self.batch_size * self.max_sequence_length + print(f" Memory usage: {buffer_memory/2**20:.0f} MiB") + + activation_pool = torch.empty(pool_size, **factory_kwargs) + self.mask_value = torch.full( + [], torch.finfo(self.softmax_dtype).min, dtype=self.softmax_dtype, device=self.device + ) + # We ensure mask tensors are contiguous to enable more efficient kernels. + attn_mask = torch.empty(self.batch_size * self.max_sequence_length, dtype=torch.bool, device=self.device) + + if self.device.type == "cuda": + print(f" Memory allocated {torch.cuda.memory_allocated()/2**20:.0f} MiB") + # Max stats give some insight on the prefill memory usage. + print(f" Max memory allocated {torch.cuda.max_memory_allocated()/2**20:.0f} MiB") + print(f" Max memory reserved {torch.cuda.max_memory_reserved()/2**20:.0f} MiB") + + key_lengths = range(self.max_sequence_length + 1) + padded_key_lengths = [key_length + -key_length % self.pad_key_length for key_length in key_lengths] + + self.padded_attn_masks = [ + attn_mask[: self.batch_size * key_length].view(self.batch_size, 1, key_length) + for key_length in padded_key_lengths + ] + self.attn_masks = [ + padded_attn_mask[:, :, :key_length].squeeze(1) + for key_length, padded_attn_mask in enumerate(self.padded_attn_masks) + ] + self.attn_mask_pads = [ + padded_attn_mask[:, :, key_length:].squeeze(1) + for key_length, padded_attn_mask in enumerate(self.padded_attn_masks) + ] + + # Hidden: (batch_size, 1, embed_dim), no overlap allowed. + self.hidden_states_squeezed = activation_pool[:hidden_end].view(self.batch_size, -1) + self.hidden_states = self.hidden_states_squeezed.unsqueeze(1) + # QKV: (bs, embed_dim + 2 * kv_dim). + self.c_attn = activation_pool[query_begin:kv_end].view(self.batch_size, -1) + self.query = self.c_attn[:, : attn.embed_dim].view(self.batch_size, attn.num_heads, attn.head_dim) + self.kv_attn = self.c_attn[:, attn.embed_dim :] + + keys, values = zip(*(kv_cache.split((attn.head_dim, attn.head_dim), dim=-1) for kv_cache in kv_caches)) + head_slice = 0 if attn.multi_query else slice(None) + + self.padded_keys = [ + [key[:, head_slice, :key_length, :].transpose(-1, -2) for key in keys] for key_length in padded_key_lengths + ] + self.padded_values = [ + [value[:, head_slice, :key_length, :] for value in values] for key_length in padded_key_lengths + ] + + # This is nonsense for key_length == 0, but we never need the value. + self.current_key_values = [ + [kv_cache[:, head_slice, key_length - 1, :] for kv_cache in kv_caches] for key_length in key_lengths + ] + self.past_key_values = [ + [kv_cache[:, head_slice, : key_length - 1, :] for kv_cache in kv_caches] for key_length in key_lengths + ] + + # Attn weights: (batch_size, num_heads, key_length), no overlap with value. 
+ attn_weights = activation_pool[attn_weights_begin:attn_weights_end].view( + self.batch_size, attn.num_heads, self.max_sequence_length + ) + self.padded_attn_weights = [attn_weights[:, :, :key_length] for key_length in padded_key_lengths] + + # Attn outputs: (batch_size, embed_dim), no overlap with value. + self.attn_output = activation_pool[query_begin:query_end].view(self.batch_size, -1) + self.attn_output_expanded = self.attn_output.view(self.batch_size, attn.num_heads, attn.head_dim) + # Attn projection: (batch_size, embed_dim), no overlap with attn outputs. + self.c_proj = activation_pool[c_proj_begin:c_proj_end].view(self.batch_size, -1) + + # MLP first layer: (batch_size, embed_dim) + self.mlp_c_fc = activation_pool[c_fc_begin:c_fc_end].view(self.batch_size, -1) + # MLP projection: (batch_size, inner_dim) + self.mlp_c_proj = activation_pool[query_begin:query_end].view(self.batch_size, -1) + + if self.inference_runner_type != InferenceRunnerType.BASE_RUNNER: + print("Generating cuda graphs") + self.memory_pool = None + if self.inference_runner_type == InferenceRunnerType.FULL_GRAPH: + self.cuda_graphs = {} + # The output may not always be at the same memory location. + self.output_hidden_states = {} + # Generate the largest one first to warm up the memory pool. + # The other ones are generated lazily. + self._generate_full_cuda_graph(self.max_sequence_length) + else: + self._generate_cuda_graphs() + + def _generate_cuda_graphs(self): + self.cuda_graphs = {} + for layer_idx in range(self.n_layer + 1): + graph = torch.cuda.CUDAGraph() + + with torch.cuda.graph(graph, pool=self.memory_pool): + if layer_idx > 0: + self._forward_post_attn(self.model.h[layer_idx - 1]) + if layer_idx < self.n_layer: + self._forward_qkv(self.model.h[layer_idx]) + else: + self.output_hidden_states = self._forward_end() + if self.memory_pool is None: + self.memory_pool = graph.pool() + self.cuda_graphs[layer_idx] = graph + + def _generate_full_cuda_graph(self, key_length): + # We need to warmup the jit function before creating the graph, otherwise it will crash. + # Warmup needs to be done for every input shape (key length), and for both scale == 1 and scale != 1 + if self.upcast: + for scale in (1.0, 2.0): + upcast_masked_softmax( + self.padded_attn_weights[key_length], + self.padded_attn_masks[key_length], + self.mask_value, + scale, + self.softmax_dtype, + ) + else: + masked_softmax( + self.padded_attn_weights[key_length], + self.padded_attn_masks[key_length], + self.mask_value, + ) + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, pool=self.memory_pool): + self.output_hidden_states[key_length] = self._forward(key_length) + if self.memory_pool is None: + self.memory_pool = graph.pool() + self.cuda_graphs[key_length] = graph + + def _forward_embed(self, input_ids, position_ids): + # Embedding doesn't support out argument. + inputs_embeds = self.model.wte(input_ids) + position_embeds = self.model.wpe(position_ids) + torch.add(inputs_embeds, position_embeds, out=self.hidden_states) + + def _forward_qkv(self, block): + # LN doesn't support out argument. 
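The PARTIAL_GRAPH and FULL_GRAPH modes above are built on PyTorch's stock CUDA-graph capture/replay API. As a point of reference, here is a minimal, self-contained sketch of that pattern (static buffers, warm-up before capture, optional pool sharing), assuming a CUDA device; it is illustrative only and not part of the runner itself.

import torch

static_input = torch.randn(8, 512, device="cuda")
weight = torch.randn(512, 512, device="cuda")

# Warm up on a side stream so lazy CUDA initialization happens outside capture.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    static_output = static_input @ weight
torch.cuda.current_stream().wait_stream(side_stream)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):  # a second graph could pass pool=graph.pool() to share memory
    static_output = static_input @ weight

# Replays re-run the recorded kernels on whatever is currently in the static buffers.
static_input.copy_(torch.randn(8, 512, device="cuda"))
graph.replay()  # static_output now holds the result for the new input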
+ hidden_states = block.ln_1(self.hidden_states_squeezed) + torch.nn.functional.linear( + hidden_states, + block.attn.c_attn.weight, + block.attn.c_attn.bias, + out=self.c_attn, + ) + + def _forward_attn(self, block, key_length): + layer_idx = block.attn.layer_idx + self.current_key_values[key_length][layer_idx].copy_(self.kv_attn) + attn_weights = self.padded_attn_weights[key_length] + + torch.baddbmm( + attn_weights, + self.query, + self.padded_keys[key_length][layer_idx], + beta=0, + alpha=self.scale[layer_idx], + out=attn_weights, + ) + # Use a fused kernel to prevent a large overhead from casting and scaling. + # Jit doesn't allow inplace kernel. + if self.upcast: + attn_weights = upcast_masked_softmax( + attn_weights, + self.padded_attn_masks[key_length], + self.mask_value, + self.unscale[layer_idx], + self.softmax_dtype, + ) + else: + attn_weights = masked_softmax(attn_weights, self.padded_attn_masks[key_length], self.mask_value) + + torch.bmm(attn_weights, self.padded_values[key_length][layer_idx], out=self.attn_output_expanded) + + def _forward_post_attn(self, block): + torch.nn.functional.linear( + self.attn_output, + block.attn.c_proj.weight, + block.attn.c_proj.bias, + out=self.c_proj, + ) + self.hidden_states_squeezed.add_(self.c_proj) + # LN doesn't support out argument. + hidden_states = block.ln_2(self.hidden_states_squeezed) + torch.nn.functional.linear(hidden_states, block.mlp.c_fc.weight, block.mlp.c_fc.bias, out=self.mlp_c_fc) + # Most activations don't support out argument. + feed_forward_hidden_states = block.mlp.act(self.mlp_c_fc) + torch.nn.functional.linear( + feed_forward_hidden_states, block.mlp.c_proj.weight, block.mlp.c_proj.bias, out=self.mlp_c_proj + ) + self.hidden_states_squeezed.add_(self.mlp_c_proj) + + def _forward_end(self): + # LN doesn't support out argument. 
+ return self.model.ln_f(self.hidden_states) + + def _forward(self, key_length): + for block in self.model.h: + self._forward_qkv(block) + self._forward_attn(block, key_length) + self._forward_post_attn(block) + return self._forward_end() + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Union[List[torch.Tensor], int], + ) -> BaseModelOutputWithPastAndCrossAttentions: + batch_size, query_length = input_ids.shape + assert query_length == 1 + if self.batch_size is None: + self._allocate(batch_size, device=input_ids.device, dtype=self.model.dtype) + elif self.validate_input: + assert batch_size == self.batch_size + assert self.dtype == self.model.dtype + assert self.device == input_ids.device + + if self.validate_input: + assert attention_mask.dim() == 2 + assert attention_mask.shape[0] == batch_size + key_length = attention_mask.shape[1] + assert key_length <= self.max_sequence_length + if isinstance(past_key_values, int): + assert key_length == past_key_values + 1 + else: + key_length = attention_mask.shape[1] + + self._forward_embed(input_ids, position_ids) + + self.attn_masks[key_length].copy_(attention_mask) + + attn_mask_pad = self.attn_mask_pads[key_length] + if attn_mask_pad.size(1) > 0: + attn_mask_pad.fill_(False) + + if self.inference_runner_type == InferenceRunnerType.FULL_GRAPH: + if key_length not in self.cuda_graphs: + self._generate_full_cuda_graph(key_length) + self.cuda_graphs[key_length].replay() + hidden_states = self.output_hidden_states[key_length] + elif self.inference_runner_type == InferenceRunnerType.PARTIAL_GRAPH: + for i, block in enumerate(self.model.h): + self.cuda_graphs[i].replay() + self._forward_attn(block, key_length) + self.cuda_graphs[self.n_layer].replay() + hidden_states = self.output_hidden_states + else: + hidden_states = self._forward(key_length) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=key_length, + ) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 72858532bf..76ba07b73e 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -34,7 +34,7 @@ add_start_docstrings_to_model_forward, logging, ) -from .configuration_gpt_bigcode import GPTBigCodeConfig +from .configuration_gpt_bigcode import GPTBigCodeConfig, InferenceRunnerType logger = logging.get_logger(__name__) @@ -105,6 +105,14 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32 ) + # KV caching and padding + self.kv_cache = None + self.kv_cache_max_batch_size = config.max_batch_size or 0 + self.kv_cache_max_sequence_length = config.max_sequence_length or config.n_positions + self.pre_allocate_kv_cache = config.pre_allocate_kv_cache + self.pad_key_length = config.pad_key_length and config.pre_allocate_kv_cache + self._frozen_kv_cache = False + if self.is_cross_attention: if self.multi_query: raise NotImplementedError("Multi-Query Attention not supported for cross_attention") @@ -202,6 +210,41 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): return attn_output, attn_weights + def freeze_kv_cache(self, enable=True): + if self.kv_cache is None: + raise RuntimeError("KV cache not found.") + # Prevent re-allocation of the KV cache. 
+ self._frozen_kv_cache = enable + + def get_kv_cache(self, batch_size, sequence_length, device, dtype, allocate=True): + if ( + self.kv_cache is None + or self.kv_cache.dtype != dtype + or self.kv_cache.device != device + or batch_size > self.kv_cache_max_batch_size + or sequence_length > self.kv_cache_max_sequence_length + ): + if self._frozen_kv_cache or not allocate: + if self.kv_cache is None: + raise RuntimeError("KV cache not found.") + else: + raise RuntimeError( + f"Invalid KV cache: " + f"existing = {(self.kv_cache.dtype,self.kv_cache.device,self.kv_cache_max_batch_size,self.kv_cache_max_sequence_length)}, " + f"requested = {(dtype,device,batch_size,sequence_length)}" + ) + # Free memory first. + self.kv_cache = None + self.kv_cache_max_sequence_length = max(sequence_length, self.kv_cache_max_sequence_length) + self.kv_cache_max_batch_size = max(batch_size, self.kv_cache_max_batch_size) + kv_cache_size = 2 * self.kv_cache_max_batch_size * self.kv_cache_max_sequence_length * self.kv_dim + self.kv_cache = torch.empty([kv_cache_size], device=device, dtype=dtype) + # This view ensures the cache is contiguous for all batch sizes. + kv_cache = self.kv_cache[: 2 * batch_size * self.kv_cache_max_sequence_length * self.kv_dim].view( + batch_size, self.kv_heads, self.kv_cache_max_sequence_length, 2 * self.head_dim + ) + return kv_cache[:, 0, :sequence_length, :] if self.multi_query else kv_cache[:, :, :sequence_length, :] + def forward( self, hidden_states: torch.Tensor, @@ -239,9 +282,27 @@ def forward( .split((self.head_dim, 2 * self.head_dim), dim=3) ) - if layer_past is not None: - key_value = torch.cat((layer_past, key_value), dim=-2) - present = key_value if use_cache else None + present = None + + if self.pre_allocate_kv_cache: + if use_cache or layer_past is not None: + last_key_length = layer_past or 0 + batch_size = key_value.size(0) + key_length = last_key_length + key_value.size(-2) + padded_key_length = key_length + -key_length % (8 if self.pad_key_length else 1) + kv_cache = self.get_kv_cache( + batch_size, padded_key_length, key_value.device, key_value.dtype, allocate=last_key_length == 0 + ) + if self.multi_query: + kv_cache[:, last_key_length:key_length, :].copy_(key_value) + key_value = kv_cache + if use_cache: + present = key_length + else: + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + if use_cache: + present = key_value key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) @@ -513,6 +574,17 @@ def __init__(self, config): self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.pre_allocate_kv_cache = config.pre_allocate_kv_cache + self.pad_key_length = config.pad_key_length and self.pre_allocate_kv_cache + self.inference_runner_type = InferenceRunnerType(config.inference_runner) + + if self.inference_runner_type == InferenceRunnerType.NO_RUNNER: + self.inference_runner = None + else: + from .inference_runner import GPTBigCodeInferenceRunner + + self.inference_runner = GPTBigCodeInferenceRunner(config, self) + max_positions = config.max_position_embeddings self.register_buffer( "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False @@ -551,6 +623,31 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + if self.inference_runner is not None and 
past_key_values is not None: + if self.config.validate_runner_input: + assert input_ids is not None + assert past_key_values is not None + assert attention_mask is not None + assert token_type_ids is None + assert position_ids is not None + assert head_mask is None + assert inputs_embeds is None + assert encoder_hidden_states is None + assert encoder_attention_mask is None + use_cache = use_cache if use_cache is not None else self.config.use_cache + assert use_cache is True + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + assert output_attentions is False + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + assert output_hidden_states is False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + assert return_dict is True + return self.inference_runner.forward(input_ids, attention_mask, position_ids, past_key_values) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -583,8 +680,10 @@ def forward( if past_key_values is None: past_length = 0 past_key_values = tuple([None] * len(self.h)) + elif self.pre_allocate_kv_cache: + past_length = past_key_values[0] else: - past_length = past_key_values[0][0].size(-2) + past_length = past_key_values[0].size(-2) if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None: # create position_ids on the fly for batch generation @@ -610,6 +709,11 @@ def forward( # MHA models: (batch_size, n_heads, query_length, key_length) attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) + if self.pad_key_length: + pad = -key_length % 8 + if pad > 0: + attention_mask = torch.nn.functional.pad(attention_mask, (0, pad), mode="constant", value=False) + # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if ( diff --git a/src/transformers/models/megatron_gpt_bigcode/__init__.py b/src/transformers/models/megatron_gpt_bigcode/__init__.py new file mode 100644 index 0000000000..8228eea536 --- /dev/null +++ b/src/transformers/models/megatron_gpt_bigcode/__init__.py @@ -0,0 +1,17 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021 NVIDIA Corporation and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
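With the pieces above in place, enabling the runner at generation time might look like the sketch below. The checkpoint path is a placeholder, a CUDA device is assumed, and note that the runner only handles single-token decode steps (it asserts query_length == 1, use_cache and return_dict); the prefill pass still goes through the regular forward.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import InferenceRunnerType

checkpoint = "./my-gpt-bigcode-checkpoint"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    inference_runner=InferenceRunnerType.PARTIAL_GRAPH,
    pre_allocate_kv_cache=True,
    max_sequence_length=512,
    torch_dtype=torch.float16,
).cuda()

inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to("cuda")
# The first (prefill) step runs the regular forward pass; subsequent
# single-token decode steps are dispatched to GPTBigCodeInferenceRunner.forward.
outputs = model.generate(**inputs, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(outputs[0]))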
diff --git a/src/transformers/models/megatron_gpt_bigcode/checkpoint_reshaping_and_interoperability.py b/src/transformers/models/megatron_gpt_bigcode/checkpoint_reshaping_and_interoperability.py new file mode 100644 index 0000000000..e854d7df19 --- /dev/null +++ b/src/transformers/models/megatron_gpt_bigcode/checkpoint_reshaping_and_interoperability.py @@ -0,0 +1,897 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: Update this file + +import argparse +import json +import os +import re +import sys +import types + +import torch + +from transformers import AutoTokenizer, GPT2Config +from transformers.modeling_utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, shard_checkpoint + + +def add_checkpointing_args(parser): + parser.add_argument("--megatron-path", type=str, default=None, help="Base directory of Megatron repository") + parser.add_argument( + "--convert_checkpoint_from_megatron_to_transformers", + action="store_true", + help=( + "If True, convert a Megatron checkpoint to a Transformers checkpoint. " + "If False, convert a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--load_path", + type=str, + required=True, + help="Path to the checkpoint to convert.", + ) + parser.add_argument( + "--save_path", + type=str, + required=True, + help="Path to the converted checkpoint.", + ) + parser.add_argument("--print-checkpoint-structure", action="store_true") + return parser + + +def add_megatron_checkpoint_args(parser): + parser.add_argument( + "--target_tensor_model_parallel_size", + type=int, + default=1, + help=( + "The tensor model parallel size of the converted checkpoint. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--target_pipeline_model_parallel_size", + type=int, + default=1, + help=( + "The pipeline model parallel size of the converted checkpoint. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--target_data_parallel_size", + type=int, + default=1, + help=( + "The data parallel size of the converted checkpoint. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--target_params_dtype", + type=str, + default="fp32", + help=( + "The dtype of the converted checkpoint. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--make_vocab_size_divisible_by", + type=int, + default=128, + help=( + "Pad the vocab size to be divisible by this value. " + "This is added for computational efficieny reasons. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--use_distributed_optimizer", + action="store_true", + help=( + "If True, use the distributed optimizer. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." 
+ ), + ) + return parser + + +def add_transformers_checkpoint_args(parser): + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help=( + "The name of the pre-trained tokenizer to save. " + "If not None, the tokenizer will be saved. " + "Only used when converting a Megatron checkpoint to a Transformers checkpoint." + ), + ) + parser.add_argument( + "--max_shard_size", + type=str, + default="10GB", + help=( + "The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size " + "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`). " + "Only used when converting a Megatron checkpoint to a Transformers checkpoint." + ), + ) + + return parser + + +# The simple map of names for "automated" rules. +megatron_to_transformers = { + "attention.dense": ".attn.c_proj.", + "self_attention.dense": ".attn.c_proj.", + "mlp.dense_h_to_4h": ".mlp.c_fc.", + "mlp.dense_4h_to_h": ".mlp.c_proj.", +} +transformers_to_megatron = {v[1:-1]: k for k, v in megatron_to_transformers.items()} + +tensor_parallel_params = [ + # megatron-lm layers to merge across tp ranks + "self_attention.query_key_value.weight", + "self_attention.query_key_value.bias", + "self_attention.dense.weight", + "mlp.dense_h_to_4h.weight", + "mlp.dense_h_to_4h.bias", + "mlp.dense_4h_to_h.weight", + # deprecated + "attention.query_key_value.weight", + "attention.query_key_value.bias", + "attention.dense.weight", + # transformers layers to split across tp ranks + "attn.c_attn.weight", + "attn.c_attn.bias", + "attn.c_proj.weight", + "mlp.c_fc.weight", + "mlp.c_fc.bias", + "mlp.c_proj.weight", +] + + +def recursive_print(name, val, spaces=0): + """ + Recursively print the structure of a checkpoint. This function is taken from `convert_megatron_gpt2_checkpoint.py` + + Args: + name (str): the name of the current tensor parameter + val (Tuple(int)): the shape of the current tensor parameter + spaces (int): the number of spaces to print before the output for a nested structure + """ + # Format the message. + if name is None: + msg = None + else: + fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}" + msg = fmt.format(name) + + # Print and recurse (if needed). + if isinstance(val, dict): + if msg is not None: + print(msg) + for k in val.keys(): + recursive_print(k, val[k], spaces + 2) + elif isinstance(val, torch.Tensor): + print(msg, ":", val.size()) + else: + print(msg, ":", val) + + +def megatron_to_transformers_fix_query_key_value_ordering( + param, checkpoint_version, num_splits, num_heads, hidden_size +): + """ + Permutes layout of param tensor to [num_splits * num_heads * hidden_size, :] for compatibility with later versions + of NVIDIA Megatron-LM. The inverse operation is performed inside Megatron-LM to read checkpoints: + https://github.com/NVIDIA/Megatron-LM/blob/v2.4/megatron/checkpointing.py#L209 If param is the weight tensor of the + self-attention block, the returned tensor will have to be transposed one more time to be read by HuggingFace GPT2. + This function is taken from `convert_megatron_gpt2_checkpoint.py` + + Args: + param (torch.Tensor): the tensor to permute + checkpoint_version (int): the version of the checkpoint. 
+ num_splits (int): the number of projections, usually 3 for (Query, Key, Value) + num_heads (int): the number of attention heads + hidden_size (int): the hidden size per head + """ + + input_shape = param.size() + if checkpoint_version == 1.0: + # version 1.0 stores [num_heads * hidden_size * num_splits, :] + saved_shape = (num_heads, hidden_size, num_splits) + input_shape[1:] + param = param.view(*saved_shape) + param = param.transpose(0, 2) + param = param.transpose(1, 2).contiguous() + elif checkpoint_version >= 2.0: + # other versions store [num_heads * num_splits * hidden_size, :] + saved_shape = (num_heads, num_splits, hidden_size) + input_shape[1:] + param = param.view(*saved_shape) + param = param.transpose(0, 1).contiguous() + param = param.view(*input_shape) + return param + + +def transformers_to_megatron_fix_query_key_value_ordering( + param, checkpoint_version, num_splits, num_heads, hidden_size +): + """ + Permutes layout of param tensor to the one compatible with respective NVIDIA Megatron-LM chekpoint versions. Input + is [num_splits * num_heads * hidden_size, :] and output is [num_heads * hidden_size * num_splits, :] for version + 1.0 and [num_heads * num_splits * hidden_size, :] for version 2.0 and later. If param is the weight tensor of the + self-attention block, the param needs to be already transposed before calling this function. + + Args: + param (torch.Tensor): the tensor to permute + checkpoint_version (int): the version of the checkpoint. + num_splits (int): the number of projections, usually 3 for (Query, Key, Value) + num_heads (int): the number of attention heads + hidden_size (int): the hidden size per head + """ + + # Input is [num_splits * num_heads * hidden_size, :] + input_shape = param.size() + if checkpoint_version == 1.0: + # version 1.0 stores [num_heads * hidden_size * num_splits, :] + current_shape = (num_splits, num_heads, hidden_size) + input_shape[1:] + param = param.view(*current_shape) + param = param.transpose(0, 2) + param = param.transpose(1, 2).contiguous() + elif checkpoint_version >= 2.0: + # other versions store [num_heads * num_splits * hidden_size, :] + current_shape = (num_splits, num_heads, hidden_size) + input_shape[1:] + param = param.view(*current_shape) + param = param.transpose(0, 1).contiguous() + param = param.view(*input_shape) + return param + + +def merge_transformers_sharded_states(path, num_checkpoints): + """ + Merge sharded checkpoints from transformers into a single checkpoint. + + Args: + path (str): the path to the sharded checkpoints + num_checkpoints (int): the number of checkpoints to merge + """ + state_dict = {} + for i in range(1, num_checkpoints + 1): + checkpoint_path = os.path.join(path, f"pytorch_model-{i:05d}-of-{num_checkpoints:05d}.bin") + current_chunk = torch.load(checkpoint_path, map_location="cpu") + state_dict.update(current_chunk) + return state_dict + + +def get_megatron_sharded_states(args, tp_size, pp_size, pp_rank): + """ + Get sharded checkpoints from NVIDIA Megatron-LM checkpoint based on the provided tensor parallel size, pipeline + parallel size and pipeline parallel rank. 
+ + Args: + args (argparse.Namespace): the arguments to the script + tp_size (int): the tensor parallel size + pp_size (int): the pipeline parallel size + pp_rank (int): the pipeline parallel rank + """ + tp_state_dicts = [] + for i in range(tp_size): + sub_dir_name = f"mp_rank_{i:02d}" if pp_size == 1 else f"mp_rank_{i:02d}_{pp_rank:03d}" + checkpoint_name = os.listdir(os.path.join(args.load_path, sub_dir_name))[0] + checkpoint_path = os.path.join(args.load_path, sub_dir_name, checkpoint_name) + state_dict = torch.load(checkpoint_path, map_location="cpu") + tp_state_dicts.append(state_dict) + return tp_state_dicts + + +def get_element_from_dict_by_path(d, path): + """ + Get element from dictionary by path. If element is not present, recursively add empty dictionaries. + + Args: + d (dict): the dictionary to get the element from + path (list): the path to the element which is delimited by "." + """ + path = path.split(".") + for k in path: + if k not in d: + d[k] = {} + d = d[k] + return d + + +def convert_checkpoint_from_megatron_to_transformers(args): + """ + Convert NVIDIA Megatron-LM checkpoint to HuggingFace Transformers checkpoint. This handles Megatron checkpoints + with different tensor parallelism and pipeline parallelism sizes. It saves the converted checkpoint into shards + using HuggingFace Transformers checkpoint sharding functionality. This greatly extends the functionality of + `convert_megatron_gpt2_checkpoint.py` + + Args: + args (argparse.Namespace): the arguments to the script + """ + # Load Megatron-LM checkpoint arguments from the state dict + sub_dirs = os.listdir(args.load_path) + possible_sub_dirs = ["mp_rank_00", "mp_rank_00_000"] + for sub_dir in possible_sub_dirs: + if sub_dir in sub_dirs: + rank0_checkpoint_name = os.listdir(os.path.join(args.load_path, sub_dir))[0] + rank0_checkpoint_path = os.path.join(args.load_path, sub_dir, rank0_checkpoint_name) + break + print(f"Loading Megatron-LM checkpoint arguments from: {rank0_checkpoint_path}") + state_dict = torch.load(rank0_checkpoint_path, map_location="cpu") + megatron_args = state_dict.get("args", None) + if megatron_args is None: + raise ValueError( + "Megatron-LM checkpoint does not contain arguments. This utility only supports Megatron-LM checkpoints" + " containing all the megatron arguments. This is because it loads all config related to model" + " architecture, the tensor and pipeline model parallel size from the checkpoint insead of user having to" + " manually specify all the details. Please save Megatron-LM checkpoint along with all the megatron" + " arguments to use this utility." 
+ ) + + # Create Transformers GPT2 config from Megatron-LM arguments + if megatron_args is not None: + if megatron_args.bias_gelu_fusion: + activation_function = "gelu_fast" + elif megatron_args.openai_gelu: + activation_function = "gelu_new" + else: + activation_function = "gelu" + else: + # in the very early days this used to be "gelu_new" + activation_function = "gelu_new" + vocab_size = ( + megatron_args.padded_vocab_size + if getattr(megatron_args, "orig_vocab_size", None) is None + else megatron_args.orig_vocab_size + ) + print(vocab_size) + + config = GPT2Config( + vocab_size=vocab_size, + n_positions=megatron_args.max_position_embeddings, + n_embd=megatron_args.hidden_size, + n_layer=megatron_args.num_layers, + n_head=megatron_args.num_attention_heads, + n_inner=megatron_args.ffn_hidden_size, + activation_function=activation_function, + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=vocab_size - 1, + eos_token_id=vocab_size - 1, + architectures=["GPT2LMHeadModel"], + ) + + output_state_dict = {} + + checkpoint_version = state_dict.get("checkpoint_version", 0.0) + tp_size = megatron_args.tensor_model_parallel_size + pp_size = megatron_args.pipeline_model_parallel_size + dtype = torch.float32 + # The regex to extract layer names. + layer_re = re.compile("layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") + + # Convert. + print("Converting") + + # Embeddings + print("Converting embeddings") + tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, 0) + + # Convert and store the position embeddings. + position_embeddings = get_element_from_dict_by_path( + tp_state_dicts[0], "model.language_model.embedding.position_embeddings.weight" + ) + output_state_dict["transformer.wpe.weight"] = position_embeddings.to(dtype) + + # Convert and store the word embeddings. + word_embeddings = torch.cat( + [ + get_element_from_dict_by_path( + tp_state_dicts[tp_rank], "model.language_model.embedding.word_embeddings.weight" + ) + for tp_rank in range(tp_size) + ], + dim=0, + ) + word_embeddings = word_embeddings[:vocab_size].to(dtype) + output_state_dict["transformer.wte.weight"] = word_embeddings + + # Transformer Layers + print("Converting transformer layers") + # The number of heads. + heads = config.n_head + # The hidden_size per head. + hidden_size_per_head = config.n_embd // config.n_head + n_positions = config.n_positions + num_layers = config.num_hidden_layers // pp_size + + for pp_rank in range(pp_size): + if pp_size > 0: + print(f"Converting pipeline parallel rank {pp_rank}") + tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, pp_rank) + + # The transformer. + path = ( + "model.language_model.transformer" + if "transformer" in get_element_from_dict_by_path(tp_state_dicts[0], "model.language_model").keys() + else "model.language_model.encoder" + ) + # Extract the layers. + for key, val in get_element_from_dict_by_path(tp_state_dicts[0], path).items(): + # Match the name. + m = layer_re.match(key) + # Stop if that's not a layer + if m is None: + break + + # The index of the layer. + layer_idx = int(m.group(1)) + pp_rank * num_layers + # The name of the operation. + op_name = m.group(2) + # Is it a weight or a bias? + weight_or_bias = m.group(3) + + # The name of the layer. 
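To make the name mapping concrete, this is roughly what the layer regex and the megatron_to_transformers table do for a single, hypothetical Megatron key (ignoring the pipeline-parallel layer offset added above):

import re

layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
key = "layers.3.mlp.dense_h_to_4h.weight"  # hypothetical Megatron parameter name
m = layer_re.match(key)
layer_idx, op_name, weight_or_bias = int(m.group(1)), m.group(2), m.group(3)
# megatron_to_transformers["mlp.dense_h_to_4h"] == ".mlp.c_fc."
print(f"transformer.h.{layer_idx}" + ".mlp.c_fc." + weight_or_bias)
# -> transformer.h.3.mlp.c_fc.weight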
+ layer_name = f"transformer.h.{layer_idx}" + + if op_name + "." + weight_or_bias not in tensor_parallel_params: + params = val.to(dtype) + else: + dim = 1 if op_name in ["self_attention.dense", "mlp.dense_4h_to_h", "attention.dense"] else 0 + params = torch.cat( + [val] + + [ + get_element_from_dict_by_path(tp_state_dicts[tp_rank], f"{path}")[key] + for tp_rank in range(1, tp_size) + ], + dim=dim, + ).to(dtype) + + # For layernorm(s), simply store the layer norm. + if op_name.endswith("layernorm"): + ln_name = "ln_1" if op_name.startswith("input") else "ln_2" + output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = params + + # Transpose the QKV matrix. + elif ( + op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value" + ) and weight_or_bias == "weight": + # Insert a tensor of 1x1xDxD bias. + causal_mask = torch.tril(torch.ones((n_positions, n_positions), dtype=dtype)).view( + 1, 1, n_positions, n_positions + ) + output_state_dict[layer_name + ".attn.bias"] = causal_mask + + # Insert a "dummy" tensor for masked_bias. + masked_bias = torch.tensor(-1e4, dtype=dtype) + output_state_dict[layer_name + ".attn.masked_bias"] = masked_bias + + out_val = megatron_to_transformers_fix_query_key_value_ordering( + params, + checkpoint_version, + 3, + heads, + hidden_size_per_head, + ) + # Megatron stores (3*D) x D but transformers-GPT2 expects D x 3*D. + out_val = out_val.transpose(0, 1).contiguous() + # Store. + output_state_dict[layer_name + ".attn.c_attn.weight"] = out_val + + # Transpose the bias. + elif ( + op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value" + ) and weight_or_bias == "bias": + out_val = megatron_to_transformers_fix_query_key_value_ordering( + params, checkpoint_version, 3, heads, hidden_size_per_head + ) + # Store. No change of shape. + output_state_dict[layer_name + ".attn.c_attn.bias"] = out_val + + # Transpose the weights. + elif weight_or_bias == "weight": + out_name = megatron_to_transformers[op_name] + output_state_dict[layer_name + out_name + "weight"] = params.transpose(0, 1) + + # Copy the bias. + elif weight_or_bias == "bias": + out_name = megatron_to_transformers[op_name] + output_state_dict[layer_name + out_name + "bias"] = params + + if config.n_layer != (layer_idx + 1): + raise ValueError(f"Expected {config.n_layer} layers but found {layer_idx + 1}") + + # The final layernorm. + print("Converting final layernorm") + params = get_element_from_dict_by_path(tp_state_dicts[0], str(path)) + output_state_dict["transformer.ln_f.weight"] = params["final_layernorm.weight"].to(dtype) + output_state_dict["transformer.ln_f.bias"] = params["final_layernorm.bias"].to(dtype) + + # For LM head, transformers' wants the matrix to weight embeddings. + print("Converting LM head") + output_state_dict["lm_head.weight"] = word_embeddings.to(dtype) + + # It should be done! + print("Conversion from Megatron-LM to Transformers is done!") + + # Print the structure of converted state dict. + if args.print_checkpoint_structure: + recursive_print(None, output_state_dict) + + # Add tokenizer class info to config + # see https://github.com/huggingface/transformers/issues/13906) + + if args.tokenizer_name is None: + tokenizer_name = "gpt2" + else: + tokenizer_name = args.tokenizer_name + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenizer_class = type(tokenizer).__name__ + config.tokenizer_class = tokenizer_class + + # Store the config to file. 
+ print("Saving config") + config.save_pretrained(args.save_path) + + # Save tokenizer based on args + if args.tokenizer_name is not None: + print(f"Adding {tokenizer_class} tokenizer files") + tokenizer.save_pretrained(args.save_path) + + # Store the state_dict to file. + max_shard_size = int(args.max_shard_size) if args.max_shard_size.isdigit() else args.max_shard_size + shards, index = shard_checkpoint(output_state_dict, max_shard_size=max_shard_size) + + # Save the model + for shard_file, shard in shards.items(): + torch.save(shard, os.path.join(args.save_path, shard_file)) + + if index is None: + print(f"Model weights saved in {os.path.join(args.save_path, WEIGHTS_NAME)}") + else: + save_index_file = os.path.join(args.save_path, WEIGHTS_INDEX_NAME) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + print( + f"The model is bigger than the maximum size per checkpoint ({args.max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + +def convert_checkpoint_from_transformers_to_megatron(args): + """ + Convert a checkpoint from HuggingFace Transformers to Megatron-LM. This allows converted checkpoints with variable + tensor parallelism and pipeline parallelism sizes. It takes as input a checkpoint from HuggingFace Transformers + which can have multiple shards. + + Args: + args (argparse.Namespace): the arguments to the script + + """ + os.makedirs(args.save_path, exist_ok=True) + # Search in directory above this + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) + if args.megatron_path is not None: + sys.path.insert(0, args.megatron_path) + + try: + from megatron.tokenizer.tokenizer import _vocab_size_with_padding + except ModuleNotFoundError: + print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. 
Exiting.") + exit(1) + + # load the transformers model state dict and config + sub_dirs = [x for x in os.listdir(args.load_path) if x.startswith("pytorch_model")] + if len(sub_dirs) == 1: + checkpoint_name = "pytorch_model.bin" + state_dict = torch.load(os.path.join(args.load_path, checkpoint_name), map_location="cpu") + else: + num_checkpoints = len(sub_dirs) - 1 + state_dict = merge_transformers_sharded_states(args.load_path, num_checkpoints) + + config = GPT2Config.from_pretrained(args.load_path) + + # Saving the tracker file + tracker_filepath = os.path.join(args.save_path, "latest_checkpointed_iteration.txt") + with open(tracker_filepath, "w") as f: + f.write("release") + + # create `release` dir in args.load_path + release_dir = os.path.join(args.save_path, "release") + os.makedirs(release_dir, exist_ok=True) + + # megatron args + megatron_args = { + "orig_vocab_size": config.vocab_size, + "max_position_embeddings": config.n_positions, + "hidden_size": config.n_embd, + "num_layers": config.n_layer, + "num_attention_heads": config.n_head, + "ffn_hidden_size": config.n_inner, + "tensor_model_parallel_size": args.target_tensor_model_parallel_size, + "pipeline_model_parallel_size": args.target_pipeline_model_parallel_size, + "data_parallel_size": args.target_data_parallel_size, + "make_vocab_size_divisible_by": args.make_vocab_size_divisible_by, + "rank": 0, + "tokenizer_type": "GPT2BPETokenizer", + } + + if config.activation_function == "gelu": + megatron_args["bias_gelu_fusion"] = False + megatron_args["openai_gelu"] = False + elif config.activation_function == "gelu_fast": + megatron_args["bias_gelu_fusion"] = True + megatron_args["openai_gelu"] = False + elif config.activation_function == "gelu_new": + megatron_args["bias_gelu_fusion"] = False + megatron_args["openai_gelu"] = True + + margs = types.SimpleNamespace() + for k, v in megatron_args.items(): + setattr(margs, k, v) + + # params dtype + if args.target_params_dtype == "fp16": + dtype = torch.float16 + elif args.target_params_dtype == "bf16": + dtype = torch.bfloat16 + else: + dtype = torch.float32 + setattr(margs, "params_dtype", dtype) + + # save dummy optim state dict + dummy_optim_state_dict = {} + dummy_optim_state_dict["optimizer"] = { + "step": 0, + "param_groups": [ + { + "lr": 0.0, + "beta1": 0.0, + "beta2": 0.0, + "eps": 0.0, + "weight_decay": 0.0, + "correct_bias": False, + "params": [], + } + ], + } + if args.use_distributed_optimizer: + for i in range(args.target_pipeline_model_parallel_size): + for j in range(args.target_tensor_model_parallel_size): + for k in range(args.target_data_parallel_size): + if args.target_pipeline_model_parallel_size == 1: + checkpoint_dir = f"mp_rank_{i:02d}_{k:03d}" + else: + checkpoint_dir = f"mp_rank_{i:02d}_{j:03d}_{k:03d}" + checkpoint_dir = os.path.join(release_dir, checkpoint_dir) + os.makedirs(checkpoint_dir, exist_ok=True) + torch.save( + dummy_optim_state_dict, + os.path.join(checkpoint_dir, "optim.pt"), + ) + + # Convert. 
+ print("Converting") + output_state_dict = [] + for i in range(args.target_tensor_model_parallel_size): + output_state_dict.append({}) + + # Embedding layer + print("converting embedding layer") + pos_embedding = state_dict["transformer.wpe.weight"].to(dtype) + word_embedding = state_dict["transformer.wte.weight"].to(dtype) + orig_vocab_size = config.vocab_size + padded_vocab_size = _vocab_size_with_padding(orig_vocab_size, margs) + setattr(margs, "padded_vocab_size", padded_vocab_size) + # Cut out extra padding we don't need + if orig_vocab_size > padded_vocab_size: + full_word_embed = word_embedding[0:padded_vocab_size, :] + # Expanding embedding to larger size by replicating final entry + elif orig_vocab_size < padded_vocab_size: + padding_size = padded_vocab_size - orig_vocab_size + full_word_embed = torch.cat((word_embedding, word_embedding[-1].unsqueeze(0).expand(padding_size, -1))) + # Same size! + else: + full_word_embed = word_embedding + + # Split into new tensor model parallel sizes + out_word_embed = torch.chunk(full_word_embed, args.target_tensor_model_parallel_size, dim=0) + for i in range(args.target_tensor_model_parallel_size): + pos_emb_dict = get_element_from_dict_by_path( + output_state_dict[i], "model.language_model.embedding.position_embeddings" + ) + pos_emb_dict["weight"] = pos_embedding + + word_emb_dict = get_element_from_dict_by_path( + output_state_dict[i], "model.language_model.embedding.word_embeddings" + ) + word_emb_dict["weight"] = out_word_embed[i] + + # Transformer layers + print("converting transformer layers") + if config.num_hidden_layers % args.target_tensor_model_parallel_size != 0: + raise ValueError( + f"Number of layers ({config.num_hidden_layers}) must be divisible by number of tensor parallelism" + f" ({args.target_tensor_model_parallel_size})" + ) + num_layers = config.num_hidden_layers // args.target_pipeline_model_parallel_size + + layer_re = re.compile("transformer.h\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") + # The number of heads. + heads = config.n_head + # The hidden_size per head. + hidden_size_per_head = config.n_embd // config.n_head + for pp_rank in range(args.target_pipeline_model_parallel_size): + layer_offset = pp_rank * num_layers + if pp_rank > 0: + output_state_dict = [] + for i in range(args.target_tensor_model_parallel_size): + output_state_dict.append({}) + + for layer in range(num_layers): + pp_layer_id = layer + layer_offset + layers_to_copy = [ + layer_name + for layer_name in state_dict.keys() + if layer_name.startswith(f"transformer.h.{pp_layer_id}.") + ] + + for layer_name in layers_to_copy: + m = layer_re.match(layer_name) + # Stop if that's not a layer + if m is None: + break + + # The index of the layer. + _ = int(m.group(1)) + # The name of the operation. + op_name = m.group(2) + # Is it a weight or a bias? + weight_or_bias = m.group(3) + + params = state_dict[layer_name].to(dtype) + # handle layernorm + if op_name.startswith("ln"): + out_name = "input_layernorm" if op_name.endswith("1") else "post_attention_layernorm" + layer_name = f"layers.{layer}.{out_name}.{weight_or_bias}" + + # handle attention K, V, Q weights + elif op_name.startswith("attn.c_attn") and weight_or_bias == "weight": + # transformers stores D X (3*D) but Megatron-LM expects (3*D) X D. 
+ params = params.transpose(0, 1).contiguous() + + params = transformers_to_megatron_fix_query_key_value_ordering( + params, + 3.0, + 3, + heads, + hidden_size_per_head, + ) + layer_name = f"layers.{layer}.self_attention.query_key_value.{weight_or_bias}" + + # handle attention K, V, Q bias + elif op_name.startswith("attn.c_attn") and weight_or_bias == "bias": + params = transformers_to_megatron_fix_query_key_value_ordering( + params, + 3.0, + 3, + heads, + hidden_size_per_head, + ) + layer_name = f"layers.{layer}.self_attention.query_key_value.{weight_or_bias}" + + # handle attention and mlp weights + elif weight_or_bias == "weight": + out_name = transformers_to_megatron.get(op_name, None) + if out_name is None: + continue + params = params.transpose(0, 1) + layer_name = f"layers.{layer}.{out_name}.{weight_or_bias}" + + # handle attention and mlp bias + elif weight_or_bias == "bias": + out_name = transformers_to_megatron.get(op_name, None) + if out_name is None: + continue + layer_name = f"layers.{layer}.{out_name}.{weight_or_bias}" + + # skip + else: + continue + + if op_name + "." + weight_or_bias in tensor_parallel_params: + dim = 1 if op_name in ["attn.c_proj", "mlp.c_proj"] else 0 + params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=dim) + + for i in range(args.target_tensor_model_parallel_size): + params_dict = get_element_from_dict_by_path(output_state_dict[i], "model.language_model.encoder") + params_dict[layer_name] = ( + params[i] if (op_name + "." + weight_or_bias in tensor_parallel_params) else params + ) + + if pp_rank == args.target_pipeline_model_parallel_size - 1: + # handle final layernorm + for weight_or_bias in ["weight", "bias"]: + params = state_dict[f"transformer.ln_f.{weight_or_bias}"].to(dtype) + layer_name = f"final_layernorm.{weight_or_bias}" + for i in range(args.target_tensor_model_parallel_size): + params_dict = get_element_from_dict_by_path(output_state_dict[i], "model.language_model.encoder") + params_dict[layer_name] = params + + # add the LM head + for i in range(args.target_tensor_model_parallel_size): + params_dict = get_element_from_dict_by_path(output_state_dict[i], "model.word_embeddings_for_head") + params_dict["weight"] = out_word_embed[i] + + # saving the state dict as per the tp_rank and pp_rank + for tp_rank in range(args.target_tensor_model_parallel_size): + output_state_dict[tp_rank]["checkpoint_version"] = 3.0 + output_state_dict[tp_rank]["args"] = margs + checkpoint_dir = ( + f"mp_rank_{tp_rank:02d}" + if args.target_pipeline_model_parallel_size == 1 + else f"mp_rank_{tp_rank:02d}_{pp_rank:03d}" + ) + if args.use_distributed_optimizer: + checkpoint_name = "model_rng.pt" + else: + checkpoint_name = "model_optim_rng.pt" + output_state_dict[tp_rank]["optimizer"] = dummy_optim_state_dict["optimizer"] + checkpoint_dir = os.path.join(release_dir, checkpoint_dir) + os.makedirs(checkpoint_dir, exist_ok=True) + checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name) + if args.print_checkpoint_structure: + print( + f"Checkpoint structure of model state dict shard belonging to TP rank {tp_rank} and PP rank" + f" {pp_rank}:" + ) + recursive_print(None, output_state_dict[tp_rank]) + torch.save(output_state_dict[tp_rank], checkpoint_path) + + +def main(): + parser = argparse.ArgumentParser() + parser = add_checkpointing_args(parser) + parser = add_megatron_checkpoint_args(parser) + parser = add_transformers_checkpoint_args(parser) + args = parser.parse_args() + if args.convert_checkpoint_from_megatron_to_transformers: + 
convert_checkpoint_from_megatron_to_transformers(args) + else: + convert_checkpoint_from_transformers_to_megatron(args) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py b/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py new file mode 100644 index 0000000000..e3d1d85d58 --- /dev/null +++ b/src/transformers/models/megatron_gpt_bigcode/convert_megatron_gpt_bigcode_checkpoint.py @@ -0,0 +1,269 @@ +#################################################################################################### + +# Copyright (c) 2021-, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#################################################################################################### + +# +# Note: If when running this conversion script you're getting an exception: +# ModuleNotFoundError: No module named 'megatron.model.enums' +# you need to tell python where to find the clone of Megatron-LM, e.g.: +# +# cd /tmp +# git clone https://github.com/NVIDIA/Megatron-LM +# PYTHONPATH=/tmp/Megatron-LM python src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py ... +# +# if you already have it cloned elsewhere, simply adjust the path to the existing path +# +# If the training was done using a Megatron-LM fork, e.g., +# https://github.com/microsoft/Megatron-DeepSpeed/ then chances are that you need to have that one +# in your path, i.e., /path/to/Megatron-DeepSpeed/ +# + +import argparse +import os +import re + +import torch + +from transformers import GPTBigCodeConfig, GPTBigCodeForCausalLM, GPTBigCodeModel + + +#################################################################################################### + + +def recursive_print(name, val, spaces=0): + # Format the message. + if name is None: + msg = None + else: + fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}" + msg = fmt.format(name) + + # Print and recurse (if needed). + if isinstance(val, dict): + if msg is not None: + print(msg) + for k in val.keys(): + recursive_print(k, val[k], spaces + 2) + elif isinstance(val, torch.Tensor): + print(msg, ":", val.size()) + else: + print(msg, ":", val) + + +#################################################################################################### + +# The simple map of names for "automated" rules. +NAME_MAP = { + "attention.dense": ".attn.c_proj.", + "self_attention.dense": ".attn.c_proj.", + "mlp.dense_h_to_4h": ".mlp.c_fc.", + "mlp.dense_4h_to_h": ".mlp.c_proj.", + "self_attention.query_key_value": ".attn.c_attn.", + "self_attention.query": ".attn.q_attn.", + "self_attention.key_value": ".attn.kv_attn.", +} + + +def convert_megatron_checkpoint(input_state_dict): + # The converted output model. 
+ output_state_dict = {} + ds_args = input_state_dict["args"] + + if ds_args is not None: + if ds_args.bias_gelu_fusion: + activation_function = "gelu_pytorch_tanh" + elif ds_args.openai_gelu: + activation_function = "gelu_new" + else: + activation_function = "gelu" + else: + # in the very early days this used to be "gelu_new" + activation_function = "gelu_new" + + if ds_args.attention_head_type == "multihead": + multi_query = False + else: + assert ds_args.attention_head_type == "multiquery" + multi_query = True + + attention_softmax_in_fp32 = ds_args.attention_softmax_in_fp32 or ds_args.apply_query_key_layer_scaling + + # Spell out all parameters in case the defaults change. + config = GPTBigCodeConfig( + architectures=["GPTBigCodeForCausalLM"], + vocab_size=ds_args.padded_vocab_size, + n_positions=ds_args.max_position_embeddings, + n_embd=ds_args.hidden_size, + n_layer=ds_args.num_layers, + n_head=ds_args.num_attention_heads, + n_inner=ds_args.ffn_hidden_size, + activation_function=activation_function, + multi_query=multi_query, + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + attention_softmax_in_fp32=attention_softmax_in_fp32, + scale_attention_softmax_in_fp32=True, + ) + + # from pprint import pprint + # pprint(vars(ds_args)) + # pprint(config) + + # Megatron-LM checkpoint version + checkpoint_version = input_state_dict["checkpoint_version"] + if checkpoint_version < 2.0: + raise NotImplementedError(f"Checkpoint version {checkpoint_version} not supported.") + + # The model. + model = input_state_dict["model"]["language_model"] + + # The word embeddings, truncated to vocab_size rows. + word_embeddings = model["embedding"]["word_embeddings"]["weight"][: config.vocab_size, :] + output_state_dict["transformer.wte.weight"] = word_embeddings + + # The position embeddings. + output_state_dict["transformer.wpe.weight"] = model["embedding"]["position_embeddings"]["weight"] + + # The transformer. + transformer = model["transformer"] if "transformer" in model else model["encoder"] + + # The regex to extract layer names. + layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") + + # Extract the layers. + for key, val in transformer.items(): + # Match the name. + m = layer_re.match(key) + + # Stop if that's not a layer. + if m is None: + break + + # The index of the layer. + layer_idx = int(m.group(1)) + # The name of the operation. + op_name = m.group(2) + # Is it a weight or a bias? + weight_or_bias = m.group(3) + + # The name of the layer. + layer_name = f"transformer.h.{layer_idx}" + + # For layernorm(s), simply store the layer norm. + if op_name.endswith("layernorm"): + ln_name = "ln_1" if op_name.startswith("input") else "ln_2" + output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = val + + # Concatenate the QKV matrix (multi-query attention stores query and key_value separately in Megatron). + elif op_name == "self_attention.key_value": + # Query is before key_value in the dict. + query = output_state_dict.pop(layer_name + ".attn.q_attn." + weight_or_bias) + out_val = torch.cat([query, val], dim=0) + output_state_dict[layer_name + ".attn.c_attn." + weight_or_bias] = out_val + + # Copy the parameters. + else: + output_state_dict[layer_name + NAME_MAP[op_name] + weight_or_bias] = val + + # Sanity check: all transformer layers were converted.
+ assert config.n_layer == layer_idx + 1 + + # The final layernorm. + output_state_dict["transformer.ln_f.weight"] = transformer["final_layernorm.weight"] + output_state_dict["transformer.ln_f.bias"] = transformer["final_layernorm.bias"] + + # For the LM head, transformers wants the weight matrix tied to the word embeddings. + output_state_dict["lm_head.weight"] = word_embeddings + + # It should be done! + return config, output_state_dict + + +#################################################################################################### + + +def main(argv=None): + # Create the argument parser. + parser = argparse.ArgumentParser() + parser.add_argument("--print-checkpoint-structure", action="store_true") + parser.add_argument( + "path_to_checkpoint", + type=str, + help="Path to the checkpoint file (.zip archive or direct .pt file)", + ) + parser.add_argument( + "--custom_model", + action="store_true", + help="Save as a custom model so it can be used with Hugging Face Transformers.", + ) + parser.add_argument( + "--save_dir", help="Path where the converted model is saved. Will use the checkpoint directory if not provided" + ) + args = parser.parse_args(argv) + + # The directory where the converted model will be saved. + basename = args.save_dir or os.path.dirname(args.path_to_checkpoint) + + # Load the model. + print(f"Extracting PyTorch state dictionary from {args.path_to_checkpoint}") + input_state_dict = torch.load(args.path_to_checkpoint, map_location="cpu") + + # Convert. + print("Converting") + config, output_state_dict = convert_megatron_checkpoint(input_state_dict) + + # Print the structure of the converted state dict. + if args.print_checkpoint_structure: + recursive_print(None, output_state_dict) + + if args.custom_model: + # Save as a custom model. + GPTBigCodeConfig.register_for_auto_class() + GPTBigCodeForCausalLM.register_for_auto_class("AutoModelForCausalLM") + hf_model = GPTBigCodeForCausalLM(config) + hf_model.load_state_dict(output_state_dict) + hf_model.save_pretrained(basename) + + else: + # Store the config to file. + print("Saving config") + config.save_pretrained(basename) + + # Store the state_dict to file. + output_checkpoint_file = os.path.join(basename, "pytorch_model.bin") + print(f'Saving checkpoint to "{output_checkpoint_file}"') + torch.save(output_state_dict, output_checkpoint_file) + + +#################################################################################################### + +if __name__ == "__main__": + main() + +#################################################################################################### diff --git a/src/transformers/models/megatron_gpt_bigcode/push_checkpoints.py b/src/transformers/models/megatron_gpt_bigcode/push_checkpoints.py new file mode 100644 index 0000000000..e2cbd38a8f --- /dev/null +++ b/src/transformers/models/megatron_gpt_bigcode/push_checkpoints.py @@ -0,0 +1,73 @@ +import argparse +import re +import subprocess +from pathlib import Path + +from huggingface_hub import Repository + +from transformers.models.megatron_gpt_bigcode.convert_megatron_gpt_bigcode_checkpoint import main as convert + + +""" +Script to upload Megatron checkpoints to a HF repo on the Hub. The script clones/creates a repo on the Hub, checks out +the branch `--branch_name`, converts each `iter_` checkpoint, and saves it as a commit on that branch.
+""" + + +def get_iter_number(iter_dir: str): + m = re.match(r"iter_(\d+)", iter_dir) + if m is not None: + return int(m.group(1)) + else: + raise ValueError(f"Invalid directory name: {iter_dir}") + + +def main(argv=None): + parser = argparse.ArgumentParser() + parser.add_argument("--exp_dir", type=Path, required=True, help="Path to experiment folder.") + parser.add_argument("--repo_name", required=True, help="Name of repository on the Hub in 'ORG/NAME' format.") + parser.add_argument("--branch_name", required=True, help="Name of branch in repository to save experiments.") + parser.add_argument( + "--save_dir", + type=Path, + help="Path where repository is cloned to locally. Will use {exp_dir}/hf_checkpoints if not provided", + ) + parser.add_argument( + "--iter_interval", + type=int, + default=1, + help="Iteration number must be divisble by iter_interval in order to be pushed", + ) + args, argv = parser.parse_known_args(argv) + + save_dir = args.save_dir or args.exp_dir / "hf_checkpoints" + + hf_repo = Repository(save_dir, clone_from=args.repo_name) + hf_repo.git_checkout(args.branch_name, create_branch_ok=True) + # Find last checkpoint that was uploaded + head_hash = hf_repo.git_head_hash() + commit_msg = subprocess.check_output(["git", "show", "-s", "--format=%B", head_hash], cwd=save_dir).decode() + try: + last_commit_iter = get_iter_number(commit_msg.strip()) + print(f"Last commit iteration: {last_commit_iter}") + except ValueError: + last_commit_iter = -1 + + # The checkpoint dirs should be in ascending iteration order, so that the last commit corresponds to the latest checkpoint + ckpt_dirs = sorted([x for x in args.exp_dir.iterdir() if x.name.startswith("iter_") and x.is_dir()]) + + for ckpt_dir in ckpt_dirs: + iter_number = get_iter_number(ckpt_dir.name) + if iter_number <= last_commit_iter: + continue + if iter_number % args.iter_interval == 0: + print(f"Converting iteration {iter_number}") + # TODO: this only works for 1-way tensor/pipeline parallelism + file_path = next((ckpt_dir / "mp_rank_00").glob("*.pt")) + convert(argv + [f"--save_dir={str(save_dir)}", str(file_path)]) + print(f"Pushing iteration {iter_number}") + hf_repo.push_to_hub(commit_message=f"{ckpt_dir.name}") + + +if __name__ == "__main__": + main()