Closed
Changes from all commits
Commits
27 commits
7384efb  implement GPT Neo's rope (NouamaneTazi, Oct 25, 2023)
7ebc9ea  fix imports (NouamaneTazi, Oct 28, 2023)
2c28b04  output logits (NouamaneTazi, Oct 28, 2023)
253da5b  attn mask.all() (NouamaneTazi, Oct 30, 2023)
d3c15ba  fix caching in rope (NouamaneTazi, Oct 31, 2023)
1e74664  GQA generation without cache (NouamaneTazi, Nov 1, 2023)
1f424bb  fix use_cache for GQA (NouamaneTazi, Nov 1, 2023)
39a3483  reshapes fixes for num_heads=2 (NouamaneTazi, Nov 1, 2023)
1c79ecd  . (NouamaneTazi, Nov 2, 2023)
19cf153  add flash_attn_with_kvcache to GQA (NouamaneTazi, Dec 7, 2023)
b493268  add merging word embedding checkpoints (xrsrke, Dec 29, 2023)
4446fe0  add merging quite a bit (xrsrke, Dec 31, 2023)
1d949b2  add reference starcoder model (xrsrke, Dec 31, 2023)
a58a947  merged most of the checkpoints (xrsrke, Dec 31, 2023)
ac559a1  add merged checkpoints (xrsrke, Jan 1, 2024)
78114b7  add mapping to target state dict (xrsrke, Jan 1, 2024)
7d50b80  refactor converting scrip (xrsrke, Jan 2, 2024)
21ee689  refactor (xrsrke, Jan 3, 2024)
210311b  add inference script (xrsrke, Jan 3, 2024)
09c086a  refactor (xrsrke, Jan 3, 2024)
ae54653  refactor all functions (xrsrke, Jan 3, 2024)
594099c  save some files before cleaning it all (xrsrke, Jan 3, 2024)
fb8a86b  delete uncessary files (xrsrke, Jan 3, 2024)
c26472c  add rope_theta to config (NouamaneTazi, Jan 5, 2024)
9c9cfbb  fix config.attn_pdrop for flash attn (NouamaneTazi, Jan 8, 2024)
6bdf78a  Merge pull request #1 from xrsrke/sc2-rope (NouamaneTazi, Jan 8, 2024)
1507798  Refactor GPTBigCode model conversion code (NouamaneTazi, Jan 8, 2024)
19 changes: 19 additions & 0 deletions src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py
@@ -50,6 +50,14 @@ class GPTBigCodeConfig(PretrainedConfig):
Number of hidden layers in the Transformer encoder.
n_head (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
`num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group's key and value head should be constructed
by mean-pooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
`num_attention_heads`.
n_inner (`int`, *optional*, defaults to None):
Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
activation_function (`str`, *optional*, defaults to `"gelu_pytorch_tanh"`):
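The `num_key_value_heads` docstring above mentions building each GQA key/value head by mean-pooling the original heads in its group. Below is a minimal sketch of that conversion step, assuming a key (or value) projection weight laid out as `(num_attention_heads * head_dim, hidden_size)`; the helper name and tensor layout are illustrative, not the PR's actual conversion code.

```python
import torch

def meanpool_kv_heads(kv_weight: torch.Tensor, num_attention_heads: int,
                      num_key_value_heads: int, head_dim: int) -> torch.Tensor:
    """Mean-pool MHA key/value projection weights into GQA groups.

    Assumes `kv_weight` has shape (num_attention_heads * head_dim, hidden_size),
    i.e. one row block per original head (illustrative layout).
    """
    hidden_size = kv_weight.shape[1]
    group_size = num_attention_heads // num_key_value_heads
    # Split into per-head blocks, group them, and average within each group.
    w = kv_weight.view(num_attention_heads, head_dim, hidden_size)
    w = w.view(num_key_value_heads, group_size, head_dim, hidden_size).mean(dim=1)
    return w.reshape(num_key_value_heads * head_dim, hidden_size)

# Example: 12 original heads pooled into 4 key/value groups of 3 heads each.
w_k = torch.randn(12 * 64, 768)
w_k_gqa = meanpool_kv_heads(w_k, num_attention_heads=12, num_key_value_heads=4, head_dim=64)
print(w_k_gqa.shape)  # torch.Size([256, 768])
```

With `num_key_value_heads=num_attention_heads` the pooling is a no-op (MHA), and with `num_key_value_heads=1` every head collapses into a single key/value head (MQA).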
@@ -63,6 +71,8 @@ class GPTBigCodeConfig(PretrainedConfig):
The dropout ratio for the attention.
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
The epsilon to use in the layer normalization layers.
rope_theta (`int`, *optional*, defaults to 10000):
The base value (theta) used to compute the rotary position embedding (RoPE) frequencies.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
scale_attn_weights (`bool`, *optional*, defaults to `True`):
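For the new `rope_theta` option, here is a minimal sketch of how the rotary inverse frequencies and cos/sin tables are typically derived from the theta base; this mirrors common RoPE implementations and is not necessarily the exact module added in this PR.

```python
import torch

def rope_inv_freq(head_dim: int, rope_theta: float = 10000.0) -> torch.Tensor:
    # One frequency per pair of channels: theta ** (-2i / head_dim)
    return 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))

def rope_cos_sin(seq_len: int, head_dim: int, rope_theta: float = 10000.0):
    inv_freq = rope_inv_freq(head_dim, rope_theta)
    positions = torch.arange(seq_len).float()
    angles = torch.outer(positions, inv_freq)   # (seq_len, head_dim // 2)
    emb = torch.cat((angles, angles), dim=-1)   # (seq_len, head_dim)
    return emb.cos(), emb.sin()

cos, sin = rope_cos_sin(seq_len=2048, head_dim=64, rope_theta=10000.0)
print(cos.shape)  # torch.Size([2048, 64])
```

A larger `rope_theta` stretches the longest rotary period, which is the usual lever for extending the context window.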
@@ -106,12 +116,14 @@ def __init__(
n_embd=768,
n_layer=12,
n_head=12,
num_key_value_heads=None,
n_inner=None,
activation_function="gelu_pytorch_tanh",
resid_pdrop=0.1,
embd_pdrop=0.1,
attn_pdrop=0.1,
layer_norm_epsilon=1e-5,
rope_theta=10000,
initializer_range=0.02,
scale_attn_weights=True,
use_cache=True,
@@ -131,12 +143,19 @@
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head

# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = 1 if multi_query else n_head
self.num_key_value_heads = num_key_value_heads

self.n_inner = n_inner
self.activation_function = activation_function
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attn_pdrop = attn_pdrop
self.layer_norm_epsilon = layer_norm_epsilon
self.rope_theta = rope_theta
self.initializer_range = initializer_range
self.scale_attn_weights = scale_attn_weights
self.use_cache = use_cache
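The backward-compatibility branch above keeps old configs working: when `num_key_value_heads` is not given, it falls back to MQA (`1`) if `multi_query=True` and to full MHA (`n_head`) otherwise. A hypothetical usage example of the resulting defaults, assuming a transformers build that includes this PR:

```python
from transformers import GPTBigCodeConfig  # assumes a build that includes this PR

# Legacy-style config: multi_query=True implies a single key/value head (MQA).
mqa_config = GPTBigCodeConfig(n_head=12, multi_query=True)
print(mqa_config.num_key_value_heads)  # 1

# multi_query=False with no explicit value falls back to full MHA.
mha_config = GPTBigCodeConfig(n_head=12, multi_query=False)
print(mha_config.num_key_value_heads)  # 12

# Explicit grouped-query setting: 12 query heads sharing 4 key/value heads.
gqa_config = GPTBigCodeConfig(n_head=12, num_key_value_heads=4, rope_theta=10000)
print(gqa_config.num_key_value_heads)  # 4
```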