```python
if past_key_value is not None:
    cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    # Prepare full_position_ids for the keys (from the cache)
    full_position_ids = torch.arange(
        0, past_key_value.seen_tokens, dtype=torch.long, device=query_states.device
    )
    full_position_ids = full_position_ids.unsqueeze(0)
    key_states = apply_single_rotary_pos_emb(key_states, cos, sin, full_position_ids)
```

Based on the code in your repository, new position IDs are assigned to `key_states` and the keys are rotated again. However, each document block was already assigned positions when its KV cache was prefilled, so after concatenating the caches the original positions are [0, 1, 2, ..., l, 0, 1, 2, ..., l, 0, 1, 2, ..., l]. Since RoPE rotations compose additively, applying the rotation in the code above on top of these positions would produce effective positions of the form [0, 2, 4, ..., 2l, l+1, l+3, ...], wouldn't it? That doesn't seem to match expectations.
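For reference, here is a minimal, self-contained sketch of the property behind this question: RoPE rotations compose additively in angle, so re-rotating keys that were already rotated during per-document prefill is equivalent to a single rotation at the summed positions. The `rope_angles` and `apply_rope` helpers below are simplified stand-ins written only for this illustration, not the repository's `apply_single_rotary_pos_emb`.

```python
import torch

def rope_angles(positions, dim, base=10000.0):
    # Standard RoPE frequencies: theta_i = base^(-2i / dim)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    freqs = positions.float()[:, None] * inv_freq[None, :]  # (seq, dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)                 # (seq, dim)
    return emb.cos(), emb.sin()

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, cos, sin):
    # Rotate pairs (x1_i, x2_i) by the angle encoded in cos/sin
    return x * cos + rotate_half(x) * sin

dim, l = 8, 4
k = torch.randn(l + 1, dim)  # keys of a single document block (positions 0..l)

# Prefill: the block is encoded with local positions 0..l
cos_local, sin_local = rope_angles(torch.arange(l + 1), dim)
k_prefill = apply_rope(k, cos_local, sin_local)

# Re-rotation with full_position_ids 0..l (first block of the concatenated cache)
k_rerotated = apply_rope(k_prefill, cos_local, sin_local)

# One rotation at the doubled positions 0, 2, 4, ..., 2l
cos_double, sin_double = rope_angles(torch.arange(l + 1) * 2, dim)
k_once_doubled = apply_rope(k, cos_double, sin_double)

# Rotating twice by position p equals rotating once by 2p
print(torch.allclose(k_rerotated, k_once_doubled, atol=1e-5))  # True
```

For the first cached block this yields effective positions 0, 2, 4, ..., 2l, which is exactly the doubling described above.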