add-glm4v model #1157
base: main
New file (@@ -0,0 +1,104 @@): lightllm/models/glm4v/layer_infer/transformer_layer_infer.py
import torch
import torch.functional as F
import torch.distributed as dist
import numpy as np
from typing import Tuple
from functools import partial

from lightllm.distributed import all_reduce
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused
from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.glm4v.layer_weight.transformer_layer_weight import Glm4VTransformerLayerWeight


class Glm4VTransformerLayerInfer(LlamaTransformerLayerInfer):
    def __init__(self, layer_num, network_config, mode=[]):
        super().__init__(layer_num, network_config, mode)
        mrope_section = network_config["rope_parameters"]["mrope_section"]
        self.mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device="cuda")
        self.partial_rotary_factor = network_config["rope_parameters"]["partial_rotary_factor"]

    def _post_self_att_norm(
        self, input, infer_state: Qwen2VLInferStateInfo, layer_weight: Glm4VTransformerLayerWeight
    ) -> torch.Tensor:
        out = self.alloc_tensor(input.shape, input.dtype)
        rmsnorm_forward(input, weight=layer_weight._post_self_att_norm_weight_.weight, eps=self.eps_, out=out)
        return out

    def _post_mlp_norm(
        self, input, infer_state: Qwen2VLInferStateInfo, layer_weight: Glm4VTransformerLayerWeight
    ) -> torch.Tensor:
        out = self.alloc_tensor(input.shape, input.dtype)
        rmsnorm_forward(input, weight=layer_weight._post_mlp_norm_weight_.weight, eps=self.eps_, out=out)
        return out

    def _get_qkv(self, input, infer_state, layer_weight):
        q = layer_weight.q_proj.mm(input)
        cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
        mrope_triton_fused(
            q.view(-1, self.tp_q_head_num_, self.head_dim_),
            cache_kv[:, : self.tp_k_head_num_, :],
            infer_state.position_cos,
            infer_state.position_sin,
            self.mrope_section,
            partial_rotary_factor=self.partial_rotary_factor,
            is_interleaved=False,
            is_glm4v=True,
        )
        return q, cache_kv

    def context_forward(self, input_embdings, infer_state: Qwen2VLInferStateInfo, layer_weight):
        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
        q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
        input1 = None
        self._post_cache_kv(cache_kv, infer_state, layer_weight)

        o = self._TransformerLayerInferTpl__context_attention_wrapper_run(
            q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight
        )

        q = None
        o = self._get_o(o, infer_state, layer_weight)
        if self.tp_world_size_ > 1:
            all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
        o = self._post_self_att_norm(o, infer_state, layer_weight)  # extra norm before the residual add
        input_embdings.add_(o.view(-1, self.embed_dim_))
        o = None

        input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
        ffn_out = self._ffn(input1, infer_state, layer_weight)
        ffn_out = self._post_mlp_norm(ffn_out, infer_state, layer_weight)  # extra norm after the MLP
        input1 = None
        if self.tp_world_size_ > 1:
            all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
        input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
        return input_embdings

    def token_forward(self, input_embdings, infer_state: Qwen2VLInferStateInfo, layer_weight):
        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
        q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
        input1 = None
        self._post_cache_kv(cache_kv, infer_state, layer_weight)
        o = self._token_attention_kernel(q, infer_state, layer_weight)
        q = None
        o = self._get_o(o, infer_state, layer_weight)
        if self.tp_world_size_ > 1:
            all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
        o = self._post_self_att_norm(o, infer_state, layer_weight)  # extra norm before the residual add
        input_embdings.add_(o.view(-1, self.embed_dim_))
        o = None

        input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
        ffn_out = self._ffn(input1, infer_state, layer_weight)
        ffn_out = self._post_mlp_norm(ffn_out, infer_state, layer_weight)  # extra norm after the MLP
        input1 = None
        if self.tp_world_size_ > 1:
            all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
        input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
        return input_embdings

    def _tpsp_get_qkv(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
        # TODO
        raise Exception("not impl")
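The two inline comments above (translated from the original Chinese: "extra norm before the residual add" / "extra norm after the MLP") describe the main deviation from the stock Llama block: the attention output and the MLP output are each passed through an additional RMSNorm before being added back onto the residual stream. The following is only a minimal, self-contained sketch of that norm placement using plain PyTorch modules rather than LightLLM's fused Triton path; the class name and sizes are illustrative, and `nn.RMSNorm` requires a recent PyTorch release.

```python
import torch
import torch.nn as nn


class PostSubLayerNormBlockSketch(nn.Module):
    """Illustrative sketch of the norm placement used in the layer above."""

    def __init__(self, hidden_size: int, attn: nn.Module, mlp: nn.Module, eps: float = 1e-5):
        super().__init__()
        self.input_norm = nn.RMSNorm(hidden_size, eps=eps)      # corresponds to _att_norm
        self.post_attn_norm = nn.RMSNorm(hidden_size, eps=eps)  # corresponds to _post_self_att_norm
        self.pre_mlp_norm = nn.RMSNorm(hidden_size, eps=eps)    # corresponds to _ffn_norm
        self.post_mlp_norm = nn.RMSNorm(hidden_size, eps=eps)   # corresponds to _post_mlp_norm
        self.attn, self.mlp = attn, mlp

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # the attention output is normalized again before the residual add
        x = x + self.post_attn_norm(self.attn(self.input_norm(x)))
        # the MLP output is normalized again before its residual add
        x = x + self.post_mlp_norm(self.mlp(self.pre_mlp_norm(x)))
        return x
```

In the PR itself this shows up as overridden `context_forward`/`token_forward` methods plus the `_post_self_att_norm`/`_post_mlp_norm` helpers backed by `rmsnorm_forward`, rather than any change to the base Llama kernel path.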
New file (@@ -0,0 +1,14 @@): lightllm/models/glm4v/layer_weight/pre_and_post_layer_weight.py
import numpy as np
from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight
from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import rename_weight_keys


class Glm4VPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
    def __init__(self, data_type, network_config, mode):
        super().__init__(data_type, network_config, mode)
        return

    def load_hf_weights(self, weights):
        rename_weight_keys(weights)
        super().load_hf_weights(weights)
        return
New file (@@ -0,0 +1,35 @@): lightllm/models/glm4v/layer_weight/transformer_layer_weight.py
from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NormWeight
from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight


class Glm4VTransformerLayerWeight(Qwen2TransformerLayerWeight):
    def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None):
        super().__init__(layer_num, data_type, network_config, mode, quant_cfg)

    def _init_weight_names(self):
        self._post_self_att_norm_weight_name = f"model.layers.{self.layer_num_}.post_self_attn_layernorm.weight"
        self._post_self_att_norm_bias_name = None
        self._post_mlp_norm_weight_name = f"model.layers.{self.layer_num_}.post_mlp_layernorm.weight"
        self._post_mlp_norm_bias_name = None
        super()._init_weight_names()

    def load_hf_weights(self, weights):
        gate_up_weight_name = f"model.layers.{self.layer_num_}.mlp.gate_up_proj.weight"
        if gate_up_weight_name in weights:
            intermediate_size = self.network_config_["intermediate_size"]
            gate_up_proj = weights[gate_up_weight_name]
            gate_weight_ = gate_up_proj[0:intermediate_size, :]
            up_weight_ = gate_up_proj[intermediate_size:, :]
            weights[self._gate_weight_name] = gate_weight_
            weights[self._up_weight_name] = up_weight_
            del weights[gate_up_weight_name]
        super().load_hf_weights(weights)

    def _init_norm(self):
        self._post_self_att_norm_weight_ = NormWeight(
            self._post_self_att_norm_weight_name, self.data_type_, bias_name=self._post_self_att_norm_bias_name
        )
        self._post_mlp_norm_weight_ = NormWeight(
            self._post_mlp_norm_weight_name, self.data_type_, bias_name=self._post_mlp_norm_bias_name
        )
        super()._init_norm()
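`load_hf_weights` above handles checkpoints that ship a fused `gate_up_proj` tensor: the first `intermediate_size` rows are the gate projection and the remaining rows the up projection, which are then re-registered under the separate gate/up weight names the Qwen2 base class expects. A quick standalone illustration of that slicing follows; the sizes are made up for the example.

```python
import torch

hidden_size, intermediate_size = 16, 64  # illustrative sizes only
# fused checkpoint tensor: gate weight stacked on top of up weight
gate_up_proj = torch.randn(2 * intermediate_size, hidden_size)

gate_weight = gate_up_proj[0:intermediate_size, :]  # rows [0, intermediate_size)
up_weight = gate_up_proj[intermediate_size:, :]     # remaining rows

assert gate_weight.shape == (intermediate_size, hidden_size)
assert up_weight.shape == (intermediate_size, hidden_size)
```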
New file (@@ -0,0 +1,87 @@): the GLM-4V tokenizer and model registration (presumably lightllm/models/glm4v/model.py)
import os
import json
import numpy as np
from lightllm.common.build_utils import repair_config
from lightllm.models.registry import ModelRegistry
from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo
from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer
from lightllm.models.glm4v.layer_infer.transformer_layer_infer import Glm4VTransformerLayerInfer
from lightllm.models.glm4v.layer_weight.pre_and_post_layer_weight import Glm4VPreAndPostLayerWeight
from lightllm.models.glm4v.layer_weight.transformer_layer_weight import Glm4VTransformerLayerWeight
from lightllm.server.multimodal_params import MultimodalParams
from lightllm.models.qwen2_vl.model import QWen2VLTokenizer
from lightllm.models.qwen2.model import Qwen2TpPartModel


class GLM4VTokenizer(QWen2VLTokenizer):
    def __init__(self, tokenizer=None, image_processor=None, **kwargs):
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.min_pixel = self.image_processor.size["shortest_edge"]
        self.max_pixel = self.image_processor.size["longest_edge"]
        self.patch_size = self.image_processor.patch_size
        self.merge_size = self.image_processor.merge_size
        self.image_start_id = kwargs["model_cfg"]["image_start_token_id"]
        self.image_end_id = kwargs["model_cfg"]["image_end_token_id"]
        self.image_token_id = kwargs["model_cfg"]["image_token_id"]

    def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
        origin_ids = self.tokenizer.encode(prompt)

        # <img><image_pad></img> -> <img></img>
        origin_ids = [token for token in origin_ids if token != self.image_token_id]
        # <img></img> --> <img>id,id+1...id+num</img>
        input_ids = []
        image_id = 0
        while True:
            try:
                start_idx = origin_ids.index(self.image_start_id)
                if start_idx + 1 >= len(origin_ids):
                    break
                if origin_ids[start_idx + 1] == self.image_end_id:
                    input_ids.extend(origin_ids[: start_idx + 1])
                    token_id = multimodal_params.images[image_id].token_id
                    token_num = multimodal_params.images[image_id].token_num
                    multimodal_params.images[image_id].start_idx = len(input_ids)
                    input_ids.extend(range(token_id, token_id + token_num))
                    input_ids.append(self.image_end_id)
                    origin_ids = origin_ids[start_idx + 2 :]
                    image_id += 1
                else:
                    raise ValueError("image token error")
            except ValueError:
                break
        input_ids.extend(origin_ids)
        return input_ids


@ModelRegistry(["glm4v"], is_multimodal=True)
class GLM4VTpPartModel(Qwen2TpPartModel):

    pre_layer_infer_class = LlamaMultimodalPreLayerInfer
    transformer_layer_infer_class = Glm4VTransformerLayerInfer

    pre_and_post_weight_class = Glm4VPreAndPostLayerWeight
    transformer_weight_class = Glm4VTransformerLayerWeight

    infer_state_class = Qwen2VLInferStateInfo

    def __init__(self, kvargs):
        super().__init__(kvargs)
        return

    def _init_inferstate_cls(self):
        pass

    def _init_config(self):
        with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
            all_config = json.load(json_file)
        self.config = all_config["text_config"]
        # rename keys
        repair_config(self.config, same_names=["num_attention_heads", "n_head"])
        repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
        repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
        if self.finetune_config:
            self.config["vocab_size"] = self.finetune_config.vocab_size
        return
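For reference, the `GLM4VTokenizer.encode` override above rewrites the prompt's token ids in two steps: it first drops every image placeholder token, then replaces each empty `<img></img>` pair with a contiguous run of per-image ids (`token_id` through `token_id + token_num - 1`) taken from `multimodal_params`. Below is only a simplified, standalone trace of that rewriting with made-up ids; the real ids come from the model config and the request's image metadata, and the real code also records each image's `start_idx`.

```python
# Hypothetical token ids, for illustration only.
IMG_START, IMG_END, IMG_PAD = 1001, 1002, 1003


def expand_image_tokens(origin_ids, image_token_ranges):
    """image_token_ranges: one (first_token_id, token_num) pair per image, in prompt order."""
    origin_ids = [t for t in origin_ids if t != IMG_PAD]  # drop <image_pad> tokens
    input_ids, image_idx = [], 0
    while IMG_START in origin_ids:
        start = origin_ids.index(IMG_START)
        if start + 1 >= len(origin_ids) or origin_ids[start + 1] != IMG_END:
            break  # malformed pair; the real code treats this as an error
        first_id, num = image_token_ranges[image_idx]
        input_ids.extend(origin_ids[: start + 1])          # text so far plus <img>
        input_ids.extend(range(first_id, first_id + num))  # per-image placeholder ids
        input_ids.append(IMG_END)
        origin_ids = origin_ids[start + 2 :]
        image_idx += 1
    input_ids.extend(origin_ids)
    return input_ids


# "hello <img><image_pad></img> world" with one 4-token image whose ids start at 50000
print(expand_image_tokens([7, IMG_START, IMG_PAD, IMG_END, 8], [(50000, 4)]))
# -> [7, 1001, 50000, 50001, 50002, 50003, 1002, 8]
```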
Contributor review comment: The import `torch.functional as F` is unused in this file and can be removed. Note that `torch.functional` is also a deprecated alias for `torch.nn.functional`.
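If a functional API is actually needed somewhere in the file, the conventional spelling is the `torch.nn.functional` namespace; otherwise the import can simply be dropped as suggested. A trivial example of the preferred import:

```python
import torch
import torch.nn.functional as F  # preferred over importing torch.functional

x = torch.randn(4, 8)
y = F.silu(x)  # e.g. the activation used in SwiGLU-style MLPs of Llama/Qwen-family models
```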