From 9352fe80286902160cf452ff97170e5d7a7977e8 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 22 Dec 2025 10:34:37 +0800 Subject: [PATCH 01/11] qwen3_vl support prefill cuda graph feature --- lightllm/common/basemodel/basemodel.py | 27 ++++++++++++++++--- .../transformer_layer_infer_template.py | 6 ++--- .../layer_infer/transformer_layer_infer.py | 2 +- lightllm/utils/config_utils.py | 13 +++++++++ 4 files changed, 40 insertions(+), 8 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index cde2b5039..e3017e0f8 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -515,9 +515,27 @@ def _context_forward(self, input_ids, infer_state: InferStateInfo): input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight) input_tensors = [input_embs] - def prefill_func(input_tensors, infer_state): + # prefill cuda graph 在 qwen3 vl 上的前几层由于特殊的处理,导致目前无法支持cuda graph + from lightllm.utils.config_utils import is_qwen3_vl + + if is_qwen3_vl(): + no_graph_layer_num = 3 + else: + no_graph_layer_num = 0 + + def no_graph_prefill_func(input_tensors, infer_state): + _input_embs = input_tensors[0] + for i in range(no_graph_layer_num): + layer = self.layers_infer[i] + layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index] + _input_embs = layer_method(_input_embs, infer_state, self.trans_layers_weight[i]) + return [_input_embs] + + input_tensors = no_graph_prefill_func(input_tensors=input_tensors, infer_state=infer_state) + + def graph_prefill_func(input_tensors, infer_state): _input_embs = input_tensors[0] - for i in range(self.layers_num): + for i in range(no_graph_layer_num, self.layers_num): layer = self.layers_infer[i] layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index] _input_embs = layer_method(_input_embs, infer_state, self.trans_layers_weight[i]) @@ -531,7 +549,7 @@ def prefill_func(input_tensors, infer_state): ) if self.prefill_graph.need_capture(handle_token_num=finded_handle_token_num): output_tensors: List[torch.Tensor] = self.prefill_graph.capture_prefill( - prefill_func=prefill_func, + prefill_func=graph_prefill_func, input_tensors=input_tensors, infer_state=infer_state, ) @@ -542,7 +560,8 @@ def prefill_func(input_tensors, infer_state): else: g_cache_manager.cache_env_in() - output_tensors: List[torch.Tensor] = prefill_func(input_tensors, infer_state) + input_tensors = no_graph_prefill_func(input_tensors=input_tensors, infer_state=infer_state) + output_tensors: List[torch.Tensor] = graph_prefill_func(input_tensors, infer_state) g_cache_manager.cache_env_out() input_embs = output_tensors[0] diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py index 06f2251ca..436ca77d8 100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py @@ -70,7 +70,7 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei input1 = None self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self.__context_attention_wrapper_run( + o = self._context_attention_wrapper_run( q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight ) @@ -116,7 +116,7 @@ def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferS input1 = None 
self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self.__context_attention_wrapper_run( + o = self._context_attention_wrapper_run( q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight ) @@ -148,7 +148,7 @@ def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferSta input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings - def __context_attention_wrapper_run( + def _context_attention_wrapper_run( self, q: torch.Tensor, cache_kv: torch.Tensor, infer_state: InferStateInfo, layer_weight ) -> torch.Tensor: if torch.cuda.is_current_stream_capturing(): diff --git a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py index 30b72a5ca..18e1302f5 100644 --- a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py @@ -63,7 +63,7 @@ def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, la q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) input1 = None self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) + o = self._context_attention_wrapper_run(q, cache_kv, infer_state, layer_weight) q = None o = self._get_o(o, infer_state, layer_weight) if self.tp_world_size_ > 1: diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index b06309f96..dec6240d4 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -132,3 +132,16 @@ def get_fixed_kv_len(): return len(model_cfg["prompt_cache_token_ids"]) else: return 0 + + +@lru_cache(maxsize=None) +def is_qwen3_vl(): + from lightllm.utils.llm_utils import get_llm_model_class + from lightllm.models import Qwen3VLTpPartModel, Qwen3VLMOETpPartModel + + model_class = get_llm_model_class() + + if model_class in [Qwen3VLTpPartModel, Qwen3VLMOETpPartModel]: + return True + else: + return False From 1ebdd7bb8bb4cc1f94c203cd7e2d872a304ecc88 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 22 Dec 2025 10:40:00 +0800 Subject: [PATCH 02/11] fix --- lightllm/common/basemodel/basemodel.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index e3017e0f8..cc8b1ac19 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -560,7 +560,6 @@ def graph_prefill_func(input_tensors, infer_state): else: g_cache_manager.cache_env_in() - input_tensors = no_graph_prefill_func(input_tensors=input_tensors, infer_state=infer_state) output_tensors: List[torch.Tensor] = graph_prefill_func(input_tensors, infer_state) g_cache_manager.cache_env_out() From 13883912f5fac7ee5e7adbd36dc7386f92735393 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 22 Dec 2025 10:43:28 +0800 Subject: [PATCH 03/11] fix --- lightllm/common/basemodel/basemodel.py | 26 ++++---------------------- lightllm/utils/config_utils.py | 13 ------------- 2 files changed, 4 insertions(+), 35 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index cc8b1ac19..cde2b5039 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -515,27 +515,9 @@ def _context_forward(self, input_ids, infer_state: InferStateInfo): input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight) input_tensors = [input_embs] - # 
prefill cuda graph 在 qwen3 vl 上的前几层由于特殊的处理,导致目前无法支持cuda graph - from lightllm.utils.config_utils import is_qwen3_vl - - if is_qwen3_vl(): - no_graph_layer_num = 3 - else: - no_graph_layer_num = 0 - - def no_graph_prefill_func(input_tensors, infer_state): - _input_embs = input_tensors[0] - for i in range(no_graph_layer_num): - layer = self.layers_infer[i] - layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index] - _input_embs = layer_method(_input_embs, infer_state, self.trans_layers_weight[i]) - return [_input_embs] - - input_tensors = no_graph_prefill_func(input_tensors=input_tensors, infer_state=infer_state) - - def graph_prefill_func(input_tensors, infer_state): + def prefill_func(input_tensors, infer_state): _input_embs = input_tensors[0] - for i in range(no_graph_layer_num, self.layers_num): + for i in range(self.layers_num): layer = self.layers_infer[i] layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index] _input_embs = layer_method(_input_embs, infer_state, self.trans_layers_weight[i]) @@ -549,7 +531,7 @@ def graph_prefill_func(input_tensors, infer_state): ) if self.prefill_graph.need_capture(handle_token_num=finded_handle_token_num): output_tensors: List[torch.Tensor] = self.prefill_graph.capture_prefill( - prefill_func=graph_prefill_func, + prefill_func=prefill_func, input_tensors=input_tensors, infer_state=infer_state, ) @@ -560,7 +542,7 @@ def graph_prefill_func(input_tensors, infer_state): else: g_cache_manager.cache_env_in() - output_tensors: List[torch.Tensor] = graph_prefill_func(input_tensors, infer_state) + output_tensors: List[torch.Tensor] = prefill_func(input_tensors, infer_state) g_cache_manager.cache_env_out() input_embs = output_tensors[0] diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index dec6240d4..b06309f96 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -132,16 +132,3 @@ def get_fixed_kv_len(): return len(model_cfg["prompt_cache_token_ids"]) else: return 0 - - -@lru_cache(maxsize=None) -def is_qwen3_vl(): - from lightllm.utils.llm_utils import get_llm_model_class - from lightllm.models import Qwen3VLTpPartModel, Qwen3VLMOETpPartModel - - model_class = get_llm_model_class() - - if model_class in [Qwen3VLTpPartModel, Qwen3VLMOETpPartModel]: - return True - else: - return False From 1bbf90b212730b95077ecb610a1184ce9ed66b08 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 22 Dec 2025 11:00:13 +0800 Subject: [PATCH 04/11] fix --- .../layer_infer/transformer_layer_infer.py | 36 ++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py index 18e1302f5..a9a6954d6 100644 --- a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py @@ -19,6 +19,7 @@ from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward +from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor class Qwen3VLTransformerLayerInfer(Qwen2VLTransformerLayerInfer): @@ -77,9 +78,42 @@ def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, la if self.tp_world_size_ > 1: all_reduce(ffn_out, op=dist.ReduceOp.SUM, 
group=infer_state.dist_group, async_op=False) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) - apply_deepstack_features( + self._apply_deepstack_features_wrapper_run( input_embeddings=input_embdings, infer_state=infer_state, layer_num=self.layer_num_, ) return input_embdings + + def _apply_deepstack_features_wrapper_run( + self, + input_embeddings: torch.Tensor, + infer_state: InferStateInfo, + layer_num: int, + ): + if torch.cuda.is_current_stream_capturing(): + input_embeddings = input_embeddings.contiguous() + _input_embeddings = tensor_to_no_ref_tensor(input_embeddings) + pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph() + pre_capture_graph.__exit__(None, None, None) + + infer_state.prefill_cuda_graph_create_graph_obj() + infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__() + + def apply_func(new_infer_state: InferStateInfo): + apply_deepstack_features( + input_embeddings=_input_embeddings, + infer_state=new_infer_state, + layer_num=layer_num, + ) + return + + infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=apply_func, after_graph=pre_capture_graph) + else: + apply_deepstack_features( + input_embeddings=input_embeddings, + infer_state=infer_state, + layer_num=layer_num, + ) + + return From a17615cc5e0f363cfdc20ac4e321b5ac0557c907 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 23 Dec 2025 02:12:10 +0000 Subject: [PATCH 05/11] fix name --- lightllm/models/qwen3_vl/infer_struct.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/models/qwen3_vl/infer_struct.py b/lightllm/models/qwen3_vl/infer_struct.py index 097424fa1..a5769e3ff 100644 --- a/lightllm/models/qwen3_vl/infer_struct.py +++ b/lightllm/models/qwen3_vl/infer_struct.py @@ -7,4 +7,4 @@ def __init__(self): self.input_ids = None self.img_start_token_ids = None self.img_token_lens = None - self.img_start_locs = None + self.img_start_locs_in_cache = None From 83c1479ad89ab60c22c9e31a2d11d90672a32106 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 23 Dec 2025 02:48:32 +0000 Subject: [PATCH 06/11] fix get_mrope input batch_size check --- lightllm/models/qwen2_vl/infer_struct.py | 8 ++++++++ .../qwen2_vl/triton_kernel/get_mrope_position_ids.py | 1 + 2 files changed, 9 insertions(+) diff --git a/lightllm/models/qwen2_vl/infer_struct.py b/lightllm/models/qwen2_vl/infer_struct.py index da49f2841..ce7938b6a 100644 --- a/lightllm/models/qwen2_vl/infer_struct.py +++ b/lightllm/models/qwen2_vl/infer_struct.py @@ -50,6 +50,13 @@ def get_mrope_position(self, multimodal_params: List[dict]) -> torch.Tensor: b_image_len = [] image_start_num = 0 b_image_thwd = [] + + # pad multimodal_params to batch size. 
+ batch_size = self.b_q_seq_len.shape[0] + multimodal_params = multimodal_params + [ + {"images": [], "audios": []} for _ in range(batch_size - len(multimodal_params)) + ] + for _, p in enumerate(multimodal_params): images = p.get("images", []) for img in images: @@ -59,6 +66,7 @@ def get_mrope_position(self, multimodal_params: List[dict]) -> torch.Tensor: b_image_nums.append(len(images)) b_image_start_num.append(image_start_num) image_start_num += len(images) + # 没有任何图片 if image_start_num == 0: return self.position_ids.unsqueeze(0).expand(3, -1).contiguous() diff --git a/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py b/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py index 4c7aa30b8..756198e89 100644 --- a/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py +++ b/lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py @@ -94,6 +94,7 @@ def get_mrope_position_triton( ) -> torch.Tensor: batch_size = b_q_seq_len.shape[0] + assert batch_size == b_image_nums.shape[0] grid = (batch_size,) BLOCK_SIZE = 64 _get_mrope_position_triton[grid]( From 20d927b89939c7d95e3e1f1b45991204e82ee2b9 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 23 Dec 2025 03:09:16 +0000 Subject: [PATCH 07/11] fix multi_modal --- lightllm/common/basemodel/basemodel.py | 4 +++- lightllm/common/basemodel/batch_objs.py | 3 +-- .../mode_backend/chunked_prefill/impl.py | 8 ++------ .../impl_for_return_all_prompt_logprobs.py | 4 +--- .../chunked_prefill/impl_for_reward_model.py | 4 +--- .../mode_backend/diverse_backend/impl.py | 2 +- .../model_infer/mode_backend/dp_backend/impl.py | 4 ++-- .../mode_backend/generic_padded_pre_process.py | 15 +++++---------- .../mode_backend/generic_pre_process.py | 9 +++------ 9 files changed, 19 insertions(+), 34 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index cde2b5039..0d01aabf4 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -378,7 +378,9 @@ def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle new_model_input.b_prefill_has_output_cpu = [e for e in new_model_input.b_prefill_has_output_cpu] + [False] new_model_input.prefix_total_token_num = model_input.prefix_total_token_num - # TODO 多模态的参数需要 pad 吗,需要check + new_model_input.multimodal_params = [e for e in new_model_input.multimodal_params] + [ + {"images": [], "audios": []} + ] # 特殊模型,特殊模式的特殊变量的特殊 padding if new_model_input.deepseekv3_mtp_draft_input_hiddens is not None: diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 06b742634..002703415 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -36,8 +36,7 @@ class ModelInput: is_prefill: bool = False b_ready_cache_len: torch.Tensor = None b_prefill_start_loc: torch.Tensor = None - multimodal_params: list = field(default_factory=list) - + multimodal_params: list = None # cpu 变量 mem_indexes_cpu: torch.Tensor = None # prefill 阶段使用的参数,但是不是推理过程使用的参数,是推理外部进行资源管理 diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 2c3cfaf11..858f4713a 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -106,9 +106,7 @@ def prefill_normal( prefill_reqs: List[InferReq], ): # 第一阶段: 模型推理 - model_input, 
run_reqs = prepare_prefill_inputs( - prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal - ) + model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( @@ -185,9 +183,7 @@ def prefill_mtp( event_pack: OverlapEventPack, prefill_reqs: List[InferReq], ): - model_input, run_reqs = prepare_prefill_inputs( - prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal - ) + model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_return_all_prompt_logprobs.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_return_all_prompt_logprobs.py index c11040278..a9fbb41e7 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_return_all_prompt_logprobs.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_return_all_prompt_logprobs.py @@ -19,9 +19,7 @@ def return_all_prompt_logprobs_prefill(self, event_pack: OverlapEventPack, prefi assert self.radix_cache is None assert self.disable_chunked_prefill is True - model_input, run_reqs = prepare_prefill_inputs( - prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal - ) + model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) model_output = self.model.forward(model_input) prompt_all_logits = model_output.logits diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_reward_model.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_reward_model.py index 0889a8ea1..dfb102082 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_reward_model.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_reward_model.py @@ -16,9 +16,7 @@ def __init__(self) -> None: def reward_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]): assert self.disable_chunked_prefill is True - model_input, run_reqs = prepare_prefill_inputs( - prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal - ) + model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) model_output = self.model.forward(model_input) scores: torch.Tensor = model_output.logits diff --git a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py index c973c4e9c..5b0bf3335 100644 --- a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py @@ -28,7 +28,7 @@ def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq group_reqs = [g_infer_context.requests_mapping[req.req_id] for req in prefill_reqs if req.is_master_req()] model_input, group_run_reqs = 
prepare_prefill_inputs( - group_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal + group_reqs, is_chuncked_mode=not self.disable_chunked_prefill ) with torch.cuda.stream(g_infer_context.get_overlap_stream()): diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 734dc9998..4f69f04b9 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -141,7 +141,7 @@ def prefill_normal( event_pack: OverlapEventPack, prefill_reqs: List[InferReq], ): - model_input, run_reqs, _ = padded_prepare_prefill_inputs(prefill_reqs, is_multimodal=self.is_multimodal) + model_input, run_reqs, _ = padded_prepare_prefill_inputs(prefill_reqs) run_reqs_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) @@ -355,7 +355,7 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]): # main model prefill - model_input, run_reqs, _ = padded_prepare_prefill_inputs(prefill_reqs, is_multimodal=self.is_multimodal) + model_input, run_reqs, _ = padded_prepare_prefill_inputs(prefill_reqs) req_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output: ModelOutput = self.model.forward(model_input) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index f6edf5893..679ce2472 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -12,7 +12,7 @@ def padded_prepare_prefill_inputs( - req_objs: List[InferReq], dest_batch_size: Optional[int] = None, is_multimodal=False + req_objs: List[InferReq], dest_batch_size: Optional[int] = None ) -> Tuple[ModelInput, List[InferReq], int]: if dest_batch_size is None: @@ -116,9 +116,8 @@ def padded_prepare_prefill_inputs( b_prefill_start_loc=b_prefill_start_loc, is_prefill=True, b_prefill_has_output_cpu=b_prefill_has_output, + multimodal_params=batch_multimodal_params, ) - if is_multimodal: - model_input.multimodal_params = batch_multimodal_params return model_input, run_reqs, padded_req_num @@ -220,8 +219,8 @@ def padded_prepare_decode_inputs( b_mtp_index=b_mtp_index, b_seq_len=b_seq_len, is_prefill=False, + multimodal_params=batch_multimodal_params, ) - model_input.multimodal_params = batch_multimodal_params return model_input, run_reqs, padded_req_num @@ -245,12 +244,8 @@ def padded_overlap_prepare_decode_inputs( def padded_overlap_prepare_prefill_inputs(req_objs: List[InferReq], is_multimodal=False): micro_batch1_req_num = triton.cdiv(len(req_objs), 2) - micro_input, run_reqs, padded_req_num = padded_prepare_prefill_inputs( - req_objs[0:micro_batch1_req_num], is_multimodal=is_multimodal - ) + micro_input, run_reqs, padded_req_num = padded_prepare_prefill_inputs(req_objs[0:micro_batch1_req_num]) - micro_input1, run_reqs1, padded_req_num1 = padded_prepare_prefill_inputs( - req_objs[micro_batch1_req_num:], is_multimodal=is_multimodal - ) + micro_input1, run_reqs1, padded_req_num1 = padded_prepare_prefill_inputs(req_objs[micro_batch1_req_num:]) return micro_input, run_reqs, padded_req_num, micro_input1, 
run_reqs1, padded_req_num1 diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index a122d4756..bdb36054b 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -10,9 +10,7 @@ ) -def prepare_prefill_inputs( - req_objs: List[InferReq], is_chuncked_mode: bool, is_multimodal: bool = False -) -> Tuple[ModelInput, List[InferReq]]: +def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> Tuple[ModelInput, List[InferReq]]: run_reqs = [] total_token_num = 0 prefix_total_token_num = 0 @@ -88,9 +86,8 @@ def prepare_prefill_inputs( is_prefill=True, b_prefill_has_output_cpu=b_prefill_has_output, prefix_total_token_num=prefix_total_token_num, + multimodal_params=batch_multimodal_params, ) - if is_multimodal: - model_input.multimodal_params = batch_multimodal_params return model_input, run_reqs @@ -160,8 +157,8 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_shared_seq_len=b_shared_seq_len, b_mark_shared_group=b_mark_shared_group, is_prefill=False, + multimodal_params=multimodal_params, ) - model_input.multimodal_params = multimodal_params return model_input, run_reqs From 6b80d8aa48bd95bea7699ad237088482415e7b7b Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 23 Dec 2025 03:21:32 +0000 Subject: [PATCH 08/11] fix multi_modal --- lightllm/common/basemodel/basemodel.py | 5 +++-- lightllm/common/basemodel/batch_objs.py | 3 +++ lightllm/common/basemodel/cuda_graph.py | 2 ++ lightllm/common/basemodel/prefill_cuda_graph.py | 2 ++ test/benchmark/static_inference/model_infer.py | 8 ++++++-- test/benchmark/static_inference/model_infer_mtp.py | 2 ++ 6 files changed, 18 insertions(+), 4 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 0d01aabf4..1418d1a1c 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -829,6 +829,7 @@ def _check_max_len_infer(self): is_prefill=True, b_ready_cache_len=b_ready_cache_len, b_prefill_start_loc=b_prefill_start_loc, + multimodal_params=[{"images": [], "audios": []}], ) model_output = self.forward( model_input, @@ -905,7 +906,7 @@ def _autotune_warmup(self): is_prefill=True, b_ready_cache_len=b_ready_cache_len, b_prefill_start_loc=b_prefill_start_loc, - multimodal_params=[], + multimodal_params=[{"images": [], "audios": []}], **self._gen_special_model_input(total_token_num), ) model_output = self.forward( @@ -968,7 +969,7 @@ def _init_padded_req(self): b_ready_cache_len=b_ready_cache_len, b_prefill_start_loc=b_prefill_start_loc, is_prefill=True, - multimodal_params=[], + multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], **self._gen_special_model_input(total_token_num), ) diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 002703415..6574d64bd 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -73,6 +73,9 @@ def to_cuda(self): else: self.b_shared_seq_len = self.b_shared_seq_len.cuda(non_blocking=True) + def __post_init__(self): + assert len(self.multimodal_params) == self.batch_size + @dataclass class ModelOutput: diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index c754fabce..15c55e91c 100644 --- 
a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -216,6 +216,7 @@ def warmup(self, model): b_seq_len=b_seq_len, b_mtp_index=b_mtp_index, is_prefill=False, + multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], **model._gen_special_model_input(batch_size), ) model_output: ModelOutput = model.forward(model_input) @@ -274,6 +275,7 @@ def warmup_overlap(self, model): mem_indexes=mem_indexes, b_req_idx=b_req_idx, b_seq_len=b_seq_len, + multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], **model._gen_special_model_input(batch_size), ) decode_batches.append(micro_batch) diff --git a/lightllm/common/basemodel/prefill_cuda_graph.py b/lightllm/common/basemodel/prefill_cuda_graph.py index d7412156a..3d77a3ae4 100644 --- a/lightllm/common/basemodel/prefill_cuda_graph.py +++ b/lightllm/common/basemodel/prefill_cuda_graph.py @@ -182,6 +182,7 @@ def warmup(self, model): is_prefill=True, b_prefill_has_output_cpu=[False], prefix_total_token_num=0, + multimodal_params=[{"images": [], "audios": []}], **model._gen_special_model_input(token_num=total_token_num), ) model_output: ModelOutput = model.forward(model_input) @@ -242,6 +243,7 @@ def warmup_overlap(self, model): is_prefill=True, b_prefill_has_output_cpu=[False], prefix_total_token_num=0, + multimodal_params=[{"images": [], "audios": []}], **model._gen_special_model_input(token_num=total_token_num), ) diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index b8b535647..3fc7ee4b4 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -90,7 +90,7 @@ def overlap_prefill( b_seq_len=_0_b_seq_len, is_prefill=True, b_ready_cache_len=_o_b_ready_cache_len, - multimodal_params={}, + multimodal_params=[{"images": [], "audios": []} for _ in range(_0_batch_size)], mem_indexes_cpu=_0_mem_indexes, ) @@ -114,7 +114,7 @@ def overlap_prefill( b_seq_len=_1_b_seq_len, is_prefill=True, b_ready_cache_len=_1_b_ready_cache_len, - multimodal_params={}, + multimodal_params=[{"images": [], "audios": []} for _ in range(_1_batch_size)], mem_indexes_cpu=_1_mem_indexes, ) @@ -144,6 +144,7 @@ def overlap_decode( b_mtp_index=_0_b_mtp_index, b_seq_len=_0_b_seq_len, mem_indexes_cpu=_0_mem_indexes, + multimodal_params=[{"images": [], "audios": []} for _ in range(_0_batch_size)], ) _1_batch_size = batch_size - batch_size // 2 @@ -164,6 +165,7 @@ def overlap_decode( b_mtp_index=_1_b_mtp_index, b_seq_len=_1_b_seq_len, mem_indexes_cpu=_1_mem_indexes, + multimodal_params=[{"images": [], "audios": []} for _ in range(_1_batch_size)], ) output, output1 = model_part.microbatch_overlap_decode(micro_batch1, micro_batch2) @@ -202,6 +204,7 @@ def prefill( b_ready_cache_len=b_ready_cache_len, # b_ready_cache_len b_prefill_start_loc=b_prefill_start_loc, prefix_total_token_num=0, # the default kvcache len is zero. 
+ multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], ) model_output = model_part.forward(model_input) @@ -223,6 +226,7 @@ def decode( b_mtp_index=b_mtp_index, mem_indexes_cpu=mem_indexes, is_prefill=False, + multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], ) model_output = model_part.forward(model_input) return model_output.logits diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index ba90e709b..ef2ada64c 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -136,6 +136,7 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ b_seq_len=b_seq_len, is_prefill=True, b_ready_cache_len=b_ready_cache_len, + multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)], ) model_output: ModelOutput = main_model.forward(model_input) @@ -202,6 +203,7 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ b_req_idx=nopad_b_seq_idx, b_seq_len=nopad_b_seq_len, is_prefill=False, + multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size * (len(draft_models) + 1))], ) # Main decode From e2d4ed4a6867e0667967135debe4f2d662b2e841 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 23 Dec 2025 04:30:39 +0000 Subject: [PATCH 09/11] fix --- lightllm/common/basemodel/basemodel.py | 6 ++++++ lightllm/common/basemodel/batch_objs.py | 3 +++ 2 files changed, 9 insertions(+) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 1418d1a1c..84d53f3b1 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -328,6 +328,10 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s mode="constant", value=self.mem_manager.HOLD_TOKEN_MEMINDEX, ) + new_model_input.multimodal_params = new_model_input.multimodal_params + [ + {"images": [], "audios": []} for _ in range(padded_batch_size) + ] + if enable_diverse_mode_gqa_decode_fast_kernel(): if new_model_input.b_shared_seq_len is not None: new_model_input.b_shared_seq_len = F.pad( @@ -345,6 +349,7 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s new_batch_size=new_batch_size, ) + new_model_input.check_input() return new_model_input def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle_token_num: int): @@ -389,6 +394,7 @@ def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle new_batch_size=new_handle_token_num, ) + new_model_input.check_input() return new_model_input def _create_unpad_decode_model_output(self, model_output: ModelOutput, origin_batch_size: int): diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 6574d64bd..5a98a13df 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -74,6 +74,9 @@ def to_cuda(self): self.b_shared_seq_len = self.b_shared_seq_len.cuda(non_blocking=True) def __post_init__(self): + self.check_input() + + def check_input(self): assert len(self.multimodal_params) == self.batch_size From b091ccf0c4226b1f5975e6de70dd4150f515dabd Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 23 Dec 2025 04:59:06 +0000 Subject: [PATCH 10/11] fix --- .../server/router/model_infer/mode_backend/dp_backend/impl.py | 4 ++-- .../model_infer/mode_backend/generic_padded_pre_process.py | 2 +- 2 files 
changed, 3 insertions(+), 3 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 4f69f04b9..a1414b8b2 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -232,7 +232,7 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer model_input1, run_reqs1, _, - ) = padded_overlap_prepare_prefill_inputs(prefill_reqs, is_multimodal=self.is_multimodal) + ) = padded_overlap_prepare_prefill_inputs(prefill_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1) @@ -626,7 +626,7 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I model_input1, run_reqs1, _, - ) = padded_overlap_prepare_prefill_inputs(prefill_reqs, is_multimodal=self.is_multimodal) + ) = padded_overlap_prepare_prefill_inputs(prefill_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1) logits0 = model_output0.logits diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index 679ce2472..6465995c4 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -241,7 +241,7 @@ def padded_overlap_prepare_decode_inputs( return micro_input, run_reqs, padded_req_num, micro_input1, run_reqs1, padded_req_num1 -def padded_overlap_prepare_prefill_inputs(req_objs: List[InferReq], is_multimodal=False): +def padded_overlap_prepare_prefill_inputs(req_objs: List[InferReq]): micro_batch1_req_num = triton.cdiv(len(req_objs), 2) micro_input, run_reqs, padded_req_num = padded_prepare_prefill_inputs(req_objs[0:micro_batch1_req_num]) From 5ea82c7781286a3558b875670379b576889ccba8 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 23 Dec 2025 05:06:05 +0000 Subject: [PATCH 11/11] fix --- lightllm/server/core/objs/req.py | 16 +++++++--------- .../server/router/model_infer/infer_batch.py | 2 +- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 0d2e7ae38..1e1796335 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -234,15 +234,13 @@ def get_prompt_ids_numpy(self): return self.shm_prompt_ids.arr[: self.input_len] def to_router_rpc_obj(self): - if hasattr(self, "multimodal_params"): - return ( - self.request_id, - self.index_in_shm_mem, - self.multimodal_params, - self.sample_params.suggested_dp_index, - ) - else: - return (self.request_id, self.index_in_shm_mem, None, self.sample_params.suggested_dp_index) + assert hasattr(self, "multimodal_params") + return ( + self.request_id, + self.index_in_shm_mem, + self.multimodal_params, + self.sample_params.suggested_dp_index, + ) def can_release(self): # 只有管理节点有一个引用 diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index c0ba9303e..4b8b3c538 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -320,7 +320,7 @@ def __init__( req_id: int, req_idx: int, 
shm_index: int, - multimodal_params=None, + multimodal_params: MultimodalParams, vocab_size: int = -1, init_prefix_cache: bool = True, ):