diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 432be599469..55e82fe9128 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -520,7 +520,11 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
         return ()
 
     def prepare_tensors(self):
-        max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
+        # Handle empty tensor_map for models with block_count=0 (like MobileNetV5)
+        if self.tensor_map.mapping:
+            max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
+        else:
+            max_name_len = len("vision_encoder.weight,")  # Default reasonable length
 
         for name, data_torch in chain(self.generate_extra_tensors(), self.get_tensors()):
             # we don't need these
@@ -5959,8 +5963,106 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
         return []  # skip other tensors
 
+@ModelBase.register("Gemma3nForConditionalGeneration", "Gemma3nVisionModel")
+class Gemma3nVisionModel(MmprojModel):
+    """Vision encoder converter for Gemma3n using MobileNetV5 architecture"""
+    n_block_keys = []
+
+    def find_hparam(self, keys: list[str], optional: bool = False) -> Any:
+        """Override to return 0 for block count since MobileNetV5 is CNN-based"""
+        if not keys:  # If n_block_keys is empty (our case)
+            return 0
+        # Otherwise use parent implementation
+        return super().find_hparam(keys, optional)
+
+    def find_vparam(self, keys: list[str], optional: bool = False) -> Any:
+        """Override to provide hardcoded MobileNetV5 parameters that aren't in config"""
+        # Handle empty keys list (n_block_keys) - return 0 for CNN architecture
+        if not keys:
+            return 0
+
+        if "intermediate_size" in keys:
+            # Typical expansion is 4x the embedding dimension
+            hidden_size = self.hparams_vision.get("hidden_size", 2048)
+            return hidden_size * 4
+
+        if "num_attention_heads" in keys or "num_heads" in keys:
+            # Multi-Query Attention with 8 heads
+            return 8
+
+        # For other parameters, use parent implementation
+        return super().find_vparam(keys, optional)
+
+    def set_gguf_parameters(self):
+        # MobileNetV5 does not use normalisation at all
+        self.preprocessor_config["image_mean"] = [0.0, 0.0, 0.0]
+        self.preprocessor_config["image_std"] = [1.0, 1.0, 1.0]
+        self.hparams_vision["image_size"] = self.preprocessor_config.get(
+            "size", {"height": 768, "width": 768}
+        )["height"]
+
+        # Image sequence length (256 tokens = 16x16 for Gemma3n).
+        # The stored "patch_size" is a pseudo value: the runtime derives the token count as image_size / patch_size.
+        image_seq_length = self.preprocessor_config.get("image_seq_length", 256)
+        image_size = self.hparams_vision["image_size"]
+        if image_seq_length <= 0 or image_size < image_seq_length:
+            raise ValueError(f"Cannot derive patch_size from image_size={image_size}, image_seq_length={image_seq_length}")
+        self.hparams_vision["patch_size"] = image_size // image_seq_length
+
+        # Now call parent which will use the corrected values
+        super().set_gguf_parameters()
+
+        # Set projector type to GEMMA3N
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GEMMA3N)
+
+        # MobileNetV5 specific parameters
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6))
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        # Force quantization settings for specific tensor types
+        if "input_projection" in name or "input_proj" in name:
+            return gguf.GGMLQuantizationType.F16
+        if ".embeddings." in name or "stem" in name:
+            return gguf.GGMLQuantizationType.F32
+        return super().tensor_force_quant(name, new_name, bid, n_dims)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        # Gemma3n uses different prefixes than other models:
+        # - model.embed_vision.* for projection layers
+        # - model.vision_tower.* for vision encoder
+        # Skip non-vision tensors
+        if not (name.startswith("model.embed_vision.") or
+                name.startswith("model.vision_tower.")):
+            return []
+
+        # Strip "model." prefix (guaranteed present by the check above) to match expected llama.cpp format
+        name = name[len("model."):]
+
+        # Process MobileNetV5 and projection tensors
+        name = name.replace("_weight", ".weight")
+
+        # Rename embed_vision to match our C++ implementation expectations
+        name = name.replace("embed_vision.", "")
+
+        # Rename vision_tower.timm_model to vision_tower for cleaner naming
+        name = name.replace("vision_tower.timm_model.", "vision_tower.")
+
+        # Handle normalization layer naming
+        name = name.replace("hard_embedding_norm", "hard_emb_norm")
+        name = name.replace("soft_embedding_norm", "soft_emb_norm")
+
+        # Gemma3n uses Gemma3p5RMSNorm which has scale_shift=0, so soft_emb_norm needs no correction
+        # (unlike Gemma3 which uses Gemma3RMSNorm with scale_shift=1)
+
+        if name.startswith("vision_tower."):
+            tensor_suffix = name[len("vision_tower."):]
+            return [(f"v.enc.{tensor_suffix}", data_torch)]
+        return [(self.map_tensor_name(name), data_torch)]
+
 
-@ModelBase.register("Gemma3nForConditionalGeneration")
+@ModelBase.register("Gemma3nForCausalLM", "Gemma3nForConditionalGeneration")
 class Gemma3NModel(Gemma3Model):
     model_arch = gguf.MODEL_ARCH.GEMMA3N
     norm_shift = 0.0 # same value with Gemma3p5RMSNorm scale_shift on python code
@@ -5983,8 +6085,20 @@ def __init__(self, *args, **kwargs):
         ]
 
     def set_vocab(self):
+        # For Gemma3n multimodal models, we need the FULL vocab_size (262400)
+        # which includes special tokens from 262144-262399 for vision/audio.
+        # The vocab_size_per_layer_input (262144) is only the embedding size per layer,
+        # so hide it while the parent builds the vocab to force the vocab_size lookup.
+        vocab_size_per_layer_input = self.hparams.get("vocab_size_per_layer_input")
+        if vocab_size_per_layer_input is not None:
+            del self.hparams["vocab_size_per_layer_input"]
+        try:
-        super().set_vocab()
+            super().set_vocab()
+        finally:
+            # Restore the key for later use, even if the parent raised
+            if vocab_size_per_layer_input is not None:
+                self.hparams["vocab_size_per_layer_input"] = vocab_size_per_layer_input
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_altup_active_idx(self.hparams["altup_active_idx"])
@@ -6020,8 +6134,32 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         if "language_model." not in name:
             return [] # skip non-language model tensors
 
+        # Pad token embeddings for vision/audio special tokens (262144-262399)
+        if "embed_tokens.weight" in name or "embed_tokens_per_layer" in name:
+            # Move to CPU to avoid meta device issues during padding
+            data_torch = data_torch.to(device="cpu")
+
+            vocab_size = self.hparams.get("vocab_size", 262400)
+            current_size = data_torch.shape[0]  # First dimension is vocab_size
+
+            if current_size < vocab_size:
+                # Pad with zeros for vision/audio tokens (they get embeddings from vision tower)
+                padding_size = vocab_size - current_size
+                tensor_type = "per-layer embeddings" if "per_layer" in name else "token embeddings"
+                logger.info(f"Padding {tensor_type} shape {list(data_torch.shape)} from {current_size} to {vocab_size} (adding {padding_size} vision/audio token slots)")
+
+                # Create padding with zeros (vision tokens won't use these embeddings)
+                padding = torch.zeros((padding_size, data_torch.shape[1]), dtype=data_torch.dtype, device=data_torch.device)
+                data_torch = torch.cat([data_torch, padding], dim=0)
+
+            # Continue with normal processing
+            name = name.replace("language_model.", "")
+            return [(self.map_tensor_name(name), data_torch)]
+
         if "altup_unembed_projections" in name:
             data_torch = data_torch.to(device="cpu")
