20 changes: 20 additions & 0 deletions lightllm/common/basemodel/basemodel.py
@@ -348,6 +348,24 @@ def _prefill(
model_input: ModelInput,
):
infer_state = self._create_inferstate(model_input)

# Capture old indexer_ks positions before they are overwritten
# This is needed for DeepSeek v3.2 to copy cached tokens' indexer_ks
old_indexer_ks_positions = []
for i in range(infer_state.b_req_idx.shape[0]):
req_idx = infer_state.b_req_idx[i].item()
ready_cache_len = infer_state.b_ready_cache_len[i].item()

if ready_cache_len > 0:
# Capture old positions for cached tokens
old_pos = self.req_manager.req_to_token_indexs[
req_idx, 0:ready_cache_len
].clone() # Clone to avoid view issues
old_indexer_ks_positions.append(old_pos)
else:
# No cached tokens for this request
old_indexer_ks_positions.append(None)

init_req_to_token_indexes(
req_to_token_indexs=self.req_manager.req_to_token_indexs,
b_req_idx=infer_state.b_req_idx,
@@ -356,6 +374,8 @@ def _prefill(
b_start_loc=model_input.b_prefill_start_loc,
alloc_mem_index=infer_state.mem_index,
max_q_seq_len=infer_state.max_q_seq_len,
mem_manager=self.req_manager.mem_manager,
old_indexer_ks_positions=old_indexer_ks_positions,
)
prefill_mem_indexes_ready_event = torch.cuda.Event()
prefill_mem_indexes_ready_event.record()
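Note on the `.clone()` above: the sliced row of `req_to_token_indexs` is a view into the same tensor that `init_req_to_token_indexes` is about to overwrite, so the old slot indexes must be snapshotted first. A minimal standalone illustration of that behavior (tensor shapes are made up for the example):

import torch

req_to_token = torch.arange(8).reshape(2, 4)      # stand-in for req_to_token_indexs
view = req_to_token[0, 0:2]                       # a view: shares storage with req_to_token
snapshot = req_to_token[0, 0:2].clone()           # an independent copy of the old slot indexes

req_to_token[0, 0:2] = torch.tensor([100, 101])   # what init_req_to_token_indexes effectively does
print(view)       # tensor([100, 101])  -- the view reflects the overwrite
print(snapshot)   # tensor([0, 1])      -- the clone still holds the old positions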
@@ -67,9 +67,6 @@ def __init__(
self.global_rank_ = get_global_rank()
self.redundancy_expert_num = get_redundancy_expert_num()
self.redundancy_expert_ids = get_redundancy_expert_ids(layer_num)
logger.info(
f"global_rank {self.global_rank_} layerindex {layer_num} redundancy_expertids: {self.redundancy_expert_ids}"
)
self.redundancy_expert_ids_tensor = torch.tensor(self.redundancy_expert_ids, dtype=torch.int64, device="cuda")
self.routed_expert_counter_tensor = torch.zeros((self.n_routed_experts,), dtype=torch.int64, device="cuda")
self.total_expert_num_contain_redundancy = (
@@ -36,8 +36,17 @@ def load_hf_weights(self, weights):
"""
for attr_name in dir(self):
attr = getattr(self, attr_name, None)
if isinstance(attr, MultiMMWeightTpl):
if isinstance(attr, TransformerLayerWeight):
attr.load_hf_weights(weights)
elif isinstance(attr, MultiMMWeightTpl):
with self.lock:
attr.load_hf_weights(weights)
elif isinstance(attr, BaseWeight):
attr.load_hf_weights(weights)

def verify_load(self):
for attr_name in dir(self):
attr = getattr(self, attr_name, None)
if isinstance(attr, TransformerLayerWeight):
attr.verify_load()
super().verify_load()
4 changes: 2 additions & 2 deletions lightllm/common/deepseek2_fp8kv_mem_manager.py
@@ -3,6 +3,6 @@


class Deepseek2FP8KVMemoryManager(Deepseek2MemoryManager):
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False):
# The scale is appended to the end of kv_buffer, hence head_dim + 2; dtype is unified to uint8
super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction)
super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction, is_sub_mem_manager)
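The `head_dim + 2` reserves two extra uint8 bytes per head so a 16-bit dequantization scale can sit after the fp8 payload, as the comment above notes. A hypothetical unpacking sketch, assuming a bfloat16 scale in those trailing bytes (the actual scale dtype and layout are not shown in this diff):

import torch

def unpack_fp8_cell(cell: torch.Tensor, head_dim: int) -> torch.Tensor:
    # cell: uint8 tensor of shape [head_num, head_dim + 2] for one token of one layer
    fp8_payload = cell[:, :head_dim].contiguous().view(torch.float8_e4m3fn)  # quantized kv values
    scale = cell[:, head_dim:].contiguous().view(torch.bfloat16)             # trailing 2 bytes reinterpreted as the scale
    return fp8_payload.to(torch.float32) * scale                             # dequantized kv, [head_num, head_dim]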
4 changes: 2 additions & 2 deletions lightllm/common/deepseek2_mem_manager.py
@@ -14,8 +14,8 @@


class Deepseek2MemoryManager(MemoryManager):
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction)
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False):
super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_sub_mem_manager)

def get_cell_size(self):
return self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)
70 changes: 69 additions & 1 deletion lightllm/common/infer_utils.py
@@ -2,8 +2,17 @@


def init_req_to_token_indexes(
req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, b_start_loc, alloc_mem_index, max_q_seq_len
req_to_token_indexs,
b_req_idx,
b_seq_len,
b_ready_cache_len,
b_start_loc,
alloc_mem_index,
max_q_seq_len,
mem_manager=None,
old_indexer_ks_positions=None,
):
# Step 1: Record the newly allocated KV cache slot indexes for NEW tokens (existing logic)
copy_kv_index_to_req_prefill(
req_to_token_indexs=req_to_token_indexs,
b_req_idx=b_req_idx,
@@ -13,3 +22,62 @@ def init_req_to_token_indexes(
memindex=alloc_mem_index,
max_q_seq_len=max_q_seq_len,
)

# Step 2: Copy indexer_ks for CACHED tokens (DeepSeek v3.2 specific)
# This ensures consistency between KV cache and indexer_ks buffers
# when prefix cache is hit
if (
mem_manager is not None
and hasattr(mem_manager, "indexer_ks_mem_manager")
and old_indexer_ks_positions is not None
):

_copy_cached_indexer_ks_to_new_positions(
req_to_token_indexs=req_to_token_indexs,
b_req_idx=b_req_idx,
b_ready_cache_len=b_ready_cache_len,
mem_manager=mem_manager,
old_indexer_ks_positions=old_indexer_ks_positions,
)


def _copy_cached_indexer_ks_to_new_positions(
req_to_token_indexs,
b_req_idx,
b_ready_cache_len,
mem_manager,
old_indexer_ks_positions,
):
"""
Copy cached tokens' indexer_ks from old positions to new positions.

This function is called after copy_kv_index_to_req_prefill() has updated
req_to_token_indexs to point to new contiguous positions. We need to copy
indexer_ks data to match the KV cache layout.

For each layer and each request with cached tokens:
- Copy indexer_ks data from old positions to new positions
- This ensures consistency when using extract_indexer_ks later
"""
from lightllm.models.deepseek3_2.triton_kernel.copy_indexer_ks import copy_indexer_ks

# Get number of layers from indexer_ks_mem_manager
num_layers = len(mem_manager.indexer_ks_mem_manager.kv_buffer)
indexer_buffer = mem_manager.indexer_ks_mem_manager.kv_buffer

for layer_idx in range(num_layers):
for i in range(b_req_idx.shape[0]):
req_idx = b_req_idx[i].item()
ready_cache_len = b_ready_cache_len[i].item()
old_positions = old_indexer_ks_positions[i]

if ready_cache_len > 0 and old_positions is not None:
# New positions after copy_kv_index_to_req_prefill
new_positions = req_to_token_indexs[req_idx, 0:ready_cache_len]

# Copy indexer_ks: old_positions -> new_positions
copy_indexer_ks(
buffer=indexer_buffer[layer_idx],
src_loc=old_positions,
dest_loc=new_positions,
)
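The Triton kernel `copy_indexer_ks` itself is not part of this diff; judging only from how it is called here, its effect should be a gather/scatter within a single per-layer buffer, roughly equivalent to this pure-PyTorch sketch (an assumption, not the actual kernel):

import torch

def copy_indexer_ks_reference(buffer: torch.Tensor, src_loc: torch.Tensor, dest_loc: torch.Tensor) -> None:
    # buffer:   per-layer indexer_ks storage, first dim indexed by token slot
    # src_loc:  int64 tensor of old token slots holding the cached prefix
    # dest_loc: int64 tensor of new token slots, same length as src_loc
    buffer[dest_loc] = buffer[src_loc]   # advanced indexing copies the source rows before scattering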
34 changes: 24 additions & 10 deletions lightllm/common/mem_manager.py
@@ -19,7 +19,9 @@


class MemoryManager:
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
def __init__(
self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False
):
self.size = size
self.head_num = head_num
self.head_dim = head_dim
@@ -41,15 +43,16 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False

self.can_use_mem_size = self.size

# Shared via shared memory so the router module can read it for accurate scheduling estimates; the nccl port acts as the marker of a single instance on one machine, to avoid conflicts.
from lightllm.utils.envs_utils import get_unique_server_name
if not is_sub_mem_manager:
# Shared via shared memory so the router module can read it for accurate scheduling estimates; the nccl port acts as the marker of a single instance on one machine, to avoid conflicts.
from lightllm.utils.envs_utils import get_unique_server_name

rank_in_node = get_current_rank_in_node()
self.shared_can_use_token_num = SharedInt(
f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
)
rank_in_node = get_current_rank_in_node()
self.shared_can_use_token_num = SharedInt(
f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
)

self.shared_can_use_token_num.set_value(self.can_use_mem_size)
self.shared_can_use_token_num.set_value(self.can_use_mem_size)
self._init_buffers(
self.size,
dtype,
@@ -71,6 +74,17 @@ def profile_size(self, mem_fraction):
available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction)
cell_size = self.get_cell_size()
self.size = int(available_memory * 1024 ** 3 / cell_size)

# Ensure size is at least a minimum positive value to avoid torch.arange errors
MIN_SIZE = 1024 # Minimum 1024 tokens
if self.size < MIN_SIZE:
logger.warning(
f"Insufficient memory for KV cache. Available: {available_memory:.2f} GB, "
f"but calculated size is {self.size} tokens. Using minimum size {MIN_SIZE} tokens instead. "
f"Consider reducing model size, using fewer GPUs, or increasing mem_fraction."
)
self.size = MIN_SIZE

if world_size > 1:
tensor = torch.tensor(self.size, dtype=torch.int64, device=f"cuda:{get_current_device_id()}")
dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
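For a sense of scale, combining profile_size() with the Deepseek2 get_cell_size() shown earlier gives the following rough numbers; the model dimensions and the 20 GB of free memory are illustrative assumptions, not values from this PR:

# head_num=1, head_dim=576 (compressed MLA kv), layer_num=61, bf16 (2 bytes) -- assumed values
head_num, head_dim, layer_num, elem_size = 1, 576, 61, 2
cell_size = head_num * head_dim * layer_num * elem_size     # 70,272 bytes per token slot
available_memory = 20.0                                     # GB left after the (1 - mem_fraction) reserve
size = int(available_memory * 1024 ** 3 / cell_size)        # 305,595 token slots, well above MIN_SIZE
print(cell_size, size)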
@@ -93,7 +107,7 @@ def alloc_kv_move_buffer(self, max_req_total_len):
"""
Special interface used in PD (prefill/decode) disaggregation mode.
"""
if isinstance(self, MemoryManager) and type(self) != MemoryManager:
if isinstance(self, MemoryManager) and type(self) is not MemoryManager:
raise NotImplementedError("subclass need reimpl this method")
self.kv_move_buffer = torch.empty(
(1, max_req_total_len + 8, 2 * self.head_num, self.head_dim), dtype=self.dtype, device="cuda"
@@ -103,7 +117,7 @@ def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor:
return

def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor:
if isinstance(self, MemoryManager) and type(self) != MemoryManager:
if isinstance(self, MemoryManager) and type(self) is not MemoryManager:
raise NotImplementedError("subclass need reimpl this method")

num_kv_head = get_num_key_value_heads(get_env_start_args().model_dir)
2 changes: 1 addition & 1 deletion lightllm/common/triton_utils/autotuner.py
@@ -62,7 +62,7 @@ def autotune(
as needed before invocation.
"""

def decorator(fn):
def decorator(fn: Callable) -> Callable:
return Autotuner(
fn=fn,
kernel_name=kernel_name,
1 change: 1 addition & 0 deletions lightllm/models/__init__.py
@@ -20,6 +20,7 @@
from lightllm.models.gemma_2b.model import Gemma_2bTpPartModel
from lightllm.models.phi3.model import Phi3TpPartModel
from lightllm.models.deepseek2.model import Deepseek2TpPartModel
from lightllm.models.deepseek3_2.model import Deepseek3_2TpPartModel
from lightllm.models.internvl.model import (
InternVLLlamaTpPartModel,
InternVLPhi3TpPartModel,