diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index 3c5cff3a7e..cd1978e4d9 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -50,6 +50,14 @@ class GPTBigCodeConfig(PretrainedConfig): Number of hidden layers in the Transformer encoder. n_head (`int`, *optional*, defaults to 12): Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details check out [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. n_inner (`int`, *optional*, defaults to None): Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd activation_function (`str`, *optional*, defaults to `"gelu_pytorch_tanh"`): @@ -63,6 +71,8 @@ class GPTBigCodeConfig(PretrainedConfig): The dropout ratio for the attention. layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): The epsilon to use in the layer normalization layers. + rope_theta (`int`, *optional*, defaults to 10000): + The base period (theta) of the rotary position embeddings. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
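Note for reviewers: the head meanpooling mentioned in the `num_key_value_heads` docstring can be sketched as below. This is only an illustration of the conversion described in the GQA paper, not code from this PR; the helper name and the (num_heads * head_dim, hidden_size) weight layout (with adjacent heads forming a group) are assumptions.

import torch

def meanpool_kv_heads(kv_weight: torch.Tensor, num_heads: int, num_key_value_heads: int, head_dim: int) -> torch.Tensor:
    # kv_weight: one key (or value) projection weight, assumed shape (num_heads * head_dim, hidden_size)
    # Group the original heads and average each group to obtain the GQA key/value heads.
    hidden_size = kv_weight.shape[-1]
    grouped = kv_weight.view(num_key_value_heads, num_heads // num_key_value_heads, head_dim, hidden_size)
    return grouped.mean(dim=1).reshape(num_key_value_heads * head_dim, hidden_size)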
scale_attn_weights (`bool`, *optional*, defaults to `True`): @@ -106,12 +116,14 @@ def __init__( n_embd=768, n_layer=12, n_head=12, + num_key_value_heads=None, n_inner=None, activation_function="gelu_pytorch_tanh", resid_pdrop=0.1, embd_pdrop=0.1, attn_pdrop=0.1, layer_norm_epsilon=1e-5, + rope_theta=10000, initializer_range=0.02, scale_attn_weights=True, use_cache=True, @@ -131,12 +143,19 @@ def __init__( self.n_embd = n_embd self.n_layer = n_layer self.n_head = n_head + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = 1 if multi_query else n_head + self.num_key_value_heads = num_key_value_heads + self.n_inner = n_inner self.activation_function = activation_function self.resid_pdrop = resid_pdrop self.embd_pdrop = embd_pdrop self.attn_pdrop = attn_pdrop self.layer_norm_epsilon = layer_norm_epsilon + self.rope_theta = rope_theta self.initializer_range = initializer_range self.scale_attn_weights = scale_attn_weights self.use_cache = use_cache diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 387e78eb46..6103d31b3c 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -15,6 +15,15 @@ import math from typing import List, Optional, Tuple, Union +try: + from flash_attn import bert_padding + from flash_attn.flash_attn_interface import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, + ) + FLASHATTN_IS_AVAILABLE = True +except ImportError: + FLASHATTN_IS_AVAILABLE = False import torch import torch.utils.checkpoint from torch import nn @@ -35,7 +44,12 @@ logging, ) from .configuration_gpt_bigcode import GPTBigCodeConfig - +import warnings +try: + from apex.normalization import FusedLayerNorm as LayerNorm +except ImportError: + from torch.nn import LayerNorm + warnings.warn("Install NVIDIA Apex for apex.normalization.FusedLayerNorm to be available") logger = logging.get_logger(__name__) @@ -93,17 +107,137 @@ def _apply_rotary_embeddings( return torch.view_as_real(complex_tensor * rope_frequencies).view_as(tensor).type_as(tensor) +# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...) +@torch.jit.script +def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +class StarcoderRotaryEmbedding(nn.Module): + """Implementation of RotaryEmbedding from GPT-NeoX.""" + + def __init__(self, head_dim: int, base=10000): + super().__init__() + self.base = base + self.head_dim = head_dim + self.seq_len_cached = -1 + # TODO @nouamane: Figure out why we can't set `DTypeInvariantTensor` ... 
+ self.inv_freq: torch.Tensor + self.register_buffer( + "inv_freq", + torch.empty(head_dim // 2, dtype=torch.float), + persistent=False, + ) + self.cos_cached: Optional[torch.Tensor] = None + self.sin_cached: Optional[torch.Tensor] = None + self._initialized_buffer = False + + def init_rotary_embeddings(self): + if self._initialized_buffer is True: + # Buffer is already initialized + return + + assert self.inv_freq.device.type == "cuda" + # TODO @nouamane: Once we figure out how to do the DTypeInvariantTensor, this can be removed and changed to an assert + if self.inv_freq.dtype != torch.float: + self.inv_freq = self.inv_freq.to(torch.float) + assert self.inv_freq.dtype == torch.float + + self.inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.head_dim, 2, dtype=torch.float, device="cuda") / self.head_dim) + ) + + self._initialized_buffer = True + + def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> Tuple[torch.Tensor, torch.Tensor]: + total_length = seq_len + past_key_values_length + if total_length > self.seq_len_cached: + self.seq_len_cached = total_length + assert self.inv_freq.dtype == torch.float + t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) # [seq_len, head_dim] + + if dtype in [torch.float16, torch.bfloat16]: + emb = emb.float() + + self.cos_cached = emb.cos()[None, :, None, :] # [1, seq_len, 1, head_dim] + self.sin_cached = emb.sin()[None, :, None, :] + + self.cos_cached = self.cos_cached.type(dtype) + self.sin_cached = self.sin_cached.type(dtype) + + return ( + self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length], + self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length], + ) + + def forward(self, query, key, past_key_values_length=0): + """ + Args: + query: [batch_size, seq_len, num_heads, head_dim] + key: [batch_size, seq_len, num_heads_k, head_dim] + past_key_values_length: int + + Returns: + query: [batch_size, seq_len, num_heads, head_dim] + key: [batch_size, seq_len, num_heads_k, head_dim] + """ + if self._initialized_buffer is False: + self.init_rotary_embeddings() + seq_len = query.shape[1] + cos, sin = self.cos_sin( + seq_len, past_key_values_length, query.device, query.dtype + ) # [1, seq_len, 1, head_dim] + query = (query * cos) + (rotate_half(query) * sin) + key = (key * cos) + (rotate_half(key) * sin) + if past_key_values_length > 0: + assert ( + query.shape[1] == 1 + ), f"past_key_values_length={past_key_values_length} but query.shape[1]={query.shape[1]}" + return query, key + + +def pad_to_right(tensor, mask, new_tensor=None): + """Transform a left-padded tensor into a right-padded tensor. 
+ (Useful for prefilling key/value states) + Args: + tensor: (batch_size, seqlen, d1, d2) + mask: (batch_size, seqlen) + new_tensor: (batch_size, new_tensor_seqlen, d1, d2) + Returns: + new_tensor: (batch_size, new_tensor_seqlen, d1, d2) + right_padded_mask: (batch_size, seqlen) + """ + # First, we need to find the number of non-padded tokens in each row + unpad_seqlens = mask.sum(1) + # Then, we need to find the maximum length of the tensor + max_seqlen = mask.shape[1] + # We can then create the position indices + # The indices are the same for each row + indices = torch.arange(max_seqlen, device=mask.device) + # We can then create the mask marking the valid (non-padded) positions on the right + right_padded_mask = indices < unpad_seqlens[:, None] + # We select the useful values + useful_values = tensor[mask] + # We create the new tensor (if not provided) + new_tensor = torch.zeros_like(tensor) if new_tensor is None else new_tensor + # We fill the new tensor with the useful values + new_tensor[:, : right_padded_mask.shape[1], :, :][right_padded_mask] = useful_values + return new_tensor, right_padded_mask + class GPTBigCodeAttention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() self.mask_value = None + # self.multi_query = config.num_key_value_heads == 1 self.multi_query = config.multi_query + assert not config.multi_query or config.num_key_value_heads == 1, f"Got MQA with num_key_value_heads = {config.num_key_value_heads}" self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads - self.kv_heads = 1 if self.multi_query else self.num_heads - self.kv_dim = self.kv_heads * self.head_dim + self.kv_heads = config.num_key_value_heads self.split_size = self.embed_dim if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( @@ -124,17 +258,35 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): if self.is_cross_attention: if self.multi_query: raise NotImplementedError("Multi-Query Attention not supported for cross_attention") + if self.kv_heads != self.num_heads: + raise NotImplementedError("Cross-Attention not supported for num_key_value_heads != num_attention_heads") self.c_attn = nn.Linear(self.embed_dim, 2 * self.embed_dim) self.q_attn = nn.Linear(self.embed_dim, self.embed_dim) else: - self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim) + self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_heads * self.head_dim) + + # Defined unconditionally: the flash-attention path calls maybe_rotary even when rotary embeddings are disabled. + self.maybe_rotary = ( + StarcoderRotaryEmbedding(head_dim=self.head_dim, base=config.rope_theta) + if config.use_rotary_embeddings + else lambda q, k, past_key_values_length=0: (q, k) + ) self.c_proj = nn.Linear(self.embed_dim, self.embed_dim) self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.attn_pdrop = config.attn_pdrop self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.prefill_kv_len = ( + config.max_position_embeddings + ) # TODO @nouamane: compute based on free memory, because in rope we can surpass max_position_embeddings + self.n_local_kv_heads = self.kv_heads + self.n_local_q_heads = self.num_heads + self.n_repeats = self.num_heads // self.kv_heads + + def _get_mask_value(self, device, dtype): # torch.where expects a tensor. We use a cache to avoid recreating it every time. 
+ if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: @@ -142,6 +294,13 @@ def _get_mask_value(self, device, dtype): return self.mask_value def _attn(self, query, key, value, attention_mask=None, head_mask=None): + """ + Args: + query (batch_size, query_length, num_heads * head_dim) + key (batch_size, kv_heads * head_dim, key_length) + value (batch_size, key_length, kv_heads * head_dim) + + """ dtype = query.dtype softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype upcast = dtype != softmax_dtype @@ -164,6 +323,29 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): attn_view = (batch_size, query_length * self.num_heads, key_length) # No copy needed for MQA 2, or when layer_past is provided. query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) + # value shape (batch_size, key_length, head_dim) + elif self.kv_heads < self.num_heads: # GQA + # (batch_size, query_length, num_heads, head_dim) x (batch_size, kv_heads, head_dim, key_length) + # -> (batch_size, num_heads, query_length, key_length) + query_length = query_shape[1] + attn_shape = (batch_size, self.num_heads, query_length, key_length) + attn_view = (batch_size * self.num_heads, query_length, key_length) + # Always copies + query = query.view(batch_size, query_length, self.num_heads, self.head_dim) + query = query.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim) + # query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) + # Need to repeat key/value for each query -> we need kv_heads dim to precede key_length dim + # key shape: (batch_size, kv_heads, head_dim, key_length) + n_repeats = self.num_heads // self.kv_heads + # TODO @nouamane: refactor this + key = (key.view(batch_size, self.kv_heads, 1, self.head_dim, key_length) + .expand(batch_size, self.kv_heads, n_repeats, self.head_dim, key_length) + .reshape(batch_size * self.num_heads, self.head_dim, key_length)) + value = value.view(batch_size, key_length, self.kv_heads, self.head_dim) + value = value.transpose(1, 2) + value = (value.view(batch_size, self.kv_heads, 1, key_length, self.head_dim) + .expand(batch_size, self.kv_heads, n_repeats, key_length, self.head_dim) + .reshape(batch_size, self.num_heads, key_length, self.head_dim)) else: # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length) # -> (batch_size, num_heads, query_length, key_length) @@ -185,6 +367,8 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): else: beta = 0 attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape) + # if not self.multi_query: + # attn_weights = attn_weights.transpose(1, 2) # (batch_size, query_length, num_heads, key_length) if upcast: # Use a fused kernel to prevent a large overhead from casting and scaling. 
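Reviewer note: the inline view/expand/reshape in the GQA branch above is the standard "repeat key/value heads across query heads" trick. For reference only, an equivalent standalone helper is sketched below; the name repeat_kv and the (batch, heads, seq, head_dim) layout are illustrative and not part of this diff.

import torch

def repeat_kv(x: torch.Tensor, n_repeats: int) -> torch.Tensor:
    # x: (batch_size, kv_heads, seq_len, head_dim) -> (batch_size, kv_heads * n_repeats, seq_len, head_dim)
    if n_repeats == 1:
        return x
    batch_size, kv_heads, seq_len, head_dim = x.shape
    # Insert a repeat dimension, broadcast it, then fold it into the head dimension.
    x = x[:, :, None, :, :].expand(batch_size, kv_heads, n_repeats, seq_len, head_dim)
    return x.reshape(batch_size, kv_heads * n_repeats, seq_len, head_dim)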
@@ -212,24 +396,151 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): attn_weights = attn_weights * head_mask if self.multi_query: + # value shape: (batch_size, key_length, 1, head_dim) attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape) else: attn_output = torch.matmul(attn_weights, value) return attn_output, attn_weights + def _flash_attn(self, hidden_states, layer_past, sequence_mask, position_ids): + # hidden_states (batch_size, query_length, embed_dim) + # layer_past (batch_size, past_key_values_length, embed_dim) + # sequence_mask (batch_size, query_length) + + fused_qkv = self.c_attn( + hidden_states + ) # [batch_size, query_length, n_local_q_heads * head_dim + 2 * n_local_kv_heads * head_dim] + batch_size, q_length, _ = fused_qkv.size() + + qkv = fused_qkv.view(batch_size, q_length, self.n_local_kv_heads, self.n_repeats + 2, self.head_dim) + + query, key, value = torch.split(qkv, [self.n_repeats, 1, 1], dim=3) + query_states = query.reshape( + batch_size, q_length, self.n_local_q_heads, self.head_dim + ) + key_states = key.reshape(batch_size, q_length, self.n_local_kv_heads, self.head_dim) + value_states = value.reshape(batch_size, q_length, self.n_local_kv_heads, self.head_dim) + + # Compute rotary embeddings + if layer_past is None: + past_key_values_length = 0 + else: + past_key_values_length = layer_past.shape[1] + query_states, key_states = self.maybe_rotary( + query_states, key_states, past_key_values_length=past_key_values_length + ) + + if layer_past is None: + # First inference iteration (Prefill) + # TODO @nouamane: support custom masking + # assert that [ False, False, False, False, True, True, True, True, True, True] is accepted + # but [ False, False, False, False, True, True, False, False, True, True] is not (can't mask in the middle of sequence) + assert ~( + sequence_mask[:, :-1] & (~sequence_mask[:, 1:]) # True is never followed by False + ).any(), f"Can't mask in the middle of sequence, please use USE_FAST=0 instead.\nGot sequence_mask: {sequence_mask}" + + # preallocate k_cache, v_cache to self.prefill_kv_len + k_cache = torch.zeros( + ( + batch_size, + self.prefill_kv_len, + self.n_local_kv_heads, + self.head_dim, + ), + dtype=query_states.dtype, + device=query_states.device, + ) + v_cache = torch.zeros( + (batch_size, self.prefill_kv_len, self.n_local_kv_heads, self.head_dim), + dtype=query_states.dtype, + device=query_states.device, + ) + self.register_buffer("k_cache", k_cache) + self.register_buffer("v_cache", v_cache) + + # Remove pad tokens from key_states and concatenate samples in key_unpad + # cu_seqlens_k is the cumulative sequence lengths of key_states + (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input( + query_states, + sequence_mask, + ) + (key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input( + key_states, sequence_mask + ) + (value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask) + + output_unpad = flash_attn_varlen_func( + q=query_unpad, # (total_q, self.n_local_q_heads, d_qk) + k=key_unpad, # (total_kv, self.n_local_kv_heads, d_qk) + v=value_unpad, # (total_kv, self.n_local_kv_heads, d_v) + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=self.attn_pdrop, + softmax_scale=None, + causal=True, # True in prefill phase, False in subsequent phases + return_attn_probs=False, + ) # (total_unpadded, n_local_q_heads, d_v) + + attention_output = 
bert_padding.pad_input( + output_unpad, indices_q, batch_size, q_length + ) # (batch_size, q_length, n_local_q_heads, d_v) + pad_to_right(key_states, sequence_mask, new_tensor=k_cache) + pad_to_right(value_states, sequence_mask, new_tensor=v_cache) + + + else: + # Pull pre-computed key/value states + # Subsequent inference iterations (q_length=1) + k_cache = self.k_cache + v_cache = self.v_cache + # TODO: remove layer_past as it's redundant with k_cache, v_cache + + # [batch_size, seq_length, num_heads, d_qk] + query_states = query_states.view( + batch_size, q_length, self.n_local_q_heads, self.head_dim + ) # [batch_size, q_length, self.n_local_q_heads, self.head_dim] + kv_length = key_states.shape[1] + key_states = key_states.view( + batch_size, kv_length, self.n_local_kv_heads, self.head_dim + ) # [batch_size, kv_length, self.n_local_kv_heads, self.head_dim] + value_states = value_states.view( + batch_size, kv_length, self.n_local_kv_heads, self.head_dim + ) # [batch_size, kv_length, self.n_local_kv_heads, self.head_dim] + + position_offsets = position_ids[:, -1] + attention_output = flash_attn_with_kvcache( + query_states, + k_cache, + v_cache, + key_states, + value_states, + rotary_cos=None, + rotary_sin=None, + # TODO @nouamane: seems like this doesnt help to indicate padding in (for first iteration it's just 0) + cache_seqlens=position_offsets.contiguous(), + softmax_scale=None, + causal=True, + rotary_interleaved=False, # GPT-NeoX style + ) + + return key_states, value_states, attention_output + def forward( self, - hidden_states: torch.Tensor, - layer_past: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, + hidden_states: torch.Tensor, # (batch_size, query_length, embed_dim) + layer_past: Optional[torch.Tensor] = None, # (batch_size, past_key_values_length, embed_dim) + attention_mask: Optional[torch.Tensor] = None, # (batch_size, query_length) head_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, rotary_embedding_frequencies_q: Optional[torch.Tensor] = None, - rotary_embedding_frequencies_k: Optional[torch.Tensor] = None + rotary_embedding_frequencies_k: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, ) -> Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], @@ -245,32 +556,66 @@ def forward( key_value = self.c_attn(encoder_hidden_states) attention_mask = encoder_attention_mask elif self.multi_query: - query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) - else: - # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), - # i.e., the memory layout is not the same as GPT2. - # This makes the concatenation with past_key_value more efficient. 
- query, key_value = ( - self.c_attn(hidden_states) - .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) - .transpose(1, 2) - .split((self.head_dim, 2 * self.head_dim), dim=3) - ) + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_heads * self.head_dim), dim=2) + # query shape: (batch_size, query_length, num_heads * head_dim) + # key_value shape: (batch_size, query_length, 2 * 1 * head_dim) + key, value = key_value.split( + (self.head_dim, self.head_dim), dim=-1 + ) # (batch_size, query_length, 1 * head_dim) + else: # GQA + if FLASHATTN_IS_AVAILABLE: + key, value, attn_output = self._flash_attn(hidden_states, layer_past, attention_mask, position_ids) + attn_output = attn_output.view(hidden_states.shape) + + else: + # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), + # i.e., the memory layout is not the same as GPT2. + # This makes the concatenation with past_key_value more efficient. + + n_repeats = self.num_heads // self.kv_heads + + query, key, value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.kv_heads, n_repeats + 2, self.head_dim) + .split((n_repeats, 1, 1), dim=3) + ) + # query shape: (batch_size, query_length, kv_heads, n_repeats, head_dim) + # key, value shape: (batch_size, query_length, kv_heads, 1, head_dim) if layer_past is not None: - key_value = torch.cat((layer_past, key_value), dim=-2) - present = key_value if use_cache else None + past_kv_length = layer_past.shape[-2] + else: + past_kv_length = 0 - key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) - if self.use_rotary_embeddings: - query = _apply_rotary_embeddings(query, rotary_embedding_frequencies_q) - key = _apply_rotary_embeddings(key, rotary_embedding_frequencies_k) + if self.use_rotary_embeddings and not FLASHATTN_IS_AVAILABLE: + # query = _apply_rotary_embeddings(query, rotary_embedding_frequencies_q) + # key = _apply_rotary_embeddings(key, rotary_embedding_frequencies_k) + query = query.reshape(*query.shape[:2], self.num_heads, self.head_dim) + key = key.view(*key.shape[:2], self.kv_heads, self.head_dim) + query, key = self.maybe_rotary(query, key, past_kv_length) + query = query.view(*query.shape[:2], self.num_heads * self.head_dim) + key = key.view(*key.shape[:2], self.kv_heads * self.head_dim) + value = value.reshape(*value.shape[:2], self.kv_heads*self.head_dim) - attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) + if use_cache: + key_value = torch.cat((key, value), dim=-1) + key_value = key_value.view(*key_value.shape[:2], 2 * self.kv_heads * self.head_dim) + # TODO @nouamane: do we need to concat? + + if layer_past is not None: + # Concatenate past key/values with new key/values. 
+ key_value = torch.cat((layer_past, key_value), dim=1) + key, value = key_value.split( + (self.kv_heads * self.head_dim, self.kv_heads * self.head_dim), dim=-1 + ) # (batch_size, key_length, kv_heads * head_dim) - if not self.multi_query: - attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) + present = key_value if use_cache else None + if not FLASHATTN_IS_AVAILABLE: + attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) + + if not self.multi_query: + attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) @@ -279,7 +624,7 @@ def forward( if self.multi_query: # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) attn_weights = attn_weights.transpose(1, 2) - outputs += (attn_weights,) + outputs += (attn_weights,) #TODO: fix for GQA return outputs # a, present, (attentions) @@ -308,15 +653,15 @@ def __init__(self, config, layer_idx=None): hidden_size = config.hidden_size self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size - self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.ln_1 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx) - self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.ln_2 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) if config.add_cross_attention: if config.multi_query: raise NotImplementedError("Cross-attention not implemented for MQA") self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx) - self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.ln_cross_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPTBigCodeMLP(self.inner_dim, config) @@ -331,7 +676,8 @@ def forward( use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, rotary_embedding_frequencies_q: Optional[torch.Tensor] = None, - rotary_embedding_frequencies_k: Optional[torch.Tensor] = None + rotary_embedding_frequencies_k: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, ) -> Union[ Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor] ]: @@ -345,7 +691,8 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, rotary_embedding_frequencies_q=rotary_embedding_frequencies_q, - rotary_embedding_frequencies_k=rotary_embedding_frequencies_k + rotary_embedding_frequencies_k=rotary_embedding_frequencies_k, + position_ids=position_ids, ) attn_output = attn_outputs[0] # output_attn: a, present, (attentions) outputs = attn_outputs[1:] @@ -369,7 +716,8 @@ def forward( encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, rotary_embedding_frequencies_q=rotary_embedding_frequencies_q, - rotary_embedding_frequencies_k=rotary_embedding_frequencies_k + rotary_embedding_frequencies_k=rotary_embedding_frequencies_k, + position_ids=position_ids, ) attn_output = cross_attn_outputs[0] # residual connection @@ -428,7 +776,7 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): + elif isinstance(module, LayerNorm): module.bias.data.zero_() 
module.weight.data.fill_(1.0) @@ -536,24 +884,21 @@ def __init__(self, config): self.wte = nn.Embedding(config.vocab_size, self.embed_dim) if config.use_position_embeddings: self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) - if config.use_rotary_embeddings: - # Calculate the complex frequencies (https://blog.eleuther.ai/rotary-embeddings/) - # `exp(i * n * a) = cos(n * a) + i sin(n * a)`, - # `a = theta ** - (2 * (channel // 2) / kv_channels)`, - # where n is the position in the sequence. - kv_channels = config.n_embd / config.n_head - angles = torch.outer( - torch.arange(config.max_position_embeddings, dtype=torch.float32), - torch.exp( - config.rotary_embedding_scale - * torch.arange(0, 1, 2 / kv_channels, dtype=torch.float32) - ), - ) - self._rotary_embedding_frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :] + # if config.use_rotary_embeddings: + # # Calculate the complex frequencies (https://blog.eleuther.ai/rotary-embeddings/) + # # `exp(i * n * a) = cos(n * a) + i sin(n * a)`, + # # `a = theta ** - (2 * (channel // 2) / kv_channels)`, + # # where n is the position in the sequence. + # kv_channels = config.n_embd / config.n_head + # angles = torch.outer( + # torch.arange(config.max_position_embeddings, dtype=torch.float32), + # torch.exp(config.rotary_embedding_scale * torch.arange(0, 1, 2 / kv_channels, dtype=torch.float32)), + # ) + # self._rotary_embedding_frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :] self.drop = nn.Dropout(config.embd_pdrop) self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) - self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) max_positions = config.max_position_embeddings self.register_buffer( @@ -638,30 +983,36 @@ def forward( elif position_ids is None: position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - + position_ids = position_ids.to(torch.int32) # For flash-attn with kv cache + # Rotary frequencies rotary_embedding_frequencies_q = None rotary_embedding_frequencies_k = None - if self.config.use_rotary_embeddings: - rotary_embedding_frequencies_q = self._rotary_embedding_frequencies[:, past_length : past_length + input_shape[-1]].to(device=device) - rotary_embedding_frequencies_k = self._rotary_embedding_frequencies[:, :past_length + input_shape[-1], :, :].to(device=device) + # if self.config.use_rotary_embeddings: + # rotary_embedding_frequencies_q = self._rotary_embedding_frequencies[ + # :, past_length : past_length + input_shape[-1] + # ].to(device=device) + # rotary_embedding_frequencies_k = self._rotary_embedding_frequencies[ + # :, : past_length + input_shape[-1], :, : + # ].to(device=device) # Self-attention mask. 
query_length = input_shape[-1] key_length = past_length + query_length - self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] + # self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] # Sliding window attention - if self.config.attention_window_size is not None: - self_attention_mask.triu_(-self.config.attention_window_size + 1) + # if self.config.attention_window_size is not None: + # self_attention_mask.triu_(-self.config.attention_window_size + 1) - if attention_mask is not None: - self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to( - dtype=torch.bool, device=self_attention_mask.device - ) + # if attention_mask is not None: + # self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to( + # dtype=torch.bool, device=self_attention_mask.device + # ) - # MQA models: (batch_size, query_length, n_heads, key_length) - # MHA models: (batch_size, n_heads, query_length, key_length) - attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) + # # MQA models: (batch_size, query_length, n_heads, key_length) + # # MHA models: (batch_size, n_heads, query_length, key_length) + # attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) + attention_mask = attention_mask.to(dtype=torch.bool) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -712,7 +1063,14 @@ def forward( def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, use_cache, output_attentions, rotary_embedding_frequencies_q, rotary_embedding_frequencies_k) + return module( + *inputs, + use_cache, + output_attentions, + rotary_embedding_frequencies_q, + rotary_embedding_frequencies_k, + position_ids, + ) return custom_forward @@ -736,7 +1094,8 @@ def custom_forward(*inputs): use_cache=use_cache, output_attentions=output_attentions, rotary_embedding_frequencies_q=rotary_embedding_frequencies_q, - rotary_embedding_frequencies_k=rotary_embedding_frequencies_k + rotary_embedding_frequencies_k=rotary_embedding_frequencies_k, + position_ids=position_ids, ) hidden_states = outputs[0]
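For completeness, a small usage sketch of the two config options introduced by this diff (`num_key_value_heads` and `rope_theta`). It assumes a transformers build that includes this patch; the head counts are purely illustrative.

from transformers import GPTBigCodeConfig

# GQA: 16 query heads sharing 4 key/value heads, with an explicit RoPE base.
config = GPTBigCodeConfig(n_head=16, num_key_value_heads=4, rope_theta=10000, multi_query=False)
assert config.num_key_value_heads == 4

# Backward compatibility: when num_key_value_heads is omitted, it defaults to
# 1 if multi_query=True (MQA) and to n_head otherwise (MHA), per __init__ above.
mqa_config = GPTBigCodeConfig(n_head=16, multi_query=True)
assert mqa_config.num_key_value_heads == 1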