+            # altup_unembed matrices are [hidden_size, hidden_size], NOT vocab-based
+            # They should NOT be padded
             if ".0." in name:
                 self._altup_unembd[0] = data_torch
             elif ".1." in name:
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index cab8f2901ae..869a8582b12 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -456,6 +456,7 @@ class VISION_PROJECTOR_TYPE(IntEnum):
     RESAMPLER = auto()
     GLM_EDGE = auto()
     MERGER = auto()
+    GEMMA3N = auto()
     GEMMA3 = auto()
     QWEN3VL = auto()
     COGVLM = auto()
@@ -666,6 +667,9 @@ class MODEL_TENSOR(IntEnum):
     V_MM_INP_NORM = auto()
     V_MM_INP_PROJ = auto() # gemma3
     V_MM_SOFT_EMB_NORM = auto() # gemma3
+    V_MM_EMBEDDING = auto() # gemma3n
+    V_MM_HARD_EMB_NORM = auto() # gemma3n
+    V_MM_POST_PROJ_NORM = auto() # gemma3n
     V_RESMPL_POS_EMBD_K = auto() # minicpmv
     V_RESMPL_ATTN_Q = auto() # minicpmv
     V_RESMPL_ATTN_K = auto() # minicpmv
@@ -1058,6 +1062,9 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection",
     MODEL_TENSOR.V_MM_INP_NORM: "mm.input_norm",
     MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm",
+    MODEL_TENSOR.V_MM_EMBEDDING: "mm.embedding",
+    MODEL_TENSOR.V_MM_HARD_EMB_NORM: "mm.hard_emb_norm",
+    MODEL_TENSOR.V_MM_POST_PROJ_NORM: "mm.post_proj_norm",
     MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "resampler.pos_embd_k",
     MODEL_TENSOR.V_RESMPL_ATTN_Q: "resampler.attn.q",
     MODEL_TENSOR.V_RESMPL_ATTN_K: "resampler.attn.k",
@@ -1156,6 +1163,9 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.V_MM_INP_PROJ,
         MODEL_TENSOR.V_MM_INP_NORM,
         MODEL_TENSOR.V_MM_SOFT_EMB_NORM,
+        MODEL_TENSOR.V_MM_EMBEDDING,
+        MODEL_TENSOR.V_MM_HARD_EMB_NORM,
+        MODEL_TENSOR.V_MM_POST_PROJ_NORM,
         MODEL_TENSOR.V_RESMPL_POS_EMBD_K,
         MODEL_TENSOR.V_RESMPL_ATTN_Q,
         MODEL_TENSOR.V_RESMPL_ATTN_K,
@@ -3397,6 +3407,7 @@ def get_type(val: Any) -> GGUFValueType:
 
 class VisionProjectorType:
     GEMMA3 = "gemma3"
