15 changes: 12 additions & 3 deletions lightllm/common/basemodel/basemodel.py
@@ -328,6 +328,10 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s
mode="constant",
value=self.mem_manager.HOLD_TOKEN_MEMINDEX,
)
new_model_input.multimodal_params = new_model_input.multimodal_params + [
{"images": [], "audios": []} for _ in range(padded_batch_size)
]

if enable_diverse_mode_gqa_decode_fast_kernel():
if new_model_input.b_shared_seq_len is not None:
new_model_input.b_shared_seq_len = F.pad(
@@ -345,6 +349,7 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s
new_batch_size=new_batch_size,
)

new_model_input.check_input()
return new_model_input

def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle_token_num: int):
@@ -378,7 +383,9 @@ def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle
new_model_input.b_prefill_has_output_cpu = [e for e in new_model_input.b_prefill_has_output_cpu] + [False]
new_model_input.prefix_total_token_num = model_input.prefix_total_token_num

# TODO: do the multimodal params need padding? Needs to be checked.
new_model_input.multimodal_params = [e for e in new_model_input.multimodal_params] + [
{"images": [], "audios": []}
]

# special padding for the special variables of special models in special modes
if new_model_input.deepseekv3_mtp_draft_input_hiddens is not None:
@@ -387,6 +394,7 @@ def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle
new_batch_size=new_handle_token_num,
)

new_model_input.check_input()
return new_model_input

def _create_unpad_decode_model_output(self, model_output: ModelOutput, origin_batch_size: int):
@@ -827,6 +835,7 @@ def _check_max_len_infer(self):
is_prefill=True,
b_ready_cache_len=b_ready_cache_len,
b_prefill_start_loc=b_prefill_start_loc,
multimodal_params=[{"images": [], "audios": []}],
)
model_output = self.forward(
model_input,
@@ -903,7 +912,7 @@ def _autotune_warmup(self):
is_prefill=True,
b_ready_cache_len=b_ready_cache_len,
b_prefill_start_loc=b_prefill_start_loc,
multimodal_params=[],
multimodal_params=[{"images": [], "audios": []}],
**self._gen_special_model_input(total_token_num),
)
model_output = self.forward(
@@ -966,7 +975,7 @@ def _init_padded_req(self):
b_ready_cache_len=b_ready_cache_len,
b_prefill_start_loc=b_prefill_start_loc,
is_prefill=True,
multimodal_params=[],
multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)],
**self._gen_special_model_input(total_token_num),
)

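Side note on the padding convention introduced in this file: every padded (fake) request now gets an empty placeholder entry, so `multimodal_params` stays index-aligned with the padded batch. A minimal sketch of the idiom, with `pad_multimodal_params` as a hypothetical helper name and the placeholder shape taken from this diff:

```python
def pad_multimodal_params(multimodal_params: list, target_len: int) -> list:
    # Append one empty placeholder per padded (fake) request so that
    # len(multimodal_params) == target_len, as ModelInput.check_input expects.
    pad_count = target_len - len(multimodal_params)
    return list(multimodal_params) + [{"images": [], "audios": []} for _ in range(pad_count)]
```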
9 changes: 7 additions & 2 deletions lightllm/common/basemodel/batch_objs.py
@@ -36,8 +36,7 @@ class ModelInput:
is_prefill: bool = False
b_ready_cache_len: torch.Tensor = None
b_prefill_start_loc: torch.Tensor = None
multimodal_params: list = field(default_factory=list)

multimodal_params: list = None
# cpu-side variables
mem_indexes_cpu: torch.Tensor = None
# parameters used during the prefill stage, but not by the inference pass itself; they are for resource management outside of inference
@@ -74,6 +73,12 @@ def to_cuda(self):
else:
self.b_shared_seq_len = self.b_shared_seq_len.cuda(non_blocking=True)

def __post_init__(self):
self.check_input()

def check_input(self):
assert len(self.multimodal_params) == self.batch_size


@dataclass
class ModelOutput:
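With `__post_init__` delegating to `check_input`, a `ModelInput` now fails fast whenever `multimodal_params` is missing or does not match `batch_size`, which is why every warmup and padding call site in this PR passes explicit placeholder lists. A stripped-down sketch of the same dataclass pattern (field list heavily reduced, class name hypothetical):

```python
from dataclasses import dataclass

@dataclass
class MiniModelInput:
    batch_size: int
    multimodal_params: list = None

    def __post_init__(self):
        self.check_input()

    def check_input(self):
        # One entry per request is required, even if it is an empty placeholder.
        assert len(self.multimodal_params) == self.batch_size

# Text-only batches still need placeholders:
MiniModelInput(batch_size=2, multimodal_params=[{"images": [], "audios": []} for _ in range(2)])
```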
2 changes: 2 additions & 0 deletions lightllm/common/basemodel/cuda_graph.py
@@ -216,6 +216,7 @@ def warmup(self, model):
b_seq_len=b_seq_len,
b_mtp_index=b_mtp_index,
is_prefill=False,
multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)],
**model._gen_special_model_input(batch_size),
)
model_output: ModelOutput = model.forward(model_input)
@@ -274,6 +275,7 @@ def warmup_overlap(self, model):
mem_indexes=mem_indexes,
b_req_idx=b_req_idx,
b_seq_len=b_seq_len,
multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)],
**model._gen_special_model_input(batch_size),
)
decode_batches.append(micro_batch)
@@ -70,7 +70,7 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei
input1 = None
self._post_cache_kv(cache_kv, infer_state, layer_weight)

o = self.__context_attention_wrapper_run(
o = self._context_attention_wrapper_run(
q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight
)

@@ -116,7 +116,7 @@ def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferS
input1 = None
self._post_cache_kv(cache_kv, infer_state, layer_weight)

o = self.__context_attention_wrapper_run(
o = self._context_attention_wrapper_run(
q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight
)

@@ -148,7 +148,7 @@ def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferSta
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return input_embdings

def __context_attention_wrapper_run(
def _context_attention_wrapper_run(
self, q: torch.Tensor, cache_kv: torch.Tensor, infer_state: InferStateInfo, layer_weight
) -> torch.Tensor:
if torch.cuda.is_current_stream_capturing():
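The rename from `__context_attention_wrapper_run` to `_context_attention_wrapper_run` matters: double-underscore attributes are name-mangled per defining class, so subclasses (such as `Qwen3VLTransformerLayerInfer` further down, which now calls `self._context_attention_wrapper_run`) cannot reach the double-underscore version under its plain name. A small illustration with hypothetical classes:

```python
class Base:
    def __private_run(self):   # stored as _Base__private_run
        return "private"

    def _protected_run(self):  # plain name, reachable from subclasses
        return "protected"

class Child(Base):
    def run(self):
        # self.__private_run() here would be mangled to _Child__private_run -> AttributeError
        return self._protected_run()

print(Child().run())  # prints "protected"
```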
2 changes: 2 additions & 0 deletions lightllm/common/basemodel/prefill_cuda_graph.py
@@ -182,6 +182,7 @@ def warmup(self, model):
is_prefill=True,
b_prefill_has_output_cpu=[False],
prefix_total_token_num=0,
multimodal_params=[{"images": [], "audios": []}],
**model._gen_special_model_input(token_num=total_token_num),
)
model_output: ModelOutput = model.forward(model_input)
@@ -242,6 +243,7 @@ def warmup_overlap(self, model):
is_prefill=True,
b_prefill_has_output_cpu=[False],
prefix_total_token_num=0,
multimodal_params=[{"images": [], "audios": []}],
**model._gen_special_model_input(token_num=total_token_num),
)

8 changes: 8 additions & 0 deletions lightllm/models/qwen2_vl/infer_struct.py
@@ -50,6 +50,13 @@ def get_mrope_position(self, multimodal_params: List[dict]) -> torch.Tensor:
b_image_len = []
image_start_num = 0
b_image_thwd = []

# pad multimodal_params to batch size.
batch_size = self.b_q_seq_len.shape[0]
multimodal_params = multimodal_params + [
{"images": [], "audios": []} for _ in range(batch_size - len(multimodal_params))
]

for _, p in enumerate(multimodal_params):
images = p.get("images", [])
for img in images:
@@ -59,6 +66,7 @@ def get_mrope_position(self, multimodal_params: List[dict]) -> torch.Tensor:
b_image_nums.append(len(images))
b_image_start_num.append(image_start_num)
image_start_num += len(images)

# no images at all
if image_start_num == 0:
return self.position_ids.unsqueeze(0).expand(3, -1).contiguous()
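The padding added to `get_mrope_position` simply guarantees one (possibly empty) entry per request before the loop, keeping the per-image bookkeeping index-aligned with `b_q_seq_len`. A quick check of the arithmetic, assuming `multimodal_params` is never longer than the batch (values below are made up):

```python
batch_size = 4  # taken from b_q_seq_len.shape[0] in the diff
multimodal_params = [{"images": ["img0"], "audios": []}]  # only request 0 carries an image

multimodal_params = multimodal_params + [
    {"images": [], "audios": []} for _ in range(batch_size - len(multimodal_params))
]
assert len(multimodal_params) == batch_size  # requests 1..3 got empty placeholders
```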
@@ -94,6 +94,7 @@ def get_mrope_position_triton(
) -> torch.Tensor:

batch_size = b_q_seq_len.shape[0]
assert batch_size == b_image_nums.shape[0]
grid = (batch_size,)
BLOCK_SIZE = 64
_get_mrope_position_triton[grid](
2 changes: 1 addition & 1 deletion lightllm/models/qwen3_vl/infer_struct.py
@@ -7,4 +7,4 @@ def __init__(self):
self.input_ids = None
self.img_start_token_ids = None
self.img_token_lens = None
self.img_start_locs = None
self.img_start_locs_in_cache = None
38 changes: 36 additions & 2 deletions lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py
@@ -19,6 +19,7 @@
from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features
from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer
from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward
from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor


class Qwen3VLTransformerLayerInfer(Qwen2VLTransformerLayerInfer):
@@ -63,7 +64,7 @@ def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, la
q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
input1 = None
self._post_cache_kv(cache_kv, infer_state, layer_weight)
o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
o = self._context_attention_wrapper_run(q, cache_kv, infer_state, layer_weight)
q = None
o = self._get_o(o, infer_state, layer_weight)
if self.tp_world_size_ > 1:
@@ -77,9 +78,42 @@ def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, la
if self.tp_world_size_ > 1:
all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
apply_deepstack_features(
self._apply_deepstack_features_wrapper_run(
input_embeddings=input_embdings,
infer_state=infer_state,
layer_num=self.layer_num_,
)
return input_embdings

def _apply_deepstack_features_wrapper_run(
self,
input_embeddings: torch.Tensor,
infer_state: InferStateInfo,
layer_num: int,
):
if torch.cuda.is_current_stream_capturing():
input_embeddings = input_embeddings.contiguous()
_input_embeddings = tensor_to_no_ref_tensor(input_embeddings)
pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
pre_capture_graph.__exit__(None, None, None)

infer_state.prefill_cuda_graph_create_graph_obj()
infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__()

def apply_func(new_infer_state: InferStateInfo):
apply_deepstack_features(
input_embeddings=_input_embeddings,
infer_state=new_infer_state,
layer_num=layer_num,
)
return

infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=apply_func, after_graph=pre_capture_graph)
else:
apply_deepstack_features(
input_embeddings=input_embeddings,
infer_state=infer_state,
layer_num=layer_num,
)

return
16 changes: 7 additions & 9 deletions lightllm/server/core/objs/req.py
@@ -234,15 +234,13 @@ def get_prompt_ids_numpy(self):
return self.shm_prompt_ids.arr[: self.input_len]

def to_router_rpc_obj(self):
if hasattr(self, "multimodal_params"):
return (
self.request_id,
self.index_in_shm_mem,
self.multimodal_params,
self.sample_params.suggested_dp_index,
)
else:
return (self.request_id, self.index_in_shm_mem, None, self.sample_params.suggested_dp_index)
assert hasattr(self, "multimodal_params")
return (
self.request_id,
self.index_in_shm_mem,
self.multimodal_params,
self.sample_params.suggested_dp_index,
)

def can_release(self):
# only the management node holds a reference
2 changes: 1 addition & 1 deletion lightllm/server/router/model_infer/infer_batch.py
@@ -320,7 +320,7 @@ def __init__(
req_id: int,
req_idx: int,
shm_index: int,
multimodal_params=None,
multimodal_params: MultimodalParams,
vocab_size: int = -1,
init_prefix_cache: bool = True,
):
@@ -106,9 +106,7 @@ def prefill_normal(
prefill_reqs: List[InferReq],
):
# stage 1: model inference
model_input, run_reqs = prepare_prefill_inputs(
prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal
)
model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output = self.model.forward(model_input)
_, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
@@ -185,9 +183,7 @@ def prefill_mtp(
event_pack: OverlapEventPack,
prefill_reqs: List[InferReq],
):
model_input, run_reqs = prepare_prefill_inputs(
prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal
)
model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output = self.model.forward(model_input)
next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
@@ -19,9 +19,7 @@ def return_all_prompt_logprobs_prefill(self, event_pack: OverlapEventPack, prefi
assert self.radix_cache is None
assert self.disable_chunked_prefill is True

model_input, run_reqs = prepare_prefill_inputs(
prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal
)
model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill)

model_output = self.model.forward(model_input)
prompt_all_logits = model_output.logits
@@ -16,9 +16,7 @@ def __init__(self) -> None:
def reward_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]):

assert self.disable_chunked_prefill is True
model_input, run_reqs = prepare_prefill_inputs(
prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal
)
model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill)

model_output = self.model.forward(model_input)
scores: torch.Tensor = model_output.logits
@@ -28,7 +28,7 @@ def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq
group_reqs = [g_infer_context.requests_mapping[req.req_id] for req in prefill_reqs if req.is_master_req()]

model_input, group_run_reqs = prepare_prefill_inputs(
group_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal
group_reqs, is_chuncked_mode=not self.disable_chunked_prefill
)

with torch.cuda.stream(g_infer_context.get_overlap_stream()):
@@ -141,7 +141,7 @@ def prefill_normal(
event_pack: OverlapEventPack,
prefill_reqs: List[InferReq],
):
model_input, run_reqs, _ = padded_prepare_prefill_inputs(prefill_reqs, is_multimodal=self.is_multimodal)
model_input, run_reqs, _ = padded_prepare_prefill_inputs(prefill_reqs)
run_reqs_num = len(run_reqs)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output = self.model.forward(model_input)
@@ -232,7 +232,7 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer
model_input1,
run_reqs1,
_,
) = padded_overlap_prepare_prefill_inputs(prefill_reqs, is_multimodal=self.is_multimodal)
) = padded_overlap_prepare_prefill_inputs(prefill_reqs)

with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1)
@@ -355,7 +355,7 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe

def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]):
# main model prefill
model_input, run_reqs, _ = padded_prepare_prefill_inputs(prefill_reqs, is_multimodal=self.is_multimodal)
model_input, run_reqs, _ = padded_prepare_prefill_inputs(prefill_reqs)
req_num = len(run_reqs)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output: ModelOutput = self.model.forward(model_input)
@@ -626,7 +626,7 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I
model_input1,
run_reqs1,
_,
) = padded_overlap_prepare_prefill_inputs(prefill_reqs, is_multimodal=self.is_multimodal)
) = padded_overlap_prepare_prefill_inputs(prefill_reqs)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1)
logits0 = model_output0.logits