87 changes: 79 additions & 8 deletions convert_hf_to_gguf.py
@@ -1173,6 +1173,8 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5":
# ref: https://huggingface.co/deepseek-ai/DeepSeek-V3
res = "deepseek-v3"
if chkhsh == "9d70134b369a70e5735009b6de918f7581b5211f7c074d1f89f753aea8248af1":
res = "utu-vl"
if chkhsh == "b3f499bb4255f8ca19fccd664443283318f2fd2414d5e0b040fbdd0cc195d6c5":
# ref: https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
res = "deepseek-r1-qwen"
@@ -7133,6 +7135,7 @@ def prepare_tensors(self):
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"KimiVLForConditionalGeneration",
"UTUVLForCausalLM",
)
class DeepseekV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
@@ -7211,11 +7214,26 @@ def set_gguf_parameters(self):
self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])

self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
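# Configs routed through this class without MoE blocks (e.g. UTU-VL's text config)
# may omit the expert hparams below; fall back to dense defaults when a key is missing.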
if hparams.get("moe_intermediate_size") is not None:
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
else:
self.gguf_writer.add_expert_feed_forward_length(hparams.get("intermediate_size", 0))

if hparams.get("n_routed_experts") is not None:
self.gguf_writer.add_expert_count(hparams["n_routed_experts"])

if hparams.get("n_shared_experts") is not None:
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
else:
self.gguf_writer.add_expert_shared_count(0)

if hparams.get("routed_scaling_factor") is not None:
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
else:
self.gguf_writer.add_expert_weights_scale(1.0)

if hparams.get("norm_topk_prob") is not None and hparams["norm_topk_prob"]:
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])

self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])

@@ -7226,15 +7244,26 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_mscale_all)

_experts: list[dict[str, Tensor]] | None = None
_token_embd: Tensor | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# skip vision tensors and remove "language_model." for Kimi-VL
if "vision_tower" in name or "multi_modal_projector" in name:
return []

if name.startswith("siglip2.") or name.startswith("merger."):
return []
if name.startswith("language_model."):
name = name.replace("language_model.", "")

# skip lm_head.weight if tie_word_embeddings is True
if self.hparams.get("tie_word_embeddings", False):
# Save token_embd for potential duplication as output if tie_word_embeddings is True
if name == "model.embed_tokens.weight":
self._token_embd = data_torch
if name == "lm_head.weight" or name == "model.lm_head.weight":
logger.info("Skipping tied output layer 'lm_head.weight' - will duplicate from token_embd.weight")
return []

# rename e_score_correction_bias tensors
if name.endswith("e_score_correction_bias"):
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
@@ -7246,7 +7275,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return []

# process the experts separately
if name.find("mlp.experts") != -1:
if name.find("mlp.experts") != -1 and self.hparams.get("n_routed_experts") is not None:
n_experts = self.hparams["n_routed_experts"]
assert bid is not None

@@ -7308,7 +7337,10 @@ def prepare_tensors(self):
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")

if self._token_embd is not None:
logger.info("Model has tie_word_embeddings=True but no lm_head.weight found - adding output.weight from token_embd.weight")
output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
self.gguf_writer.add_tensor(output_name, self._token_embd.numpy())

@ModelBase.register("MiniMaxM2ForCausalLM")
class MiniMaxM2Model(TextModel):
@@ -10466,7 +10498,46 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

return []

@ModelBase.register("UtuVLForConditionalGeneration", "UTUVLForCausalLM")
class UtuVLVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
self.hparams_vision["image_size"] = self.hparams_vision.get("image_size", 560)

def set_gguf_parameters(self):
super().set_gguf_parameters()

self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.UTUVL)
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6))

# Handle activation function
hidden_act = str(self.hparams.get("hidden_act", "gelu_pytorch_tanh")).lower()
if hidden_act in ("gelu", "gelu_pytorch_tanh", "gelu_fast", "gelu_new", "gelu_accurate"):
self.gguf_writer.add_vision_use_gelu(True)
elif hidden_act == "silu":
self.gguf_writer.add_vision_use_silu(True)
else:
raise ValueError(f"Unsupported activation function for UTUVL: {hidden_act}")

self.gguf_writer.add_vision_spatial_merge_size(self.hparams.get("spatial_merge_size", 2))

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

# Skip language model tensors
skip_prefixes = ('lm_head.', 'model.layers.', 'model.embed_tokens.', 'model.norm.')
if name.startswith(skip_prefixes):
return []

# Try to map the tensor using TensorNameMap (handles vision encoder and projector)
try:
new_name = self.map_tensor_name(name)
return [(new_name, data_torch)]
except ValueError:
# If mapping fails, log warning and skip
logger.warning(f"Cannot map tensor: {name}")
return []
###### CONVERSION LOGIC ######


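A self-contained sketch of the tied-embeddings path added above, runnable outside the converter. The dict stands in for a real checkpoint, and the "output.weight" name written at the end illustrates the GGUF output tensor rather than the converter's exact API:

import torch

hparams = {"tie_word_embeddings": True}
checkpoint = {"model.embed_tokens.weight": torch.randn(8, 4)}  # no lm_head.weight stored

token_embd = None
written = {}
for name, tensor in checkpoint.items():
    if hparams.get("tie_word_embeddings", False):
        if name == "model.embed_tokens.weight":
            token_embd = tensor  # stash for duplication later
        if name in ("lm_head.weight", "model.lm_head.weight"):
            continue  # tied: the explicit head is skipped
    written[name] = tensor

# mirrors prepare_tensors(): emit the stashed embeddings as the output head
if token_embd is not None:
    written["output.weight"] = token_embd

assert "output.weight" in written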
1 change: 1 addition & 0 deletions gguf-py/gguf/constants.py
@@ -3432,6 +3432,7 @@ class VisionProjectorType:
JANUS_PRO = "janus_pro"
LFM2A = "lfm2a" # audio
GLM4V = "glm4v"
UTUVL = "utuvl"


# Items here are (block size, type size)
12 changes: 12 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -1218,6 +1218,7 @@ class TensorNameMap:
MODEL_TENSOR.V_MMPROJ: (
"multi_modal_projector.linear_{bid}",
"visual.merger.mlp.{bid}", # qwen2vl
"merger.mlp.{bid}",
),

MODEL_TENSOR.V_MMPROJ_FC: (
@@ -1255,6 +1256,7 @@ class TensorNameMap:
"visual.patch_embed.proj", # qwen2vl
"vision_tower.patch_embed.proj", # kimi-vl
"model.vision.patch_embedding.proj", # cogvlm
"siglip2.vision_model.embeddings.patch_embedding",
),

MODEL_TENSOR.V_ENC_EMBD_NORM: (
@@ -1288,6 +1290,7 @@ class TensorNameMap:
"vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral
"visual.blocks.{bid}.attn.q", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated
"siglip2.vision_model.encoder.layers.{bid}.self_attn.q_proj", # utuvl
),

MODEL_TENSOR.V_ENC_ATTN_Q_NORM: (
@@ -1305,6 +1308,7 @@ class TensorNameMap:
"vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral
"visual.blocks.{bid}.attn.k", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated
"siglip2.vision_model.encoder.layers.{bid}.self_attn.k_proj",
),

MODEL_TENSOR.V_ENC_ATTN_K_NORM: (
@@ -1322,6 +1326,7 @@ class TensorNameMap:
"vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral
"visual.blocks.{bid}.attn.v", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated
"siglip2.vision_model.encoder.layers.{bid}.self_attn.v_proj",
),

MODEL_TENSOR.V_ENC_INPUT_NORM: (
@@ -1336,6 +1341,7 @@ class TensorNameMap:
"visual.blocks.{bid}.norm1", # qwen2vl
"vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1)
"model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm
"siglip2.vision_model.encoder.layers.{bid}.layer_norm1",
),

MODEL_TENSOR.V_ENC_ATTN_O: (
@@ -1351,6 +1357,7 @@ class TensorNameMap:
"visual.blocks.{bid}.attn.proj", # qwen2vl
"vision_tower.encoder.blocks.{bid}.wo", # kimi-vl
"model.vision.transformer.layers.{bid}.attention.dense", # cogvlm
"siglip2.vision_model.encoder.layers.{bid}.self_attn.out_proj", # utuvl
),

MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
@@ -1365,6 +1372,7 @@ class TensorNameMap:
"visual.blocks.{bid}.norm2", # qwen2vl
"vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1)
"model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm
"siglip2.vision_model.encoder.layers.{bid}.layer_norm2",
),

MODEL_TENSOR.V_ENC_FFN_UP: (
@@ -1380,6 +1388,7 @@ class TensorNameMap:
"visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl
"vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
"model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm
"siglip2.vision_model.encoder.layers.{bid}.mlp.fc1",
),

MODEL_TENSOR.V_ENC_FFN_GATE: (
@@ -1401,6 +1410,7 @@ class TensorNameMap:
"visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl
"vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
"model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm
"siglip2.vision_model.encoder.layers.{bid}.mlp.fc2",
),

MODEL_TENSOR.V_LAYER_SCALE_1: (
@@ -1427,6 +1437,7 @@ class TensorNameMap:
"visual.merger.ln_q", # qwen2vl
"vision_tower.encoder.final_layernorm", # kimi-vl
"visual.post_layernorm", # glm4v
"siglip2.vision_model.post_layernorm",
),

MODEL_TENSOR.V_MM_POST_NORM: (
@@ -1443,6 +1454,7 @@ class TensorNameMap:
"multi_modal_projector.pre_norm",
"pre_mm_projector_norm",
"model.vision.linear_proj.norm1", # cogvlm
"merger.ln_q",
),

MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
11 changes: 11 additions & 0 deletions src/llama-vocab.cpp
@@ -314,6 +314,12 @@ struct llm_tokenizer_bpe : llm_tokenizer {
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
};
break;
case LLAMA_VOCAB_PRE_TYPE_UTU_VL:
regex_exprs = {
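// 1st expr: keep runs of Hangul, CJK punctuation, Bopomofo, and CJK ideographs/Kana as single pieces
// 2nd expr: case-aware word split with English contractions, written with explicit letter
// subcategories (Lu/Ll/Lt/Lm/Lo) so the collapsed-category fallback in unicode.cpp can resolve them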
"[가-힣ㄱ-ㆎ]+|[!…“”‘’—:;,、-〿︰-﹏]+|[ㄅ-ㄯ]+|[一-龥぀-ゟ゠-ヿ]+",
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
regex_exprs = {
"[\r\n]",
@@ -1860,6 +1866,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "deepseek-v3") {
pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM;
clean_spaces = false;
} else if (
tokenizer_pre == "utu-vl") {
pre_type = LLAMA_VOCAB_PRE_TYPE_UTU_VL;
clean_spaces = false;
ignore_merges = true;
} else if (
tokenizer_pre == "falcon") {
pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON;
1 change: 1 addition & 0 deletions src/llama-vocab.h
@@ -51,6 +51,7 @@ enum llama_vocab_pre_type {
LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40,
LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41,
LLAMA_VOCAB_PRE_TYPE_AFMOE = 42,
LLAMA_VOCAB_PRE_TYPE_UTU_VL = 43,
};

struct LLM_KV;
37 changes: 23 additions & 14 deletions src/unicode.cpp
@@ -964,6 +964,11 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
{ "\\p{P}", unicode_cpt_flags::PUNCTUATION },
{ "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
{ "\\p{S}", unicode_cpt_flags::SYMBOL },
{ "\\p{Lu}", unicode_cpt_flags::LETTER }, // Uppercase letter
{ "\\p{Ll}", unicode_cpt_flags::LETTER }, // Lowercase letter
{ "\\p{Lt}", unicode_cpt_flags::LETTER }, // Titlecase letter
{ "\\p{Lm}", unicode_cpt_flags::LETTER }, // Modifier letter
{ "\\p{Lo}", unicode_cpt_flags::LETTER }, // Other letter
};

static const std::map<int, int> k_ucat_cpt = {
@@ -1074,22 +1079,26 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
continue;
}

if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() &&
// Match \p{...} Unicode properties of varying lengths
if (regex_expr[i + 0] == '\\' && i + 3 < regex_expr.size() &&
regex_expr[i + 1] == 'p' &&
regex_expr[i + 2] == '{' &&
regex_expr[i + 4] == '}') {
const std::string pat = regex_expr.substr(i, 5);
if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
if (!inside) {
regex_expr_collapsed += '[';
regex_expr[i + 2] == '{') {
// Find the closing brace
size_t closing_brace = regex_expr.find('}', i + 3);
if (closing_brace != std::string::npos && closing_brace <= i + 10) { // reasonable limit
const std::string pat = regex_expr.substr(i, closing_brace - i + 1);
if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
if (!inside) {
regex_expr_collapsed += '[';
}
regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
if (!inside) {
regex_expr_collapsed += ']';
}
i = closing_brace;
continue;
}
regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
if (!inside) {
regex_expr_collapsed += ']';
}
i += 4;
continue;
}
}

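The hunk above generalizes the old fixed five-character match (only "\p{X}") to variable-length properties such as "\p{Lu}". A reduced sketch of just the scanning step, in Python for brevity and assuming the same 10-character cap (match_ucat is an illustrative name, not a function in the tree):

# Returns the full "\p{...}" pattern starting at i, or "" if there is none.
def match_ucat(re_str: str, i: int) -> str:
    if i + 3 >= len(re_str) or re_str[i:i + 3] != "\\p{":
        return ""
    close = re_str.find("}", i + 3)
    if close == -1 or close > i + 10:  # same "reasonable limit" as the C++ hunk
        return ""
    return re_str[i:close + 1]

assert match_ucat("\\p{Lu}+", 0) == "\\p{Lu}"
assert match_ucat("\\p{L}", 0) == "\\p{L}"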
1 change: 1 addition & 0 deletions tools/mtmd/CMakeLists.txt
@@ -27,6 +27,7 @@ add_library(mtmd
models/qwen3vl.cpp
models/siglip.cpp
models/whisper-enc.cpp
models/utuvl.cpp
)

set_target_properties(mtmd PROPERTIES
2 changes: 2 additions & 0 deletions tools/mtmd/clip-impl.h
@@ -187,6 +187,7 @@ enum projector_type {
PROJECTOR_TYPE_JANUS_PRO,
PROJECTOR_TYPE_LFM2A,
PROJECTOR_TYPE_GLM4V,
PROJECTOR_TYPE_UTUVL,
PROJECTOR_TYPE_UNKNOWN,
};

@@ -216,6 +217,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
{ PROJECTOR_TYPE_LFM2A, "lfm2a"},
{ PROJECTOR_TYPE_GLM4V, "glm4v"},
{ PROJECTOR_TYPE_UTUVL, "utuvl"},
};

static projector_type clip_projector_type_from_string(const std::string & str) {