+    GEMMA3N =
"gemma3n" IDEFICS3 = "idefics3" PIXTRAL = "pixtral" LLAMA4 = "llama4" diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 301aafa9102..3e1cf8a136f 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -119,6 +119,27 @@ class TensorNameMap: MODEL_TENSOR.CONV1D: ( "backbone.embed", # roberta ), + + # Vision multimodal projector tensors (non-block) for gemma3n + MODEL_TENSOR.V_MM_INP_PROJ: ( + "embedding_projection", # gemma3n + ), + + MODEL_TENSOR.V_MM_SOFT_EMB_NORM: ( + "soft_emb_norm", # gemma3n + ), + + MODEL_TENSOR.V_MM_EMBEDDING: ( + "embedding", # gemma3n + ), + + MODEL_TENSOR.V_MM_HARD_EMB_NORM: ( + "hard_emb_norm", # gemma3n + ), + + MODEL_TENSOR.V_MM_POST_PROJ_NORM: ( + "post_proj_norm", # gemma3n + ), } block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { diff --git a/src/models/gemma3n-iswa.cpp b/src/models/gemma3n-iswa.cpp index a0bdd6a15a1..e172b9a79f8 100644 --- a/src/models/gemma3n-iswa.cpp +++ b/src/models/gemma3n-iswa.cpp @@ -259,7 +259,51 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() { inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup)); cb(inp_per_layer, "inp_per_layer_selected", -1); } else { - GGML_ABORT("TODO: support embd input"); + // For embedding inputs (e.g., from vision encoder) + // Vision tokens should use the padding token (ID=0) embedding + // from tok_embd_per_layer, NOT project the vision embeddings. + // The projection happens later in project_per_layer_inputs(). 
+ // This matches PyTorch behavior: + // per_layer_inputs_tokens = torch.where(mask, input_ids, torch.zeros_like(input_ids)) + // per_layer_inputs = EmbedPerLayer(per_layer_inputs_tokens) # Uses padding (0) for vision + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + ggml_set_input(inp->embd); + + // tok_embd_per_layer shape: [embd_size, vocab_size] where embd_size = n_embd_altup * n_layer + const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_altup * n_layer + + // Create zeros tensor [embd_size, n_tokens] by projecting vision embeddings and multiplying by 0 + // First, project inp->embd [n_embd, n_tokens] to per-layer space [embd_size, n_tokens] + ggml_tensor * zeros_per_layer = ggml_mul_mat(ctx0, model.per_layer_model_proj, inp->embd); + zeros_per_layer = ggml_scale(ctx0, zeros_per_layer, 0.0f); // Multiply by 0 to get zeros + ggml_set_name(zeros_per_layer, "zeros_per_layer"); + + // Extract column 0 (padding token's embedding) as a vector: [embd_size] + // Note: tok_embd_per_layer is quantized (q8_0), so the view is also q8_0 + ggml_tensor * padding_embd_vec_q = ggml_view_1d(ctx0, model.tok_embd_per_layer, + embd_size, // number of elements + 0); // offset (column 0) + ggml_set_name(padding_embd_vec_q, "padding_token_emb_q8"); + + // Dequantize to f32 using ggml_cpy + ggml_tensor * padding_embd_vec_f32 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, embd_size); + ggml_tensor * padding_embd_vec = ggml_cpy(ctx0, padding_embd_vec_q, padding_embd_vec_f32); + ggml_set_name(padding_embd_vec, "padding_token_emb_f32"); + + // Reshape to [embd_size, 1] for broadcasting + ggml_tensor * padding_embd_col = ggml_reshape_2d(ctx0, padding_embd_vec, embd_size, 1); + + // Add: zeros [embd_size, n_tokens] + padding [embd_size, 1] = broadcasted padding [embd_size, n_tokens] + ggml_tensor * inp_per_layer_flat = ggml_add(ctx0, zeros_per_layer, padding_embd_col); + ggml_set_name(inp_per_layer_flat, "inp_per_layer_broadcasted"); + + // Reshape to 
[n_embd_altup, n_layer, n_tokens] for per-layer processing + inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer_flat, n_embd_altup, n_layer, n_tokens); + + // Apply same scaling as text tokens + // inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup)); + cb(inp_per_layer, "inp_per_layer_vision", -1); } res->add_input(std::move(inp)); return inp_per_layer; diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 317d5f19fd9..a74b4bc2154 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -27,6 +27,7 @@ add_library(mtmd models/qwen3vl.cpp models/siglip.cpp models/whisper-enc.cpp + models/mobilenetv5.cpp ) set_target_properties(mtmd PROPERTIES diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index a0939865e3f..24a1ef52d08 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -153,6 +153,47 @@ #define TN_CONV_PW1 "%s.blk.%d.conv_pw1.%s" #define TN_CONV_PW2 "%s.blk.%d.conv_pw2.%s" +// mobilenetv5 (gemma3n) definitions +#define TN_MNV5_STEM_CONV "v.enc.conv_stem.conv.weight" +#define TN_MNV5_STEM_BIAS "v.enc.conv_stem.conv.bias" +#define TN_MNV5_STEM_BN "v.enc.conv_stem.bn.weight" + +// Stage 0 Block (Edge Residual) +#define TN_MNV5_BLK_S0_EXP_W "v.enc.blocks.%d.%d.conv_exp.weight" +#define TN_MNV5_BLK_S0_BN1_W "v.enc.blocks.%d.%d.bn1.weight" +#define TN_MNV5_BLK_S0_PWL_W "v.enc.blocks.%d.%d.conv_pwl.weight" +#define TN_MNV5_BLK_S0_BN2_W "v.enc.blocks.%d.%d.bn2.weight" + +// Stage 1+ Block (Universal Inverted Residual) +#define TN_MNV5_BLK_DW_START_W "v.enc.blocks.%d.%d.dw_start.conv.weight" +#define TN_MNV5_BLK_DW_START_BN "v.enc.blocks.%d.%d.dw_start.bn.weight" +#define TN_MNV5_BLK_DW_MID_W "v.enc.blocks.%d.%d.dw_mid.conv.weight" +#define TN_MNV5_BLK_DW_MID_BN "v.enc.blocks.%d.%d.dw_mid.bn.weight" +#define TN_MNV5_BLK_PW_EXP_W "v.enc.blocks.%d.%d.pw_exp.conv.weight" +#define TN_MNV5_BLK_PW_EXP_BN "v.enc.blocks.%d.%d.pw_exp.bn.weight" +#define TN_MNV5_BLK_PW_PROJ_W 
"v.enc.blocks.%d.%d.pw_proj.conv.weight" +#define TN_MNV5_BLK_PW_PROJ_BN "v.enc.blocks.%d.%d.pw_proj.bn.weight" +#define TN_MNV5_BLK_LAYER_SCALE "v.enc.blocks.%d.%d.layer_scale.gamma" + +// Attention Components +#define TN_MNV5_ATTN_Q_W "v.enc.blocks.%d.%d.attn.query.proj.weight" +#define TN_MNV5_ATTN_K_W "v.enc.blocks.%d.%d.attn.key.proj.weight" +#define TN_MNV5_ATTN_V_W "v.enc.blocks.%d.%d.attn.value.proj.weight" +#define TN_MNV5_ATTN_O_W "v.enc.blocks.%d.%d.attn.output.proj.weight" +#define TN_MNV5_ATTN_K_DW "v.enc.blocks.%d.%d.attn.key.down_conv.weight" +#define TN_MNV5_ATTN_K_NORM "v.enc.blocks.%d.%d.attn.key.norm.weight" +#define TN_MNV5_ATTN_V_DW "v.enc.blocks.%d.%d.attn.value.down_conv.weight" +#define TN_MNV5_ATTN_V_NORM "v.enc.blocks.%d.%d.attn.value.norm.weight" +#define TN_MNV5_ATTN_NORM "v.enc.blocks.%d.%d.norm.weight" // Block norm used in attn blocks + +// MSFA +#define TN_MNV5_MSFA_FFN_EXP_W "v.enc.msfa.ffn.pw_exp.conv.weight" +#define TN_MNV5_MSFA_FFN_EXP_BN "v.enc.msfa.ffn.pw_exp.bn.weight" +#define TN_MNV5_MSFA_FFN_PROJ_W "v.enc.msfa.ffn.pw_proj.conv.weight" +#define TN_MNV5_MSFA_FFN_PROJ_BN "v.enc.msfa.ffn.pw_proj.bn.weight" +#define TN_MNV5_MSFA_NORM "v.enc.msfa.norm.weight" + + // align x to upper multiple of n #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) @@ -170,6 +211,7 @@ enum projector_type { PROJECTOR_TYPE_QWEN2VL, PROJECTOR_TYPE_QWEN3VL, PROJECTOR_TYPE_GEMMA3, + PROJECTOR_TYPE_GEMMA3N, PROJECTOR_TYPE_IDEFICS3, PROJECTOR_TYPE_PIXTRAL, PROJECTOR_TYPE_QWEN25VL, @@ -200,6 +242,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_QWEN25VL, "qwen2.5vl_merger"}, { PROJECTOR_TYPE_QWEN3VL, "qwen3vl_merger"}, { PROJECTOR_TYPE_GEMMA3, "gemma3"}, + { PROJECTOR_TYPE_GEMMA3N, "gemma3n"}, { PROJECTOR_TYPE_IDEFICS3, "idefics3"}, { PROJECTOR_TYPE_PIXTRAL, "pixtral"}, { PROJECTOR_TYPE_ULTRAVOX, "ultravox"}, diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index b4c31cdde6b..be168b97ef2 100644 --- 
a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -172,6 +172,45 @@ struct clip_layer { } }; +// Expanded MobileNetV5 block structure for Gemma3n vision encoder +struct mobilenetv5_block { + // Stage 0 (Edge Residual) + ggml_tensor * s0_conv_exp_w = nullptr; + ggml_tensor * s0_bn1_w = nullptr; + ggml_tensor * s0_conv_pwl_w = nullptr; + ggml_tensor * s0_bn2_w = nullptr; + + // Stage 1+ (Universal Inverted Residual) + ggml_tensor * dw_start_w = nullptr; + ggml_tensor * dw_start_bn_w = nullptr; + + ggml_tensor * pw_exp_w = nullptr; + ggml_tensor * pw_exp_bn_w = nullptr; + + ggml_tensor * dw_mid_w = nullptr; + ggml_tensor * dw_mid_bn_w = nullptr; + + ggml_tensor * pw_proj_w = nullptr; + ggml_tensor * pw_proj_bn_w = nullptr; + + ggml_tensor * layer_scale_w = nullptr; + + // Attention (MQA) components + ggml_tensor * attn_q_w = nullptr; + ggml_tensor * attn_k_w = nullptr; + ggml_tensor * attn_v_w = nullptr; + ggml_tensor * attn_o_w = nullptr; + + // Optional downsampling/norm in attention + ggml_tensor * attn_k_dw_w = nullptr; + ggml_tensor * attn_k_norm_w = nullptr; + ggml_tensor * attn_v_dw_w = nullptr; + ggml_tensor * attn_v_norm_w = nullptr; + + // Block norm (often present in attention blocks) + ggml_tensor * attn_norm_w = nullptr; +}; + struct clip_model { clip_modality modality = CLIP_MODALITY_VISION; projector_type proj_type = PROJECTOR_TYPE_MLP; @@ -288,6 +327,23 @@ struct clip_model { ggml_tensor * mm_input_proj_w = nullptr; ggml_tensor * mm_soft_emb_norm_w = nullptr; + // mobilenetv5 for gemma3n + std::vector mobilenet_blocks; + std::vector mobilenet_stage_ends; + ggml_tensor * mobilenet_stem_conv_w = nullptr; + ggml_tensor * mobilenet_stem_conv_b = nullptr; + ggml_tensor * mobilenet_stem_norm_w = nullptr; + ggml_tensor * mm_post_proj_norm_w = nullptr; + + // Multi-Scale Fusion Adapter (MSFA) components + ggml_tensor * msfa_concat_conv_w = nullptr; + ggml_tensor * msfa_concat_norm_w = nullptr; + ggml_tensor * msfa_ffn_expand_w = nullptr; + ggml_tensor 
* msfa_ffn_project_w = nullptr; + ggml_tensor * msfa_ffn_expand_bn = nullptr; + ggml_tensor * msfa_ffn_project_bn = nullptr; + + // pixtral, glm4v ggml_tensor * token_embd_img_break = nullptr; ggml_tensor * mm_patch_merger_w = nullptr; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 3ba0823defb..dd778ea3c96 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -788,6 +788,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_GEMMA3N: + { + builder = std::make_unique(ctx, img); + } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_LIGHTONOCR: { @@ -1141,6 +1145,14 @@ struct clip_model_loader { // test model (tinygemma3) has a different value, we optionally read it get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); } break; + + case PROJECTOR_TYPE_GEMMA3N: + { + // Gemma3n uses MobileNetV5 which produces 256 tokens (16x16) + // Similar configuration to Gemma3 + hparams.n_merge = 1; // MobileNetV5 handles resizing internally + get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); + } break; case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: @@ -1381,6 +1393,7 @@ struct clip_model_loader { } } + switch (model.proj_type) { case PROJECTOR_TYPE_MLP: case PROJECTOR_TYPE_MLP_NORM: @@ -1512,6 +1525,102 @@ struct clip_model_loader { model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ); model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N); } break; + case PROJECTOR_TYPE_GEMMA3N: + { + model.mobilenet_stem_conv_w = get_tensor(TN_MNV5_STEM_CONV, false); + model.mobilenet_stem_conv_b = get_tensor(TN_MNV5_STEM_BIAS, false); + model.mobilenet_stem_norm_w = get_tensor(TN_MNV5_STEM_BN, false); + + model.msfa_ffn_expand_w = get_tensor(TN_MNV5_MSFA_FFN_EXP_W, false); + model.msfa_ffn_expand_bn = get_tensor(TN_MNV5_MSFA_FFN_EXP_BN, false); // Consume BN if present but likely folded + model.msfa_ffn_project_w = 
get_tensor(TN_MNV5_MSFA_FFN_PROJ_W, false); + model.msfa_ffn_project_bn = get_tensor(TN_MNV5_MSFA_FFN_PROJ_BN, false); + + model.msfa_concat_norm_w = get_tensor(TN_MNV5_MSFA_NORM, false); + + // Dynamically load blocks stage by stage + for (int stage = 0; stage < 4; ++stage) { + int blocks_found_in_stage = 0; + + for (int blk_idx = 0; ; ++blk_idx) { + bool found_block = false; + mobilenetv5_block block; + + // 1. Check for Edge Residual (S0) + block.s0_conv_exp_w = get_tensor(string_format(TN_MNV5_BLK_S0_EXP_W, stage, blk_idx), false); + if (block.s0_conv_exp_w) { + found_block = true; + block.s0_bn1_w = get_tensor(string_format(TN_MNV5_BLK_S0_BN1_W, stage, blk_idx), false); + block.s0_conv_pwl_w = get_tensor(string_format(TN_MNV5_BLK_S0_PWL_W, stage, blk_idx), false); + block.s0_bn2_w = get_tensor(string_format(TN_MNV5_BLK_S0_BN2_W, stage, blk_idx), false); + } + // 2. Check for UIR (Universal Inverted Residual) + else { + // Check for dw_start OR pw_exp (some UIR blocks skip dw_start) + block.dw_start_w = get_tensor(string_format(TN_MNV5_BLK_DW_START_W, stage, blk_idx), false); + block.pw_exp_w = get_tensor(string_format(TN_MNV5_BLK_PW_EXP_W, stage, blk_idx), false); + + if (block.dw_start_w || block.pw_exp_w) { + found_block = true; + if (block.dw_start_w) { + block.dw_start_bn_w = get_tensor(string_format(TN_MNV5_BLK_DW_START_BN, stage, blk_idx), false); + } + if (block.pw_exp_w) { + block.pw_exp_bn_w = get_tensor(string_format(TN_MNV5_BLK_PW_EXP_BN, stage, blk_idx), false); + } + block.dw_mid_w = get_tensor(string_format(TN_MNV5_BLK_DW_MID_W, stage, blk_idx), false); + if (block.dw_mid_w) { + block.dw_mid_bn_w = get_tensor(string_format(TN_MNV5_BLK_DW_MID_BN, stage, blk_idx), false); + } + block.pw_proj_w = get_tensor(string_format(TN_MNV5_BLK_PW_PROJ_W, stage, blk_idx), false); + if (block.pw_proj_w) { + block.pw_proj_bn_w = get_tensor(string_format(TN_MNV5_BLK_PW_PROJ_BN, stage, blk_idx), false); + } + block.layer_scale_w = 
get_tensor(string_format(TN_MNV5_BLK_LAYER_SCALE, stage, blk_idx), false); + } + } + + // 3. Check for Attention (MQA) + // Even if UIR/Edge check failed, this might be a pure attention block + ggml_tensor* attn_q_check = get_tensor(string_format(TN_MNV5_ATTN_Q_W, stage, blk_idx), false); + if (attn_q_check) { + found_block = true; + block.attn_q_w = attn_q_check; + block.attn_k_w = get_tensor(string_format(TN_MNV5_ATTN_K_W, stage, blk_idx), false); + block.attn_v_w = get_tensor(string_format(TN_MNV5_ATTN_V_W, stage, blk_idx), false); + block.attn_o_w = get_tensor(string_format(TN_MNV5_ATTN_O_W, stage, blk_idx), false); + block.attn_k_dw_w = get_tensor(string_format(TN_MNV5_ATTN_K_DW, stage, blk_idx), false); + block.attn_k_norm_w = get_tensor(string_format(TN_MNV5_ATTN_K_NORM, stage, blk_idx), false); + block.attn_v_dw_w = get_tensor(string_format(TN_MNV5_ATTN_V_DW, stage, blk_idx), false); + block.attn_v_norm_w = get_tensor(string_format(TN_MNV5_ATTN_V_NORM, stage, blk_idx), false); + block.attn_norm_w = get_tensor(string_format(TN_MNV5_ATTN_NORM, stage, blk_idx), false); + // Note: Attention blocks also have layer_scale, load it if not already loaded by UIR check + if (!block.layer_scale_w) { + block.layer_scale_w = get_tensor(string_format(TN_MNV5_BLK_LAYER_SCALE, stage, blk_idx), false); + } + } + + if (found_block) { + model.mobilenet_blocks.push_back(block); + blocks_found_in_stage++; + } else { + // End of blocks for this stage + break; + } + } + + // Track where this stage ends in the flat vector + if (blocks_found_in_stage > 0) { + model.mobilenet_stage_ends.push_back(model.mobilenet_blocks.size() - 1); + LOG_INF("%s: Stage %d ended at global block index %zu\n", __func__, stage, model.mobilenet_blocks.size() - 1); + } + } + // Load projection weights (similar to Gemma3) + model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ); + model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N); + model.mm_0_w = get_tensor("mm.embedding.weight", false); // Input 
embedding + model.mm_1_w = get_tensor("mm.hard_emb_norm.weight", false); // Hard embedding norm + } break; case PROJECTOR_TYPE_IDEFICS3: { model.projection = get_tensor(TN_MM_PROJECTOR); @@ -2747,6 +2856,16 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->entries.push_back(std::move(img_f32)); } break; + case PROJECTOR_TYPE_GEMMA3N: + { + clip_image_u8 resized_image; + int sz = params.image_size; + img_tool::resize(*img, resized_image, {sz, sz}, img_tool::RESIZE_ALGO_BILINEAR, false); + clip_image_f32_ptr img_f32(clip_image_f32_init()); + normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(img_f32)); + } break; + case PROJECTOR_TYPE_JANUS_PRO: { // Janus Pro preprocessing: pad to square with gray(127), resize to 384x384 @@ -3006,6 +3125,12 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im int scale_factor = ctx->model.hparams.n_merge; n_patches /= (scale_factor * scale_factor); } break; + case PROJECTOR_TYPE_GEMMA3N: + { + // MobileNetV5 MSFA adapter always outputs fixed 16x16 resolution + // regardless of input size (see architecture description) + n_patches = ctx->model.hparams.image_size / ctx->model.hparams.patch_size; + } break; case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: { @@ -3396,6 +3521,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima set_input_i32("patches", patches); } break; case PROJECTOR_TYPE_GEMMA3: + case PROJECTOR_TYPE_GEMMA3N: case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_INTERNVL: case PROJECTOR_TYPE_QWEN2A: @@ -3521,6 +3647,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { // main path + deepstack paths return ctx->model.mm_1_b->ne[0] * (1 + ctx->model.n_deepstack_layers); case PROJECTOR_TYPE_GEMMA3: + case PROJECTOR_TYPE_GEMMA3N: return ctx->model.mm_input_proj_w->ne[0]; case PROJECTOR_TYPE_IDEFICS3: return 
ctx->model.projection->ne[1];
diff --git a/tools/mtmd/models/mobilenetv5.cpp b/tools/mtmd/models/mobilenetv5.cpp
new file mode 100644
index 00000000000..bc1185c10eb
--- /dev/null
+++ b/tools/mtmd/models/mobilenetv5.cpp
@@ -0,0 +1,473 @@
+#include "models.h"
+
+// --- Helpers for MobileNetV5 Blocks ---
+// RMS Norm 2D - normalizes over channels for each spatial position
+ggml_tensor * clip_graph_mobilenetv5::rms_norm_2d(ggml_tensor * inp, ggml_tensor * weight, float eps) {
+    // inp: [W, H, C, B]
+
+    ggml_tensor * cur = ggml_permute(ctx0, inp, 2, 1, 0, 3);
+    cur = ggml_cont(ctx0, cur);
+    cur = ggml_rms_norm(ctx0, cur, eps);
+
+    if (weight) {
+        cur = ggml_mul(ctx0, cur, weight);
+    }
+
+    cur = ggml_permute(ctx0, cur, 2, 1, 0, 3);
+    cur = ggml_cont(ctx0, cur);
+
+    return cur;
+}
+
+// Helper for Conv2dSame padding (asymmetric SAME padding like PyTorch/TF)
+ggml_tensor* clip_graph_mobilenetv5::pad_same_2d(ggml_tensor* inp, int kernel_h, int kernel_w, int stride_h, int stride_w, int dilation_h, int dilation_w) {
+    const int64_t ih = inp->ne[1];  // height
+    const int64_t iw = inp->ne[0];  // width
+
+    // Calculate output size (ceil division)
+    const int64_t oh = (ih + stride_h - 1) / stride_h;
+    const int64_t ow = (iw + stride_w - 1) / stride_w;
+
+    // Calculate padding needed
+    const int64_t pad_h = std::max((int64_t)0, (oh - 1) * stride_h + (kernel_h - 1) * dilation_h + 1 - ih);
+    const int64_t pad_w = std::max((int64_t)0, (ow - 1) * stride_w + (kernel_w - 1) * dilation_w + 1 - iw);
+
+    // Split padding asymmetrically
+    const int pad_h_top = pad_h / 2;
+    const int pad_h_bottom = pad_h - pad_h_top;
+    const int pad_w_left = pad_w / 2;
+    const int pad_w_right = pad_w - pad_w_left;
+
+    // Apply padding if needed
+    // ggml_pad_ext: (ctx, tensor, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3)
+    // For [W, H, C, B]: p0=width, p1=height, p2=channels, p3=batch
+    if (pad_h > 0 || pad_w > 0) {
+        inp = ggml_pad_ext(ctx0, inp,
+            pad_w_left, pad_w_right,     // width padding (dim 0)
+            pad_h_top, pad_h_bottom,     // height padding (dim 1)
+            0, 0,                        // no channel padding (dim 2)
+            0, 0);                       // no batch padding (dim 3)
+    }
+
+    return inp;
+}
+
+
+// Edge Residual Block (Stage 0)
+ggml_tensor * clip_graph_mobilenetv5::build_edge_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride) {
+    // Both conv weights are loaded as optional tensors; fail loudly instead of
+    // handing a null pointer to ggml if the GGUF is missing them.
+    GGML_ASSERT(block.s0_conv_exp_w && block.s0_conv_pwl_w);
+
+    ggml_tensor * cur = inp;
+
+    // 1. Expansion Conv (3x3)
+    if (stride == 2) {
+        // Case: Downsampling (Block 0)
+        // Replicates Conv2dSame(kernel=3, stride=2)
+        cur = pad_same_2d(cur, 3, 3, stride, stride);
+        cur = ggml_conv_2d_direct(ctx0, block.s0_conv_exp_w, cur, stride, stride, 0, 0, 1, 1);
+    } else {
+        // Case: Normal 3x3 Block (Block 1, 2)
+        // Replicates Conv2d(kernel=3, stride=1, padding=1)
+        cur = ggml_conv_2d_direct(ctx0, block.s0_conv_exp_w, cur, stride, stride, 1, 1, 1, 1);
+    }
+
+    // BN + Activation
+    if (block.s0_bn1_w) cur = rms_norm_2d(cur, block.s0_bn1_w);
+    cur = ggml_gelu(ctx0, cur);
+
+    // 2. Pointwise Linear Conv (1x1)
+    // 1x1 Convs usually have padding=0 and stride=1
+    cur = ggml_conv_2d_direct(ctx0, block.s0_conv_pwl_w, cur, 1, 1, 0, 0, 1, 1);
+    if (block.s0_bn2_w) cur = rms_norm_2d(cur, block.s0_bn2_w);
+
+    // 3. Residual Connection
+    // Only apply residual if spatial dimensions (width AND height) and channels match (stride 1)
+    const bool same_shape = inp->ne[0] == cur->ne[0] && inp->ne[1] == cur->ne[1] && inp->ne[2] == cur->ne[2];
+    if (stride == 1 && same_shape) {
+        cur = ggml_add(ctx0, cur, inp);
+    }
+
+    return cur;
+}
+
+// Universal Inverted Residual Block (Stage 1+); every sub-layer is optional and skipped when its weight is absent
+ggml_tensor * clip_graph_mobilenetv5::build_inverted_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride) {
+    ggml_tensor * cur = inp;
+
+    // 1. Depthwise Start (Optional)
+    // NOTE: dw_start always has stride=1 (no downsampling here)
+    if (block.dw_start_w) {
+        int k = block.dw_start_w->ne[0]; // 3 or 5
+        int p = k / 2;
+        cur = ggml_conv_2d_dw(ctx0, block.dw_start_w, cur, 1, 1, p, p, 1, 1);
+        if (block.dw_start_bn_w) cur = rms_norm_2d(cur, block.dw_start_bn_w);
+    }
+
+    // 2. Pointwise Expansion (1x1)
+    if (block.pw_exp_w) {
+        // Standard 1x1 conv, pad=0, stride=1
+        cur = ggml_conv_2d_direct(ctx0, block.pw_exp_w, cur, 1, 1, 0, 0, 1, 1);
+        if (block.pw_exp_bn_w) cur = rms_norm_2d(cur, block.pw_exp_bn_w);
+        cur = ggml_gelu(ctx0, cur);
+    }
+
+    // 3. Depthwise Mid (Optional)
+    // NOTE: dw_mid is where downsampling happens (stride=2 for first block of stage)
+    if (block.dw_mid_w) {
+        int k = block.dw_mid_w->ne[0]; // 3 or 5
+
+        if (stride > 1) {
+            // Case: Stride 2 (Downsample) -> Use Asymmetric "Same" Padding
+            cur = pad_same_2d(cur, k, k, stride, stride);
+            cur = ggml_conv_2d_dw(ctx0, block.dw_mid_w, cur, stride, stride, 0, 0, 1, 1); // pad=0
+        } else {
+            // Case: Stride 1 -> Use Standard Symmetric Padding
+            int p = k / 2;
+            cur = ggml_conv_2d_dw(ctx0, block.dw_mid_w, cur, stride, stride, p, p, 1, 1);
+        }
+
+        if (block.dw_mid_bn_w) cur = rms_norm_2d(cur, block.dw_mid_bn_w);
+        cur = ggml_gelu(ctx0, cur);
+    }
+
+    // 4. Pointwise Projection (1x1)
+    if (block.pw_proj_w) {
+        cur = ggml_conv_2d_direct(ctx0, block.pw_proj_w, cur, 1, 1, 0, 0, 1, 1);
+        if (block.pw_proj_bn_w) cur = rms_norm_2d(cur, block.pw_proj_bn_w);
+    }
+
+    // Apply Layer Scaling if present
+    if (block.layer_scale_w) {
+        ggml_tensor * scale_w_reshaped = ggml_reshape_4d(ctx0, block.layer_scale_w,
+            1, 1, block.layer_scale_w->ne[0], 1);
+
+        cur = ggml_mul(ctx0, cur, scale_w_reshaped);
+    }
+
+    // 5. Residual Connection
+    bool same_spatial = (inp->ne[0] == cur->ne[0]) && (inp->ne[1] == cur->ne[1]);
+    bool same_channel = (inp->ne[2] == cur->ne[2]);
+    if (same_spatial && same_channel) {
+        cur = ggml_add(ctx0, cur, inp);
+    }
+
+    return cur;
+}
+
+// MobileNetV5 Builder (Gemma 3n) - Attention Block
+ggml_tensor * clip_graph_mobilenetv5::build_mobilenet_attn(ggml_tensor * inp, const mobilenetv5_block & block) {
+    // The projection weights are loaded as optional tensors, but all four are required here.
+    GGML_ASSERT(block.attn_q_w && block.attn_k_w && block.attn_v_w && block.attn_o_w);
+
+    ggml_tensor * cur = inp;
+
+    // --- Norm ---
+    if (block.attn_norm_w) {
+        cur = rms_norm_2d(cur, block.attn_norm_w, 1e-6f);
+    }
+
+    // --- 1. Q Calculation ---
+    ggml_tensor * q = ggml_conv_2d_direct(ctx0, block.attn_q_w, cur, 1, 1, 0, 0, 1, 1);
+
+    // --- 2. K Calculation (Downsampled) ---
+    // Uses Conv2dSame(640, 640, kernel_size=(3, 3), stride=(2, 2), groups=640)
+    ggml_tensor * k_inp = cur;
+    if (block.attn_k_dw_w) {
+        int k_size = block.attn_k_dw_w->ne[0]; // Usually 3
+        k_inp = pad_same_2d(cur, k_size, k_size, 2, 2); // Apply SAME padding
+        k_inp = ggml_conv_2d_dw(ctx0, block.attn_k_dw_w, k_inp, 2, 2, 0, 0, 1, 1); // padding=0
+        if (block.attn_k_norm_w) {
+            k_inp = rms_norm_2d(k_inp, block.attn_k_norm_w, 1e-6f);
+        }
+    }
+    ggml_tensor * k = ggml_conv_2d_direct(ctx0, block.attn_k_w, k_inp, 1, 1, 0, 0, 1, 1);
+
+    // --- 3. V Calculation (Downsampled) ---
+    // Uses Conv2dSame(640, 640, kernel_size=(3, 3), stride=(2, 2), groups=640)
+    ggml_tensor * v_inp = cur;
+    if (block.attn_v_dw_w) {
+        int v_size = block.attn_v_dw_w->ne[0]; // Usually 3
+        v_inp = pad_same_2d(cur, v_size, v_size, 2, 2); // Apply SAME padding
+        v_inp = ggml_conv_2d_dw(ctx0, block.attn_v_dw_w, v_inp, 2, 2, 0, 0, 1, 1); // padding=0
+        if (block.attn_v_norm_w) {
+            v_inp = rms_norm_2d(v_inp, block.attn_v_norm_w, 1e-6f);
+        }
+    }
+    ggml_tensor * v = ggml_conv_2d_direct(ctx0, block.attn_v_w, v_inp, 1, 1, 0, 0, 1, 1);
+
+    const int W = cur->ne[0]; const int H = cur->ne[1]; const int B = cur->ne[3];
+    const int D = k->ne[2]; // Head dimension
+    GGML_ASSERT(D > 0 && q->ne[2] % D == 0); // Q channels must split evenly into heads of size D
+    const int n_head = q->ne[2] / D;
+    const int N = W * H;
+
+    // Process Q: [W, H, D*n_head, B] -> [D, N, n_head, B]
+    q = ggml_reshape_3d(ctx0, q, N, D*n_head, B);
+    q = ggml_reshape_4d(ctx0, q, N, D, n_head, B);
+    q = ggml_permute(ctx0, q, 1, 0, 2, 3); // [D, N, n_head, B]
+    q = ggml_cont(ctx0, q);
+
+    const int Wk = k->ne[0]; const int Hk = k->ne[1];
+    const int M = Wk * Hk;
+
+    // Process K: [Wk, Hk, D, B] -> [D, M, 1, B]
+    k = ggml_reshape_3d(ctx0, k, M, D, B);
+    k = ggml_reshape_4d(ctx0, k, M, D, 1, B);
+    k = ggml_permute(ctx0, k, 1, 0, 2, 3); // [D, M, 1, B]
+    k = ggml_cont(ctx0, k);
+
+    // Process V: [Wk, Hk, D, B] -> [M, D, 1, B]
+    v = ggml_reshape_3d(ctx0, v, M, D, B);
+    v = ggml_reshape_4d(ctx0, v, M, D, 1, B);
+    v = ggml_cont(ctx0, v); // [M, D, 1, B]
+
+    // --- Multi-Query Attention ---
+    float scale = 1.0f / sqrtf((float)D);
+
+    // Step 1: Compute Q @ K.T
+    ggml_tensor * scores = ggml_mul_mat(ctx0, k, q);
+
+    scores = ggml_scale(ctx0, scores, scale);
+
+    scores = ggml_soft_max(ctx0, scores);
+
+    ggml_tensor * kqv = ggml_mul_mat(ctx0, v, scores);
+
+    kqv = ggml_permute(ctx0, kqv, 1, 0, 2, 3);
+    kqv = ggml_cont(ctx0, kqv);
+
+
+    kqv = ggml_reshape_3d(ctx0, kqv, N, D * n_head, B);
+    kqv = ggml_reshape_4d(ctx0, kqv, W, H, D * n_head, B);
+    kqv = ggml_cont(ctx0, kqv);
+
+    // Output projection
+    cur = ggml_conv_2d_direct(ctx0, block.attn_o_w, kqv, 1, 1, 0, 0, 1, 1);
+
+    // --- Residual & Layer Scale (only when width, height and channels all match) ---
+    if (inp->ne[0] == cur->ne[0] && inp->ne[1] == cur->ne[1] && inp->ne[2] == cur->ne[2]) {
+        if (block.layer_scale_w) {
+            ggml_tensor * scale_w_reshaped = ggml_reshape_4d(ctx0, block.layer_scale_w,
+                1, 1, block.layer_scale_w->ne[0], 1);
+            cur = ggml_mul(ctx0, cur, scale_w_reshaped);
+        }
+        cur = ggml_add(ctx0, cur, inp);
+    }
+
+    return cur;
+}
+
+ggml_cgraph * clip_graph_mobilenetv5::build() {
+
+    fprintf(stderr, "\n--- START build_mobilenetv5 ---\n");
+
+    ggml_tensor * inp = build_inp_raw();
+
+    // 1. Stem - Conv2dSame(3, 64, kernel_size=(3, 3), stride=(2, 2))
+    ggml_tensor * cur = pad_same_2d(inp, 3, 3, 2, 2); // Apply SAME padding
+
+    cur = ggml_conv_2d_direct(ctx0, model.mobilenet_stem_conv_w, cur, 2, 2, 0, 0, 1, 1); // padding=0
+    if (model.mobilenet_stem_conv_b) {
+        // Bias is [C, 1, 1, 1], need to reshape to [1, 1, C, 1] for broadcasting to [W, H, C, B]
+        ggml_tensor * bias = ggml_reshape_4d(ctx0, model.mobilenet_stem_conv_b, 1, 1, cur->ne[2], 1);
+        cur = ggml_add(ctx0, cur, bias);
+    }
+    if (model.mobilenet_stem_norm_w) cur = rms_norm_2d(cur, model.mobilenet_stem_norm_w);
+    cur = ggml_gelu(ctx0, cur);
+
+
+    // 2.
Blocks + std::vector intermediate_features; + const int total_blocks = model.mobilenet_blocks.size(); + + auto is_stage_start = [&](int i) { + if (i == 0) return true; + for (int end_idx : model.mobilenet_stage_ends) { + if (i == end_idx + 1) return true; + } + return false; + }; + + auto is_fusion_point = [&](int i) { + if (model.mobilenet_stage_ends.size() >= 4) { + if (i == model.mobilenet_stage_ends[2]) return true; // End of Stage 2 + if (i == model.mobilenet_stage_ends[3]) return true; // End of Stage 3 + } else { + if (i == total_blocks - 1) return true; + } + return false; + }; + + for (int i = 0; i < total_blocks; i++) { + const auto & block = model.mobilenet_blocks[i]; + int stride = is_stage_start(i) ? 2 : 1; + + if (block.s0_conv_exp_w) cur = build_edge_residual(cur, block, stride); + else if (block.attn_q_w) cur = build_mobilenet_attn(cur, block); + else cur = build_inverted_residual(cur, block, stride); + + if (is_fusion_point(i)) { + + intermediate_features.push_back(cur); + } + } + + // 3. Multi-Scale Fusion Adapter (MSFA) + if (!intermediate_features.empty()) { + + // A. Reference Resolution: PyTorch implementation uses inputs[0] + // We assume intermediate_features[0] is the "High Resolution" target. + // In MobileNet designs, this is typically the feature map with the smallest stride (e.g. 32x32). + ggml_tensor* target_feat = intermediate_features[0]; + int high_res_w = target_feat->ne[0]; + int high_res_h = target_feat->ne[1]; + + std::vector resized_feats; + + // B. Resize inputs to match inputs[0] (High Resolution) + for (auto feat : intermediate_features) { + int feat_w = feat->ne[0]; + int feat_h = feat->ne[1]; + + // PyTorch: if feat_size < high_resolution: interpolate + if (feat_w < high_res_w || feat_h < high_res_h) { + // Calculate scale factor. + // Note: PyTorch 'nearest' works on arbitrary float scales. + // ggml_upscale generally takes integer factors or target sizes depending on helper. 
+ // Assuming standard power-of-2 scaling (e.g. 16 -> 32 means scale=2). + int scale_w = high_res_w / feat_w; + // int scale_h = high_res_h / feat_h; + + // Safety check for non-integer scaling if strictly replicating + if (high_res_w % feat_w != 0) { + fprintf(stderr, "Warning: Non-integer scaling detected in MSFA\n"); + } + + // Upsample (Nearest Neighbor) + // 2 is the scale factor + feat = ggml_upscale(ctx0, feat, scale_w, ggml_scale_mode::GGML_SCALE_MODE_NEAREST); + } + resized_feats.push_back(feat); + } + + // C. Concatenate at High Resolution (Channel Dim = 2 in ggml) + cur = resized_feats[0]; + for (size_t k = 1; k < resized_feats.size(); ++k) { + cur = ggml_concat(ctx0, cur, resized_feats[k], 2); + } + + // D. FFN (UniversalInvertedResidual) + // Structure: Expand Conv -> Norm -> GELU -> Project Conv -> Norm + + // 1. Expansion + if (model.msfa_ffn_expand_w) { + // 1x1 Conv + cur = ggml_conv_2d_direct(ctx0, model.msfa_ffn_expand_w, cur, 1, 1, 0, 0, 1, 1); + + if (model.msfa_ffn_expand_bn) { + cur = rms_norm_2d(cur, model.msfa_ffn_expand_bn); + } + + cur = ggml_gelu(ctx0, cur); + + } + + // 2. Projection (No DW because kernel_size=0) + if (model.msfa_ffn_project_w) { + // 1x1 Conv + cur = ggml_conv_2d_direct(ctx0, model.msfa_ffn_project_w, cur, 1, 1, 0, 0, 1, 1); + + // UniversalInvertedResidual typically has a norm after projection + if (model.msfa_ffn_project_bn) { + cur = rms_norm_2d(cur, model.msfa_ffn_project_bn); + } + + } + + // E. Final Downsample to Target Resolution (Output Resolution) + // PyTorch: matches self.output_resolution (e.g. 16x16) + const int target_out_res = 16; + int current_w = cur->ne[0]; + + if (current_w > target_out_res) { + int s = current_w / target_out_res; + + if (current_w % target_out_res == 0) { + // Avg Pool: Kernel=s, Stride=s + cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, s, s, s, s, 0, 0); + } else { + fprintf(stderr, "Error: Irregular downsampling stride required.\n"); + } + + } + + // F. 
Final Norm + if (model.msfa_concat_norm_w) { + cur = rms_norm_2d(cur, model.msfa_concat_norm_w); + + } + } + + // 4. Gemma 3n Multimodal Projection (Embedder) + // Input: 'cur' is [Width, Height, Channels, Batch] + int W = cur->ne[0]; + int H = cur->ne[1]; + int C = cur->ne[2]; // Should be 2048 + int B = cur->ne[3]; + + // 1. Permute and Flatten to [Channels, Tokens, Batch] + // PyTorch expects (Batch, Seq, Hidden), GGML usually processes (Hidden, Seq, Batch) + cur = ggml_permute(ctx0, cur, 2, 1, 0, 3); // -> [C, H, W, B] + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // -> [C, W, H, B] + cur = ggml_cont(ctx0, cur); + cur = ggml_reshape_3d(ctx0, cur, C, W*H, B); + cur = ggml_cont(ctx0, cur); + + + // 2. FEATURE SCALING + // PyTorch: vision_outputs *= self.config.vision_config.hidden_size**0.5 + // This prevents the signal from vanishing during the subsequent RMSNorm. + const float scale_factor = sqrtf((float)C); + cur = ggml_scale(ctx0, cur, scale_factor); + + + // 3. SOFT EMBEDDING NORM + // PyTorch: self._norm(x) * self.weight + // We must normalize regardless, then multiply if weight exists. + { + const float eps = 1e-6f; // Gemma3n uses 1e-6 + cur = ggml_rms_norm(ctx0, cur, eps); + + if (model.mm_soft_emb_norm_w) { + // Weight shape is (2048,) -> Element-wise broadcast multiply + cur = ggml_mul(ctx0, cur, model.mm_soft_emb_norm_w); + } + + } + + // 4. PROJECTION + // PyTorch: embedding_projection = nn.Linear(vision_hidden, text_hidden, bias=False) + // Weight stored as [out_features, in_features] = [text_hidden_size, vision_hidden_size] + if (model.mm_input_proj_w) { + cur = ggml_mul_mat(ctx0, model.mm_input_proj_w, cur); + } + + // 5. 
POST PROJECTION NORM + // PyTorch: embedding_post_projection_norm = Gemma3nRMSNorm(..., with_scale=False) + // with_scale=False means weight is registered as buffer with value 1.0 + // So output = rms_norm(x) * 1.0 = rms_norm(x), magnitude ~1 + { + const float eps = 1e-6f; + cur = ggml_rms_norm(ctx0, cur, eps); + + if (model.mm_post_proj_norm_w) { + // If weight is loaded, multiply (should be ~1.0 anyway) + cur = ggml_mul(ctx0, cur, model.mm_post_proj_norm_w); + } + } + + ggml_build_forward_expand(gf, cur); + return gf; +} \ No newline at end of file diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index 8d6d4ef67be..54664d10ce3 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -66,3 +66,36 @@ struct clip_graph_glm4v : clip_graph { clip_graph_glm4v(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; }; + +struct clip_graph_mobilenetv5 : clip_graph { + clip_graph_mobilenetv5(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; + + ggml_tensor * rms_norm_2d( + ggml_tensor * inp, + ggml_tensor * weight, + float eps = 1e-6f); + + ggml_tensor* pad_same_2d( + ggml_tensor* inp, + int kernel_h, + int kernel_w, + int stride_h, + int stride_w, + int dilation_h = 1, + int dilation_w = 1); + + ggml_tensor * build_edge_residual( + ggml_tensor * inp, + const mobilenetv5_block & block, + int stride); + + ggml_tensor * build_inverted_residual( + ggml_tensor * inp, + const mobilenetv5_block & block, + int stride); + + ggml_tensor * build_mobilenet_attn( + ggml_tensor * inp, + const mobilenetv5_block & block); +}; diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index b9c4fa90980..2d970cf45c2 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -266,7 +266,7 @@ struct mtmd_context { } // set boi/eoi - if (proj == PROJECTOR_TYPE_GEMMA3) { + if (proj == PROJECTOR_TYPE_GEMMA3 || proj == PROJECTOR_TYPE_GEMMA3N) { // ... 
(image embeddings) ... img_beg = ""; img_end = ""; @@ -858,7 +858,8 @@ float * mtmd_get_output_embd(mtmd_context * ctx) { } bool mtmd_decode_use_non_causal(mtmd_context * ctx) { - if (ctx->ctx_v && clip_get_projector_type(ctx->ctx_v) == PROJECTOR_TYPE_GEMMA3) { + if (ctx->ctx_v && + (clip_get_projector_type(ctx->ctx_v) == PROJECTOR_TYPE_GEMMA3 || clip_get_projector_type(ctx->ctx_v) == PROJECTOR_TYPE_GEMMA3N)) { return true; } return false;