Changes from all commits
65 commits
bc281e2
qwen3_moe mtp
shihaobai Dec 1, 2025
f4f8415
fix weight name
shihaobai Dec 1, 2025
652dd7d
fix qwen3 fa3 mtp
shihaobai Dec 1, 2025
11dc305
fix
shihaobai Dec 1, 2025
5c2ae24
fix
shihaobai Dec 1, 2025
f09c9bb
fix
shihaobai Dec 1, 2025
c9de6f6
fix rebase
Dec 30, 2025
60b3113
mtp dense
shihaobai Dec 9, 2025
bfa1cfa
mtp dense weight
shihaobai Dec 9, 2025
323761e
fix
shihaobai Dec 9, 2025
e973b5e
fix
shihaobai Dec 9, 2025
aeb4609
fix
shihaobai Dec 9, 2025
224c398
remove mtp norm
shihaobai Dec 30, 2025
71bcd72
mtp dense
shihaobai Dec 12, 2025
5046c53
update
shihaobai Dec 25, 2025
043799b
fix
Dec 30, 2025
72616ef
fix
Dec 30, 2025
7955d77
fix mtp mistral model
Dec 30, 2025
47f768a
mistral mtp pre layer infer
Dec 30, 2025
f63a725
fix pre layer mtp
Dec 30, 2025
acdd94e
fix mistral mtp weight load
Dec 30, 2025
253a60c
fix
Dec 30, 2025
979cd27
fix
Dec 30, 2025
6048371
qwen3next
sufubao Dec 9, 2025
75193df
add radix cache hit rate
sufubao Dec 9, 2025
da00caf
hhh
sufubao Dec 9, 2025
22d4669
reset
sufubao Dec 10, 2025
1d52953
draft
sufubao Dec 10, 2025
fd24683
tmp
sufubao Dec 10, 2025
3680965
done
sufubao Dec 11, 2025
333dca9
fix
sufubao Dec 11, 2025
dc5ce01
fix cudagraph
sufubao Dec 11, 2025
2ae7135
update kernel
sufubao Dec 11, 2025
c27c7c6
use autotuner
sufubao Dec 11, 2025
d1009a5
fix autotuner
sufubao Dec 11, 2025
4cd4727
update_kernel
sufubao Dec 11, 2025
7c18ce8
update kernel
sufubao Dec 12, 2025
8ef741a
try fix
sufubao Dec 12, 2025
e54b8dd
clean code
sufubao Dec 12, 2025
a71a68d
try fix
sufubao Dec 12, 2025
062dfc8
fix prefix cache
sufubao Dec 12, 2025
fee5a78
fix
sufubao Dec 12, 2025
76210f1
refactor
sufubao Dec 15, 2025
280fb26
fix
sufubao Dec 15, 2025
6759e0a
fix
sufubao Dec 16, 2025
4145564
fix
sufubao Dec 16, 2025
4e73244
fix
sufubao Dec 17, 2025
7cf6261
feat: add Qwen3Next MTP (Multi-Token Prediction) support
sufubao Dec 17, 2025
8af273e
fix
sufubao Dec 18, 2025
d6184db
fix: set Triton allocator for TMA kernels in solve_tril
sufubao Dec 17, 2025
e547035
feat: add TTFT/ITL metrics and fix Chinese input handling in test_chat
sufubao Dec 17, 2025
fb15516
feat: optimize hybrid radix cache buffer insertion strategy
sufubao Dec 17, 2025
9cf2579
fix: prefix cache optim
sufubao Dec 18, 2025
e433cd1
fix
sufubao Dec 18, 2025
0bb5b63
feat: add radix prefix hit rate log
sufubao Dec 18, 2025
d552d7d
fix
sufubao Dec 18, 2025
c7799c3
fix: prefix cache
sufubao Dec 19, 2025
f1f5db1
update prefix algo
sufubao Dec 22, 2025
22d32c6
draft
sufubao Dec 30, 2025
1fd4c92
fix mistral support fa3
Dec 31, 2025
3528145
draft
sufubao Dec 31, 2025
810c8b9
Merge branch 'qwen3_mtp_dense' from origin
sufubao Dec 31, 2025
f204a9f
before cc
sufubao Jan 3, 2026
40f82b7
mtp can run
sufubao Jan 4, 2026
4210ec1
refactor
sufubao Jan 5, 2026
1 change: 1 addition & 0 deletions .gitignore
@@ -6,3 +6,4 @@ dist
.idea
.vscode
tmp/
.claude
98 changes: 98 additions & 0 deletions lightllm/common/allocator_utils.py
@@ -0,0 +1,98 @@
from typing import List, Union

import torch

from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class TokenAllocator:
def __init__(self, size, shared_can_use_token_num_name: str):
self.size = size

self.mem_state = torch.arange(
0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
)
self._mem_state_return = torch.arange(
0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
)
self._return_start = 0
self.mark_start = 0
self.mark_end = self.size

self.can_use_mem_size = self.size

# Shared via shared memory so the router module can read it for accurate scheduling estimates; the nccl port serves as the marker of a single instance on one machine, to avoid conflicts.
self.shared_can_use_token_num = SharedInt(shared_can_use_token_num_name)

self.shared_can_use_token_num.set_value(self.can_use_mem_size)
self.HOLD_TOKEN_MEMINDEX = self.size

def alloc(self, need_size) -> torch.Tensor:
if need_size > self.mark_end - self.mark_start:
logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
assert False, "error alloc state"

start = self.mark_start
end = self.mark_start + need_size
self.mark_start += need_size

self.can_use_mem_size -= need_size
self.shared_can_use_token_num.set_value(self.can_use_mem_size)

# Return through a buffer to avoid memory races under asynchronous execution
if self._return_start + need_size > self._mem_state_return.shape[0]:
self._return_start = 0
ans = self._mem_state_return[self._return_start : self._return_start + need_size]
ans.copy_(self.mem_state[start:end])
self._return_start += need_size
return ans

def free(self, free_index: Union[torch.Tensor, List[int]]):
"""_summary_

Args:
free_index (torch.Tensor): _description_
"""
end = self.mark_start
start = self.mark_start - len(free_index)
assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}"

if isinstance(free_index, list):
free_index_tensor = torch.tensor(free_index, dtype=self.mem_state.dtype, device=self.mem_state.device)
self.mem_state[start:end] = free_index_tensor
else:
# The copy from GPU to CPU is a blocking operation within the stream
self.mem_state[start:end] = free_index

self.mark_start -= len(free_index)

self.can_use_mem_size += len(free_index)
self.shared_can_use_token_num.set_value(self.can_use_mem_size)

if self.can_use_mem_size == len(self.mem_state):
logger.debug(f"freed all gpu mem size {self.can_use_mem_size}")
return

def free_all(self):
self.can_use_mem_size = len(self.mem_state)
self.shared_can_use_token_num.set_value(self.can_use_mem_size)
self.mem_state.numpy()[:] = list(range(0, len(self.mem_state)))
self.mark_start = 0
self.mark_end = len(self.mem_state)

def resize_mem(self, new_size):
"""
just for test code
"""
self.size = new_size
self.mem_state = torch.arange(
0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
)
self.mark_start = 0
self.mark_end = self.size
self.can_use_mem_size = self.size
self.shared_can_use_token_num.set_value(self.can_use_mem_size)
return
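A minimal usage sketch for TokenAllocator (the size and the shared-memory name are illustrative; SharedInt is assumed to create or attach to the named shared integer):

import torch
from lightllm.common.allocator_utils import TokenAllocator

# "demo_can_use_token_num" is a made-up shared-memory name for this sketch.
allocator = TokenAllocator(size=1024, shared_can_use_token_num_name="demo_can_use_token_num")

idx = allocator.alloc(8)                 # pinned CPU int32 tensor holding 8 token indices
assert allocator.can_use_mem_size == 1024 - 8

allocator.free(idx)                      # indices are pushed back (LIFO via mark_start)
assert allocator.can_use_mem_size == 1024

allocator.free_all()                     # reset the whole pool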
10 changes: 8 additions & 2 deletions lightllm/common/basemodel/basemodel.py
@@ -282,6 +282,7 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
infer_state.prefix_total_token_num = model_input.prefix_total_token_num
assert model_input.b_req_idx.shape[0] == model_input.b_seq_len.shape[0]
infer_state.b_req_idx = model_input.b_req_idx
infer_state.b_mtp_index = model_input.b_mtp_index
infer_state.b_seq_len = model_input.b_seq_len
if model_input.is_prefill:
if model_input.b_ready_cache_len is not None:
@@ -993,8 +994,13 @@ def _init_padded_req(self):
def _gen_special_model_input(self, token_num: int):
special_model_input = {}

is_deepseekv3_mtp_draft_model = "Deepseek3MTPModel" in str(self.__class__)
if is_deepseekv3_mtp_draft_model:
is_mtp_draft_model = (
"Deepseek3MTPModel" in str(self.__class__)
or "Qwen3NextMTPModel" in str(self.__class__)
or "Qwen3MOEMTPModel" in str(self.__class__)
or "MistralMTPModel" in str(self.__class__)
)
if is_mtp_draft_model:
special_model_input["deepseekv3_mtp_draft_input_hiddens"] = torch.randn(
token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda"
)
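The chained class-name checks above amount to membership in a fixed set of draft-model markers; the sketch below restates the same logic (an illustration only, not code from this PR):

_MTP_DRAFT_CLASS_MARKERS = (
    "Deepseek3MTPModel",
    "Qwen3NextMTPModel",
    "Qwen3MOEMTPModel",
    "MistralMTPModel",
)

def _is_mtp_draft_model(model) -> bool:
    # Equivalent to the chained `in str(self.__class__)` checks above.
    cls_str = str(model.__class__)
    return any(marker in cls_str for marker in _MTP_DRAFT_CLASS_MARKERS)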
17 changes: 15 additions & 2 deletions lightllm/common/basemodel/cuda_graph.py
@@ -3,6 +3,7 @@
import copy
import bisect
from typing import Optional
from tqdm import tqdm
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
@@ -191,7 +192,12 @@ def warmup(self, model):
model: TpPartBaseModel = model

# decode cuda graph init
for batch_size in self.cuda_graph_batch_sizes[::-1]:
progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing CUDA graphs")
for batch_size in progress_bar:
# Get available memory info
avail_mem, total_mem = torch.cuda.mem_get_info()
avail_mem_gb = avail_mem / (1024 ** 3)
progress_bar.set_description(f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB")
seq_len = 2
total_token_num = batch_size * seq_len
max_len_in_batch = self.graph_max_len_in_batch
@@ -247,7 +253,14 @@ def warmup_overlap(self, model):

model: TpPartBaseModel = model

for batch_size in self.cuda_graph_batch_sizes[::-1]:
progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing overlap CUDA graphs")
for batch_size in progress_bar:
# Get available memory info
avail_mem, total_mem = torch.cuda.mem_get_info()
avail_mem_gb = avail_mem / (1024 ** 3)
progress_bar.set_description(
f"Capturing overlap CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB"
)
decode_batches = []
for micro_batch_index in [0, 1]:
# dummy decoding, capture the cudagraph
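Stripped of the model-specific details, the progress reporting added above follows this standalone pattern (the batch sizes below are made up; a CUDA device is required for mem_get_info):

import torch
from tqdm import tqdm

batch_sizes = [256, 128, 64, 32, 16, 8, 4, 2, 1]  # example capture sizes, largest first
progress_bar = tqdm(batch_sizes, desc="Capturing CUDA graphs")
for batch_size in progress_bar:
    avail_mem, total_mem = torch.cuda.mem_get_info()  # (free bytes, total bytes) on the current device
    avail_mem_gb = avail_mem / (1024 ** 3)
    progress_bar.set_description(f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB")
    # ... capture the CUDA graph for this batch size ...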
2 changes: 2 additions & 0 deletions lightllm/common/basemodel/infer_struct.py
@@ -22,6 +22,7 @@ def __init__(self):
self.batch_size: int = None
self.total_token_num: int = None
self.b_req_idx: torch.Tensor = None
self.b_mtp_index: torch.Tensor = None # MTP index for each batch item (0: main, 1-mtp_step: candidates)
self.b_start_loc: torch.Tensor = None
self.b_ready_cache_len: torch.Tensor = None # only for prefill prompt cache used.

@@ -112,6 +113,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
) = gen_decode_params(self.b_seq_len)
# TODO: check the correctness
self.max_kv_seq_len = self.max_len_in_batch
self.max_q_seq_len = self.b_q_seq_len.max().item() if self.b_q_seq_len.numel() > 0 else 1
self.b_start_loc = self.b1_cu_kv_seq_len[0:-1]

def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"):
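For intuition only: with mtp_step = 1 and two requests in a decode batch, one plausible layout (assuming each request's candidate rows directly follow its main row) gives b_mtp_index the values below.

import torch

# Rows 0 and 2 are the main tokens of the two requests (index 0),
# rows 1 and 3 are their first MTP candidates (index 1). Layout is illustrative.
b_mtp_index = torch.tensor([0, 1, 0, 1], dtype=torch.int32)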
(file path not shown in the diff view)
@@ -64,20 +64,21 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor
def _tpsp_ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
raise Exception("need to impl")

def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
input1 = None
def context_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight)
self._post_cache_kv(cache_kv, infer_state, layer_weight)

o = self._context_attention_wrapper_run(
q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight
)

q = None
o = self._get_o(o, infer_state, layer_weight)
if self.tp_world_size_ > 1:
all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
return o

def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
o = self.context_attention_forward(input1, infer_state, layer_weight)
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None

@@ -89,39 +90,42 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return input_embdings

def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
input1 = None
def token_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight)
self._post_cache_kv(cache_kv, infer_state, layer_weight)
o = self._token_attention_kernel(q, infer_state, layer_weight)
q = None
o = self._get_o(o, infer_state, layer_weight)
if self.tp_world_size_ > 1:
all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
return o

def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
o = self.token_attention_forward(input1, infer_state, layer_weight)
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None

input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
ffn_out = self._ffn(input1, infer_state, layer_weight)
input1 = None
if self.tp_world_size_ > 1:
all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return input_embdings

def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
q, cache_kv = self._tpsp_get_qkv(input1, infer_state, layer_weight)
input1 = None
def tpsp_context_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight)
self._post_cache_kv(cache_kv, infer_state, layer_weight)

o = self._context_attention_wrapper_run(
q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight
)

q = None
o = self._tpsp_get_o(o, infer_state, layer_weight)
return o

def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
o = self.tpsp_context_attention_forward(input1, infer_state, layer_weight)
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None

@@ -131,14 +135,17 @@ def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferS
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return input_embdings

def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
q, cache_kv = self._tpsp_get_qkv(input1, infer_state, layer_weight)
input1 = None
def tpsp_token_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight)
self._post_cache_kv(cache_kv, infer_state, layer_weight)
o = self._token_attention_kernel(q, infer_state, layer_weight)
q = None
o = self._tpsp_get_o(o, infer_state, layer_weight)
return o

def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
o = self.tpsp_token_attention_forward(input1, infer_state, layer_weight)
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None

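The effect of this refactor is that the pre-attention norm and the residual add now live only in the *_forward methods, while the attention path becomes a reusable *_attention_forward method. A toy, self-contained sketch of that pattern (the names and the math are stand-ins, not LightLLM code):

import torch

class AttentionBlock:
    def attention_forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for _get_qkv / attention kernel / _get_o.
        return x * 0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = x / (x.norm() + 1e-6)             # stand-in for _att_norm
        return x + self.attention_forward(normed)  # residual add stays outside the attention path

class DraftBlock(AttentionBlock):
    def forward(self, x: torch.Tensor, draft_hiddens: torch.Tensor) -> torch.Tensor:
        # A variant layer (e.g. an MTP draft layer) can mix in extra inputs before the norm
        # while reusing attention_forward() unchanged.
        mixed = x + draft_hiddens
        normed = mixed / (mixed.norm() + 1e-6)
        return x + self.attention_forward(normed)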
(file path not shown in the diff view)
@@ -9,3 +9,4 @@
from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
from .fused_moe_weight_tp import create_tp_moe_wegiht_obj
from .fused_moe_weight_ep import FusedMoeWeightEP
from .parameter_weight import ParameterWeight, TpParameterWeight
(file path not shown in the diff view)
@@ -0,0 +1,44 @@
import torch
from typing import Dict
from .base_weight import BaseWeightTpl
from lightllm.utils.dist_utils import get_current_device_id


class ParameterWeight(BaseWeightTpl):
def __init__(self, weight_name: str, data_type: torch.dtype, bias_name: str = None):
super().__init__()
self.weight_name = weight_name
self.bias_name = bias_name
self.data_type_ = data_type
self.weight = None
self.bias = None

def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
if self.weight_name in weights:
self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id())
if self.bias_name in weights:
self.bias = weights[self.bias_name].to(self.data_type_).cuda(get_current_device_id())

def verify_load(self):
load_ok = True
# Verify weight. The weight must not be None.
load_ok = load_ok and self.weight is not None
# Verify bias. If bias_name is set, the bias must not be None.
if self.bias_name is not None:
load_ok = load_ok and self.bias is not None
return load_ok


class TpParameterWeight(ParameterWeight):
def __init__(self, weight_name: str, data_type: torch.dtype, split_n_embed: int, bias_name: str = None):
super().__init__(weight_name, data_type, bias_name)
self.split_n_embed = split_n_embed

def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
start = self.split_n_embed * self.tp_rank_
end = self.split_n_embed * (self.tp_rank_ + 1)

if self.weight_name in weights:
self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(get_current_device_id())
if self.bias_name in weights:
self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(get_current_device_id())
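A small usage sketch for the new weight classes (the weight names, sizes, and dtype are made up; a CUDA device is required, and tp_rank_ is assumed to be set by BaseWeightTpl, 0 in a single-process run):

import torch

# Fake HF-style state dict; the names are illustrative only.
weights = {
    "model.some_param.weight": torch.randn(4096, dtype=torch.float32),
    "model.some_param.bias": torch.randn(4096, dtype=torch.float32),
}

w = ParameterWeight("model.some_param.weight", torch.bfloat16, bias_name="model.some_param.bias")
w.load_hf_weights(weights)
assert w.verify_load()

# TpParameterWeight keeps only this rank's slice of the first dimension:
# rows [split_n_embed * tp_rank_, split_n_embed * (tp_rank_ + 1)).
tp_w = TpParameterWeight("model.some_param.weight", torch.bfloat16, split_n_embed=2048)
tp_w.load_hf_weights(weights)
assert tp_w.weight.shape[0] == 2048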