153 changes: 151 additions & 2 deletions convert_hf_to_gguf.py
@@ -520,7 +520,11 @@
return ()

def prepare_tensors(self):
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
# Handle empty tensor_map for models with block_count=0 (like MobileNetV5)
if self.tensor_map.mapping:
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
else:
max_name_len = len("vision_encoder.weight,") # Default reasonable length

for name, data_torch in chain(self.generate_extra_tensors(), self.get_tensors()):
# we don't need these
@@ -5959,8 +5963,112 @@

return [] # skip other tensors

@ModelBase.register("Gemma3nForConditionalGeneration", "Gemma3nVisionModel")
class Gemma3nVisionModel(MmprojModel):
"""Vision encoder converter for Gemma3n using MobileNetV5 architecture"""
n_block_keys = []

def find_hparam(self, keys: list[str], optional: bool = False) -> Any:
"""Override to return 0 for block count since MobileNetV5 is CNN-based"""
if not keys: # If n_block_keys is empty (our case)
return 0
# Otherwise use parent implementation
return super().find_hparam(keys, optional)

def __init__(self, *args, **kwargs):
# Parent init will call find_hparam which now returns 0 for empty keys

Check failure on line 5979 in convert_hf_to_gguf.py (GitHub Actions / pyright type-check): Method "find_hparam" overrides class "ModelBase" in an incompatible manner. Parameter 2 type mismatch: base parameter is type "Iterable[str]", override parameter is type "list[str]"; "Iterable[str]" is not assignable to "list[str]" (reportIncompatibleMethodOverride)
super().__init__(*args, **kwargs)

def find_vparam(self, keys: list[str], optional: bool = False) -> Any:
"""Override to provide hardcoded MobileNetV5 parameters that aren't in config"""
# Handle empty keys list (n_block_keys) - return 0 for CNN architecture
if not keys:
return 0

if "intermediate_size" in keys:
# Typical expansion is 4x the embedding dimension
hidden_size = self.hparams_vision.get("hidden_size", 2048)

Check failure on line 5990 in convert_hf_to_gguf.py (GitHub Actions / pyright type-check): Method "find_vparam" overrides class "MmprojModel" in an incompatible manner. Parameter 2 type mismatch: base parameter is type "Iterable[str]", override parameter is type "list[str]"; "Iterable[str]" is not assignable to "list[str]" (reportIncompatibleMethodOverride)
return hidden_size * 4

if "num_attention_heads" in keys or "num_heads" in keys:
# Multi-Query Attention with 8 heads
return 8

# For other parameters, use parent implementation
return super().find_vparam(keys, optional)

Check failure on line 5998 in convert_hf_to_gguf.py (GitHub Actions / pyright type-check): "get" is not a known attribute of "None" (reportOptionalMemberAccess)

def set_gguf_parameters(self):
# MobileNetV5 does not use normalisation at all
self.preprocessor_config["image_mean"] = [0.0 , 0.0 , 0.0 ]
self.preprocessor_config["image_std"] = [1.0 , 1.0 , 1.0 ]
self.hparams_vision["image_size"] = self.preprocessor_config.get(
"size", {"height": 768, "width": 768}
)["height"]

# Image sequence length (256 tokens = 16x16 for Gemma3n)
image_seq_length = self.preprocessor_config.get("image_seq_length", 256)
image_size = self.hparams_vision["image_size"]
self.hparams_vision["patch_size"] = image_size // image_seq_length

Check failure on line 6012 in convert_hf_to_gguf.py (GitHub Actions / pyright type-check): Object of type "None" is not subscriptable (reportOptionalSubscript)
Comment on lines +6008 to +6012

⚠️ Potential issue | 🟠 Major



Patch size computation for Gemma3n MobileNetV5 is off by a factor of √image_seq_length

The code currently computes patch_size = 768 // 256 = 3, which results in n_per_side = 256 and tokens = 65,536—wildly inconsistent with the 16×16 grid (256 tokens) described in the comment.

The issue is treating image_seq_length (total token count) as if it were patch size. Gemma 3n uses MobileNet v5 with a default image resolution of 768×768 pixels, and image_seq_length defaults to 256.

The correct approach derives patch size from the per-side token count:

n_per_side = int(image_seq_length ** 0.5)  # 16 for 256 tokens
patch_size = image_size // n_per_side      # 768 // 16 = 48

This yields n_per_side = 16 and tokens = 256, matching the expected grid layout and HF processor config.

🧰 Tools: 🪛 GitHub Actions flake8 Lint
[error] 6010: E202 whitespace before ']'
[error] 6011: E202 whitespace before ']'

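As a quick sanity check of the arithmetic above (a standalone sketch assuming the stated defaults of a 768×768 image and 256 image tokens; the variable names mirror the diff but the snippet is illustrative, not part of the PR):

image_size, image_seq_length = 768, 256

# Current diff: patch_size = image_size // image_seq_length
buggy_patch_size = image_size // image_seq_length           # 3 px per patch
assert (image_size // buggy_patch_size) ** 2 == 65536       # 256x256 grid, far too many tokens

# Reviewer's fix: derive the per-side count first, then the patch size
n_per_side = int(image_seq_length ** 0.5)                    # 16 patches per side
patch_size = image_size // n_per_side                        # 48 px per patch
assert n_per_side ** 2 == image_seq_length                   # 16x16 grid = 256 tokens
assert (image_size // patch_size) ** 2 == image_seq_length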

# Now call parent which will use the corrected values
super().set_gguf_parameters()

# Set projector type to GEMMA3N
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GEMMA3N)

Check failure on line 6018 in convert_hf_to_gguf.py (GitHub Actions / pyright type-check): Object of type "None" is not subscriptable (reportOptionalSubscript)
# MobileNetV5 specific parameters

Check failure on line 6019 in convert_hf_to_gguf.py (GitHub Actions / pyright type-check): Object of type "None" is not subscriptable (reportOptionalSubscript)
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6))

def tensor_force_quant(self, name, new_name, bid, n_dims):
# Force quantization settings for specific tensor types
if "input_projection" in name or "input_proj" in name:
return gguf.GGMLQuantizationType.F16
if ".embeddings." in name or "stem" in name:
return gguf.GGMLQuantizationType.F32
return super().tensor_force_quant(name, new_name, bid, n_dims)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

# Gemma3n uses different prefixes than other models:
# - model.embed_vision.* for projection layers
# - model.vision_tower.* for vision encoder
# Skip non-vision tensors
if not (name.startswith("model.embed_vision.") or
name.startswith("model.vision_tower.")):
return []

# Strip "model." prefix to match expected llama.cpp format
if name.startswith("model."):
name = name[6:] # Remove "model." prefix

# Process MobileNetV5 and projection tensors
name = name.replace("_weight", ".weight")

# Rename embed_vision to match our C++ implementation expectations
name = name.replace("embed_vision.", "")

# Rename vision_tower.timm_model to vision_tower for cleaner naming
name = name.replace("vision_tower.timm_model.", "vision_tower.")

# Handle normalization layer naming
name = name.replace("hard_embedding_norm", "hard_emb_norm")
name = name.replace("soft_embedding_norm", "soft_emb_norm")

# Gemma3n uses Gemma3p5RMSNorm which has scale_shift=0, so no correction needed
# Unlike Gemma3 which uses Gemma3RMSNorm with scale_shift=1
if "soft_emb_norm.weight" in name:
# No correction needed for Gemma3n
pass

if name.startswith("vision_tower."):
tensor_suffix = name[13:]
return [(f"v.enc.{tensor_suffix}", data_torch)]
else:
return [(self.map_tensor_name(name), data_torch)]


@ModelBase.register("Gemma3nForConditionalGeneration")
@ModelBase.register("Gemma3nForCausalLM", "Gemma3nForConditionalGeneration")
class Gemma3NModel(Gemma3Model):
model_arch = gguf.MODEL_ARCH.GEMMA3N
norm_shift = 0.0 # same value as the Gemma3p5RMSNorm scale_shift in the Python code
@@ -5983,8 +6091,25 @@
]

def set_vocab(self):
# For Gemma3n multimodal models, we need the FULL vocab_size (262400)
# which includes special tokens from 262144-262399 for vision/audio.
# The vocab_size_per_layer_input (262144) is only the embedding size per layer.
# Temporarily override the hparams lookup order to prioritize vocab_size.

# Store original vocab_size_per_layer_input if it exists
vocab_size_per_layer_input = self.hparams.get("vocab_size_per_layer_input")

# Temporarily remove vocab_size_per_layer_input to force using vocab_size
if vocab_size_per_layer_input is not None:
del self.hparams["vocab_size_per_layer_input"]

# Call parent set_vocab which will now use vocab_size (262400)
super().set_vocab()

# Restore vocab_size_per_layer_input for later use
if vocab_size_per_layer_input is not None:
self.hparams["vocab_size_per_layer_input"] = vocab_size_per_layer_input

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_altup_active_idx(self.hparams["altup_active_idx"])
@@ -6020,8 +6145,32 @@
if "language_model." not in name:
return [] # skip non-language model tensors

# Pad token embeddings for vision/audio special tokens (262144-262399)
if "embed_tokens.weight" in name or "embed_tokens_per_layer" in name:
# Move to CPU to avoid meta device issues during padding
data_torch = data_torch.to(device="cpu")

vocab_size = self.hparams.get("vocab_size", 262400)
current_size = data_torch.shape[0] # First dimension is vocab_size

if current_size < vocab_size:
# Pad with zeros for vision/audio tokens (they get embeddings from vision tower)
padding_size = vocab_size - current_size
tensor_type = "per-layer embeddings" if "per_layer" in name else "token embeddings"
logger.info(f"Padding {tensor_type} shape {list(data_torch.shape)} from {current_size} to {vocab_size} (adding {padding_size} vision/audio token slots)")

# Create padding with zeros (vision tokens won't use these embeddings)
padding = torch.zeros((padding_size, data_torch.shape[1]), dtype=data_torch.dtype, device=data_torch.device)
data_torch = torch.cat([data_torch, padding], dim=0)

# Continue with normal processing
name = name.replace("language_model.", "")
return [(self.map_tensor_name(name), data_torch)]

if "altup_unembed_projections" in name:
data_torch = data_torch.to(device="cpu")
# altup_unembed matrices are [hidden_size, hidden_size], NOT vocab-based
# They should NOT be padded
if ".0." in name:
self._altup_unembd[0] = data_torch
elif ".1." in name:
11 changes: 11 additions & 0 deletions gguf-py/gguf/constants.py
@@ -456,6 +456,7 @@ class VISION_PROJECTOR_TYPE(IntEnum):
RESAMPLER = auto()
GLM_EDGE = auto()
MERGER = auto()
GEMMA3N = auto()

⚠️ Potential issue | 🟡 Minor

Add GEMMA3N mapping to VISION_PROJECTOR_TYPE_NAMES.

The GEMMA3N entry was added to VISION_PROJECTOR_TYPE enum but is missing from the VISION_PROJECTOR_TYPE_NAMES dictionary at lines 850-858. This mapping is used to convert the enum value to its string representation.

🔎 Proposed fix

Add the mapping to the VISION_PROJECTOR_TYPE_NAMES dictionary:

 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
     VISION_PROJECTOR_TYPE.MLP:       "mlp",
     VISION_PROJECTOR_TYPE.LDP:       "ldp",
     VISION_PROJECTOR_TYPE.LDPV2:     "ldpv2",
     VISION_PROJECTOR_TYPE.RESAMPLER: "resampler",
     VISION_PROJECTOR_TYPE.GLM_EDGE:  "adapter",
     VISION_PROJECTOR_TYPE.MERGER:    "qwen2vl_merger",
     VISION_PROJECTOR_TYPE.GEMMA3:    "gemma3",
+    VISION_PROJECTOR_TYPE.GEMMA3N:   "gemma3n",
+    VISION_PROJECTOR_TYPE.QWEN3VL:   "qwen3vl_merger",
+    VISION_PROJECTOR_TYPE.COGVLM:    "cogvlm",
 }

Note: QWEN3VL and COGVLM are also missing from this dictionary.
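A quick way to catch this class of omission locally (a hedged sketch, not part of the PR; it assumes only the enum and dictionary names shown in this diff):

missing = [m for m in VISION_PROJECTOR_TYPE if m not in VISION_PROJECTOR_TYPE_NAMES]
assert not missing, f"projector types without a string mapping: {missing}"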


GEMMA3 = auto()
QWEN3VL = auto()
COGVLM = auto()
@@ -666,6 +667,9 @@ class MODEL_TENSOR(IntEnum):
V_MM_INP_NORM = auto()
V_MM_INP_PROJ = auto() # gemma3
V_MM_SOFT_EMB_NORM = auto() # gemma3
V_MM_EMBEDDING = auto() # gemma3n
V_MM_HARD_EMB_NORM = auto() # gemma3n
V_MM_POST_PROJ_NORM = auto() # gemma3n
V_RESMPL_POS_EMBD_K = auto() # minicpmv
V_RESMPL_ATTN_Q = auto() # minicpmv
V_RESMPL_ATTN_K = auto() # minicpmv
@@ -1058,6 +1062,9 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection",
MODEL_TENSOR.V_MM_INP_NORM: "mm.input_norm",
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm",
MODEL_TENSOR.V_MM_EMBEDDING: "mm.embedding",
MODEL_TENSOR.V_MM_HARD_EMB_NORM: "mm.hard_emb_norm",
MODEL_TENSOR.V_MM_POST_PROJ_NORM: "mm.post_proj_norm",
MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "resampler.pos_embd_k",
MODEL_TENSOR.V_RESMPL_ATTN_Q: "resampler.attn.q",
MODEL_TENSOR.V_RESMPL_ATTN_K: "resampler.attn.k",
@@ -1156,6 +1163,9 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_MM_INP_PROJ,
MODEL_TENSOR.V_MM_INP_NORM,
MODEL_TENSOR.V_MM_SOFT_EMB_NORM,
MODEL_TENSOR.V_MM_EMBEDDING,
MODEL_TENSOR.V_MM_HARD_EMB_NORM,
MODEL_TENSOR.V_MM_POST_PROJ_NORM,
MODEL_TENSOR.V_RESMPL_POS_EMBD_K,
MODEL_TENSOR.V_RESMPL_ATTN_Q,
MODEL_TENSOR.V_RESMPL_ATTN_K,
@@ -3397,6 +3407,7 @@ def get_type(val: Any) -> GGUFValueType:

class VisionProjectorType:
GEMMA3 = "gemma3"
GEMMA3N = "gemma3n"
IDEFICS3 = "idefics3"
PIXTRAL = "pixtral"
LLAMA4 = "llama4"
21 changes: 21 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -119,6 +119,27 @@ class TensorNameMap:
MODEL_TENSOR.CONV1D: (
"backbone.embed", # roberta
),

# Vision multimodal projector tensors (non-block) for gemma3n
MODEL_TENSOR.V_MM_INP_PROJ: (
"embedding_projection", # gemma3n
),

MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
"soft_emb_norm", # gemma3n
),

MODEL_TENSOR.V_MM_EMBEDDING: (
"embedding", # gemma3n
),

MODEL_TENSOR.V_MM_HARD_EMB_NORM: (
"hard_emb_norm", # gemma3n
),

MODEL_TENSOR.V_MM_POST_PROJ_NORM: (
"post_proj_norm", # gemma3n
),
}

block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
46 changes: 45 additions & 1 deletion src/models/gemma3n-iswa.cpp
@@ -259,7 +259,51 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup));
cb(inp_per_layer, "inp_per_layer_selected", -1);
} else {
GGML_ABORT("TODO: support embd input");
// For embedding inputs (e.g., from vision encoder)
// Vision tokens should use the padding token (ID=0) embedding
// from tok_embd_per_layer, NOT project the vision embeddings.
// The projection happens later in project_per_layer_inputs().
// This matches PyTorch behavior:
// per_layer_inputs_tokens = torch.where(mask, input_ids, torch.zeros_like(input_ids))
// per_layer_inputs = EmbedPerLayer(per_layer_inputs_tokens) # Uses padding (0) for vision

inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
ggml_set_input(inp->embd);

// tok_embd_per_layer shape: [embd_size, vocab_size] where embd_size = n_embd_altup * n_layer
const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_altup * n_layer

// Create zeros tensor [embd_size, n_tokens] by projecting vision embeddings and multiplying by 0
// First, project inp->embd [n_embd, n_tokens] to per-layer space [embd_size, n_tokens]
ggml_tensor * zeros_per_layer = ggml_mul_mat(ctx0, model.per_layer_model_proj, inp->embd);
zeros_per_layer = ggml_scale(ctx0, zeros_per_layer, 0.0f); // Multiply by 0 to get zeros
ggml_set_name(zeros_per_layer, "zeros_per_layer");

// Extract column 0 (padding token's embedding) as a vector: [embd_size]
// Note: tok_embd_per_layer is quantized (q8_0), so the view is also q8_0
ggml_tensor * padding_embd_vec_q = ggml_view_1d(ctx0, model.tok_embd_per_layer,
embd_size, // number of elements
0); // offset (column 0)
ggml_set_name(padding_embd_vec_q, "padding_token_emb_q8");

// Dequantize to f32 using ggml_cpy
ggml_tensor * padding_embd_vec_f32 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, embd_size);
ggml_tensor * padding_embd_vec = ggml_cpy(ctx0, padding_embd_vec_q, padding_embd_vec_f32);
ggml_set_name(padding_embd_vec, "padding_token_emb_f32");

// Reshape to [embd_size, 1] for broadcasting
ggml_tensor * padding_embd_col = ggml_reshape_2d(ctx0, padding_embd_vec, embd_size, 1);

// Add: zeros [embd_size, n_tokens] + padding [embd_size, 1] = broadcasted padding [embd_size, n_tokens]
ggml_tensor * inp_per_layer_flat = ggml_add(ctx0, zeros_per_layer, padding_embd_col);
ggml_set_name(inp_per_layer_flat, "inp_per_layer_broadcasted");

// Reshape to [n_embd_altup, n_layer, n_tokens] for per-layer processing
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer_flat, n_embd_altup, n_layer, n_tokens);

// Apply same scaling as text tokens
// inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup));
cb(inp_per_layer, "inp_per_layer_vision", -1);
}
res->add_input(std::move(inp));
return inp_per_layer;
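For comparison with the HF reference, a minimal PyTorch-style sketch of the behavior this hunk reproduces (function and tensor names are illustrative assumptions, not the actual HF or llama.cpp API): vision positions are remapped to token ID 0 before the per-layer lookup, so each one receives the padding row of the per-layer embedding table.

import torch

def per_layer_inputs(input_ids, is_text_token, embed_per_layer):
    # Mirrors the PyTorch pseudo-code quoted in the C++ comment above:
    # non-text (vision) positions are replaced by token ID 0 before the lookup,
    # so every vision token picks up the padding embedding row.
    ids = torch.where(is_text_token, input_ids, torch.zeros_like(input_ids))
    return embed_per_layer(ids)  # [n_tokens, n_embd_altup * n_layer]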
1 change: 1 addition & 0 deletions tools/mtmd/CMakeLists.txt
@@ -27,6 +27,7 @@ add_library(mtmd
models/qwen3vl.cpp
models/siglip.cpp
models/whisper-enc.cpp
models/mobilenetv5.cpp
)

set_target_properties(mtmd PROPERTIES
43 changes: 43 additions & 0 deletions tools/mtmd/clip-impl.h
@@ -153,6 +153,47 @@
#define TN_CONV_PW1 "%s.blk.%d.conv_pw1.%s"
#define TN_CONV_PW2 "%s.blk.%d.conv_pw2.%s"

// mobilenetv5 (gemma3n) definitions
#define TN_MNV5_STEM_CONV "v.enc.conv_stem.conv.weight"
#define TN_MNV5_STEM_BIAS "v.enc.conv_stem.conv.bias"
#define TN_MNV5_STEM_BN "v.enc.conv_stem.bn.weight"

// Stage 0 Block (Edge Residual)
#define TN_MNV5_BLK_S0_EXP_W "v.enc.blocks.%d.%d.conv_exp.weight"
#define TN_MNV5_BLK_S0_BN1_W "v.enc.blocks.%d.%d.bn1.weight"
#define TN_MNV5_BLK_S0_PWL_W "v.enc.blocks.%d.%d.conv_pwl.weight"
#define TN_MNV5_BLK_S0_BN2_W "v.enc.blocks.%d.%d.bn2.weight"

// Stage 1+ Block (Universal Inverted Residual)
#define TN_MNV5_BLK_DW_START_W "v.enc.blocks.%d.%d.dw_start.conv.weight"
#define TN_MNV5_BLK_DW_START_BN "v.enc.blocks.%d.%d.dw_start.bn.weight"
#define TN_MNV5_BLK_DW_MID_W "v.enc.blocks.%d.%d.dw_mid.conv.weight"
#define TN_MNV5_BLK_DW_MID_BN "v.enc.blocks.%d.%d.dw_mid.bn.weight"
#define TN_MNV5_BLK_PW_EXP_W "v.enc.blocks.%d.%d.pw_exp.conv.weight"
#define TN_MNV5_BLK_PW_EXP_BN "v.enc.blocks.%d.%d.pw_exp.bn.weight"
#define TN_MNV5_BLK_PW_PROJ_W "v.enc.blocks.%d.%d.pw_proj.conv.weight"
#define TN_MNV5_BLK_PW_PROJ_BN "v.enc.blocks.%d.%d.pw_proj.bn.weight"
#define TN_MNV5_BLK_LAYER_SCALE "v.enc.blocks.%d.%d.layer_scale.gamma"

// Attention Components
#define TN_MNV5_ATTN_Q_W "v.enc.blocks.%d.%d.attn.query.proj.weight"
#define TN_MNV5_ATTN_K_W "v.enc.blocks.%d.%d.attn.key.proj.weight"
#define TN_MNV5_ATTN_V_W "v.enc.blocks.%d.%d.attn.value.proj.weight"
#define TN_MNV5_ATTN_O_W "v.enc.blocks.%d.%d.attn.output.proj.weight"
#define TN_MNV5_ATTN_K_DW "v.enc.blocks.%d.%d.attn.key.down_conv.weight"
#define TN_MNV5_ATTN_K_NORM "v.enc.blocks.%d.%d.attn.key.norm.weight"
#define TN_MNV5_ATTN_V_DW "v.enc.blocks.%d.%d.attn.value.down_conv.weight"
#define TN_MNV5_ATTN_V_NORM "v.enc.blocks.%d.%d.attn.value.norm.weight"
#define TN_MNV5_ATTN_NORM "v.enc.blocks.%d.%d.norm.weight" // Block norm used in attn blocks

// MSFA
#define TN_MNV5_MSFA_FFN_EXP_W "v.enc.msfa.ffn.pw_exp.conv.weight"
#define TN_MNV5_MSFA_FFN_EXP_BN "v.enc.msfa.ffn.pw_exp.bn.weight"
#define TN_MNV5_MSFA_FFN_PROJ_W "v.enc.msfa.ffn.pw_proj.conv.weight"
#define TN_MNV5_MSFA_FFN_PROJ_BN "v.enc.msfa.ffn.pw_proj.bn.weight"
#define TN_MNV5_MSFA_NORM "v.enc.msfa.norm.weight"


// align x to upper multiple of n
#define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))

@@ -170,6 +211,7 @@ enum projector_type {
PROJECTOR_TYPE_QWEN2VL,
PROJECTOR_TYPE_QWEN3VL,
PROJECTOR_TYPE_GEMMA3,
PROJECTOR_TYPE_GEMMA3N,
PROJECTOR_TYPE_IDEFICS3,
PROJECTOR_TYPE_PIXTRAL,
PROJECTOR_TYPE_QWEN25VL,
@@ -200,6 +242,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_QWEN25VL, "qwen2.5vl_merger"},
{ PROJECTOR_TYPE_QWEN3VL, "qwen3vl_merger"},
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
{ PROJECTOR_TYPE_GEMMA3N, "gemma3n"},
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
{ PROJECTOR_TYPE_ULTRAVOX, "ultravox"},