@@ -133,13 +133,16 @@ def __init__(
eos_token_id=50256,
attention_softmax_in_fp32=True,
scale_attention_softmax_in_fp32=True,
fused_softmax=None,
multi_query=True,
flash_attention=False,
inference_runner=InferenceRunnerType.NO_RUNNER,
validate_runner_input=True,
pre_allocate_kv_cache=False,
max_sequence_length=None,
max_batch_size=None,
pad_key_length=True,
predict_last_token: bool = False,
**kwargs,
):
self.vocab_size = vocab_size
@@ -158,7 +161,9 @@ def __init__(
self.use_cache = use_cache
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
self.fused_softmax = fused_softmax
self.multi_query = multi_query
self.flash_attention = flash_attention

self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
@@ -175,4 +180,7 @@ def __init__(
# Pad key length to a multiple of 8 (requires pre_allocate_kv_cache).
self.pad_key_length = pad_key_length

# During inference, predict only the last token even if the input contains more than one token.
self.predict_last_token = predict_last_token

super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
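
For reference, a minimal sketch of how the new configuration flags introduced above might be set when building a model for the inference runner. The argument names follow this diff; the concrete values (runner type, sequence length, batch size) are purely illustrative.

from transformers import GPTBigCodeConfig
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import InferenceRunnerType

config = GPTBigCodeConfig(
    multi_query=True,
    flash_attention=False,                             # new flag in this diff
    fused_softmax=None,                                # None lets the runner pick a default
    inference_runner=InferenceRunnerType.BASE_RUNNER,  # the runner requires pre_allocate_kv_cache
    pre_allocate_kv_cache=True,
    max_sequence_length=2048,                          # illustrative value
    max_batch_size=8,                                  # illustrative value
    pad_key_length=True,                               # pad key lengths to a multiple of 8
    predict_last_token=True,                           # only predict the last token in inference
)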
46 changes: 24 additions & 22 deletions src/transformers/models/gpt_bigcode/inference_runner.py
@@ -4,8 +4,10 @@

from transformers import GPTBigCodeConfig
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import InferenceRunnerType
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeBlock, masked_softmax, upcast_masked_softmax
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
InferenceRunnerType,
)
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeBlock, softmax_function


def _align_tensor(x):
@@ -23,6 +25,7 @@ def __init__(self, config: GPTBigCodeConfig, model):
assert config.pre_allocate_kv_cache
self.validate_input = config.validate_runner_input
self.pad_key_length = 8 if config.pad_key_length else 1
self.fused_softmax = True if config.fused_softmax is None and config.pad_key_length else config.fused_softmax

# TODO: Support other attention types?
assert model.multi_query
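
The fused_softmax default above resolves as in the following equivalent sketch (shown only for readability; the one-liner in the diff is the actual logic):

def resolve_fused_softmax(fused_softmax, pad_key_length):
    # If the config leaves fused_softmax unset (None), enable the fused kernel only
    # when key lengths are padded to a multiple of 8; otherwise keep the value
    # explicitly requested in the config.
    if fused_softmax is None and pad_key_length:
        return True
    return fused_softmax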
@@ -53,7 +56,7 @@ def _allocate(self, batch_size, device, dtype):
kv_end = query_end + 2 * self.batch_size * attn.kv_dim
# Attn weights: (batch_size, num_heads, key_length), no overlap with value
attn_weights_begin = _align_tensor(kv_end)
attn_weights_end = kv_end + self.batch_size * attn.num_heads * self.max_sequence_length
attn_weights_end = attn_weights_begin + self.batch_size * attn.num_heads * self.max_sequence_length
# Projection: (batch_size, embed_dim), no overlap with attn outputs ~ query.
# Also used for MLP projection
c_proj_begin = _align_tensor(query_end)
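
The offsets above carve several activation tensors out of one flat pre-allocated buffer: each region starts at an aligned offset, and regions that never coexist (e.g. the projection output and the query) may overlap. A rough sketch of the scheme with made-up sizes and a hypothetical stand-in for _align_tensor; note that a region's end must be computed from its aligned begin, which is exactly what the corrected attn_weights_end line above does.

import torch

def align_offset(offset, alignment=128):
    # Hypothetical stand-in for _align_tensor: round the offset up to the alignment.
    return -(-offset // alignment) * alignment

batch_size, embed_dim, kv_dim, num_heads, max_sequence_length = 4, 2048, 256, 16, 512  # made-up sizes
query_end = batch_size * embed_dim
kv_begin = align_offset(query_end)
kv_end = kv_begin + 2 * batch_size * kv_dim
attn_weights_begin = align_offset(kv_end)                      # placed after the values, no overlap
attn_weights_end = attn_weights_begin + batch_size * num_heads * max_sequence_length

buffer = torch.empty(attn_weights_end, dtype=torch.float16)
query = buffer[:query_end].view(batch_size, embed_dim)
keys_values = buffer[kv_begin:kv_end].view(batch_size, 2 * kv_dim)
attn_weights = buffer[attn_weights_begin:attn_weights_end].view(batch_size, num_heads, max_sequence_length)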
@@ -159,6 +162,10 @@ def _allocate(self, batch_size, device, dtype):
if self.inference_runner_type != InferenceRunnerType.BASE_RUNNER:
print("Generating cuda graphs")
self.memory_pool = None
# This prevents an issue with cuBLAS initialization.
# https://github.com/pytorch/pytorch/issues/99397
dummy_matrix = self.mask_value.view([1, 1])
torch.matmul(dummy_matrix, dummy_matrix)
if self.inference_runner_type == InferenceRunnerType.FULL_GRAPH:
self.cuda_graphs = {}
# The output may not always be at the same memory location.
@@ -187,22 +194,19 @@ def _generate_cuda_graphs(self):

def _generate_full_cuda_graph(self, key_length):
# We need to warm up the jit function before creating the graph, otherwise it will crash.
# https://github.com/pytorch/pytorch/issues/99397
# The warmup needs to be done for every input shape (key length), and for both scale == 1 and scale != 1.
if self.upcast:
if self.fused_softmax or (self.fused_softmax is None and key_length % 8 == 0):
for scale in (1.0, 2.0):
upcast_masked_softmax(
softmax_function(
self.padded_attn_weights[key_length],
self.padded_attn_masks[key_length],
self.mask_value,
scale,
self.softmax_dtype,
self.upcast,
self.fused_softmax,
)
else:
masked_softmax(
self.padded_attn_weights[key_length],
self.padded_attn_masks[key_length],
self.mask_value,
)
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=self.memory_pool):
self.output_hidden_states[key_length] = self._forward(key_length)
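
For context, this follows the standard torch.cuda.CUDAGraph capture pattern: run the computation once eagerly as a warmup (hence the jit warmup above and the dummy cuBLAS matmul earlier in the diff), then capture one graph per padded key length, optionally sharing a memory pool between captures. A standalone sketch with hypothetical shapes, using a matmul as a stand-in for self._forward:

import torch

pool = torch.cuda.graph_pool_handle()                # shared pool, analogous to self.memory_pool
static_input = torch.zeros(4, 2048, device="cuda")
static_input @ static_input.t()                      # eager warmup (also initializes cuBLAS)
torch.cuda.synchronize()

cuda_graphs, outputs = {}, {}
for key_length in (16, 32, 64):                      # hypothetical padded key lengths
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=pool):
        outputs[key_length] = static_input @ static_input.t()   # stand-in for self._forward(key_length)
    cuda_graphs[key_length] = graph

static_input.copy_(torch.randn_like(static_input))   # write new data into the captured input buffer
cuda_graphs[32].replay()                             # recomputes outputs[32] in place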
@@ -239,18 +243,16 @@ def _forward_attn(self, block, key_length):
alpha=self.scale[layer_idx],
out=attn_weights,
)
# Use a fused kernel to prevent a large overhead from casting and scaling.
# Jit doesn't allow inplace kernel.
if self.upcast:
attn_weights = upcast_masked_softmax(
attn_weights,
self.padded_attn_masks[key_length],
self.mask_value,
self.unscale[layer_idx],
self.softmax_dtype,
)
else:
attn_weights = masked_softmax(attn_weights, self.padded_attn_masks[key_length], self.mask_value)
attn_weights = softmax_function(
attn_weights,
self.padded_attn_masks[key_length],
self.mask_value,
self.unscale[layer_idx],
self.softmax_dtype,
self.upcast,
self.fused_softmax,
)

torch.bmm(attn_weights, self.padded_values[key_length][layer_idx], out=self.attn_output_expanded)
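
The softmax_function call above replaces the separate masked_softmax / upcast_masked_softmax call sites; its implementation is not part of this excerpt. A hedged sketch of a dispatcher with the same signature, reconstructed from the removed functions (the real fused path would route to a jit-compiled kernel rather than this eager code):

import torch

def softmax_function_sketch(x, mask, mask_value, scale, softmax_dtype, upcast, fused_softmax):
    # fused_softmax would select the jit-fused kernel in the real implementation;
    # the eager math below is the same for both paths.
    if upcast:
        input_dtype = x.dtype
        x = x.to(softmax_dtype) * scale
    if mask is not None:
        x = torch.where(mask, x, mask_value)
    x = torch.nn.functional.softmax(x, dim=-1)
    return x.to(input_dtype) if upcast else x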
