PaliGemma: Multimodal Vision-Language Model from Scratch

The main goal of this project is to learn the important concepts and methodologies behind a vision-language model: the techniques that power a VLM, the algorithms applied, and a deeper understanding of topics such as

the transformer model (embeddings, positional encoding, multi-head attention, feed-forward layers, logits, softmax), rotary positional encoding, multi-query vs. grouped-query attention, contrastive learning, the vision transformer encoder, the language model decoder, the KV cache, weight tying, top-p sampling and temperature, attention masks, and much more.

Vision-Language Model Architecture (see Vision_Language_Model.png for the full diagram)

Project Overview

PaliGemma is a family of vision-language models (VLMs), combining the SigLIP vision encoder with the Gemma 2B language model. Inspired by PaLI-3, PaliGemma is based on open components like the SigLIP vision model and the Gemma language model. PaliGemma takes both images and text as inputs and can answer questions about images with detail and context, enabling tasks such as image captioning, visual question answering (VQA), object detection, and reading text embedded within images.

  • Parameter Sizes: The original PaliGemma release is a ~3B-parameter model (SigLIP vision encoder + Gemma 2B); the later PaliGemma 2 family adds 10B and 28B variants.
  • SigLIP Vision Encoder: A "shape-optimized", contrastively pretrained ViT that converts an image into a sequence of tokens, which are prepended to the tokens of an optional text prompt.
  • Gemma 2B Decoder: Acts as the language model decoder.
  • Full Attention: Uses full (bidirectional) attention over the image and prefix text tokens, with causal attention over the generated output, to maximize capacity.

PaliGemma is not a conversational model and works best when fine-tuned for specific downstream tasks such as image captioning, VQA, object detection, and document understanding.


Learning Topics & Explanations

Transformer Model

  • Core Components: Embeddings, Positional Encoding, Multi-Head Attention, Feed Forward Layer, Logits, Softmax
  • Other Concepts: Numerical stability of Softmax and Cross Entropy Loss, Attention masks (causal and non-causal), Top-P Sampling and Temperature, Weight tying

Contrastive Learning (CLIP, SigLIP)

  • SigLIP is Google's refinement of CLIP, replacing the softmax-based contrastive loss with a sigmoid-based one.
  • Picture an n × n similarity matrix with image embeddings I as rows and text embeddings T as columns: the dot product of a matching pair (e.g. I1 · T1) should be higher than the dot product of any mismatched pair in the same row or column.
  • Contrastive learning → pull matching image and text embeddings together (high dot product) and push mismatched ones apart.
  • How do we tell the model that exactly one item per row/column should be high while all others stay low? → Cross-entropy loss over each row and each column, with the index of the matching pair as the label.

Coding:

  • image_encoder - ResNet or Vision Transformer
  • text_encoder - CBOW or text transformer
  • I[n, h, w, c] - minibatch of aligned images
  • T[n, l] - minibatch of aligned texts
  • W_i[d_i, d_e] - learned projection of image to embed
  • W_t[d_t, d_e] - learned projection of text to embed
  • t - learned temperature parameter

Steps:

  1. Extract feature representations of each modality
    • I_f = image_encoder(I) → [n, d_i]
    • T_f = text_encoder(T) → [n, d_t]
  2. Joint multimodal embedding [n, d_e]
    • I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
    • T_e = l2_normalize(np.dot(T_f, W_t), axis=1)
  3. Scaled pairwise cosine similarities
    • logits = np.dot(I_e, T_e.T) * np.exp(t)
  4. Symmetric loss function
    • labels = np.arange(n)
    • loss_i = cross_entropy_loss(logits, labels, axis=0)
    • loss_t = cross_entropy_loss(logits, labels, axis=1)
    • loss = (loss_i + loss_t) / 2
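
The pseudocode above translates almost line-for-line into PyTorch. A minimal runnable sketch (the function name, the learned log-temperature log_t, and the toy shapes are illustrative assumptions, not code from this repository):

```python
import torch
import torch.nn.functional as F

def clip_contrastive_loss(image_emb, text_emb, log_t):
    """Symmetric cross-entropy loss over an n x n similarity matrix.

    image_emb: [n, d_e] projected image embeddings
    text_emb:  [n, d_e] projected text embeddings
    log_t:     learned log-temperature scalar
    """
    # L2-normalize so the dot product is a cosine similarity
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)

    # Scaled pairwise cosine similarities: [n, n]
    logits = image_emb @ text_emb.t() * log_t.exp()

    # The i-th image matches the i-th text, so the target index is the diagonal
    labels = torch.arange(logits.size(0), device=logits.device)

    loss_i = F.cross_entropy(logits, labels)      # rows: image -> text
    loss_t = F.cross_entropy(logits.t(), labels)  # columns: text -> image
    return (loss_i + loss_t) / 2

# Toy usage with random embeddings
n, d_e = 8, 64
loss = clip_contrastive_loss(torch.randn(n, d_e), torch.randn(n, d_e),
                             torch.tensor(0.0))
```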

Note: SigLIP proposes using a Sigmoid instead of Softmax, turning each image-text pair into an independent binary classification.
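
For comparison, a sketch of the sigmoid-based loss in the same style (again an illustrative sketch of the idea, not this repository's code; the learnable bias b follows the SigLIP formulation):

```python
import torch
import torch.nn.functional as F

def siglip_sigmoid_loss(image_emb, text_emb, log_t, b):
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    logits = image_emb @ text_emb.t() * log_t.exp() + b  # [n, n]

    # +1 on the diagonal (matching pairs), -1 everywhere else
    n = logits.size(0)
    targets = 2 * torch.eye(n, device=logits.device) - 1

    # Independent binary (log-sigmoid) loss for every image-text pair
    return -F.logsigmoid(targets * logits).sum() / n
```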

Why Contrastive?

  • The embeddings should form a good shared representation of image and text together, so that matching image-text pairs end up close, not just good representations of each modality individually.
  • The contrastive vision encoder is essentially a Vision Transformer (ViT).

Vision Transformer Model

Normalization (Batch, Layer, and RMS)

  • Covariate Shift: Input distribution changes between training and testing
  • Handling Covariate Shift: Importance weighting, domain adaptation, more representative data, robust models
  • Batch Normalization: Normalizes each feature dimension using mean and standard deviation statistics computed across the batch
  • Layer Normalization: Normalizes across the feature dimensions of each individual item, independently of the batch
  • Root Mean Square Normalization (RMSNorm): Normalizes by a single root-mean-square statistic; no mean is subtracted and no standard deviation is computed (a sketch follows this list)
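
A minimal RMSNorm sketch (illustrative, not the repository's exact implementation) that makes the "single statistic" point concrete:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable gain

    def forward(self, x):
        # Normalize by the root mean square over the feature dimension;
        # unlike LayerNorm there is no mean subtraction and no bias.
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * rms)

x = torch.randn(2, 5, 16)
print(RMSNorm(16)(x).shape)  # torch.Size([2, 5, 16])
```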

Multi-Head Attention

  • Contextualizes tokens with each other
  • Vision Transformer: sequence of patches
  • Language Model: sequence of tokens
  • Parallel computation for efficiency
  • Steps: Project to Q, K, V for each head; compute scaled dot-product attention per head; concatenate the heads; apply the output projection (see the sketch below)
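
A compact, self-contained sketch of those four steps (dimensions and module names are illustrative, not the repository's classes):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Minimal multi-head self-attention sketch."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        b, s, _ = x.shape
        # 1. Project to Q, K, V and split into heads: [b, n_heads, s, d_head]
        q = self.q_proj(x).view(b, s, self.n_heads, self.d_head).transpose(1, 2)
        k = self.k_proj(x).view(b, s, self.n_heads, self.d_head).transpose(1, 2)
        v = self.v_proj(x).view(b, s, self.n_heads, self.d_head).transpose(1, 2)

        # 2. Scaled dot-product attention per head (optionally masked)
        scores = q @ k.transpose(-2, -1) / self.d_head ** 0.5
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        out = F.softmax(scores, dim=-1) @ v

        # 3. Concatenate heads, 4. apply the output projection
        out = out.transpose(1, 2).contiguous().view(b, s, -1)
        return self.o_proj(out)
```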

Grouped Query Attention (GQA):

  • Use many Query heads, but share Keys and Values across groups
  • Reduces memory and compute, scales better to long sequences, decouples Q and KV capacity
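
The core trick in a few lines (head counts and shapes are illustrative): project to fewer K/V heads than Q heads and repeat each K/V head across its group of query heads before the usual attention computation:

```python
import torch
import torch.nn as nn

d_model, n_q_heads, n_kv_heads, d_head = 256, 8, 2, 32

q_proj = nn.Linear(d_model, n_q_heads * d_head)   # 8 query heads
k_proj = nn.Linear(d_model, n_kv_heads * d_head)  # only 2 key heads
v_proj = nn.Linear(d_model, n_kv_heads * d_head)  # only 2 value heads

x = torch.randn(1, 10, d_model)
q = q_proj(x).view(1, 10, n_q_heads, d_head).transpose(1, 2)   # [1, 8, 10, 32]
k = k_proj(x).view(1, 10, n_kv_heads, d_head).transpose(1, 2)  # [1, 2, 10, 32]
v = v_proj(x).view(1, 10, n_kv_heads, d_head).transpose(1, 2)

# Each K/V head is shared by a group of n_q_heads // n_kv_heads query heads
k = k.repeat_interleave(n_q_heads // n_kv_heads, dim=1)        # [1, 8, 10, 32]
v = v.repeat_interleave(n_q_heads // n_kv_heads, dim=1)
# ...followed by the usual scaled dot-product attention on q, k, v
```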

KV-Cache (Prefilling and Token Generation)

  • Avoids recomputing all keys & values for past tokens
  • Caches previously computed K and V matrices
  • Pre-Filling: Populates the KV cache with K and V values from an input prompt before token-by-token generation
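
A bare-bones sketch of the idea (the class and method names are illustrative, not the repository's implementation): prefilling stores K/V for the whole prompt at once; each generation step then appends only the new token's K/V:

```python
import torch

class KVCache:
    """Minimal per-layer key/value cache (illustrative sketch)."""

    def __init__(self):
        self.k = None  # [batch, n_heads, seq_len_so_far, d_head]
        self.v = None

    def update(self, new_k, new_v):
        if self.k is None:            # prefill: store K/V of the full prompt
            self.k, self.v = new_k, new_v
        else:                         # decode: append K/V of the single new token
            self.k = torch.cat([self.k, new_k], dim=2)
            self.v = torch.cat([self.v, new_v], dim=2)
        return self.k, self.v         # attention then uses all cached keys/values

cache = KVCache()
cache.update(torch.randn(1, 8, 5, 32), torch.randn(1, 8, 5, 32))              # 5-token prompt
k, v = cache.update(torch.randn(1, 8, 1, 32), torch.randn(1, 8, 1, 32))       # one new token
print(k.shape)  # torch.Size([1, 8, 6, 32])
```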

Rotary Positional Encoding

  • Injects positional information by rotating Q and K vectors based on their position
  • Encodes both content similarity and relative position
  • Benefits: relative position awareness, long context generalization, no position embedding table, improved performance
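
A compact sketch of the rotate-half formulation commonly used for RoPE (the base frequency and shapes are illustrative); the same rotation is applied to both Q and K, so their dot products depend only on relative position:

```python
import torch

def rotary_embedding(x, positions, base=10000.0):
    """Rotate pairs of dimensions of x by position-dependent angles.

    x:         [seq_len, d] with d even
    positions: [seq_len] integer positions
    """
    d = x.shape[-1]
    # One rotation frequency per pair of dimensions
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2).float() / d))
    angles = positions[:, None].float() * inv_freq[None, :]   # [seq_len, d/2]
    cos = torch.cat([angles.cos(), angles.cos()], dim=-1)     # [seq_len, d]
    sin = torch.cat([angles.sin(), angles.sin()], dim=-1)

    x1, x2 = x[..., : d // 2], x[..., d // 2:]
    rotated = torch.cat([-x2, x1], dim=-1)                    # "rotate half"
    return x * cos + rotated * sin

q = torch.randn(6, 32)
q_rot = rotary_embedding(q, torch.arange(6))  # applied to Q (and likewise to K)
```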

Top-P Sampling and Temperature

  • Temperature divides the logits before the softmax: lower values sharpen the distribution (more deterministic), higher values flatten it (more random). Top-P (nucleus) sampling keeps only the smallest set of tokens whose cumulative probability exceeds p and samples the next token from that set.
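
A minimal single-step sketch combining both (the function name and default values are illustrative, not the repository's code):

```python
import torch

def sample_top_p(logits, temperature=0.8, top_p=0.9):
    """Temperature + nucleus (top-p) sampling for one generation step."""
    probs = torch.softmax(logits / temperature, dim=-1)

    # Sort probabilities and keep the smallest set whose cumulative mass covers top_p
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0  # drop tokens outside the nucleus
    sorted_probs /= sorted_probs.sum()                     # renormalize the nucleus

    # Sample from the nucleus and map back to vocabulary ids
    return sorted_idx[torch.multinomial(sorted_probs, num_samples=1)]

next_id = sample_top_p(torch.randn(32000))  # e.g. a 32k-token vocabulary
```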

Language Model - Gemma

  • Embeddings → [Normalization → Self-Attention → Skip Connection (+) → Normalization → Feed Forward Network → Skip Connection (+)] → Norm → Linear Logits → Softmax
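
A schematic sketch of one such pre-norm block (with simplified stand-ins: LayerNorm instead of Gemma's RMSNorm, a plain GELU MLP instead of its gated feed-forward, and PyTorch's built-in attention module):

```python
import torch
import torch.nn as nn

class DecoderBlock(nn.Module):
    """Pre-norm decoder block: x + Attn(Norm(x)), then x + FFN(Norm(x))."""

    def __init__(self, d_model=256, n_heads=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)   # Gemma uses RMSNorm here
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model),
                                 nn.GELU(),
                                 nn.Linear(4 * d_model, d_model))

    def forward(self, x, attn_mask=None):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)[0]  # skip connection
        x = x + self.ffn(self.norm2(x))                                         # skip connection
        return x

x = torch.randn(1, 10, 256)
print(DecoderBlock()(x).shape)  # torch.Size([1, 10, 256])
```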

Weight Tying:

  • Sharing the same weight matrix between the input embedding layer and the output projection layer
  • Reduces parameters and improves generalization
  • In practice: W = E^T, i.e., reuse the transpose of the embedding matrix
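
In PyTorch this usually amounts to a single assignment (the names below are illustrative; in this repository, utils.py performs the tying when loading weights):

```python
import torch.nn as nn

vocab_size, d_model = 32000, 2048
embed_tokens = nn.Embedding(vocab_size, d_model)      # input embedding matrix E: [vocab, d]
lm_head = nn.Linear(d_model, vocab_size, bias=False)  # output projection: logits = h @ W^T

# Tie the weights: the output projection reuses the embedding matrix,
# so logits = hidden_states @ E^T and no extra [vocab, d] matrix is stored.
lm_head.weight = embed_tokens.weight
```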

Model Weights

Download the model weights from Hugging Face: google/paligemma-3b-pt-224 (https://huggingface.co/google/paligemma-3b-pt-224)


Code Structure and Component Explanations

1. inference.py

  • Purpose: Main script for running inference with the PaliGemma model. Loads the model and tokenizer, processes the input image and prompt, and generates a response.
  • Key Components:
    • Loads the model and tokenizer using utils.py.
    • Uses the processor to convert images and text into model-ready tensors.
    • Handles device selection (CPU/GPU).
    • Implements the inference loop, including token generation, sampling, and decoding.
    • Can be run from the command line with various arguments (model path, prompt, image, etc.).

2. launch_inference.sh

  • Purpose: Shell script to automate running inference.py with preset arguments.
  • Key Components:
    • Sets environment variables for model path, prompt, image, and generation parameters.
    • Calls python inference.py with these arguments for easy reproducibility.

3. modeling_siglip.py

  • Purpose: Implements the SigLIP Vision Transformer (ViT) model, which encodes images into embeddings.
  • Key Components:
    • SigLipConfig: Configuration class for the vision transformer.
    • SigLipVisionEmbeddings: Converts images into patch embeddings and adds positional encodings.
    • SigLipAttention, SigLipMLP, SigLipVisionEncoderLayer, SigLipVisionEncoder: Core transformer blocks for processing image patches.
    • SigLipVisionTransformer: Assembles the full vision transformer.
    • SigLipVisionModel: Top-level class for the vision encoder.

4. modelling_gemma.py

  • Purpose: Implements the Gemma language model and the overall multimodal architecture (PaliGemma).
  • Key Components:
    • GemmaConfig, PaliGemmaConfig: Configuration classes for the language and multimodal models.
    • GemmaAttention, GemmaMLP, GemmaDecoderLayer, GemmaModel: Core transformer blocks for the language model.
    • GemmaForCausalLM: Language model head for text generation.
    • PaliGemmaMultiModalProjector: Projects vision features into the language model's embedding space.
    • PaliGemmaForConditionalGeneration: The main multimodal model that combines vision and language, merges embeddings, and handles attention masks and KV cache.

5. processing_paligemma.py

  • Purpose: Handles preprocessing of images and text for the PaliGemma model.
  • Key Components:
    • Image resizing, normalization, and conversion to tensor.
    • Adds the special image placeholder tokens to the prompt (a rough sketch follows this list).
    • Tokenizes text using the HuggingFace tokenizer.
    • Returns a dictionary of tensors ready for model input.
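
A rough, hypothetical sketch of the prompt-construction step (the helper name, token strings, and token count here are assumptions for illustration; the actual layout is defined in processing_paligemma.py):

```python
# Hypothetical sketch: the text prompt is preceded by one placeholder token per
# image token and a BOS token, and terminated with a newline before tokenization.
IMAGE_TOKEN = "<image>"

def build_prompt(prefix_prompt: str, bos_token: str, image_seq_len: int) -> str:
    return f"{IMAGE_TOKEN * image_seq_len}{bos_token}{prefix_prompt}\n"

print(build_prompt("caption en", "<bos>", 4))
# <image><image><image><image><bos>caption en
```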

6. utils.py

  • Purpose: Utility functions for loading models and tokenizers.
  • Key Components:
    • load_hf_model: Loads the model weights, configuration, and tokenizer from a HuggingFace directory.
    • Handles loading of safetensors and config files, and ties weights between embedding and output layers.

7. README.md

  • Purpose: Documentation for the project.
  • Key Components:
    • Project overview, learning topics, detailed explanations of core concepts, and usage instructions.

8. Vision_Language_Model.png

  • Purpose: Visual diagram of the overall architecture, showing how the vision and language components interact.

9. squirrel_example.jpg

  • Purpose: Example image used for testing the model's inference capabilities.

References

This project is purely for educational purposes. Please refer to the respective model and dataset licenses for usage restrictions. Thanks to Umar Jamil (@hkproj) for his wonderful video explanations of each of these topics.
