From 3e4c8f8faf497bdf2e02d40381fbc1a673516992 Mon Sep 17 00:00:00 2001 From: Simranjeet Singh Date: Fri, 19 Dec 2025 20:07:14 +0000 Subject: [PATCH 1/8] Add Gemma3nVisionModel - MobileNetV5 vision encoder convertor to convert_hf_to_gguf.py. Add gemma3n to vision projectors in gguf-py/gguf/constants.py. --- convert_hf_to_gguf.py | 241 +++++++++++++++++++++++++++++++++++++- gguf-py/gguf/constants.py | 2 + 2 files changed, 241 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 432be599469..36a7ed000af 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -520,7 +520,11 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: return () def prepare_tensors(self): - max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") + # Handle empty tensor_map for models with block_count=0 (like MobileNetV5) + if self.tensor_map.mapping: + max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") + else: + max_name_len = len("vision_encoder.weight,") # Default reasonable length for name, data_torch in chain(self.generate_extra_tensors(), self.get_tensors()): # we don't need these @@ -5959,8 +5963,182 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [] # skip other tensors +@ModelBase.register("Gemma3nForConditionalGeneration", "Gemma3nVisionModel") +class Gemma3nVisionModel(MmprojModel): + """Vision encoder converter for Gemma3n using MobileNetV5 architecture""" + + # MobileNetV5 doesn't have transformer layers, so we don't need block count + # Set n_block_keys to empty list to skip the find_hparam check + n_block_keys = [] + + def find_hparam(self, keys: list[str], optional: bool = False) -> Any: + """Override to return 0 for block count since MobileNetV5 is CNN-based""" + if not keys: # If n_block_keys is empty (our case) + return 0 + # Otherwise use parent implementation + return super().find_hparam(keys, optional) + + def __init__(self, *args, **kwargs): + # Parent init will call find_hparam which now returns 0 for empty keys + super().__init__(*args, **kwargs) + + def find_vparam(self, keys: list[str], optional: bool = False) -> Any: + """Override to provide hardcoded MobileNetV5 parameters that aren't in config""" + # MobileNetV5 hardcodes these values in the architecture definition + # rather than storing them in config.json + + # Handle empty keys list (n_block_keys) - return 0 for CNN architecture + if not keys: + return 0 + + # Check if we're looking for image_size + if "image_size" in keys: + # MobileNetV5 300m_enc uses 768x768 input + return 768 + + # Check if we're looking for patch_size + if "patch_size" in keys: + # MobileNetV5 is CNN-based, doesn't use patches + # Set to 1 for compatibility + return 1 + + # Check if we're looking for intermediate_size + if "intermediate_size" in keys: + # MobileNetV5 uses expansion ratios in inverted residual blocks + # Typical expansion is 4x the embedding dimension + hidden_size = self.hparams_vision.get("hidden_size", 2048) + return hidden_size * 4 + + # Check if we're looking for num_attention_heads + if "num_attention_heads" in keys or "num_heads" in keys: + # MobileNetV5 uses Multi-Query Attention with 8 heads + return 8 + + # For other parameters, use parent implementation + return super().find_vparam(keys, optional) + + def set_gguf_parameters(self): + # MobileNetV5 requires ImageNet normalization values + # Override preprocessor_config to ensure correct values before calling super() + # IMAGENET_MEAN = [0.485, 0.456, 0.406] + # IMAGENET_STD = [0.229, 0.224, 0.225] + IMAGENET_MEAN = [0.5 , 0.5 , 0.5 ] + IMAGENET_STD = [0.5 , 0.5 , 0.5 ] + + print("test") + + # Check if preprocessor_config has incorrect normalization values + if "image_mean" in self.preprocessor_config: + current_mean = self.preprocessor_config["image_mean"] + if current_mean != IMAGENET_MEAN: + logger.warning(f"Overriding image_mean from {current_mean} to ImageNet standard {IMAGENET_MEAN}") + self.preprocessor_config["image_mean"] = IMAGENET_MEAN + print("test2") + else: + logger.info(f"Setting image_mean to ImageNet standard {IMAGENET_MEAN}") + self.preprocessor_config["image_mean"] = IMAGENET_MEAN + + if "image_std" in self.preprocessor_config: + current_std = self.preprocessor_config["image_std"] + if current_std != IMAGENET_STD: + logger.warning(f"Overriding image_std from {current_std} to ImageNet standard {IMAGENET_STD}") + self.preprocessor_config["image_std"] = IMAGENET_STD + else: + logger.info(f"Setting image_std to ImageNet standard {IMAGENET_STD}") + self.preprocessor_config["image_std"] = IMAGENET_STD + + # Now call parent which will use the corrected values + super().set_gguf_parameters() + hparams = self.hparams + + # Set projector type to GEMMA3N + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GEMMA3N) + + # MobileNetV5 specific parameters + self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6)) + self.gguf_writer.add_vision_use_gelu(True) # MobileNetV5 uses approximate GELU + + # Image sequence length (256 tokens = 16x16 for Gemma3n) + image_seq_length = self.preprocessor_config.get("image_seq_length", 256) + # Note: Additional metadata can be added as needed + + def tensor_force_quant(self, name, new_name, bid, n_dims): + # Force quantization settings for specific tensor types + if "input_projection" in name or "input_proj" in name: + return gguf.GGMLQuantizationType.F16 + if ".embeddings." in name or "stem" in name: + return gguf.GGMLQuantizationType.F32 + return super().tensor_force_quant(name, new_name, bid, n_dims) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + # Gemma3n uses different prefixes than other models: + # - model.embed_vision.* for projection layers + # - model.vision_tower.* for vision encoder + # Skip non-vision tensors + if not (name.startswith("model.embed_vision.") or + name.startswith("model.vision_tower.")): + return [] + + # Strip "model." prefix to match expected llama.cpp format + if name.startswith("model."): + name = name[6:] # Remove "model." prefix + + # Process MobileNetV5 and projection tensors + name = name.replace("_weight", ".weight") + + # Rename embed_vision to match our C++ implementation expectations + name = name.replace("embed_vision.", "") + + # Rename vision_tower.timm_model to vision_tower for cleaner naming + name = name.replace("vision_tower.timm_model.", "vision_tower.") + + # Handle normalization layer naming + name = name.replace("hard_embedding_norm", "hard_emb_norm") + name = name.replace("soft_embedding_norm", "soft_emb_norm") + # name = name.replace("embedding_post_projection_norm", "post_proj_norm") + + # Gemma3n uses Gemma3p5RMSNorm which has scale_shift=0, so no correction needed + # Unlike Gemma3 which uses Gemma3RMSNorm with scale_shift=1 + if "soft_emb_norm.weight" in name: + # No correction needed for Gemma3n + pass + + return [(self.map_tensor_name(name), data_torch)] + + def map_tensor_name(self, name: str) -> str: + """Map Gemma3n tensor names to GGUF format""" + # Projector tensors (from embed_vision) - use mm. prefix like Gemma3 + # IMPORTANT: Keep the .weight suffix to match C++ expectations + if name == "embedding.weight": + return "mm.embedding.weight" + if name == "embedding_projection.weight": + return "mm.input_projection.weight" # Main projection used by C++ + if name == "hard_emb_norm.weight": + return "mm.hard_emb_norm.weight" # Hard embedding normalization + if name == "soft_emb_norm.weight": + return "mm.soft_emb_norm.weight" # Soft embedding normalization (used by C++) + if name == "post_proj_norm.weight": + return "mm.post_proj_norm.weight" # Post projection normalization (CRITICAL for Gemma3n) + + # Vision tower tensors - add v.enc. prefix for MobileNetV5 encoder + if name.startswith("vision_tower."): + # Remove vision_tower prefix and add v.enc. prefix + tensor_suffix = name[13:] # Remove "vision_tower." + return f"v.enc.{tensor_suffix}" + + # If no match, try parent implementation + try: + return super().map_tensor_name(name) + except ValueError: + # If parent also can't map it, provide a sensible default + # This shouldn't happen, but provides a fallback + logger.warning(f"Using fallback mapping for tensor: {name}") + return f"v.{name}" + -@ModelBase.register("Gemma3nForConditionalGeneration") +@ModelBase.register("Gemma3nForCausalLM", "Gemma3nForConditionalGeneration") class Gemma3NModel(Gemma3Model): model_arch = gguf.MODEL_ARCH.GEMMA3N norm_shift = 0.0 # same value with Gemma3p5RMSNorm scale_shift on python code @@ -5983,8 +6161,43 @@ def __init__(self, *args, **kwargs): ] def set_vocab(self): + # For Gemma3n multimodal models, we need the FULL vocab_size (262400) + # which includes special tokens from 262144-262399 for vision/audio. + # The vocab_size_per_layer_input (262144) is only the embedding size per layer. + # Temporarily override the hparams lookup order to prioritize vocab_size. + + # Store original vocab_size_per_layer_input if it exists + vocab_size_per_layer_input = self.hparams.get("vocab_size_per_layer_input") + + # Temporarily remove vocab_size_per_layer_input to force using vocab_size + if vocab_size_per_layer_input is not None: + del self.hparams["vocab_size_per_layer_input"] + + # Call parent set_vocab which will now use vocab_size (262400) super().set_vocab() + # Restore vocab_size_per_layer_input for later use + if vocab_size_per_layer_input is not None: + self.hparams["vocab_size_per_layer_input"] = vocab_size_per_layer_input + + # Fix chat template for Gemma3n multimodal: replace special token placeholders with mtmd markers + # The mtmd library uses <__media__> as the default marker for images/audio + # but Gemma3n's chat template uses and + chat_template_key = "tokenizer.chat_template" + for kv_dict in self.gguf_writer.kv_data: + if chat_template_key in kv_dict: + template_value = kv_dict[chat_template_key].value + + # Replace soft token placeholders with mtmd markers + if '' in template_value or '' in template_value: + logger.info("Fixing Gemma3n chat template: replacing soft token placeholders with mtmd markers") + template_value = template_value.replace('', '<__media__>') + template_value = template_value.replace('', '<__media__>') + + # Update the value in place + kv_dict[chat_template_key].value = template_value + break + def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_altup_active_idx(self.hparams["altup_active_idx"]) @@ -6020,8 +6233,32 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if "language_model." not in name: return [] # skip non-language model tensors + # Pad token embeddings for vision/audio special tokens (262144-262399) + if "embed_tokens.weight" in name or "embed_tokens_per_layer" in name: + # Move to CPU to avoid meta device issues during padding + data_torch = data_torch.to(device="cpu") + + vocab_size = self.hparams.get("vocab_size", 262400) + current_size = data_torch.shape[0] # First dimension is vocab_size + + if current_size < vocab_size: + # Pad with zeros for vision/audio tokens (they get embeddings from vision tower) + padding_size = vocab_size - current_size + tensor_type = "per-layer embeddings" if "per_layer" in name else "token embeddings" + logger.info(f"Padding {tensor_type} shape {list(data_torch.shape)} from {current_size} to {vocab_size} (adding {padding_size} vision/audio token slots)") + + # Create padding with zeros (vision tokens won't use these embeddings) + padding = torch.zeros((padding_size, data_torch.shape[1]), dtype=data_torch.dtype, device=data_torch.device) + data_torch = torch.cat([data_torch, padding], dim=0) + + # Continue with normal processing + name = name.replace("language_model.", "") + return [(self.map_tensor_name(name), data_torch)] + if "altup_unembed_projections" in name: data_torch = data_torch.to(device="cpu") + # altup_unembed matrices are [hidden_size, hidden_size], NOT vocab-based + # They should NOT be padded if ".0." in name: self._altup_unembd[0] = data_torch elif ".1." in name: diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index cab8f2901ae..41654b22b5d 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -456,6 +456,7 @@ class VISION_PROJECTOR_TYPE(IntEnum): RESAMPLER = auto() GLM_EDGE = auto() MERGER = auto() + GEMMA3N = auto() GEMMA3 = auto() QWEN3VL = auto() COGVLM = auto() @@ -3397,6 +3398,7 @@ def get_type(val: Any) -> GGUFValueType: class VisionProjectorType: GEMMA3 = "gemma3" + GEMMA3N = "gemma3n" IDEFICS3 = "idefics3" PIXTRAL = "pixtral" LLAMA4 = "llama4" From ad5ed98d7068f50447867f238c4cb1a9e1e29f3c Mon Sep 17 00:00:00 2001 From: Simranjeet Singh Date: Sat, 20 Dec 2025 20:20:54 +0000 Subject: [PATCH 2/8] Add mobilenetv5 impl --- src/models/gemma3n-iswa.cpp | 55 +++- tools/mtmd/CMakeLists.txt | 1 + tools/mtmd/clip-graph.h | 32 ++ tools/mtmd/clip-impl.h | 43 +++ tools/mtmd/clip-model.h | 56 ++++ tools/mtmd/clip.cpp | 521 ++++++++++++++++++++++++++++++ tools/mtmd/clip.h | 1 + tools/mtmd/models/mobilenetv5.cpp | 247 ++++++++++++++ tools/mtmd/models/models.h | 5 + tools/mtmd/mtmd.cpp | 5 +- 10 files changed, 963 insertions(+), 3 deletions(-) create mode 100644 tools/mtmd/models/mobilenetv5.cpp diff --git a/src/models/gemma3n-iswa.cpp b/src/models/gemma3n-iswa.cpp index a0bdd6a15a1..7a6a446eb20 100644 --- a/src/models/gemma3n-iswa.cpp +++ b/src/models/gemma3n-iswa.cpp @@ -259,7 +259,60 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() { inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup)); cb(inp_per_layer, "inp_per_layer_selected", -1); } else { - GGML_ABORT("TODO: support embd input"); + // For embedding inputs (e.g., from vision encoder) + // CRITICAL FIX: Vision tokens should use the padding token (ID=0) embedding + // from tok_embd_per_layer, NOT project the vision embeddings. + // The projection happens later in project_per_layer_inputs(). + // This matches PyTorch behavior: + // per_layer_inputs_tokens = torch.where(mask, input_ids, torch.zeros_like(input_ids)) + // per_layer_inputs = EmbedPerLayer(per_layer_inputs_tokens) # Uses padding (0) for vision + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + ggml_set_input(inp->embd); + + // For vision, we need per_layer_inputs from padding token (ID=0) + // We CANNOT use inp->tokens because batch allows EITHER tokens OR embeddings + // + // The challenge: We need to broadcast padding token embedding from [embd_size, 1] to [embd_size, n_tokens] + // but ggml_repeat+ggml_dup doesn't work in no_alloc mode (creates views without backing memory). + // + // Solution: Use ggml_add to broadcast! GGML automatically broadcasts along compatible dimensions. + // We create zeros of shape [embd_size, n_tokens], then add padding_emb [embd_size, 1] which broadcasts. + + // tok_embd_per_layer shape: [embd_size, vocab_size] where embd_size = n_embd_altup * n_layer + const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_altup * n_layer + + // Create zeros tensor [embd_size, n_tokens] by projecting vision embeddings and multiplying by 0 + // First, project inp->embd [n_embd, n_tokens] to per-layer space [embd_size, n_tokens] + ggml_tensor * zeros_per_layer = ggml_mul_mat(ctx0, model.per_layer_model_proj, inp->embd); + zeros_per_layer = ggml_scale(ctx0, zeros_per_layer, 0.0f); // Multiply by 0 to get zeros + ggml_set_name(zeros_per_layer, "zeros_per_layer"); + + // Extract column 0 (padding token's embedding) as a vector: [embd_size] + // Note: tok_embd_per_layer is quantized (q8_0), so the view is also q8_0 + ggml_tensor * padding_embd_vec_q = ggml_view_1d(ctx0, model.tok_embd_per_layer, + embd_size, // number of elements + 0); // offset (column 0) + ggml_set_name(padding_embd_vec_q, "padding_token_emb_q8"); + + // Dequantize to f32 using ggml_cpy + ggml_tensor * padding_embd_vec_f32 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, embd_size); + ggml_tensor * padding_embd_vec = ggml_cpy(ctx0, padding_embd_vec_q, padding_embd_vec_f32); + ggml_set_name(padding_embd_vec, "padding_token_emb_f32"); + + // Reshape to [embd_size, 1] for broadcasting + ggml_tensor * padding_embd_col = ggml_reshape_2d(ctx0, padding_embd_vec, embd_size, 1); + + // Add: zeros [embd_size, n_tokens] + padding [embd_size, 1] = broadcasted padding [embd_size, n_tokens] + ggml_tensor * inp_per_layer_flat = ggml_add(ctx0, zeros_per_layer, padding_embd_col); + ggml_set_name(inp_per_layer_flat, "inp_per_layer_broadcasted"); + + // Reshape to [n_embd_altup, n_layer, n_tokens] for per-layer processing + inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer_flat, n_embd_altup, n_layer, n_tokens); + + // Apply same scaling as text tokens + // inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup)); + cb(inp_per_layer, "inp_per_layer_vision", -1); } res->add_input(std::move(inp)); return inp_per_layer; diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 317d5f19fd9..a74b4bc2154 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -27,6 +27,7 @@ add_library(mtmd models/qwen3vl.cpp models/siglip.cpp models/whisper-enc.cpp + models/mobilenetv5.cpp ) set_target_properties(mtmd PROPERTIES diff --git a/tools/mtmd/clip-graph.h b/tools/mtmd/clip-graph.h index 2b1915779f2..5d8c46862bd 100644 --- a/tools/mtmd/clip-graph.h +++ b/tools/mtmd/clip-graph.h @@ -70,6 +70,38 @@ struct clip_graph { ggml_tensor * build_inp_raw(int channels = 3); + ggml_tensor * rms_norm_2d( + ggml_tensor * inp, + ggml_tensor * weight, + float eps = 1e-6f, + int block_idx=-1); + + ggml_tensor* pad_same_2d( + ggml_tensor* inp, + int kernel_h, + int kernel_w, + int stride_h, + int stride_w, + int dilation_h = 1, + int dilation_w = 1); + + ggml_tensor * build_edge_residual( + ggml_tensor * inp, + const mobilenetv5_block & block, + int stride, + int block_idx = -1); + + ggml_tensor * build_inverted_residual( + ggml_tensor * inp, + const mobilenetv5_block & block, + int stride, + int block_idx = -1); + + ggml_tensor * build_mobilenet_attn( + ggml_tensor * inp, + const mobilenetv5_block & block, + int block_idx = -1); + ggml_tensor * build_norm( ggml_tensor * cur, ggml_tensor * mw, diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index a0939865e3f..24a1ef52d08 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -153,6 +153,47 @@ #define TN_CONV_PW1 "%s.blk.%d.conv_pw1.%s" #define TN_CONV_PW2 "%s.blk.%d.conv_pw2.%s" +// mobilenetv5 (gemma3n) definitions +#define TN_MNV5_STEM_CONV "v.enc.conv_stem.conv.weight" +#define TN_MNV5_STEM_BIAS "v.enc.conv_stem.conv.bias" +#define TN_MNV5_STEM_BN "v.enc.conv_stem.bn.weight" + +// Stage 0 Block (Edge Residual) +#define TN_MNV5_BLK_S0_EXP_W "v.enc.blocks.%d.%d.conv_exp.weight" +#define TN_MNV5_BLK_S0_BN1_W "v.enc.blocks.%d.%d.bn1.weight" +#define TN_MNV5_BLK_S0_PWL_W "v.enc.blocks.%d.%d.conv_pwl.weight" +#define TN_MNV5_BLK_S0_BN2_W "v.enc.blocks.%d.%d.bn2.weight" + +// Stage 1+ Block (Universal Inverted Residual) +#define TN_MNV5_BLK_DW_START_W "v.enc.blocks.%d.%d.dw_start.conv.weight" +#define TN_MNV5_BLK_DW_START_BN "v.enc.blocks.%d.%d.dw_start.bn.weight" +#define TN_MNV5_BLK_DW_MID_W "v.enc.blocks.%d.%d.dw_mid.conv.weight" +#define TN_MNV5_BLK_DW_MID_BN "v.enc.blocks.%d.%d.dw_mid.bn.weight" +#define TN_MNV5_BLK_PW_EXP_W "v.enc.blocks.%d.%d.pw_exp.conv.weight" +#define TN_MNV5_BLK_PW_EXP_BN "v.enc.blocks.%d.%d.pw_exp.bn.weight" +#define TN_MNV5_BLK_PW_PROJ_W "v.enc.blocks.%d.%d.pw_proj.conv.weight" +#define TN_MNV5_BLK_PW_PROJ_BN "v.enc.blocks.%d.%d.pw_proj.bn.weight" +#define TN_MNV5_BLK_LAYER_SCALE "v.enc.blocks.%d.%d.layer_scale.gamma" + +// Attention Components +#define TN_MNV5_ATTN_Q_W "v.enc.blocks.%d.%d.attn.query.proj.weight" +#define TN_MNV5_ATTN_K_W "v.enc.blocks.%d.%d.attn.key.proj.weight" +#define TN_MNV5_ATTN_V_W "v.enc.blocks.%d.%d.attn.value.proj.weight" +#define TN_MNV5_ATTN_O_W "v.enc.blocks.%d.%d.attn.output.proj.weight" +#define TN_MNV5_ATTN_K_DW "v.enc.blocks.%d.%d.attn.key.down_conv.weight" +#define TN_MNV5_ATTN_K_NORM "v.enc.blocks.%d.%d.attn.key.norm.weight" +#define TN_MNV5_ATTN_V_DW "v.enc.blocks.%d.%d.attn.value.down_conv.weight" +#define TN_MNV5_ATTN_V_NORM "v.enc.blocks.%d.%d.attn.value.norm.weight" +#define TN_MNV5_ATTN_NORM "v.enc.blocks.%d.%d.norm.weight" // Block norm used in attn blocks + +// MSFA +#define TN_MNV5_MSFA_FFN_EXP_W "v.enc.msfa.ffn.pw_exp.conv.weight" +#define TN_MNV5_MSFA_FFN_EXP_BN "v.enc.msfa.ffn.pw_exp.bn.weight" +#define TN_MNV5_MSFA_FFN_PROJ_W "v.enc.msfa.ffn.pw_proj.conv.weight" +#define TN_MNV5_MSFA_FFN_PROJ_BN "v.enc.msfa.ffn.pw_proj.bn.weight" +#define TN_MNV5_MSFA_NORM "v.enc.msfa.norm.weight" + + // align x to upper multiple of n #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) @@ -170,6 +211,7 @@ enum projector_type { PROJECTOR_TYPE_QWEN2VL, PROJECTOR_TYPE_QWEN3VL, PROJECTOR_TYPE_GEMMA3, + PROJECTOR_TYPE_GEMMA3N, PROJECTOR_TYPE_IDEFICS3, PROJECTOR_TYPE_PIXTRAL, PROJECTOR_TYPE_QWEN25VL, @@ -200,6 +242,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_QWEN25VL, "qwen2.5vl_merger"}, { PROJECTOR_TYPE_QWEN3VL, "qwen3vl_merger"}, { PROJECTOR_TYPE_GEMMA3, "gemma3"}, + { PROJECTOR_TYPE_GEMMA3N, "gemma3n"}, { PROJECTOR_TYPE_IDEFICS3, "idefics3"}, { PROJECTOR_TYPE_PIXTRAL, "pixtral"}, { PROJECTOR_TYPE_ULTRAVOX, "ultravox"}, diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index b4c31cdde6b..e03f455b1b5 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -172,6 +172,45 @@ struct clip_layer { } }; +// Expanded MobileNetV5 block structure for Gemma3n vision encoder +struct mobilenetv5_block { + // Stage 0 (Edge Residual) + ggml_tensor * s0_conv_exp_w = nullptr; + ggml_tensor * s0_bn1_w = nullptr; + ggml_tensor * s0_conv_pwl_w = nullptr; + ggml_tensor * s0_bn2_w = nullptr; + + // Stage 1+ (Universal Inverted Residual) + ggml_tensor * dw_start_w = nullptr; + ggml_tensor * dw_start_bn_w = nullptr; + + ggml_tensor * pw_exp_w = nullptr; + ggml_tensor * pw_exp_bn_w = nullptr; + + ggml_tensor * dw_mid_w = nullptr; + ggml_tensor * dw_mid_bn_w = nullptr; + + ggml_tensor * pw_proj_w = nullptr; + ggml_tensor * pw_proj_bn_w = nullptr; + + ggml_tensor * layer_scale_w = nullptr; + + // Attention (MQA) components + ggml_tensor * attn_q_w = nullptr; + ggml_tensor * attn_k_w = nullptr; + ggml_tensor * attn_v_w = nullptr; + ggml_tensor * attn_o_w = nullptr; + + // Optional downsampling/norm in attention + ggml_tensor * attn_k_dw_w = nullptr; + ggml_tensor * attn_k_norm_w = nullptr; + ggml_tensor * attn_v_dw_w = nullptr; + ggml_tensor * attn_v_norm_w = nullptr; + + // Block norm (often present in attention blocks) + ggml_tensor * attn_norm_w = nullptr; +}; + struct clip_model { clip_modality modality = CLIP_MODALITY_VISION; projector_type proj_type = PROJECTOR_TYPE_MLP; @@ -288,6 +327,23 @@ struct clip_model { ggml_tensor * mm_input_proj_w = nullptr; ggml_tensor * mm_soft_emb_norm_w = nullptr; + // mobilenetv5 for gemma3n + std::vector mobilenet_blocks; + std::vector mobilenet_stage_ends; // NEW: Track end indices of stages + ggml_tensor * mobilenet_stem_conv_w = nullptr; + ggml_tensor * mobilenet_stem_conv_b = nullptr; + ggml_tensor * mobilenet_stem_norm_w = nullptr; + ggml_tensor * mm_post_proj_norm_w = nullptr; + + // Multi-Scale Fusion Adapter (MSFA) components + ggml_tensor * msfa_concat_conv_w = nullptr; // Concatenated feature processing + ggml_tensor * msfa_concat_norm_w = nullptr; + ggml_tensor * msfa_ffn_expand_w = nullptr; // FFN expansion + ggml_tensor * msfa_ffn_project_w = nullptr; // FFN projection + ggml_tensor * msfa_ffn_expand_bn = nullptr; // NEW: FFN expansion batch norm + ggml_tensor * msfa_ffn_project_bn = nullptr; // NEW: FFN projection batch norm + + // pixtral, glm4v ggml_tensor * token_embd_img_break = nullptr; ggml_tensor * mm_patch_merger_w = nullptr; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 3ba0823defb..4c357aab19e 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -263,6 +263,378 @@ void clip_graph::cb(ggml_tensor * cur0, const char * name, int il) const { } } +// Helper: Normalize over the Channel dimension (dim 2 in [W, H, C, B]) +// RMS Norm 2D - normalizes over channels for each spatial position +// PyTorch: v = torch.mean(x.pow(2), dim=1) - mean over C for each (N,H,W) +// We need to normalize each spatial position across its C channels +ggml_tensor * clip_graph::rms_norm_2d(ggml_tensor * inp, ggml_tensor * weight, float eps, int block_idx) { + // inp: [W, H, C, B] + const int64_t W = inp->ne[0]; + const int64_t H = inp->ne[1]; + const int64_t C = inp->ne[2]; + const int64_t B = inp->ne[3]; + + // Step 1: Permute [W, H, C, B] -> [C, W, H, B] + // Puts Channels in ne[0] (contiguous) + ggml_tensor * cur = ggml_permute(ctx0, inp, 2, 1, 0, 3); + cur = ggml_cont(ctx0, cur); + + // Step 2: Reshape [C, W, H, B] -> [C, W*H*B] + // We now have a 2D matrix where columns are Channels (ne[0]) + // and rows are Spatial/Batch (ne[1]). + // cur = ggml_reshape_2d(ctx0, cur, C, W * H * B); + + // REMOVED Step 3 (Transpose). + // We WANT ne[0] to be C so rms_norm reduces over it. + + // Step 4: Apply RMS Norm + // Normalizes ne[0] (C) for every element in ne[1] (Spatial/Batch). + cur = ggml_rms_norm(ctx0, cur, eps); + + // Step 5: Apply weight if present + if (weight) { + // weight is [C] + // cur is [C, W*H*B] + // ggml_mul broadcasts automatically along higher dims. + // It multiplies element i of weight with element i of cur's ne[0]. + cur = ggml_mul(ctx0, cur, weight); + } + + // REMOVED Step 6 (Transpose back). We never transposed. + + // Step 7: Reshape back to [C, W, H, B] + // cur = ggml_reshape_4d(ctx0, cur, C, W, H, B); + + // Step 8: Permute back to [W, H, C, B] + // ne[0]=C, ne[1]=W, ne[2]=H, ne[3]=B + // We want new ne[0] to be old ne[1] (W) + // We want new ne[1] to be old ne[2] (H) + // We want new ne[2] to be old ne[0] (C) + // We want new ne[3] to be old ne[3] (B) + cur = ggml_permute(ctx0, cur, 2, 1, 0, 3); + + // cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + + // Note: The second permute in your original code was likely redundant/incorrect + // after the first one. A single permute is sufficient to restore order. + cur = ggml_cont(ctx0, cur); + + return cur; +} + + +// ------------------------------------------------------------------------ +// Helper for Conv2dSame padding (asymmetric SAME padding like PyTorch/TF) +// ------------------------------------------------------------------------ +ggml_tensor* clip_graph::pad_same_2d(ggml_tensor* inp, int kernel_h, int kernel_w, int stride_h, int stride_w, int dilation_h, int dilation_w) { + const int64_t ih = inp->ne[1]; // height + const int64_t iw = inp->ne[0]; // width + + // Calculate output size (ceil division) + const int64_t oh = (ih + stride_h - 1) / stride_h; + const int64_t ow = (iw + stride_w - 1) / stride_w; + + // Calculate padding needed + const int64_t pad_h = std::max((int64_t)0, (oh - 1) * stride_h + (kernel_h - 1) * dilation_h + 1 - ih); + const int64_t pad_w = std::max((int64_t)0, (ow - 1) * stride_w + (kernel_w - 1) * dilation_w + 1 - iw); + + // Split padding asymmetrically + const int pad_h_top = pad_h / 2; + const int pad_h_bottom = pad_h - pad_h_top; + const int pad_w_left = pad_w / 2; + const int pad_w_right = pad_w - pad_w_left; + + // Apply padding if needed + // ggml_pad_ext: (ctx, tensor, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3) + // For [W, H, C, B]: p0=width, p1=height, p2=channels, p3=batch + if (pad_h > 0 || pad_w > 0) { + inp = ggml_pad_ext(ctx0, inp, + pad_w_left, pad_w_right, // width padding (dim 0) + pad_h_top, pad_h_bottom, // height padding (dim 1) + 0, 0, // no channel padding (dim 2) + 0, 0); // no batch padding (dim 3) + } + + return inp; +} + +// ------------------------------------------------------------------------ +// Edge Residual Block (Stage 0) - CORRECTED +// ------------------------------------------------------------------------ +ggml_tensor * clip_graph::build_edge_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride, int block_idx) { + ggml_tensor * cur = inp; + + // 1. Expansion Conv (3x3) + // -------------------------------------------------------------------- + // LOGIC FIX: + // Block 0 (stride=2): Uses "Conv2dSame". We must manually pad, then conv with pad=0. + // Block 1,2 (stride=1): Uses standard "Conv2d" with padding=(1,1). + // -------------------------------------------------------------------- + + if (stride == 2) { + // Case: Downsampling (Block 0) + // Replicates Conv2dSame(kernel=3, stride=2) + // We calculate asymmetric padding dynamically + cur = pad_same_2d(cur, 3, 3, stride, stride); + + // Perform conv with 0 padding because we just applied it manually + cur = ggml_conv_2d_direct(ctx0, block.s0_conv_exp_w, cur, stride, stride, 0, 0, 1, 1); + } else { + // Case: Normal 3x3 Block (Block 1, 2) + // Replicates Conv2d(kernel=3, stride=1, padding=1) + // Standard symmetric padding of 1 is sufficient for 3x3 s1 to keep dims same + cur = ggml_conv_2d_direct(ctx0, block.s0_conv_exp_w, cur, stride, stride, 1, 1, 1, 1); + } + + // BN + Activation + if (block.s0_bn1_w) cur = rms_norm_2d(cur, block.s0_bn1_w); + cur = ggml_gelu(ctx0, cur); + + // 2. Pointwise Linear Conv (1x1) + // 1x1 Convs usually have padding=0 and stride=1 + cur = ggml_conv_2d_direct(ctx0, block.s0_conv_pwl_w, cur, 1, 1, 0, 0, 1, 1); + if (block.s0_bn2_w) cur = rms_norm_2d(cur, block.s0_bn2_w); + + // 3. Residual Connection + // Only apply residual if spatial dimensions and channels match (stride 1) + if (stride == 1 && inp->ne[2] == cur->ne[2] && inp->ne[0] == cur->ne[0]) { + cur = ggml_add(ctx0, cur, inp); + } + + return cur; +} + +ggml_tensor * clip_graph::build_inverted_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride, int block_idx) { + ggml_tensor * cur = inp; + + // 1. Depthwise Start (Optional) + // NOTE: dw_start always has stride=1 (no downsampling here) + if (block.dw_start_w) { + int k = block.dw_start_w->ne[0]; // 3 or 5 + int p = k / 2; + // cur = ggml_conv_2d_dw_direct(ctx0, block.dw_start_w, cur, 1, 1, p, p, 1, 1); + cur = ggml_conv_2d_dw(ctx0, block.dw_start_w, cur, 1, 1, p, p, 1, 1); + if (block.dw_start_bn_w) cur = rms_norm_2d(cur, block.dw_start_bn_w); + } + + // 2. Pointwise Expansion (1x1) + if (block.pw_exp_w) { + // Standard 1x1 conv, pad=0, stride=1 + cur = ggml_conv_2d_direct(ctx0, block.pw_exp_w, cur, 1, 1, 0, 0, 1, 1); + if (block.pw_exp_bn_w) cur = rms_norm_2d(cur, block.pw_exp_bn_w); + cur = ggml_gelu(ctx0, cur); + } + + // 3. Depthwise Mid (Optional) + // NOTE: dw_mid is where downsampling happens (stride=2 for first block of stage) + if (block.dw_mid_w) { + int k = block.dw_mid_w->ne[0]; // 3 or 5 + + if (stride > 1) { + // Case: Stride 2 (Downsample) -> Use Asymmetric "Same" Padding + cur = pad_same_2d(cur, k, k, stride, stride); + cur = ggml_conv_2d_dw(ctx0, block.dw_mid_w, cur, stride, stride, 0, 0, 1, 1); // pad=0 + } else { + // Case: Stride 1 -> Use Standard Symmetric Padding + int p = k / 2; + cur = ggml_conv_2d_dw(ctx0, block.dw_mid_w, cur, stride, stride, p, p, 1, 1); + } + + if (block.dw_mid_bn_w) cur = rms_norm_2d(cur, block.dw_mid_bn_w); + cur = ggml_gelu(ctx0, cur); + } + + // 4. Pointwise Projection (1x1) + if (block.pw_proj_w) { + cur = ggml_conv_2d_direct(ctx0, block.pw_proj_w, cur, 1, 1, 0, 0, 1, 1); + if (block.pw_proj_bn_w) cur = rms_norm_2d(cur, block.pw_proj_bn_w); + } + + // Apply Layer Scaling if present + if (block.layer_scale_w) { + ggml_tensor * scale_w_reshaped = ggml_reshape_4d(ctx0, block.layer_scale_w, + 1, 1, block.layer_scale_w->ne[0], 1); + + cur = ggml_mul(ctx0, cur, scale_w_reshaped); + } + + // 5. Residual Connection + bool same_spatial = (inp->ne[0] == cur->ne[0]) && (inp->ne[1] == cur->ne[1]); + bool same_channel = (inp->ne[2] == cur->ne[2]); + if (same_spatial && same_channel) { + // --- FIXED LAYER SCALING --- + // --------------------------- + cur = ggml_add(ctx0, cur, inp); + } + + return cur; +} + +// MobileNetV5 Builder (Gemma 3n) - Attention Block +ggml_tensor * clip_graph::build_mobilenet_attn(ggml_tensor * inp, const mobilenetv5_block & block, int block_idx) { + + // ... [Debug Helpers kept same as original] ... + // auto DEBUG_SHAPE = [&](const char* label, ggml_tensor* t) { /* ... */ }; + // auto REGISTER_DEBUG = [&](const std::string& name, ggml_tensor* t) { /* ... */ }; + + // // Debug input + // if (block_idx == 33 || block_idx == 50 || block_idx == 52) { + // char debug_name[128]; + // snprintf(debug_name, sizeof(debug_name), "block%d_input", block_idx); + // REGISTER_DEBUG(debug_name, inp); + // } + + ggml_tensor * cur = inp; + + // --- Norm --- + if (block.attn_norm_w) { + cur = rms_norm_2d(cur, block.attn_norm_w, 1e-6f, block_idx); + } + + // --- 1. Q Calculation --- + ggml_tensor * q = ggml_conv_2d_direct(ctx0, block.attn_q_w, cur, 1, 1, 0, 0, 1, 1); + + // --- 2. K Calculation (Downsampled) --- + // Uses Conv2dSame(640, 640, kernel_size=(3, 3), stride=(2, 2), groups=640) + ggml_tensor * k_inp = cur; + if (block.attn_k_dw_w) { + int k_size = block.attn_k_dw_w->ne[0]; // Usually 3 + k_inp = pad_same_2d(cur, k_size, k_size, 2, 2); // Apply SAME padding + k_inp = ggml_conv_2d_dw(ctx0, block.attn_k_dw_w, k_inp, 2, 2, 0, 0, 1, 1); // padding=0 + if (block.attn_k_norm_w) { + k_inp = rms_norm_2d(k_inp, block.attn_k_norm_w, 1e-6f, block_idx); + } + } + ggml_tensor * k = ggml_conv_2d_direct(ctx0, block.attn_k_w, k_inp, 1, 1, 0, 0, 1, 1); + + // --- 3. V Calculation (Downsampled) --- + // Uses Conv2dSame(640, 640, kernel_size=(3, 3), stride=(2, 2), groups=640) + ggml_tensor * v_inp = cur; + if (block.attn_v_dw_w) { + int v_size = block.attn_v_dw_w->ne[0]; // Usually 3 + v_inp = pad_same_2d(cur, v_size, v_size, 2, 2); // Apply SAME padding + v_inp = ggml_conv_2d_dw(ctx0, block.attn_v_dw_w, v_inp, 2, 2, 0, 0, 1, 1); // padding=0 + if (block.attn_v_norm_w) { + v_inp = rms_norm_2d(v_inp, block.attn_v_norm_w, 1e-6f, block_idx); + } + } + ggml_tensor * v = ggml_conv_2d_direct(ctx0, block.attn_v_w, v_inp, 1, 1, 0, 0, 1, 1); + + // --- Reshape & Permute Logic --- + + const int W = cur->ne[0]; const int H = cur->ne[1]; const int B = cur->ne[3]; + const int D = k->ne[2]; // Head dimension + const int n_head = q->ne[2] / D; + const int N = W * H; + + // Process Q: [W, H, D*n_head, B] -> [D, N, n_head, B] + q = ggml_reshape_3d(ctx0, q, N, D*n_head, B); + q = ggml_reshape_4d(ctx0, q, N, D, n_head, B); + q = ggml_permute(ctx0, q, 1, 0, 2, 3); // [D, N, n_head, B] + q = ggml_cont(ctx0, q); + + const int Wk = k->ne[0]; const int Hk = k->ne[1]; + const int M = Wk * Hk; + + // Process K: [Wk, Hk, D, B] -> [D, M, 1, B] + k = ggml_reshape_3d(ctx0, k, M, D, B); + k = ggml_reshape_4d(ctx0, k, M, D, 1, B); + k = ggml_permute(ctx0, k, 1, 0, 2, 3); // [D, M, 1, B] + k = ggml_cont(ctx0, k); + + // Process V: [Wk, Hk, D, B] -> [M, D, 1, B] + // NOTE: We keep V as [M, D] because ggml_mul_mat expects src0^T * src1. + // To get output [D, N], we will need [M, D]^T * [M, N]. + v = ggml_reshape_3d(ctx0, v, M, D, B); + v = ggml_reshape_4d(ctx0, v, M, D, 1, B); + v = ggml_cont(ctx0, v); // [M, D, 1, B] + + // --- Multi-Query Attention --- + float scale = 1.0f / sqrtf((float)D); + + // Step 1: Compute Q @ K.T + // Q: [D, N, n_head, B] + // K: [D, M, 1, B] + // ggml_mul_mat computes K^T * Q -> [D, M]^T * [D, N] -> [M, D] * [D, N] -> [M, N] + // Implicit Broadcast: K has 1 head, Q has n_head. ggml handles this automatically. + ggml_tensor * scores = ggml_mul_mat(ctx0, k, q); // Result: [M, N, n_head, B] (in ggml layout) + + // // Debug scores + // if (block_idx == 33) { + // char debug_name[128]; + // snprintf(debug_name, sizeof(debug_name), "block%d_scores_raw", block_idx); + // REGISTER_DEBUG(debug_name, scores); + // } + + scores = ggml_scale(ctx0, scores, scale); + + // Step 2: Softmax + // scores is [M, N, n_head, B] (ne0=M, ne1=N) + // We need softmax over M (keys). + // ggml_soft_max applies to dim 0, which is M. Perfect - no permute needed! + scores = ggml_soft_max(ctx0, scores); + + // Step 3: Compute Attn @ V + // V: [M, D, 1, B] (ne0=M, ne1=D) + // Scores: [M, N, n_head, B] (ne0=M, ne1=N) + // + // ggml_mul_mat computes V^T * Scores -> [M, D]^T * [M, N] -> [D, M] * [M, N] -> [D, N] + // Implicit Broadcast: V has 1 head, Scores has n_head. ggml handles this automatically. + ggml_tensor * kqv = ggml_mul_mat(ctx0, v, scores); // Result: [N, D, n_head, B] + + // // Debug kqv + // if (block_idx == 33) { + // char debug_name[128]; + // snprintf(debug_name, sizeof(debug_name), "block%d_kqv_out", block_idx); + // REGISTER_DEBUG(debug_name, kqv); + // } + + // --- Reshape back to spatial layout --- + // kqv is [N, D, n_head, B]. We want [D, N, n_head, B] to merge heads. + kqv = ggml_permute(ctx0, kqv, 1, 0, 2, 3); // [D, N, n_head, B] + kqv = ggml_cont(ctx0, kqv); + + // Reshape to [N, D*n_head, B] then [W, H, C, B] + kqv = ggml_reshape_3d(ctx0, kqv, N, D * n_head, B); + kqv = ggml_reshape_4d(ctx0, kqv, W, H, D * n_head, B); + kqv = ggml_cont(ctx0, kqv); + +// Output projection + cur = ggml_conv_2d_direct(ctx0, block.attn_o_w, kqv, 1, 1, 0, 0, 1, 1); + + // --- Residual & Layer Scale (FIXED) --- + if (inp->ne[0] == cur->ne[0] && inp->ne[2] == cur->ne[2]) { + if (block.layer_scale_w) { + // FIX: Simplified Layer Scale. No permute needed. + // Tensor is [W, H, C, B]. Weight is [C]. + // We reshape Weight to [1, 1, C, 1]. + // GGML will broadcast W and H dimensions automatically. + + // Debug print shape of block.layer_scale_w + // fprintf(stderr, "DEBUG: block %d layer_scale_w shape: [%ld x %ld x %ld x %ld]\n", block_idx, block.layer_scale_w->ne[0], block.layer_scale_w->ne[1], block.layer_scale_w->ne[2], block.layer_scale_w->ne[3]); + + // Debug print shape of cur before scaling + // fprintf(stderr, "DEBUG: block %d cur shape before scaling: [%ld x %ld x %ld x %ld]\n", block_idx, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + + + ggml_tensor * scale_w_reshaped = ggml_reshape_4d(ctx0, block.layer_scale_w, + 1, 1, block.layer_scale_w->ne[0], 1); + + // Debug print shape of scale_w_reshaped + // fprintf(stderr, "DEBUG: block %d scale_w_reshaped shape: [%ld x %ld x %ld x %ld]\n", block_idx, scale_w_reshaped->ne[0], scale_w_reshaped->ne[1], scale_w_reshaped->ne[2], scale_w_reshaped->ne[3]); + + cur = ggml_mul(ctx0, cur, scale_w_reshaped); + } + + // Residual Addition + // 'cur' is the pointer to the graph node of the attention output. + // 'inp' is the pointer to the graph node of the block input. + cur = ggml_add(ctx0, cur, inp); + } + + return cur; +} + // siglip2 naflex ggml_tensor * clip_graph::resize_position_embeddings(uint32_t interpolation_mode) { ggml_tensor * pos_embd = model.position_embeddings; @@ -788,6 +1160,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_GEMMA3N: + { + builder = std::make_unique(ctx, img); + } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_LIGHTONOCR: { @@ -1141,6 +1517,14 @@ struct clip_model_loader { // test model (tinygemma3) has a different value, we optionally read it get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); } break; + + case PROJECTOR_TYPE_GEMMA3N: + { + // Gemma3n uses MobileNetV5 which produces 256 tokens (16x16) + // Similar configuration to Gemma3 + hparams.n_merge = 1; // MobileNetV5 handles resizing internally + get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); + } break; case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: @@ -1381,6 +1765,7 @@ struct clip_model_loader { } } + switch (model.proj_type) { case PROJECTOR_TYPE_MLP: case PROJECTOR_TYPE_MLP_NORM: @@ -1512,6 +1897,106 @@ struct clip_model_loader { model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ); model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N); } break; + case PROJECTOR_TYPE_GEMMA3N: + { + model.mobilenet_stem_conv_w = get_tensor(TN_MNV5_STEM_CONV, false); + model.mobilenet_stem_conv_b = get_tensor(TN_MNV5_STEM_BIAS, false); + model.mobilenet_stem_norm_w = get_tensor(TN_MNV5_STEM_BN, false); + + model.msfa_ffn_expand_w = get_tensor(TN_MNV5_MSFA_FFN_EXP_W, false); + model.msfa_ffn_expand_bn = get_tensor(TN_MNV5_MSFA_FFN_EXP_BN, false); // Consume BN if present but likely folded + model.msfa_ffn_project_w = get_tensor(TN_MNV5_MSFA_FFN_PROJ_W, false); + model.msfa_ffn_project_bn = get_tensor(TN_MNV5_MSFA_FFN_PROJ_BN, false); + + // IMPORTANT: Your GGUF log shows 'v.enc.msfa.norm.weight' -> shape {2048} + // Ensure TN_MNV5_MSFA_NORM matches this string + model.msfa_concat_norm_w = get_tensor(TN_MNV5_MSFA_NORM, false); + + // Dynamically load blocks stage by stage + for (int stage = 0; stage < 4; ++stage) { + int blocks_found_in_stage = 0; + + for (int blk_idx = 0; ; ++blk_idx) { + bool found_block = false; + mobilenetv5_block block; + + // 1. Check for Edge Residual (S0) + block.s0_conv_exp_w = get_tensor(string_format(TN_MNV5_BLK_S0_EXP_W, stage, blk_idx), false); + if (block.s0_conv_exp_w) { + found_block = true; + block.s0_bn1_w = get_tensor(string_format(TN_MNV5_BLK_S0_BN1_W, stage, blk_idx), false); + block.s0_conv_pwl_w = get_tensor(string_format(TN_MNV5_BLK_S0_PWL_W, stage, blk_idx), false); + block.s0_bn2_w = get_tensor(string_format(TN_MNV5_BLK_S0_BN2_W, stage, blk_idx), false); + } + // 2. Check for UIR (Universal Inverted Residual) + else { + // Check for dw_start OR pw_exp (some UIR blocks skip dw_start) + block.dw_start_w = get_tensor(string_format(TN_MNV5_BLK_DW_START_W, stage, blk_idx), false); + block.pw_exp_w = get_tensor(string_format(TN_MNV5_BLK_PW_EXP_W, stage, blk_idx), false); + + if (block.dw_start_w || block.pw_exp_w) { + found_block = true; + if (block.dw_start_w) { + block.dw_start_bn_w = get_tensor(string_format(TN_MNV5_BLK_DW_START_BN, stage, blk_idx), false); + } + if (block.pw_exp_w) { + block.pw_exp_bn_w = get_tensor(string_format(TN_MNV5_BLK_PW_EXP_BN, stage, blk_idx), false); + } + block.dw_mid_w = get_tensor(string_format(TN_MNV5_BLK_DW_MID_W, stage, blk_idx), false); + if (block.dw_mid_w) { + block.dw_mid_bn_w = get_tensor(string_format(TN_MNV5_BLK_DW_MID_BN, stage, blk_idx), false); + } + block.pw_proj_w = get_tensor(string_format(TN_MNV5_BLK_PW_PROJ_W, stage, blk_idx), false); + if (block.pw_proj_w) { + block.pw_proj_bn_w = get_tensor(string_format(TN_MNV5_BLK_PW_PROJ_BN, stage, blk_idx), false); + } + block.layer_scale_w = get_tensor(string_format(TN_MNV5_BLK_LAYER_SCALE, stage, blk_idx), false); + } + } + + // 3. Check for Attention (MQA) + // Even if UIR/Edge check failed, this might be a pure attention block + ggml_tensor* attn_q_check = get_tensor(string_format(TN_MNV5_ATTN_Q_W, stage, blk_idx), false); + if (attn_q_check) { + found_block = true; + block.attn_q_w = attn_q_check; + block.attn_k_w = get_tensor(string_format(TN_MNV5_ATTN_K_W, stage, blk_idx), false); + block.attn_v_w = get_tensor(string_format(TN_MNV5_ATTN_V_W, stage, blk_idx), false); + block.attn_o_w = get_tensor(string_format(TN_MNV5_ATTN_O_W, stage, blk_idx), false); + block.attn_k_dw_w = get_tensor(string_format(TN_MNV5_ATTN_K_DW, stage, blk_idx), false); + block.attn_k_norm_w = get_tensor(string_format(TN_MNV5_ATTN_K_NORM, stage, blk_idx), false); + block.attn_v_dw_w = get_tensor(string_format(TN_MNV5_ATTN_V_DW, stage, blk_idx), false); + block.attn_v_norm_w = get_tensor(string_format(TN_MNV5_ATTN_V_NORM, stage, blk_idx), false); + block.attn_norm_w = get_tensor(string_format(TN_MNV5_ATTN_NORM, stage, blk_idx), false); + // Note: Attention blocks also have layer_scale, load it if not already loaded by UIR check + if (!block.layer_scale_w) { + block.layer_scale_w = get_tensor(string_format(TN_MNV5_BLK_LAYER_SCALE, stage, blk_idx), false); + } + } + + if (found_block) { + model.mobilenet_blocks.push_back(block); + blocks_found_in_stage++; + } else { + // End of blocks for this stage + break; + } + } + + // Track where this stage ends in the flat vector + if (blocks_found_in_stage > 0) { + model.mobilenet_stage_ends.push_back(model.mobilenet_blocks.size() - 1); + LOG_INF("%s: Stage %d ended at global block index %zu\n", __func__, stage, model.mobilenet_blocks.size() - 1); + } + } + // Load projection weights (similar to Gemma3) + model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ); + model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N); + // model.mm_post_proj_norm_w = get_tensor(TN_MM_POST_PROJ_N); // CRITICAL: Post projection norm + // Load additional Gemma3n projection tensors + model.mm_0_w = get_tensor("mm.embedding.weight", false); // Input embedding + model.mm_1_w = get_tensor("mm.hard_emb_norm.weight", false); // Hard embedding norm + } break; case PROJECTOR_TYPE_IDEFICS3: { model.projection = get_tensor(TN_MM_PROJECTOR); @@ -2052,6 +2537,18 @@ void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny memcpy(img->buf.data(), rgb_pixels, img->buf.size()); } +// Rescale image from u8 to f32 without normalization (for models like GEMMA3N that use SiglipImageProcessorFast) +// This only converts from [0, 255] to [0.0, 1.0] range without applying mean/std normalization +static void rescale_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst) { + dst.nx = src.nx; + dst.ny = src.ny; + dst.buf.resize(src.buf.size()); + + for (size_t i = 0; i < src.buf.size(); ++i) { + dst.buf[i] = static_cast(src.buf[i]) / 255.0f; + } +} + // Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const float mean[3], const float std[3]) { dst.nx = src.nx; @@ -2747,6 +3244,18 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->entries.push_back(std::move(img_f32)); } break; + case PROJECTOR_TYPE_GEMMA3N: + { + // GEMMA3N uses SiglipImageProcessorFast which only rescales to [0.0, 1.0] without normalization + // Resize to 768x768 using bilinear interpolation, then rescale to f32 + clip_image_u8 resized_image; + int sz = params.image_size; + img_tool::resize(*img, resized_image, {sz, sz}, img_tool::RESIZE_ALGO_BILINEAR, false); + clip_image_f32_ptr img_f32(clip_image_f32_init()); + rescale_image_u8_to_f32(resized_image, *img_f32); + res_imgs->entries.push_back(std::move(img_f32)); + } break; + case PROJECTOR_TYPE_JANUS_PRO: { // Janus Pro preprocessing: pad to square with gray(127), resize to 384x384 @@ -3006,6 +3515,12 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im int scale_factor = ctx->model.hparams.n_merge; n_patches /= (scale_factor * scale_factor); } break; + case PROJECTOR_TYPE_GEMMA3N: + { + // MobileNetV5 MSFA adapter always outputs fixed 16x16 resolution + // regardless of input size (see architecture description) + n_patches = 16 * 16; // 256 tokens + } break; case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: { @@ -3396,6 +3911,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima set_input_i32("patches", patches); } break; case PROJECTOR_TYPE_GEMMA3: + case PROJECTOR_TYPE_GEMMA3N: case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_INTERNVL: case PROJECTOR_TYPE_QWEN2A: @@ -3521,6 +4037,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { // main path + deepstack paths return ctx->model.mm_1_b->ne[0] * (1 + ctx->model.n_deepstack_layers); case PROJECTOR_TYPE_GEMMA3: + case PROJECTOR_TYPE_GEMMA3N: return ctx->model.mm_input_proj_w->ne[0]; case PROJECTOR_TYPE_IDEFICS3: return ctx->model.projection->ne[1]; @@ -3575,6 +4092,10 @@ bool clip_is_gemma3(const struct clip_ctx * ctx) { return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3; } +bool clip_is_gemma3n(const struct clip_ctx * ctx) { + return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3N; +} + bool clip_has_vision_encoder(const struct clip_ctx * ctx) { return ctx->model.modality == CLIP_MODALITY_VISION; } diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index 68a0d6e857e..c244df2677f 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -107,6 +107,7 @@ bool clip_is_glm(const struct clip_ctx * ctx); bool clip_is_mrope(const struct clip_ctx * ctx); bool clip_is_llava(const struct clip_ctx * ctx); bool clip_is_gemma3(const struct clip_ctx * ctx); +bool clip_is_gemma3n(const struct clip_ctx * ctx); bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec); diff --git a/tools/mtmd/models/mobilenetv5.cpp b/tools/mtmd/models/mobilenetv5.cpp new file mode 100644 index 00000000000..9946ca6afa8 --- /dev/null +++ b/tools/mtmd/models/mobilenetv5.cpp @@ -0,0 +1,247 @@ +#include "models.h" + +ggml_cgraph * clip_graph_mobilenetv5::build() { + + fprintf(stderr, "\n--- START build_mobilenetv5 ---\n"); + + ggml_tensor * inp = build_inp_raw(); + + // 1. Stem - Conv2dSame(3, 64, kernel_size=(3, 3), stride=(2, 2)) + ggml_tensor * cur = pad_same_2d(inp, 3, 3, 2, 2); // Apply SAME padding + + // ggml_tensor * mobilenet_stem_conv_w_fixed = fix_1x1_weight(model.mobilenet_stem_conv_w); + + cur = ggml_conv_2d_direct(ctx0, model.mobilenet_stem_conv_w, cur, 2, 2, 0, 0, 1, 1); // padding=0 + if (model.mobilenet_stem_conv_b) { + // Bias is [C, 1, 1, 1], need to reshape to [1, 1, C, 1] for broadcasting to [W, H, C, B] + ggml_tensor * bias = ggml_reshape_4d(ctx0, model.mobilenet_stem_conv_b, 1, 1, cur->ne[2], 1); + cur = ggml_add(ctx0, cur, bias); + } + if (model.mobilenet_stem_norm_w) cur = rms_norm_2d(cur, model.mobilenet_stem_norm_w); + cur = ggml_gelu(ctx0, cur); + + + // 2. Blocks + std::vector intermediate_features; + const int total_blocks = model.mobilenet_blocks.size(); + + auto is_stage_start = [&](int i) { + if (i == 0) return true; + for (int end_idx : model.mobilenet_stage_ends) { + if (i == end_idx + 1) return true; + } + return false; + }; + + auto is_fusion_point = [&](int i) { + if (model.mobilenet_stage_ends.size() >= 4) { + if (i == model.mobilenet_stage_ends[2]) return true; // End of Stage 2 + if (i == model.mobilenet_stage_ends[3]) return true; // End of Stage 3 + } else { + if (i == total_blocks - 1) return true; + } + return false; + }; + + for (int i = 0; i < total_blocks; i++) { + const auto & block = model.mobilenet_blocks[i]; + int stride = is_stage_start(i) ? 2 : 1; + + // Debug block type + const char* block_type = block.s0_conv_exp_w ? "edge_residual" : + block.attn_q_w ? "attention" : "inverted_residual"; + + // // Debug input for problematic blocks + // if (i >= 50 && i <= 54) { + // fprintf(stderr, "DEBUG: Block %d (%s) input shape: [%ld, %ld, %ld, %ld], stride=%d\n", + // i, block_type, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], stride); + // } + + if (block.s0_conv_exp_w) cur = build_edge_residual(cur, block, stride, i); + else if (block.attn_q_w) cur = build_mobilenet_attn(cur, block, i); + else cur = build_inverted_residual(cur, block, stride, i); + + // Register block output for debugging + char block_name[64]; + + if (is_fusion_point(i)) { + + intermediate_features.push_back(cur); + } + } + + // 3. Multi-Scale Fusion Adapter (MSFA) - REPLICATED & FIXED + if (!intermediate_features.empty()) { + + // A. Reference Resolution: PyTorch implementation uses inputs[0] + // We assume intermediate_features[0] is the "High Resolution" target. + // In MobileNet designs, this is typically the feature map with the smallest stride (e.g. 32x32). + ggml_tensor* target_feat = intermediate_features[0]; + int high_res_w = target_feat->ne[0]; + int high_res_h = target_feat->ne[1]; + + std::vector resized_feats; + + // B. Resize inputs to match inputs[0] (High Resolution) + for (auto feat : intermediate_features) { + int feat_w = feat->ne[0]; + int feat_h = feat->ne[1]; + + // PyTorch: if feat_size < high_resolution: interpolate + if (feat_w < high_res_w || feat_h < high_res_h) { + // Calculate scale factor. + // Note: PyTorch 'nearest' works on arbitrary float scales. + // ggml_upscale generally takes integer factors or target sizes depending on helper. + // Assuming standard power-of-2 scaling (e.g. 16 -> 32 means scale=2). + int scale_w = high_res_w / feat_w; + int scale_h = high_res_h / feat_h; + + // Safety check for non-integer scaling if strictly replicating + if (high_res_w % feat_w != 0) { + fprintf(stderr, "Warning: Non-integer scaling detected in MSFA\n"); + } + + // Upsample (Nearest Neighbor) + // 2 is the scale factor + feat = ggml_upscale(ctx0, feat, scale_w, ggml_scale_mode::GGML_SCALE_MODE_NEAREST); + } + resized_feats.push_back(feat); + } + + // C. Concatenate at High Resolution (Channel Dim = 2 in ggml) + cur = resized_feats[0]; + for (size_t k = 1; k < resized_feats.size(); ++k) { + cur = ggml_concat(ctx0, cur, resized_feats[k], 2); + } + + // D. FFN (UniversalInvertedResidual) + // Structure: Expand Conv -> Norm -> GELU -> Project Conv -> Norm + + // 1. Expansion + if (model.msfa_ffn_expand_w) { + // 1x1 Conv + cur = ggml_conv_2d_direct(ctx0, model.msfa_ffn_expand_w, cur, 1, 1, 0, 0, 1, 1); + + // MISSING IN YOUR CODE: Expansion Norm + if (model.msfa_ffn_expand_bn) { + cur = rms_norm_2d(cur, model.msfa_ffn_expand_bn); // Helper to apply RMSNorm + } + + cur = ggml_gelu(ctx0, cur); + + } + + // 2. Projection (No DW because kernel_size=0) + if (model.msfa_ffn_project_w) { + // 1x1 Conv + cur = ggml_conv_2d_direct(ctx0, model.msfa_ffn_project_w, cur, 1, 1, 0, 0, 1, 1); + + // MISSING IN YOUR CODE: Projection Norm + // UniversalInvertedResidual typically has a norm after projection + if (model.msfa_ffn_project_bn) { + cur = rms_norm_2d(cur, model.msfa_ffn_project_bn); + } + + } + + // E. Final Downsample to Target Resolution (Output Resolution) + // PyTorch: matches self.output_resolution (e.g. 16x16) + const int target_out_res = 16; + int current_w = cur->ne[0]; + + if (current_w > target_out_res) { + int s = current_w / target_out_res; + + // PyTorch Logic: + // if divisible: avg_pool + // if not divisible: bilinear interpolate (hard to do in pure ggml, usually assumed divisible here) + + if (current_w % target_out_res == 0) { + // Avg Pool: Kernel=s, Stride=s + cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, s, s, s, s, 0, 0); + } else { + // Fallback or Error: ggml doesn't easily support bilinear downsampling + // without custom ops, but standard models usually stick to integer strides. + fprintf(stderr, "Error: Irregular downsampling stride required.\n"); + } + + } + + // F. Final Norm + if (model.msfa_concat_norm_w) { + cur = rms_norm_2d(cur, model.msfa_concat_norm_w); + + } + } + + // 4. Gemma 3n Multimodal Projection (Embedder) - FULL FIX + // Input: 'cur' is [Width, Height, Channels, Batch] + int W = cur->ne[0]; + int H = cur->ne[1]; + int C = cur->ne[2]; // Should be 2048 + int B = cur->ne[3]; + + // 1. Permute and Flatten to [Channels, Tokens, Batch] + // PyTorch expects (Batch, Seq, Hidden), GGML usually processes (Hidden, Seq, Batch) + cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // -> [C, W, H, B] + cur = ggml_cont(ctx0, cur); + cur = ggml_reshape_3d(ctx0, cur, C, W*H, B); + cur = ggml_cont(ctx0, cur); + + + // 2. FEATURE SCALING (Missing in your original code) + // PyTorch: vision_outputs *= self.config.vision_config.hidden_size**0.5 + // This prevents the signal from vanishing during the subsequent RMSNorm. + const float scale_factor = sqrtf((float)C); + cur = ggml_scale(ctx0, cur, scale_factor); + + + // 3. SOFT EMBEDDING NORM + // PyTorch: self._norm(x) * self.weight + // We must normalize regardless, then multiply if weight exists. + { + const float eps = 1e-6f; // Gemma3n uses 1e-6 + cur = ggml_rms_norm(ctx0, cur, eps); + + if (model.mm_soft_emb_norm_w) { + // Weight shape is (2048,) -> Element-wise broadcast multiply + cur = ggml_mul(ctx0, cur, model.mm_soft_emb_norm_w); + } + + } + + // 4. PROJECTION + // PyTorch: embedding_projection = nn.Linear(vision_hidden, text_hidden, bias=False) + // Weight stored as [out_features, in_features] = [text_hidden_size, vision_hidden_size] + // Need to transpose for ggml_mul_mat which computes A^T * B + // This matches Gemma3's projection at line ~1319 which also transposes + if (model.mm_input_proj_w) { + // cur = ggml_mul_mat(ctx0, model.mm_input_proj_w, cur); + cur = ggml_mul_mat(ctx0, + ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)), + cur); + + } + + // 5. POST PROJECTION NORM + // PyTorch: embedding_post_projection_norm = Gemma3nRMSNorm(..., with_scale=False) + // with_scale=False means weight is registered as buffer with value 1.0 + // So output = rms_norm(x) * 1.0 = rms_norm(x), magnitude ~1 + // NOTE: Vision embeddings intentionally have magnitude ~1, different from + // text embeddings at ~sqrt(n_embd). The model was trained with this mismatch. + { + const float eps = 1e-6f; + cur = ggml_rms_norm(ctx0, cur, eps); + + if (model.mm_post_proj_norm_w) { + // If weight is loaded, multiply (should be ~1.0 anyway) + cur = ggml_mul(ctx0, cur, model.mm_post_proj_norm_w); + } + } + + + // cur = ggml_scale(ctx0, cur, scale_factor); + + ggml_build_forward_expand(gf, cur); + return gf; +} \ No newline at end of file diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index 8d6d4ef67be..3875285fe92 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -66,3 +66,8 @@ struct clip_graph_glm4v : clip_graph { clip_graph_glm4v(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; }; + +struct clip_graph_mobilenetv5 : clip_graph { + clip_graph_mobilenetv5(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index b9c4fa90980..2d970cf45c2 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -266,7 +266,7 @@ struct mtmd_context { } // set boi/eoi - if (proj == PROJECTOR_TYPE_GEMMA3) { + if (proj == PROJECTOR_TYPE_GEMMA3 || proj == PROJECTOR_TYPE_GEMMA3N) { // ... (image embeddings) ... img_beg = ""; img_end = ""; @@ -858,7 +858,8 @@ float * mtmd_get_output_embd(mtmd_context * ctx) { } bool mtmd_decode_use_non_causal(mtmd_context * ctx) { - if (ctx->ctx_v && clip_get_projector_type(ctx->ctx_v) == PROJECTOR_TYPE_GEMMA3) { + if (ctx->ctx_v && + (clip_get_projector_type(ctx->ctx_v) == PROJECTOR_TYPE_GEMMA3 || clip_get_projector_type(ctx->ctx_v) == PROJECTOR_TYPE_GEMMA3N)) { return true; } return false; From f57705478749497f561d793e8fb7b2e0a2712b8f Mon Sep 17 00:00:00 2001 From: Simranjeet Singh Date: Sat, 20 Dec 2025 20:46:21 +0000 Subject: [PATCH 3/8] Fix comments, remove unused vars --- src/models/gemma3n-iswa.cpp | 11 +-- tools/mtmd/clip-graph.h | 12 +-- tools/mtmd/clip-model.h | 12 +-- tools/mtmd/clip.cpp | 157 ++++-------------------------- tools/mtmd/models/mobilenetv5.cpp | 44 ++------- 5 files changed, 36 insertions(+), 200 deletions(-) diff --git a/src/models/gemma3n-iswa.cpp b/src/models/gemma3n-iswa.cpp index 7a6a446eb20..e172b9a79f8 100644 --- a/src/models/gemma3n-iswa.cpp +++ b/src/models/gemma3n-iswa.cpp @@ -260,7 +260,7 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() { cb(inp_per_layer, "inp_per_layer_selected", -1); } else { // For embedding inputs (e.g., from vision encoder) - // CRITICAL FIX: Vision tokens should use the padding token (ID=0) embedding + // Vision tokens should use the padding token (ID=0) embedding // from tok_embd_per_layer, NOT project the vision embeddings. // The projection happens later in project_per_layer_inputs(). // This matches PyTorch behavior: @@ -270,15 +270,6 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() { inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); ggml_set_input(inp->embd); - // For vision, we need per_layer_inputs from padding token (ID=0) - // We CANNOT use inp->tokens because batch allows EITHER tokens OR embeddings - // - // The challenge: We need to broadcast padding token embedding from [embd_size, 1] to [embd_size, n_tokens] - // but ggml_repeat+ggml_dup doesn't work in no_alloc mode (creates views without backing memory). - // - // Solution: Use ggml_add to broadcast! GGML automatically broadcasts along compatible dimensions. - // We create zeros of shape [embd_size, n_tokens], then add padding_emb [embd_size, 1] which broadcasts. - // tok_embd_per_layer shape: [embd_size, vocab_size] where embd_size = n_embd_altup * n_layer const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_altup * n_layer diff --git a/tools/mtmd/clip-graph.h b/tools/mtmd/clip-graph.h index 5d8c46862bd..6a9efb933e5 100644 --- a/tools/mtmd/clip-graph.h +++ b/tools/mtmd/clip-graph.h @@ -73,8 +73,7 @@ struct clip_graph { ggml_tensor * rms_norm_2d( ggml_tensor * inp, ggml_tensor * weight, - float eps = 1e-6f, - int block_idx=-1); + float eps = 1e-6f); ggml_tensor* pad_same_2d( ggml_tensor* inp, @@ -88,19 +87,16 @@ struct clip_graph { ggml_tensor * build_edge_residual( ggml_tensor * inp, const mobilenetv5_block & block, - int stride, - int block_idx = -1); + int stride); ggml_tensor * build_inverted_residual( ggml_tensor * inp, const mobilenetv5_block & block, - int stride, - int block_idx = -1); + int stride); ggml_tensor * build_mobilenet_attn( ggml_tensor * inp, - const mobilenetv5_block & block, - int block_idx = -1); + const mobilenetv5_block & block); ggml_tensor * build_norm( ggml_tensor * cur, diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index e03f455b1b5..be168b97ef2 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -329,19 +329,19 @@ struct clip_model { // mobilenetv5 for gemma3n std::vector mobilenet_blocks; - std::vector mobilenet_stage_ends; // NEW: Track end indices of stages + std::vector mobilenet_stage_ends; ggml_tensor * mobilenet_stem_conv_w = nullptr; ggml_tensor * mobilenet_stem_conv_b = nullptr; ggml_tensor * mobilenet_stem_norm_w = nullptr; ggml_tensor * mm_post_proj_norm_w = nullptr; // Multi-Scale Fusion Adapter (MSFA) components - ggml_tensor * msfa_concat_conv_w = nullptr; // Concatenated feature processing + ggml_tensor * msfa_concat_conv_w = nullptr; ggml_tensor * msfa_concat_norm_w = nullptr; - ggml_tensor * msfa_ffn_expand_w = nullptr; // FFN expansion - ggml_tensor * msfa_ffn_project_w = nullptr; // FFN projection - ggml_tensor * msfa_ffn_expand_bn = nullptr; // NEW: FFN expansion batch norm - ggml_tensor * msfa_ffn_project_bn = nullptr; // NEW: FFN projection batch norm + ggml_tensor * msfa_ffn_expand_w = nullptr; + ggml_tensor * msfa_ffn_project_w = nullptr; + ggml_tensor * msfa_ffn_expand_bn = nullptr; + ggml_tensor * msfa_ffn_project_bn = nullptr; // pixtral, glm4v diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 4c357aab19e..9e4519c502b 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -263,69 +263,26 @@ void clip_graph::cb(ggml_tensor * cur0, const char * name, int il) const { } } -// Helper: Normalize over the Channel dimension (dim 2 in [W, H, C, B]) +// --- Helpers for MobileNetV5 Blocks --- // RMS Norm 2D - normalizes over channels for each spatial position -// PyTorch: v = torch.mean(x.pow(2), dim=1) - mean over C for each (N,H,W) -// We need to normalize each spatial position across its C channels -ggml_tensor * clip_graph::rms_norm_2d(ggml_tensor * inp, ggml_tensor * weight, float eps, int block_idx) { +ggml_tensor * clip_graph::rms_norm_2d(ggml_tensor * inp, ggml_tensor * weight, float eps) { // inp: [W, H, C, B] - const int64_t W = inp->ne[0]; - const int64_t H = inp->ne[1]; - const int64_t C = inp->ne[2]; - const int64_t B = inp->ne[3]; - // Step 1: Permute [W, H, C, B] -> [C, W, H, B] - // Puts Channels in ne[0] (contiguous) ggml_tensor * cur = ggml_permute(ctx0, inp, 2, 1, 0, 3); cur = ggml_cont(ctx0, cur); - - // Step 2: Reshape [C, W, H, B] -> [C, W*H*B] - // We now have a 2D matrix where columns are Channels (ne[0]) - // and rows are Spatial/Batch (ne[1]). - // cur = ggml_reshape_2d(ctx0, cur, C, W * H * B); - - // REMOVED Step 3 (Transpose). - // We WANT ne[0] to be C so rms_norm reduces over it. - - // Step 4: Apply RMS Norm - // Normalizes ne[0] (C) for every element in ne[1] (Spatial/Batch). cur = ggml_rms_norm(ctx0, cur, eps); - - // Step 5: Apply weight if present + if (weight) { - // weight is [C] - // cur is [C, W*H*B] - // ggml_mul broadcasts automatically along higher dims. - // It multiplies element i of weight with element i of cur's ne[0]. cur = ggml_mul(ctx0, cur, weight); } - // REMOVED Step 6 (Transpose back). We never transposed. - - // Step 7: Reshape back to [C, W, H, B] - // cur = ggml_reshape_4d(ctx0, cur, C, W, H, B); - - // Step 8: Permute back to [W, H, C, B] - // ne[0]=C, ne[1]=W, ne[2]=H, ne[3]=B - // We want new ne[0] to be old ne[1] (W) - // We want new ne[1] to be old ne[2] (H) - // We want new ne[2] to be old ne[0] (C) - // We want new ne[3] to be old ne[3] (B) cur = ggml_permute(ctx0, cur, 2, 1, 0, 3); - - // cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - - // Note: The second permute in your original code was likely redundant/incorrect - // after the first one. A single permute is sufficient to restore order. cur = ggml_cont(ctx0, cur); return cur; } - -// ------------------------------------------------------------------------ // Helper for Conv2dSame padding (asymmetric SAME padding like PyTorch/TF) -// ------------------------------------------------------------------------ ggml_tensor* clip_graph::pad_same_2d(ggml_tensor* inp, int kernel_h, int kernel_w, int stride_h, int stride_w, int dilation_h, int dilation_w) { const int64_t ih = inp->ne[1]; // height const int64_t iw = inp->ne[0]; // width @@ -358,31 +315,20 @@ ggml_tensor* clip_graph::pad_same_2d(ggml_tensor* inp, int kernel_h, int kernel_ return inp; } -// ------------------------------------------------------------------------ -// Edge Residual Block (Stage 0) - CORRECTED -// ------------------------------------------------------------------------ -ggml_tensor * clip_graph::build_edge_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride, int block_idx) { + +// Edge Residual Block (Stage 0) +ggml_tensor * clip_graph::build_edge_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride) { ggml_tensor * cur = inp; // 1. Expansion Conv (3x3) - // -------------------------------------------------------------------- - // LOGIC FIX: - // Block 0 (stride=2): Uses "Conv2dSame". We must manually pad, then conv with pad=0. - // Block 1,2 (stride=1): Uses standard "Conv2d" with padding=(1,1). - // -------------------------------------------------------------------- - if (stride == 2) { // Case: Downsampling (Block 0) // Replicates Conv2dSame(kernel=3, stride=2) - // We calculate asymmetric padding dynamically cur = pad_same_2d(cur, 3, 3, stride, stride); - - // Perform conv with 0 padding because we just applied it manually cur = ggml_conv_2d_direct(ctx0, block.s0_conv_exp_w, cur, stride, stride, 0, 0, 1, 1); } else { // Case: Normal 3x3 Block (Block 1, 2) // Replicates Conv2d(kernel=3, stride=1, padding=1) - // Standard symmetric padding of 1 is sufficient for 3x3 s1 to keep dims same cur = ggml_conv_2d_direct(ctx0, block.s0_conv_exp_w, cur, stride, stride, 1, 1, 1, 1); } @@ -404,7 +350,7 @@ ggml_tensor * clip_graph::build_edge_residual(ggml_tensor * inp, const mobilenet return cur; } -ggml_tensor * clip_graph::build_inverted_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride, int block_idx) { +ggml_tensor * clip_graph::build_inverted_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride) { ggml_tensor * cur = inp; // 1. Depthwise Start (Optional) @@ -412,7 +358,6 @@ ggml_tensor * clip_graph::build_inverted_residual(ggml_tensor * inp, const mobil if (block.dw_start_w) { int k = block.dw_start_w->ne[0]; // 3 or 5 int p = k / 2; - // cur = ggml_conv_2d_dw_direct(ctx0, block.dw_start_w, cur, 1, 1, p, p, 1, 1); cur = ggml_conv_2d_dw(ctx0, block.dw_start_w, cur, 1, 1, p, p, 1, 1); if (block.dw_start_bn_w) cur = rms_norm_2d(cur, block.dw_start_bn_w); } @@ -462,8 +407,6 @@ ggml_tensor * clip_graph::build_inverted_residual(ggml_tensor * inp, const mobil bool same_spatial = (inp->ne[0] == cur->ne[0]) && (inp->ne[1] == cur->ne[1]); bool same_channel = (inp->ne[2] == cur->ne[2]); if (same_spatial && same_channel) { - // --- FIXED LAYER SCALING --- - // --------------------------- cur = ggml_add(ctx0, cur, inp); } @@ -471,24 +414,12 @@ ggml_tensor * clip_graph::build_inverted_residual(ggml_tensor * inp, const mobil } // MobileNetV5 Builder (Gemma 3n) - Attention Block -ggml_tensor * clip_graph::build_mobilenet_attn(ggml_tensor * inp, const mobilenetv5_block & block, int block_idx) { - - // ... [Debug Helpers kept same as original] ... - // auto DEBUG_SHAPE = [&](const char* label, ggml_tensor* t) { /* ... */ }; - // auto REGISTER_DEBUG = [&](const std::string& name, ggml_tensor* t) { /* ... */ }; - - // // Debug input - // if (block_idx == 33 || block_idx == 50 || block_idx == 52) { - // char debug_name[128]; - // snprintf(debug_name, sizeof(debug_name), "block%d_input", block_idx); - // REGISTER_DEBUG(debug_name, inp); - // } - +ggml_tensor * clip_graph::build_mobilenet_attn(ggml_tensor * inp, const mobilenetv5_block & block) { ggml_tensor * cur = inp; // --- Norm --- if (block.attn_norm_w) { - cur = rms_norm_2d(cur, block.attn_norm_w, 1e-6f, block_idx); + cur = rms_norm_2d(cur, block.attn_norm_w, 1e-6f); } // --- 1. Q Calculation --- @@ -502,7 +433,7 @@ ggml_tensor * clip_graph::build_mobilenet_attn(ggml_tensor * inp, const mobilene k_inp = pad_same_2d(cur, k_size, k_size, 2, 2); // Apply SAME padding k_inp = ggml_conv_2d_dw(ctx0, block.attn_k_dw_w, k_inp, 2, 2, 0, 0, 1, 1); // padding=0 if (block.attn_k_norm_w) { - k_inp = rms_norm_2d(k_inp, block.attn_k_norm_w, 1e-6f, block_idx); + k_inp = rms_norm_2d(k_inp, block.attn_k_norm_w, 1e-6f); } } ggml_tensor * k = ggml_conv_2d_direct(ctx0, block.attn_k_w, k_inp, 1, 1, 0, 0, 1, 1); @@ -515,13 +446,11 @@ ggml_tensor * clip_graph::build_mobilenet_attn(ggml_tensor * inp, const mobilene v_inp = pad_same_2d(cur, v_size, v_size, 2, 2); // Apply SAME padding v_inp = ggml_conv_2d_dw(ctx0, block.attn_v_dw_w, v_inp, 2, 2, 0, 0, 1, 1); // padding=0 if (block.attn_v_norm_w) { - v_inp = rms_norm_2d(v_inp, block.attn_v_norm_w, 1e-6f, block_idx); + v_inp = rms_norm_2d(v_inp, block.attn_v_norm_w, 1e-6f); } } ggml_tensor * v = ggml_conv_2d_direct(ctx0, block.attn_v_w, v_inp, 1, 1, 0, 0, 1, 1); - // --- Reshape & Permute Logic --- - const int W = cur->ne[0]; const int H = cur->ne[1]; const int B = cur->ne[3]; const int D = k->ne[2]; // Head dimension const int n_head = q->ne[2] / D; @@ -543,8 +472,6 @@ ggml_tensor * clip_graph::build_mobilenet_attn(ggml_tensor * inp, const mobilene k = ggml_cont(ctx0, k); // Process V: [Wk, Hk, D, B] -> [M, D, 1, B] - // NOTE: We keep V as [M, D] because ggml_mul_mat expects src0^T * src1. - // To get output [D, N], we will need [M, D]^T * [M, N]. v = ggml_reshape_3d(ctx0, v, M, D, B); v = ggml_reshape_4d(ctx0, v, M, D, 1, B); v = ggml_cont(ctx0, v); // [M, D, 1, B] @@ -553,82 +480,32 @@ ggml_tensor * clip_graph::build_mobilenet_attn(ggml_tensor * inp, const mobilene float scale = 1.0f / sqrtf((float)D); // Step 1: Compute Q @ K.T - // Q: [D, N, n_head, B] - // K: [D, M, 1, B] - // ggml_mul_mat computes K^T * Q -> [D, M]^T * [D, N] -> [M, D] * [D, N] -> [M, N] - // Implicit Broadcast: K has 1 head, Q has n_head. ggml handles this automatically. - ggml_tensor * scores = ggml_mul_mat(ctx0, k, q); // Result: [M, N, n_head, B] (in ggml layout) - - // // Debug scores - // if (block_idx == 33) { - // char debug_name[128]; - // snprintf(debug_name, sizeof(debug_name), "block%d_scores_raw", block_idx); - // REGISTER_DEBUG(debug_name, scores); - // } + ggml_tensor * scores = ggml_mul_mat(ctx0, k, q); scores = ggml_scale(ctx0, scores, scale); - // Step 2: Softmax - // scores is [M, N, n_head, B] (ne0=M, ne1=N) - // We need softmax over M (keys). - // ggml_soft_max applies to dim 0, which is M. Perfect - no permute needed! scores = ggml_soft_max(ctx0, scores); - // Step 3: Compute Attn @ V - // V: [M, D, 1, B] (ne0=M, ne1=D) - // Scores: [M, N, n_head, B] (ne0=M, ne1=N) - // - // ggml_mul_mat computes V^T * Scores -> [M, D]^T * [M, N] -> [D, M] * [M, N] -> [D, N] - // Implicit Broadcast: V has 1 head, Scores has n_head. ggml handles this automatically. - ggml_tensor * kqv = ggml_mul_mat(ctx0, v, scores); // Result: [N, D, n_head, B] - - // // Debug kqv - // if (block_idx == 33) { - // char debug_name[128]; - // snprintf(debug_name, sizeof(debug_name), "block%d_kqv_out", block_idx); - // REGISTER_DEBUG(debug_name, kqv); - // } - - // --- Reshape back to spatial layout --- - // kqv is [N, D, n_head, B]. We want [D, N, n_head, B] to merge heads. - kqv = ggml_permute(ctx0, kqv, 1, 0, 2, 3); // [D, N, n_head, B] + ggml_tensor * kqv = ggml_mul_mat(ctx0, v, scores); + + kqv = ggml_permute(ctx0, kqv, 1, 0, 2, 3); kqv = ggml_cont(ctx0, kqv); - // Reshape to [N, D*n_head, B] then [W, H, C, B] + kqv = ggml_reshape_3d(ctx0, kqv, N, D * n_head, B); kqv = ggml_reshape_4d(ctx0, kqv, W, H, D * n_head, B); kqv = ggml_cont(ctx0, kqv); -// Output projection + // Output projection cur = ggml_conv_2d_direct(ctx0, block.attn_o_w, kqv, 1, 1, 0, 0, 1, 1); // --- Residual & Layer Scale (FIXED) --- if (inp->ne[0] == cur->ne[0] && inp->ne[2] == cur->ne[2]) { if (block.layer_scale_w) { - // FIX: Simplified Layer Scale. No permute needed. - // Tensor is [W, H, C, B]. Weight is [C]. - // We reshape Weight to [1, 1, C, 1]. - // GGML will broadcast W and H dimensions automatically. - - // Debug print shape of block.layer_scale_w - // fprintf(stderr, "DEBUG: block %d layer_scale_w shape: [%ld x %ld x %ld x %ld]\n", block_idx, block.layer_scale_w->ne[0], block.layer_scale_w->ne[1], block.layer_scale_w->ne[2], block.layer_scale_w->ne[3]); - - // Debug print shape of cur before scaling - // fprintf(stderr, "DEBUG: block %d cur shape before scaling: [%ld x %ld x %ld x %ld]\n", block_idx, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); - - ggml_tensor * scale_w_reshaped = ggml_reshape_4d(ctx0, block.layer_scale_w, 1, 1, block.layer_scale_w->ne[0], 1); - - // Debug print shape of scale_w_reshaped - // fprintf(stderr, "DEBUG: block %d scale_w_reshaped shape: [%ld x %ld x %ld x %ld]\n", block_idx, scale_w_reshaped->ne[0], scale_w_reshaped->ne[1], scale_w_reshaped->ne[2], scale_w_reshaped->ne[3]); - cur = ggml_mul(ctx0, cur, scale_w_reshaped); } - - // Residual Addition - // 'cur' is the pointer to the graph node of the attention output. - // 'inp' is the pointer to the graph node of the block input. cur = ggml_add(ctx0, cur, inp); } diff --git a/tools/mtmd/models/mobilenetv5.cpp b/tools/mtmd/models/mobilenetv5.cpp index 9946ca6afa8..88bd1e6fcb9 100644 --- a/tools/mtmd/models/mobilenetv5.cpp +++ b/tools/mtmd/models/mobilenetv5.cpp @@ -9,8 +9,6 @@ ggml_cgraph * clip_graph_mobilenetv5::build() { // 1. Stem - Conv2dSame(3, 64, kernel_size=(3, 3), stride=(2, 2)) ggml_tensor * cur = pad_same_2d(inp, 3, 3, 2, 2); // Apply SAME padding - // ggml_tensor * mobilenet_stem_conv_w_fixed = fix_1x1_weight(model.mobilenet_stem_conv_w); - cur = ggml_conv_2d_direct(ctx0, model.mobilenet_stem_conv_w, cur, 2, 2, 0, 0, 1, 1); // padding=0 if (model.mobilenet_stem_conv_b) { // Bias is [C, 1, 1, 1], need to reshape to [1, 1, C, 1] for broadcasting to [W, H, C, B] @@ -47,22 +45,9 @@ ggml_cgraph * clip_graph_mobilenetv5::build() { const auto & block = model.mobilenet_blocks[i]; int stride = is_stage_start(i) ? 2 : 1; - // Debug block type - const char* block_type = block.s0_conv_exp_w ? "edge_residual" : - block.attn_q_w ? "attention" : "inverted_residual"; - - // // Debug input for problematic blocks - // if (i >= 50 && i <= 54) { - // fprintf(stderr, "DEBUG: Block %d (%s) input shape: [%ld, %ld, %ld, %ld], stride=%d\n", - // i, block_type, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], stride); - // } - - if (block.s0_conv_exp_w) cur = build_edge_residual(cur, block, stride, i); - else if (block.attn_q_w) cur = build_mobilenet_attn(cur, block, i); - else cur = build_inverted_residual(cur, block, stride, i); - - // Register block output for debugging - char block_name[64]; + if (block.s0_conv_exp_w) cur = build_edge_residual(cur, block, stride); + else if (block.attn_q_w) cur = build_mobilenet_attn(cur, block); + else cur = build_inverted_residual(cur, block, stride); if (is_fusion_point(i)) { @@ -94,7 +79,7 @@ ggml_cgraph * clip_graph_mobilenetv5::build() { // ggml_upscale generally takes integer factors or target sizes depending on helper. // Assuming standard power-of-2 scaling (e.g. 16 -> 32 means scale=2). int scale_w = high_res_w / feat_w; - int scale_h = high_res_h / feat_h; + // int scale_h = high_res_h / feat_h; // Safety check for non-integer scaling if strictly replicating if (high_res_w % feat_w != 0) { @@ -122,9 +107,8 @@ ggml_cgraph * clip_graph_mobilenetv5::build() { // 1x1 Conv cur = ggml_conv_2d_direct(ctx0, model.msfa_ffn_expand_w, cur, 1, 1, 0, 0, 1, 1); - // MISSING IN YOUR CODE: Expansion Norm if (model.msfa_ffn_expand_bn) { - cur = rms_norm_2d(cur, model.msfa_ffn_expand_bn); // Helper to apply RMSNorm + cur = rms_norm_2d(cur, model.msfa_ffn_expand_bn); } cur = ggml_gelu(ctx0, cur); @@ -136,7 +120,6 @@ ggml_cgraph * clip_graph_mobilenetv5::build() { // 1x1 Conv cur = ggml_conv_2d_direct(ctx0, model.msfa_ffn_project_w, cur, 1, 1, 0, 0, 1, 1); - // MISSING IN YOUR CODE: Projection Norm // UniversalInvertedResidual typically has a norm after projection if (model.msfa_ffn_project_bn) { cur = rms_norm_2d(cur, model.msfa_ffn_project_bn); @@ -151,17 +134,11 @@ ggml_cgraph * clip_graph_mobilenetv5::build() { if (current_w > target_out_res) { int s = current_w / target_out_res; - - // PyTorch Logic: - // if divisible: avg_pool - // if not divisible: bilinear interpolate (hard to do in pure ggml, usually assumed divisible here) - + if (current_w % target_out_res == 0) { // Avg Pool: Kernel=s, Stride=s cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, s, s, s, s, 0, 0); } else { - // Fallback or Error: ggml doesn't easily support bilinear downsampling - // without custom ops, but standard models usually stick to integer strides. fprintf(stderr, "Error: Irregular downsampling stride required.\n"); } @@ -174,7 +151,7 @@ ggml_cgraph * clip_graph_mobilenetv5::build() { } } - // 4. Gemma 3n Multimodal Projection (Embedder) - FULL FIX + // 4. Gemma 3n Multimodal Projection (Embedder) // Input: 'cur' is [Width, Height, Channels, Batch] int W = cur->ne[0]; int H = cur->ne[1]; @@ -189,7 +166,7 @@ ggml_cgraph * clip_graph_mobilenetv5::build() { cur = ggml_cont(ctx0, cur); - // 2. FEATURE SCALING (Missing in your original code) + // 2. FEATURE SCALING // PyTorch: vision_outputs *= self.config.vision_config.hidden_size**0.5 // This prevents the signal from vanishing during the subsequent RMSNorm. const float scale_factor = sqrtf((float)C); @@ -227,8 +204,6 @@ ggml_cgraph * clip_graph_mobilenetv5::build() { // PyTorch: embedding_post_projection_norm = Gemma3nRMSNorm(..., with_scale=False) // with_scale=False means weight is registered as buffer with value 1.0 // So output = rms_norm(x) * 1.0 = rms_norm(x), magnitude ~1 - // NOTE: Vision embeddings intentionally have magnitude ~1, different from - // text embeddings at ~sqrt(n_embd). The model was trained with this mismatch. { const float eps = 1e-6f; cur = ggml_rms_norm(ctx0, cur, eps); @@ -239,9 +214,6 @@ ggml_cgraph * clip_graph_mobilenetv5::build() { } } - - // cur = ggml_scale(ctx0, cur, scale_factor); - ggml_build_forward_expand(gf, cur); return gf; } \ No newline at end of file From 4589d3eb748c48a33446f6d1465cb8b9a65d3635 Mon Sep 17 00:00:00 2001 From: Simranjeet Singh Date: Sun, 21 Dec 2025 10:59:15 +0000 Subject: [PATCH 4/8] Fix permute and remove transpose of projection weights --- tools/mtmd/models/mobilenetv5.cpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tools/mtmd/models/mobilenetv5.cpp b/tools/mtmd/models/mobilenetv5.cpp index 88bd1e6fcb9..6dd1a3d465d 100644 --- a/tools/mtmd/models/mobilenetv5.cpp +++ b/tools/mtmd/models/mobilenetv5.cpp @@ -160,7 +160,8 @@ ggml_cgraph * clip_graph_mobilenetv5::build() { // 1. Permute and Flatten to [Channels, Tokens, Batch] // PyTorch expects (Batch, Seq, Hidden), GGML usually processes (Hidden, Seq, Batch) - cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // -> [C, W, H, B] + cur = ggml_permute(ctx0, cur, 2, 1, 0, 3); // -> [C, H, W, B] + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // -> [C, W, H, B] cur = ggml_cont(ctx0, cur); cur = ggml_reshape_3d(ctx0, cur, C, W*H, B); cur = ggml_cont(ctx0, cur); @@ -193,11 +194,7 @@ ggml_cgraph * clip_graph_mobilenetv5::build() { // Need to transpose for ggml_mul_mat which computes A^T * B // This matches Gemma3's projection at line ~1319 which also transposes if (model.mm_input_proj_w) { - // cur = ggml_mul_mat(ctx0, model.mm_input_proj_w, cur); - cur = ggml_mul_mat(ctx0, - ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)), - cur); - + cur = ggml_mul_mat(ctx0, model.mm_input_proj_w, cur); } // 5. POST PROJECTION NORM From 47423a295ba1c272d38b85f98e6da89be995b7c0 Mon Sep 17 00:00:00 2001 From: Simranjeet Singh Date: Sun, 21 Dec 2025 11:50:10 +0000 Subject: [PATCH 5/8] Fix comments, remove debugging prints from hf_to_gguf --- convert_hf_to_gguf.py | 17 +++++------------ tools/mtmd/models/mobilenetv5.cpp | 4 +--- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 36a7ed000af..dd94efe7ed0 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6018,22 +6018,16 @@ def find_vparam(self, keys: list[str], optional: bool = False) -> Any: return super().find_vparam(keys, optional) def set_gguf_parameters(self): - # MobileNetV5 requires ImageNet normalization values - # Override preprocessor_config to ensure correct values before calling super() - # IMAGENET_MEAN = [0.485, 0.456, 0.406] - # IMAGENET_STD = [0.229, 0.224, 0.225] + # MobileNetV5 does not use normalisation at all IMAGENET_MEAN = [0.5 , 0.5 , 0.5 ] IMAGENET_STD = [0.5 , 0.5 , 0.5 ] - print("test") - # Check if preprocessor_config has incorrect normalization values if "image_mean" in self.preprocessor_config: current_mean = self.preprocessor_config["image_mean"] if current_mean != IMAGENET_MEAN: logger.warning(f"Overriding image_mean from {current_mean} to ImageNet standard {IMAGENET_MEAN}") self.preprocessor_config["image_mean"] = IMAGENET_MEAN - print("test2") else: logger.info(f"Setting image_mean to ImageNet standard {IMAGENET_MEAN}") self.preprocessor_config["image_mean"] = IMAGENET_MEAN @@ -6060,7 +6054,6 @@ def set_gguf_parameters(self): # Image sequence length (256 tokens = 16x16 for Gemma3n) image_seq_length = self.preprocessor_config.get("image_seq_length", 256) - # Note: Additional metadata can be added as needed def tensor_force_quant(self, name, new_name, bid, n_dims): # Force quantization settings for specific tensor types @@ -6110,17 +6103,17 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter def map_tensor_name(self, name: str) -> str: """Map Gemma3n tensor names to GGUF format""" # Projector tensors (from embed_vision) - use mm. prefix like Gemma3 - # IMPORTANT: Keep the .weight suffix to match C++ expectations + # IMPORTANT: Keep the .weight suffix to match ggml expectations if name == "embedding.weight": return "mm.embedding.weight" if name == "embedding_projection.weight": - return "mm.input_projection.weight" # Main projection used by C++ + return "mm.input_projection.weight" # Main projection if name == "hard_emb_norm.weight": return "mm.hard_emb_norm.weight" # Hard embedding normalization if name == "soft_emb_norm.weight": - return "mm.soft_emb_norm.weight" # Soft embedding normalization (used by C++) + return "mm.soft_emb_norm.weight" # Soft embedding normalization if name == "post_proj_norm.weight": - return "mm.post_proj_norm.weight" # Post projection normalization (CRITICAL for Gemma3n) + return "mm.post_proj_norm.weight" # Post projection normalization (if exists) # Vision tower tensors - add v.enc. prefix for MobileNetV5 encoder if name.startswith("vision_tower."): diff --git a/tools/mtmd/models/mobilenetv5.cpp b/tools/mtmd/models/mobilenetv5.cpp index 6dd1a3d465d..930da38e302 100644 --- a/tools/mtmd/models/mobilenetv5.cpp +++ b/tools/mtmd/models/mobilenetv5.cpp @@ -55,7 +55,7 @@ ggml_cgraph * clip_graph_mobilenetv5::build() { } } - // 3. Multi-Scale Fusion Adapter (MSFA) - REPLICATED & FIXED + // 3. Multi-Scale Fusion Adapter (MSFA) if (!intermediate_features.empty()) { // A. Reference Resolution: PyTorch implementation uses inputs[0] @@ -191,8 +191,6 @@ ggml_cgraph * clip_graph_mobilenetv5::build() { // 4. PROJECTION // PyTorch: embedding_projection = nn.Linear(vision_hidden, text_hidden, bias=False) // Weight stored as [out_features, in_features] = [text_hidden_size, vision_hidden_size] - // Need to transpose for ggml_mul_mat which computes A^T * B - // This matches Gemma3's projection at line ~1319 which also transposes if (model.mm_input_proj_w) { cur = ggml_mul_mat(ctx0, model.mm_input_proj_w, cur); } From 67801e5b62a68db509ef879a0c47b5bf096df785 Mon Sep 17 00:00:00 2001 From: Simranjeet Singh Date: Sun, 21 Dec 2025 19:13:47 +0000 Subject: [PATCH 6/8] 1. Hard-code image_mean = 0 and image_std = 1 2. Use available tensor mapping logic 3. Remove redundant chat template replacement of soft tokens placeholder with media placeholder --- convert_hf_to_gguf.py | 113 +++++---------------------------- gguf-py/gguf/constants.py | 9 +++ gguf-py/gguf/tensor_mapping.py | 21 ++++++ 3 files changed, 46 insertions(+), 97 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index dd94efe7ed0..55e82fe9128 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5966,9 +5966,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter @ModelBase.register("Gemma3nForConditionalGeneration", "Gemma3nVisionModel") class Gemma3nVisionModel(MmprojModel): """Vision encoder converter for Gemma3n using MobileNetV5 architecture""" - - # MobileNetV5 doesn't have transformer layers, so we don't need block count - # Set n_block_keys to empty list to skip the find_hparam check n_block_keys = [] def find_hparam(self, keys: list[str], optional: bool = False) -> Any: @@ -5984,34 +5981,17 @@ def __init__(self, *args, **kwargs): def find_vparam(self, keys: list[str], optional: bool = False) -> Any: """Override to provide hardcoded MobileNetV5 parameters that aren't in config""" - # MobileNetV5 hardcodes these values in the architecture definition - # rather than storing them in config.json - # Handle empty keys list (n_block_keys) - return 0 for CNN architecture if not keys: return 0 - # Check if we're looking for image_size - if "image_size" in keys: - # MobileNetV5 300m_enc uses 768x768 input - return 768 - - # Check if we're looking for patch_size - if "patch_size" in keys: - # MobileNetV5 is CNN-based, doesn't use patches - # Set to 1 for compatibility - return 1 - - # Check if we're looking for intermediate_size if "intermediate_size" in keys: - # MobileNetV5 uses expansion ratios in inverted residual blocks # Typical expansion is 4x the embedding dimension hidden_size = self.hparams_vision.get("hidden_size", 2048) return hidden_size * 4 - # Check if we're looking for num_attention_heads if "num_attention_heads" in keys or "num_heads" in keys: - # MobileNetV5 uses Multi-Query Attention with 8 heads + # Multi-Query Attention with 8 heads return 8 # For other parameters, use parent implementation @@ -6019,41 +5999,25 @@ def find_vparam(self, keys: list[str], optional: bool = False) -> Any: def set_gguf_parameters(self): # MobileNetV5 does not use normalisation at all - IMAGENET_MEAN = [0.5 , 0.5 , 0.5 ] - IMAGENET_STD = [0.5 , 0.5 , 0.5 ] - - # Check if preprocessor_config has incorrect normalization values - if "image_mean" in self.preprocessor_config: - current_mean = self.preprocessor_config["image_mean"] - if current_mean != IMAGENET_MEAN: - logger.warning(f"Overriding image_mean from {current_mean} to ImageNet standard {IMAGENET_MEAN}") - self.preprocessor_config["image_mean"] = IMAGENET_MEAN - else: - logger.info(f"Setting image_mean to ImageNet standard {IMAGENET_MEAN}") - self.preprocessor_config["image_mean"] = IMAGENET_MEAN - - if "image_std" in self.preprocessor_config: - current_std = self.preprocessor_config["image_std"] - if current_std != IMAGENET_STD: - logger.warning(f"Overriding image_std from {current_std} to ImageNet standard {IMAGENET_STD}") - self.preprocessor_config["image_std"] = IMAGENET_STD - else: - logger.info(f"Setting image_std to ImageNet standard {IMAGENET_STD}") - self.preprocessor_config["image_std"] = IMAGENET_STD + self.preprocessor_config["image_mean"] = [0.0 , 0.0 , 0.0 ] + self.preprocessor_config["image_std"] = [1.0 , 1.0 , 1.0 ] + self.hparams_vision["image_size"] = self.preprocessor_config.get( + "size", {"height": 768, "width": 768} + )["height"] + + # Image sequence length (256 tokens = 16x16 for Gemma3n) + image_seq_length = self.preprocessor_config.get("image_seq_length", 256) + image_size = self.hparams_vision["image_size"] + self.hparams_vision["patch_size"] = image_size // image_seq_length # Now call parent which will use the corrected values super().set_gguf_parameters() - hparams = self.hparams # Set projector type to GEMMA3N self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GEMMA3N) # MobileNetV5 specific parameters - self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6)) - self.gguf_writer.add_vision_use_gelu(True) # MobileNetV5 uses approximate GELU - - # Image sequence length (256 tokens = 16x16 for Gemma3n) - image_seq_length = self.preprocessor_config.get("image_seq_length", 256) + self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6)) def tensor_force_quant(self, name, new_name, bid, n_dims): # Force quantization settings for specific tensor types @@ -6090,7 +6054,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # Handle normalization layer naming name = name.replace("hard_embedding_norm", "hard_emb_norm") name = name.replace("soft_embedding_norm", "soft_emb_norm") - # name = name.replace("embedding_post_projection_norm", "post_proj_norm") # Gemma3n uses Gemma3p5RMSNorm which has scale_shift=0, so no correction needed # Unlike Gemma3 which uses Gemma3RMSNorm with scale_shift=1 @@ -6098,37 +6061,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # No correction needed for Gemma3n pass - return [(self.map_tensor_name(name), data_torch)] - - def map_tensor_name(self, name: str) -> str: - """Map Gemma3n tensor names to GGUF format""" - # Projector tensors (from embed_vision) - use mm. prefix like Gemma3 - # IMPORTANT: Keep the .weight suffix to match ggml expectations - if name == "embedding.weight": - return "mm.embedding.weight" - if name == "embedding_projection.weight": - return "mm.input_projection.weight" # Main projection - if name == "hard_emb_norm.weight": - return "mm.hard_emb_norm.weight" # Hard embedding normalization - if name == "soft_emb_norm.weight": - return "mm.soft_emb_norm.weight" # Soft embedding normalization - if name == "post_proj_norm.weight": - return "mm.post_proj_norm.weight" # Post projection normalization (if exists) - - # Vision tower tensors - add v.enc. prefix for MobileNetV5 encoder if name.startswith("vision_tower."): - # Remove vision_tower prefix and add v.enc. prefix - tensor_suffix = name[13:] # Remove "vision_tower." - return f"v.enc.{tensor_suffix}" - - # If no match, try parent implementation - try: - return super().map_tensor_name(name) - except ValueError: - # If parent also can't map it, provide a sensible default - # This shouldn't happen, but provides a fallback - logger.warning(f"Using fallback mapping for tensor: {name}") - return f"v.{name}" + tensor_suffix = name[13:] + return [(f"v.enc.{tensor_suffix}", data_torch)] + else: + return [(self.map_tensor_name(name), data_torch)] @ModelBase.register("Gemma3nForCausalLM", "Gemma3nForConditionalGeneration") @@ -6173,24 +6110,6 @@ def set_vocab(self): if vocab_size_per_layer_input is not None: self.hparams["vocab_size_per_layer_input"] = vocab_size_per_layer_input - # Fix chat template for Gemma3n multimodal: replace special token placeholders with mtmd markers - # The mtmd library uses <__media__> as the default marker for images/audio - # but Gemma3n's chat template uses and - chat_template_key = "tokenizer.chat_template" - for kv_dict in self.gguf_writer.kv_data: - if chat_template_key in kv_dict: - template_value = kv_dict[chat_template_key].value - - # Replace soft token placeholders with mtmd markers - if '' in template_value or '' in template_value: - logger.info("Fixing Gemma3n chat template: replacing soft token placeholders with mtmd markers") - template_value = template_value.replace('', '<__media__>') - template_value = template_value.replace('', '<__media__>') - - # Update the value in place - kv_dict[chat_template_key].value = template_value - break - def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_altup_active_idx(self.hparams["altup_active_idx"]) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 41654b22b5d..869a8582b12 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -667,6 +667,9 @@ class MODEL_TENSOR(IntEnum): V_MM_INP_NORM = auto() V_MM_INP_PROJ = auto() # gemma3 V_MM_SOFT_EMB_NORM = auto() # gemma3 + V_MM_EMBEDDING = auto() # gemma3n + V_MM_HARD_EMB_NORM = auto() # gemma3n + V_MM_POST_PROJ_NORM = auto() # gemma3n V_RESMPL_POS_EMBD_K = auto() # minicpmv V_RESMPL_ATTN_Q = auto() # minicpmv V_RESMPL_ATTN_K = auto() # minicpmv @@ -1059,6 +1062,9 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection", MODEL_TENSOR.V_MM_INP_NORM: "mm.input_norm", MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm", + MODEL_TENSOR.V_MM_EMBEDDING: "mm.embedding", + MODEL_TENSOR.V_MM_HARD_EMB_NORM: "mm.hard_emb_norm", + MODEL_TENSOR.V_MM_POST_PROJ_NORM: "mm.post_proj_norm", MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "resampler.pos_embd_k", MODEL_TENSOR.V_RESMPL_ATTN_Q: "resampler.attn.q", MODEL_TENSOR.V_RESMPL_ATTN_K: "resampler.attn.k", @@ -1157,6 +1163,9 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_MM_INP_PROJ, MODEL_TENSOR.V_MM_INP_NORM, MODEL_TENSOR.V_MM_SOFT_EMB_NORM, + MODEL_TENSOR.V_MM_EMBEDDING, + MODEL_TENSOR.V_MM_HARD_EMB_NORM, + MODEL_TENSOR.V_MM_POST_PROJ_NORM, MODEL_TENSOR.V_RESMPL_POS_EMBD_K, MODEL_TENSOR.V_RESMPL_ATTN_Q, MODEL_TENSOR.V_RESMPL_ATTN_K, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 301aafa9102..3e1cf8a136f 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -119,6 +119,27 @@ class TensorNameMap: MODEL_TENSOR.CONV1D: ( "backbone.embed", # roberta ), + + # Vision multimodal projector tensors (non-block) for gemma3n + MODEL_TENSOR.V_MM_INP_PROJ: ( + "embedding_projection", # gemma3n + ), + + MODEL_TENSOR.V_MM_SOFT_EMB_NORM: ( + "soft_emb_norm", # gemma3n + ), + + MODEL_TENSOR.V_MM_EMBEDDING: ( + "embedding", # gemma3n + ), + + MODEL_TENSOR.V_MM_HARD_EMB_NORM: ( + "hard_emb_norm", # gemma3n + ), + + MODEL_TENSOR.V_MM_POST_PROJ_NORM: ( + "post_proj_norm", # gemma3n + ), } block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { From 04947c7f9e355914e7e4d6cdf79c28193218d988 Mon Sep 17 00:00:00 2001 From: Simranjeet Singh Date: Sun, 21 Dec 2025 19:19:47 +0000 Subject: [PATCH 7/8] 1. Move mobilenetv5 helpers declarations to `clip_graph_mobilenetv5` struct and definitions to mobilenetv5.cpp 2.Remove unused `clip_is_gemma3n` func declarations and definitions 3. Remove redundant `rescale_image_u8_to_f32` func and use `normalize_image_u8_to_f32` with zero mean and unit std 4. Calculate n_patches using image_size / patch_size --- tools/mtmd/clip-graph.h | 28 --- tools/mtmd/clip.cpp | 271 +----------------------------- tools/mtmd/clip.h | 1 - tools/mtmd/models/mobilenetv5.cpp | 249 +++++++++++++++++++++++++++ tools/mtmd/models/models.h | 28 +++ 5 files changed, 279 insertions(+), 298 deletions(-) diff --git a/tools/mtmd/clip-graph.h b/tools/mtmd/clip-graph.h index 6a9efb933e5..2b1915779f2 100644 --- a/tools/mtmd/clip-graph.h +++ b/tools/mtmd/clip-graph.h @@ -70,34 +70,6 @@ struct clip_graph { ggml_tensor * build_inp_raw(int channels = 3); - ggml_tensor * rms_norm_2d( - ggml_tensor * inp, - ggml_tensor * weight, - float eps = 1e-6f); - - ggml_tensor* pad_same_2d( - ggml_tensor* inp, - int kernel_h, - int kernel_w, - int stride_h, - int stride_w, - int dilation_h = 1, - int dilation_w = 1); - - ggml_tensor * build_edge_residual( - ggml_tensor * inp, - const mobilenetv5_block & block, - int stride); - - ggml_tensor * build_inverted_residual( - ggml_tensor * inp, - const mobilenetv5_block & block, - int stride); - - ggml_tensor * build_mobilenet_attn( - ggml_tensor * inp, - const mobilenetv5_block & block); - ggml_tensor * build_norm( ggml_tensor * cur, ggml_tensor * mw, diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 9e4519c502b..e86a09bb5c1 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -263,255 +263,6 @@ void clip_graph::cb(ggml_tensor * cur0, const char * name, int il) const { } } -// --- Helpers for MobileNetV5 Blocks --- -// RMS Norm 2D - normalizes over channels for each spatial position -ggml_tensor * clip_graph::rms_norm_2d(ggml_tensor * inp, ggml_tensor * weight, float eps) { - // inp: [W, H, C, B] - - ggml_tensor * cur = ggml_permute(ctx0, inp, 2, 1, 0, 3); - cur = ggml_cont(ctx0, cur); - cur = ggml_rms_norm(ctx0, cur, eps); - - if (weight) { - cur = ggml_mul(ctx0, cur, weight); - } - - cur = ggml_permute(ctx0, cur, 2, 1, 0, 3); - cur = ggml_cont(ctx0, cur); - - return cur; -} - -// Helper for Conv2dSame padding (asymmetric SAME padding like PyTorch/TF) -ggml_tensor* clip_graph::pad_same_2d(ggml_tensor* inp, int kernel_h, int kernel_w, int stride_h, int stride_w, int dilation_h, int dilation_w) { - const int64_t ih = inp->ne[1]; // height - const int64_t iw = inp->ne[0]; // width - - // Calculate output size (ceil division) - const int64_t oh = (ih + stride_h - 1) / stride_h; - const int64_t ow = (iw + stride_w - 1) / stride_w; - - // Calculate padding needed - const int64_t pad_h = std::max((int64_t)0, (oh - 1) * stride_h + (kernel_h - 1) * dilation_h + 1 - ih); - const int64_t pad_w = std::max((int64_t)0, (ow - 1) * stride_w + (kernel_w - 1) * dilation_w + 1 - iw); - - // Split padding asymmetrically - const int pad_h_top = pad_h / 2; - const int pad_h_bottom = pad_h - pad_h_top; - const int pad_w_left = pad_w / 2; - const int pad_w_right = pad_w - pad_w_left; - - // Apply padding if needed - // ggml_pad_ext: (ctx, tensor, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3) - // For [W, H, C, B]: p0=width, p1=height, p2=channels, p3=batch - if (pad_h > 0 || pad_w > 0) { - inp = ggml_pad_ext(ctx0, inp, - pad_w_left, pad_w_right, // width padding (dim 0) - pad_h_top, pad_h_bottom, // height padding (dim 1) - 0, 0, // no channel padding (dim 2) - 0, 0); // no batch padding (dim 3) - } - - return inp; -} - - -// Edge Residual Block (Stage 0) -ggml_tensor * clip_graph::build_edge_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride) { - ggml_tensor * cur = inp; - - // 1. Expansion Conv (3x3) - if (stride == 2) { - // Case: Downsampling (Block 0) - // Replicates Conv2dSame(kernel=3, stride=2) - cur = pad_same_2d(cur, 3, 3, stride, stride); - cur = ggml_conv_2d_direct(ctx0, block.s0_conv_exp_w, cur, stride, stride, 0, 0, 1, 1); - } else { - // Case: Normal 3x3 Block (Block 1, 2) - // Replicates Conv2d(kernel=3, stride=1, padding=1) - cur = ggml_conv_2d_direct(ctx0, block.s0_conv_exp_w, cur, stride, stride, 1, 1, 1, 1); - } - - // BN + Activation - if (block.s0_bn1_w) cur = rms_norm_2d(cur, block.s0_bn1_w); - cur = ggml_gelu(ctx0, cur); - - // 2. Pointwise Linear Conv (1x1) - // 1x1 Convs usually have padding=0 and stride=1 - cur = ggml_conv_2d_direct(ctx0, block.s0_conv_pwl_w, cur, 1, 1, 0, 0, 1, 1); - if (block.s0_bn2_w) cur = rms_norm_2d(cur, block.s0_bn2_w); - - // 3. Residual Connection - // Only apply residual if spatial dimensions and channels match (stride 1) - if (stride == 1 && inp->ne[2] == cur->ne[2] && inp->ne[0] == cur->ne[0]) { - cur = ggml_add(ctx0, cur, inp); - } - - return cur; -} - -ggml_tensor * clip_graph::build_inverted_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride) { - ggml_tensor * cur = inp; - - // 1. Depthwise Start (Optional) - // NOTE: dw_start always has stride=1 (no downsampling here) - if (block.dw_start_w) { - int k = block.dw_start_w->ne[0]; // 3 or 5 - int p = k / 2; - cur = ggml_conv_2d_dw(ctx0, block.dw_start_w, cur, 1, 1, p, p, 1, 1); - if (block.dw_start_bn_w) cur = rms_norm_2d(cur, block.dw_start_bn_w); - } - - // 2. Pointwise Expansion (1x1) - if (block.pw_exp_w) { - // Standard 1x1 conv, pad=0, stride=1 - cur = ggml_conv_2d_direct(ctx0, block.pw_exp_w, cur, 1, 1, 0, 0, 1, 1); - if (block.pw_exp_bn_w) cur = rms_norm_2d(cur, block.pw_exp_bn_w); - cur = ggml_gelu(ctx0, cur); - } - - // 3. Depthwise Mid (Optional) - // NOTE: dw_mid is where downsampling happens (stride=2 for first block of stage) - if (block.dw_mid_w) { - int k = block.dw_mid_w->ne[0]; // 3 or 5 - - if (stride > 1) { - // Case: Stride 2 (Downsample) -> Use Asymmetric "Same" Padding - cur = pad_same_2d(cur, k, k, stride, stride); - cur = ggml_conv_2d_dw(ctx0, block.dw_mid_w, cur, stride, stride, 0, 0, 1, 1); // pad=0 - } else { - // Case: Stride 1 -> Use Standard Symmetric Padding - int p = k / 2; - cur = ggml_conv_2d_dw(ctx0, block.dw_mid_w, cur, stride, stride, p, p, 1, 1); - } - - if (block.dw_mid_bn_w) cur = rms_norm_2d(cur, block.dw_mid_bn_w); - cur = ggml_gelu(ctx0, cur); - } - - // 4. Pointwise Projection (1x1) - if (block.pw_proj_w) { - cur = ggml_conv_2d_direct(ctx0, block.pw_proj_w, cur, 1, 1, 0, 0, 1, 1); - if (block.pw_proj_bn_w) cur = rms_norm_2d(cur, block.pw_proj_bn_w); - } - - // Apply Layer Scaling if present - if (block.layer_scale_w) { - ggml_tensor * scale_w_reshaped = ggml_reshape_4d(ctx0, block.layer_scale_w, - 1, 1, block.layer_scale_w->ne[0], 1); - - cur = ggml_mul(ctx0, cur, scale_w_reshaped); - } - - // 5. Residual Connection - bool same_spatial = (inp->ne[0] == cur->ne[0]) && (inp->ne[1] == cur->ne[1]); - bool same_channel = (inp->ne[2] == cur->ne[2]); - if (same_spatial && same_channel) { - cur = ggml_add(ctx0, cur, inp); - } - - return cur; -} - -// MobileNetV5 Builder (Gemma 3n) - Attention Block -ggml_tensor * clip_graph::build_mobilenet_attn(ggml_tensor * inp, const mobilenetv5_block & block) { - ggml_tensor * cur = inp; - - // --- Norm --- - if (block.attn_norm_w) { - cur = rms_norm_2d(cur, block.attn_norm_w, 1e-6f); - } - - // --- 1. Q Calculation --- - ggml_tensor * q = ggml_conv_2d_direct(ctx0, block.attn_q_w, cur, 1, 1, 0, 0, 1, 1); - - // --- 2. K Calculation (Downsampled) --- - // Uses Conv2dSame(640, 640, kernel_size=(3, 3), stride=(2, 2), groups=640) - ggml_tensor * k_inp = cur; - if (block.attn_k_dw_w) { - int k_size = block.attn_k_dw_w->ne[0]; // Usually 3 - k_inp = pad_same_2d(cur, k_size, k_size, 2, 2); // Apply SAME padding - k_inp = ggml_conv_2d_dw(ctx0, block.attn_k_dw_w, k_inp, 2, 2, 0, 0, 1, 1); // padding=0 - if (block.attn_k_norm_w) { - k_inp = rms_norm_2d(k_inp, block.attn_k_norm_w, 1e-6f); - } - } - ggml_tensor * k = ggml_conv_2d_direct(ctx0, block.attn_k_w, k_inp, 1, 1, 0, 0, 1, 1); - - // --- 3. V Calculation (Downsampled) --- - // Uses Conv2dSame(640, 640, kernel_size=(3, 3), stride=(2, 2), groups=640) - ggml_tensor * v_inp = cur; - if (block.attn_v_dw_w) { - int v_size = block.attn_v_dw_w->ne[0]; // Usually 3 - v_inp = pad_same_2d(cur, v_size, v_size, 2, 2); // Apply SAME padding - v_inp = ggml_conv_2d_dw(ctx0, block.attn_v_dw_w, v_inp, 2, 2, 0, 0, 1, 1); // padding=0 - if (block.attn_v_norm_w) { - v_inp = rms_norm_2d(v_inp, block.attn_v_norm_w, 1e-6f); - } - } - ggml_tensor * v = ggml_conv_2d_direct(ctx0, block.attn_v_w, v_inp, 1, 1, 0, 0, 1, 1); - - const int W = cur->ne[0]; const int H = cur->ne[1]; const int B = cur->ne[3]; - const int D = k->ne[2]; // Head dimension - const int n_head = q->ne[2] / D; - const int N = W * H; - - // Process Q: [W, H, D*n_head, B] -> [D, N, n_head, B] - q = ggml_reshape_3d(ctx0, q, N, D*n_head, B); - q = ggml_reshape_4d(ctx0, q, N, D, n_head, B); - q = ggml_permute(ctx0, q, 1, 0, 2, 3); // [D, N, n_head, B] - q = ggml_cont(ctx0, q); - - const int Wk = k->ne[0]; const int Hk = k->ne[1]; - const int M = Wk * Hk; - - // Process K: [Wk, Hk, D, B] -> [D, M, 1, B] - k = ggml_reshape_3d(ctx0, k, M, D, B); - k = ggml_reshape_4d(ctx0, k, M, D, 1, B); - k = ggml_permute(ctx0, k, 1, 0, 2, 3); // [D, M, 1, B] - k = ggml_cont(ctx0, k); - - // Process V: [Wk, Hk, D, B] -> [M, D, 1, B] - v = ggml_reshape_3d(ctx0, v, M, D, B); - v = ggml_reshape_4d(ctx0, v, M, D, 1, B); - v = ggml_cont(ctx0, v); // [M, D, 1, B] - - // --- Multi-Query Attention --- - float scale = 1.0f / sqrtf((float)D); - - // Step 1: Compute Q @ K.T - ggml_tensor * scores = ggml_mul_mat(ctx0, k, q); - - scores = ggml_scale(ctx0, scores, scale); - - scores = ggml_soft_max(ctx0, scores); - - ggml_tensor * kqv = ggml_mul_mat(ctx0, v, scores); - - kqv = ggml_permute(ctx0, kqv, 1, 0, 2, 3); - kqv = ggml_cont(ctx0, kqv); - - - kqv = ggml_reshape_3d(ctx0, kqv, N, D * n_head, B); - kqv = ggml_reshape_4d(ctx0, kqv, W, H, D * n_head, B); - kqv = ggml_cont(ctx0, kqv); - - // Output projection - cur = ggml_conv_2d_direct(ctx0, block.attn_o_w, kqv, 1, 1, 0, 0, 1, 1); - - // --- Residual & Layer Scale (FIXED) --- - if (inp->ne[0] == cur->ne[0] && inp->ne[2] == cur->ne[2]) { - if (block.layer_scale_w) { - ggml_tensor * scale_w_reshaped = ggml_reshape_4d(ctx0, block.layer_scale_w, - 1, 1, block.layer_scale_w->ne[0], 1); - cur = ggml_mul(ctx0, cur, scale_w_reshaped); - } - cur = ggml_add(ctx0, cur, inp); - } - - return cur; -} - // siglip2 naflex ggml_tensor * clip_graph::resize_position_embeddings(uint32_t interpolation_mode) { ggml_tensor * pos_embd = model.position_embeddings; @@ -2414,18 +2165,6 @@ void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny memcpy(img->buf.data(), rgb_pixels, img->buf.size()); } -// Rescale image from u8 to f32 without normalization (for models like GEMMA3N that use SiglipImageProcessorFast) -// This only converts from [0, 255] to [0.0, 1.0] range without applying mean/std normalization -static void rescale_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst) { - dst.nx = src.nx; - dst.ny = src.ny; - dst.buf.resize(src.buf.size()); - - for (size_t i = 0; i < src.buf.size(); ++i) { - dst.buf[i] = static_cast(src.buf[i]) / 255.0f; - } -} - // Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const float mean[3], const float std[3]) { dst.nx = src.nx; @@ -3123,13 +2862,11 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str case PROJECTOR_TYPE_GEMMA3N: { - // GEMMA3N uses SiglipImageProcessorFast which only rescales to [0.0, 1.0] without normalization - // Resize to 768x768 using bilinear interpolation, then rescale to f32 clip_image_u8 resized_image; int sz = params.image_size; img_tool::resize(*img, resized_image, {sz, sz}, img_tool::RESIZE_ALGO_BILINEAR, false); clip_image_f32_ptr img_f32(clip_image_f32_init()); - rescale_image_u8_to_f32(resized_image, *img_f32); + normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(img_f32)); } break; @@ -3396,7 +3133,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im { // MobileNetV5 MSFA adapter always outputs fixed 16x16 resolution // regardless of input size (see architecture description) - n_patches = 16 * 16; // 256 tokens + n_patches = ctx->model.hparams.image_size / ctx->model.hparams.patch_size; } break; case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: @@ -3969,10 +3706,6 @@ bool clip_is_gemma3(const struct clip_ctx * ctx) { return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3; } -bool clip_is_gemma3n(const struct clip_ctx * ctx) { - return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3N; -} - bool clip_has_vision_encoder(const struct clip_ctx * ctx) { return ctx->model.modality == CLIP_MODALITY_VISION; } diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index c244df2677f..68a0d6e857e 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -107,7 +107,6 @@ bool clip_is_glm(const struct clip_ctx * ctx); bool clip_is_mrope(const struct clip_ctx * ctx); bool clip_is_llava(const struct clip_ctx * ctx); bool clip_is_gemma3(const struct clip_ctx * ctx); -bool clip_is_gemma3n(const struct clip_ctx * ctx); bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec); diff --git a/tools/mtmd/models/mobilenetv5.cpp b/tools/mtmd/models/mobilenetv5.cpp index 930da38e302..bc1185c10eb 100644 --- a/tools/mtmd/models/mobilenetv5.cpp +++ b/tools/mtmd/models/mobilenetv5.cpp @@ -1,5 +1,254 @@ #include "models.h" +// --- Helpers for MobileNetV5 Blocks --- +// RMS Norm 2D - normalizes over channels for each spatial position +ggml_tensor * clip_graph_mobilenetv5::rms_norm_2d(ggml_tensor * inp, ggml_tensor * weight, float eps) { + // inp: [W, H, C, B] + + ggml_tensor * cur = ggml_permute(ctx0, inp, 2, 1, 0, 3); + cur = ggml_cont(ctx0, cur); + cur = ggml_rms_norm(ctx0, cur, eps); + + if (weight) { + cur = ggml_mul(ctx0, cur, weight); + } + + cur = ggml_permute(ctx0, cur, 2, 1, 0, 3); + cur = ggml_cont(ctx0, cur); + + return cur; +} + +// Helper for Conv2dSame padding (asymmetric SAME padding like PyTorch/TF) +ggml_tensor* clip_graph_mobilenetv5::pad_same_2d(ggml_tensor* inp, int kernel_h, int kernel_w, int stride_h, int stride_w, int dilation_h, int dilation_w) { + const int64_t ih = inp->ne[1]; // height + const int64_t iw = inp->ne[0]; // width + + // Calculate output size (ceil division) + const int64_t oh = (ih + stride_h - 1) / stride_h; + const int64_t ow = (iw + stride_w - 1) / stride_w; + + // Calculate padding needed + const int64_t pad_h = std::max((int64_t)0, (oh - 1) * stride_h + (kernel_h - 1) * dilation_h + 1 - ih); + const int64_t pad_w = std::max((int64_t)0, (ow - 1) * stride_w + (kernel_w - 1) * dilation_w + 1 - iw); + + // Split padding asymmetrically + const int pad_h_top = pad_h / 2; + const int pad_h_bottom = pad_h - pad_h_top; + const int pad_w_left = pad_w / 2; + const int pad_w_right = pad_w - pad_w_left; + + // Apply padding if needed + // ggml_pad_ext: (ctx, tensor, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3) + // For [W, H, C, B]: p0=width, p1=height, p2=channels, p3=batch + if (pad_h > 0 || pad_w > 0) { + inp = ggml_pad_ext(ctx0, inp, + pad_w_left, pad_w_right, // width padding (dim 0) + pad_h_top, pad_h_bottom, // height padding (dim 1) + 0, 0, // no channel padding (dim 2) + 0, 0); // no batch padding (dim 3) + } + + return inp; +} + + +// Edge Residual Block (Stage 0) +ggml_tensor * clip_graph_mobilenetv5::build_edge_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride) { + ggml_tensor * cur = inp; + + // 1. Expansion Conv (3x3) + if (stride == 2) { + // Case: Downsampling (Block 0) + // Replicates Conv2dSame(kernel=3, stride=2) + cur = pad_same_2d(cur, 3, 3, stride, stride); + cur = ggml_conv_2d_direct(ctx0, block.s0_conv_exp_w, cur, stride, stride, 0, 0, 1, 1); + } else { + // Case: Normal 3x3 Block (Block 1, 2) + // Replicates Conv2d(kernel=3, stride=1, padding=1) + cur = ggml_conv_2d_direct(ctx0, block.s0_conv_exp_w, cur, stride, stride, 1, 1, 1, 1); + } + + // BN + Activation + if (block.s0_bn1_w) cur = rms_norm_2d(cur, block.s0_bn1_w); + cur = ggml_gelu(ctx0, cur); + + // 2. Pointwise Linear Conv (1x1) + // 1x1 Convs usually have padding=0 and stride=1 + cur = ggml_conv_2d_direct(ctx0, block.s0_conv_pwl_w, cur, 1, 1, 0, 0, 1, 1); + if (block.s0_bn2_w) cur = rms_norm_2d(cur, block.s0_bn2_w); + + // 3. Residual Connection + // Only apply residual if spatial dimensions and channels match (stride 1) + if (stride == 1 && inp->ne[2] == cur->ne[2] && inp->ne[0] == cur->ne[0]) { + cur = ggml_add(ctx0, cur, inp); + } + + return cur; +} + +ggml_tensor * clip_graph_mobilenetv5::build_inverted_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride) { + ggml_tensor * cur = inp; + + // 1. Depthwise Start (Optional) + // NOTE: dw_start always has stride=1 (no downsampling here) + if (block.dw_start_w) { + int k = block.dw_start_w->ne[0]; // 3 or 5 + int p = k / 2; + cur = ggml_conv_2d_dw(ctx0, block.dw_start_w, cur, 1, 1, p, p, 1, 1); + if (block.dw_start_bn_w) cur = rms_norm_2d(cur, block.dw_start_bn_w); + } + + // 2. Pointwise Expansion (1x1) + if (block.pw_exp_w) { + // Standard 1x1 conv, pad=0, stride=1 + cur = ggml_conv_2d_direct(ctx0, block.pw_exp_w, cur, 1, 1, 0, 0, 1, 1); + if (block.pw_exp_bn_w) cur = rms_norm_2d(cur, block.pw_exp_bn_w); + cur = ggml_gelu(ctx0, cur); + } + + // 3. Depthwise Mid (Optional) + // NOTE: dw_mid is where downsampling happens (stride=2 for first block of stage) + if (block.dw_mid_w) { + int k = block.dw_mid_w->ne[0]; // 3 or 5 + + if (stride > 1) { + // Case: Stride 2 (Downsample) -> Use Asymmetric "Same" Padding + cur = pad_same_2d(cur, k, k, stride, stride); + cur = ggml_conv_2d_dw(ctx0, block.dw_mid_w, cur, stride, stride, 0, 0, 1, 1); // pad=0 + } else { + // Case: Stride 1 -> Use Standard Symmetric Padding + int p = k / 2; + cur = ggml_conv_2d_dw(ctx0, block.dw_mid_w, cur, stride, stride, p, p, 1, 1); + } + + if (block.dw_mid_bn_w) cur = rms_norm_2d(cur, block.dw_mid_bn_w); + cur = ggml_gelu(ctx0, cur); + } + + // 4. Pointwise Projection (1x1) + if (block.pw_proj_w) { + cur = ggml_conv_2d_direct(ctx0, block.pw_proj_w, cur, 1, 1, 0, 0, 1, 1); + if (block.pw_proj_bn_w) cur = rms_norm_2d(cur, block.pw_proj_bn_w); + } + + // Apply Layer Scaling if present + if (block.layer_scale_w) { + ggml_tensor * scale_w_reshaped = ggml_reshape_4d(ctx0, block.layer_scale_w, + 1, 1, block.layer_scale_w->ne[0], 1); + + cur = ggml_mul(ctx0, cur, scale_w_reshaped); + } + + // 5. Residual Connection + bool same_spatial = (inp->ne[0] == cur->ne[0]) && (inp->ne[1] == cur->ne[1]); + bool same_channel = (inp->ne[2] == cur->ne[2]); + if (same_spatial && same_channel) { + cur = ggml_add(ctx0, cur, inp); + } + + return cur; +} + +// MobileNetV5 Builder (Gemma 3n) - Attention Block +ggml_tensor * clip_graph_mobilenetv5::build_mobilenet_attn(ggml_tensor * inp, const mobilenetv5_block & block) { + ggml_tensor * cur = inp; + + // --- Norm --- + if (block.attn_norm_w) { + cur = rms_norm_2d(cur, block.attn_norm_w, 1e-6f); + } + + // --- 1. Q Calculation --- + ggml_tensor * q = ggml_conv_2d_direct(ctx0, block.attn_q_w, cur, 1, 1, 0, 0, 1, 1); + + // --- 2. K Calculation (Downsampled) --- + // Uses Conv2dSame(640, 640, kernel_size=(3, 3), stride=(2, 2), groups=640) + ggml_tensor * k_inp = cur; + if (block.attn_k_dw_w) { + int k_size = block.attn_k_dw_w->ne[0]; // Usually 3 + k_inp = pad_same_2d(cur, k_size, k_size, 2, 2); // Apply SAME padding + k_inp = ggml_conv_2d_dw(ctx0, block.attn_k_dw_w, k_inp, 2, 2, 0, 0, 1, 1); // padding=0 + if (block.attn_k_norm_w) { + k_inp = rms_norm_2d(k_inp, block.attn_k_norm_w, 1e-6f); + } + } + ggml_tensor * k = ggml_conv_2d_direct(ctx0, block.attn_k_w, k_inp, 1, 1, 0, 0, 1, 1); + + // --- 3. V Calculation (Downsampled) --- + // Uses Conv2dSame(640, 640, kernel_size=(3, 3), stride=(2, 2), groups=640) + ggml_tensor * v_inp = cur; + if (block.attn_v_dw_w) { + int v_size = block.attn_v_dw_w->ne[0]; // Usually 3 + v_inp = pad_same_2d(cur, v_size, v_size, 2, 2); // Apply SAME padding + v_inp = ggml_conv_2d_dw(ctx0, block.attn_v_dw_w, v_inp, 2, 2, 0, 0, 1, 1); // padding=0 + if (block.attn_v_norm_w) { + v_inp = rms_norm_2d(v_inp, block.attn_v_norm_w, 1e-6f); + } + } + ggml_tensor * v = ggml_conv_2d_direct(ctx0, block.attn_v_w, v_inp, 1, 1, 0, 0, 1, 1); + + const int W = cur->ne[0]; const int H = cur->ne[1]; const int B = cur->ne[3]; + const int D = k->ne[2]; // Head dimension + const int n_head = q->ne[2] / D; + const int N = W * H; + + // Process Q: [W, H, D*n_head, B] -> [D, N, n_head, B] + q = ggml_reshape_3d(ctx0, q, N, D*n_head, B); + q = ggml_reshape_4d(ctx0, q, N, D, n_head, B); + q = ggml_permute(ctx0, q, 1, 0, 2, 3); // [D, N, n_head, B] + q = ggml_cont(ctx0, q); + + const int Wk = k->ne[0]; const int Hk = k->ne[1]; + const int M = Wk * Hk; + + // Process K: [Wk, Hk, D, B] -> [D, M, 1, B] + k = ggml_reshape_3d(ctx0, k, M, D, B); + k = ggml_reshape_4d(ctx0, k, M, D, 1, B); + k = ggml_permute(ctx0, k, 1, 0, 2, 3); // [D, M, 1, B] + k = ggml_cont(ctx0, k); + + // Process V: [Wk, Hk, D, B] -> [M, D, 1, B] + v = ggml_reshape_3d(ctx0, v, M, D, B); + v = ggml_reshape_4d(ctx0, v, M, D, 1, B); + v = ggml_cont(ctx0, v); // [M, D, 1, B] + + // --- Multi-Query Attention --- + float scale = 1.0f / sqrtf((float)D); + + // Step 1: Compute Q @ K.T + ggml_tensor * scores = ggml_mul_mat(ctx0, k, q); + + scores = ggml_scale(ctx0, scores, scale); + + scores = ggml_soft_max(ctx0, scores); + + ggml_tensor * kqv = ggml_mul_mat(ctx0, v, scores); + + kqv = ggml_permute(ctx0, kqv, 1, 0, 2, 3); + kqv = ggml_cont(ctx0, kqv); + + + kqv = ggml_reshape_3d(ctx0, kqv, N, D * n_head, B); + kqv = ggml_reshape_4d(ctx0, kqv, W, H, D * n_head, B); + kqv = ggml_cont(ctx0, kqv); + + // Output projection + cur = ggml_conv_2d_direct(ctx0, block.attn_o_w, kqv, 1, 1, 0, 0, 1, 1); + + // --- Residual & Layer Scale (FIXED) --- + if (inp->ne[0] == cur->ne[0] && inp->ne[2] == cur->ne[2]) { + if (block.layer_scale_w) { + ggml_tensor * scale_w_reshaped = ggml_reshape_4d(ctx0, block.layer_scale_w, + 1, 1, block.layer_scale_w->ne[0], 1); + cur = ggml_mul(ctx0, cur, scale_w_reshaped); + } + cur = ggml_add(ctx0, cur, inp); + } + + return cur; +} + ggml_cgraph * clip_graph_mobilenetv5::build() { fprintf(stderr, "\n--- START build_mobilenetv5 ---\n"); diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index 3875285fe92..54664d10ce3 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -70,4 +70,32 @@ struct clip_graph_glm4v : clip_graph { struct clip_graph_mobilenetv5 : clip_graph { clip_graph_mobilenetv5(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; + + ggml_tensor * rms_norm_2d( + ggml_tensor * inp, + ggml_tensor * weight, + float eps = 1e-6f); + + ggml_tensor* pad_same_2d( + ggml_tensor* inp, + int kernel_h, + int kernel_w, + int stride_h, + int stride_w, + int dilation_h = 1, + int dilation_w = 1); + + ggml_tensor * build_edge_residual( + ggml_tensor * inp, + const mobilenetv5_block & block, + int stride); + + ggml_tensor * build_inverted_residual( + ggml_tensor * inp, + const mobilenetv5_block & block, + int stride); + + ggml_tensor * build_mobilenet_attn( + ggml_tensor * inp, + const mobilenetv5_block & block); }; From 86618c7c0a0a3aff2aa12294fb17b2ad15610c29 Mon Sep 17 00:00:00 2001 From: Simranjeet Singh Date: Mon, 22 Dec 2025 13:45:24 +0000 Subject: [PATCH 8/8] Remove obsolete comments --- tools/mtmd/clip.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index e86a09bb5c1..dd778ea3c96 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1536,8 +1536,6 @@ struct clip_model_loader { model.msfa_ffn_project_w = get_tensor(TN_MNV5_MSFA_FFN_PROJ_W, false); model.msfa_ffn_project_bn = get_tensor(TN_MNV5_MSFA_FFN_PROJ_BN, false); - // IMPORTANT: Your GGUF log shows 'v.enc.msfa.norm.weight' -> shape {2048} - // Ensure TN_MNV5_MSFA_NORM matches this string model.msfa_concat_norm_w = get_tensor(TN_MNV5_MSFA_NORM, false); // Dynamically load blocks stage by stage @@ -1620,8 +1618,6 @@ struct clip_model_loader { // Load projection weights (similar to Gemma3) model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ); model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N); - // model.mm_post_proj_norm_w = get_tensor(TN_MM_POST_PROJ_N); // CRITICAL: Post projection norm - // Load additional Gemma3n projection tensors model.mm_0_w = get_tensor("mm.embedding.weight", false); // Input embedding model.mm_1_w = get_tensor("mm.hard_emb_norm.weight", false); // Hard embedding norm } break;