The main goal of this project is to understand the key concepts and methodologies behind a vision-language model (VLM): the techniques that power it, the algorithms it applies, and the ideas underlying topics such as
the transformer model (embeddings, positional encoding, multi-head attention, feed-forward layers, logits, softmax), rotary positional encoding, multi-query vs. grouped-query attention, contrastive learning, the vision transformer encoder, the language model decoder, the KV cache, weight tying, top-p sampling and temperature, attention masks, and much more.
PaliGemma is a family of vision-language models (VLMs) that combines the SigLIP vision encoder with the Gemma 2B language model. Inspired by PaLI-3, it is built on open components: the SigLIP vision model and the Gemma language model. PaliGemma takes both images and text as input and can answer questions about images with detail and context, enabling tasks such as image captioning, visual question answering (VQA), object detection, and reading text embedded within images.
- Parameter Sizes: Available in 3B, 10B, and 28B parameter versions.
- SigLIP Vision Encoder: A "shape-optimized" contrastively pretrained ViT that converts an image into a sequence of tokens, prepended to an optional prompt.
- Gemma 2B Decoder: Acts as the language model decoder.
- Full Attention: Uses full attention on all image and text tokens to maximize capacity.
PaliGemma is not a conversational model and works best when fine-tuned for specific downstream tasks such as image captioning, VQA, object detection, and document understanding.
- Core Components: Embeddings, Positional Encoding, Multi-Head Attention, Feed Forward Layer, Logits, Softmax
- Other Concepts: Numerical stability of Softmax and Cross Entropy Loss, Attention masks (causal and non-causal), Top-P Sampling and Temperature, Weight tying
- SigLIP is Google's modification of CLIP: it keeps the contrastive setup but replaces the softmax-based loss with a pairwise sigmoid loss.
- Picture the batch as a matrix with image embeddings I as rows and text embeddings T as columns: the dot product of a matching pair, e.g. I1 · T1 (image embedding 1 with text embedding 1), should be higher than that of any non-matching pair.
- Contrastive learning → take the I1 embedding and the T1 embedding; their dot product should be the highest value in its row and column.
- How do we tell the model that we want the matching item in each row/column to be high while all the others are minimized? → Cross-entropy loss.
Coding (CLIP-style pseudocode):

```python
# image_encoder - ResNet or Vision Transformer
# text_encoder  - CBOW or text transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l]       - minibatch of aligned texts
# W_i[d_i, d_e] - learned projection of image to embed
# W_t[d_t, d_e] - learned projection of text to embed
# t             - learned temperature parameter

# Step 1: extract feature representations of each modality
I_f = image_encoder(I)  # [n, d_i]
T_f = text_encoder(T)   # [n, d_t]

# Step 2: joint multimodal embedding [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)

# Step 3: scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)

# Step 4: symmetric loss function
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss   = (loss_i + loss_t) / 2
```
Note: SigLIP proposes using a sigmoid instead of the softmax, turning each image–text pair into an independent binary classification.
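A minimal sketch of that pairwise sigmoid objective, assuming L2-normalized embeddings and learned scalar temperature/bias tensors (the function name and argument names are illustrative, not taken from the SigLIP code):

```python
import torch
import torch.nn.functional as F

def siglip_loss(image_emb, text_emb, t, b):
    """Pairwise sigmoid loss sketch.

    image_emb, text_emb: [n, d] L2-normalized embeddings.
    t, b: learned log-temperature and bias (scalar tensors, assumed here).
    Every (image, text) pair is an independent binary classification:
    label +1 on the diagonal (matching pairs), -1 everywhere else.
    """
    n = image_emb.size(0)
    logits = image_emb @ text_emb.t() * t.exp() + b          # [n, n]
    labels = 2 * torch.eye(n, device=logits.device) - 1      # +1 diag, -1 off-diag
    return -F.logsigmoid(labels * logits).sum() / n
```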
Why Contrastive?
- The embeddings should be good joint representations: matching image and text pairs should land close together in a shared embedding space, not just be good representations of each modality individually.
- Contrastive Vision Encoder is basically a Vision Transformer.
- An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
- Sequence-to-sequence model
- Vision encoder to extract information from the image
- Trained with contrastive learning
- Image is split into patches, embedded, and passed through a transformer
- Not auto-regressive; any patch can depend on any other patch
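As a rough illustration of the patch-embedding step, here is a minimal sketch (not the project's actual `SigLipVisionEmbeddings` code; the class name and the 224/16/768 sizes are assumptions for the example):

```python
import torch
import torch.nn as nn

class PatchEmbeddings(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    A Conv2d with kernel_size == stride == patch_size is the standard trick:
    each patch is flattened and linearly projected in a single step.
    """
    def __init__(self, image_size=224, patch_size=16, channels=3, embed_dim=768):
        super().__init__()
        self.num_patches = (image_size // patch_size) ** 2
        self.proj = nn.Conv2d(channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        # Learned absolute position embedding, one vector per patch.
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))

    def forward(self, pixel_values):                  # [B, C, H, W]
        x = self.proj(pixel_values)                   # [B, D, H/P, W/P]
        x = x.flatten(2).transpose(1, 2)              # [B, num_patches, D]
        return x + self.pos_embed
```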
Normalization:
- Covariate Shift: the input distribution changes between training and testing
- Handling Covariate Shift: importance weighting, domain adaptation, more representative data, robust models
- Batch Normalization: normalizes each feature dimension using mean and std statistics computed across the batch
- Layer Normalization: normalizes each item along its feature dimension
- Root Mean Square Normalization (RMSNorm): computes only a single RMS statistic per item, instead of both mean and std; see the sketch below
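A minimal sketch contrasting LayerNorm and RMSNorm (this `RMSNorm` class is illustrative; real implementations such as Gemma's may differ in details like the weight parameterization):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Rescale each token's features by their root-mean-square.
    Unlike LayerNorm, there is no mean subtraction and no bias term."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):                               # x: [..., dim]
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)

x = torch.randn(2, 5, 8)
print(nn.LayerNorm(8)(x).shape, RMSNorm(8)(x).shape)    # both: [2, 5, 8]
```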
Multi-Head Attention:
- Contextualizes tokens with each other
- Vision Transformer: operates on a sequence of patches
- Language Model: operates on a sequence of tokens
- Heads are computed in parallel for efficiency
- Steps: project to Q, K, V; split into heads and compute attention per head; concatenate the heads; apply the output projection (see the sketch below)
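A compact sketch of those steps as a standard PyTorch-style multi-head self-attention (dimensions and class name are assumptions, not the project's exact code):

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Project to Q, K, V; run the heads in parallel; concatenate; project out."""
    def __init__(self, embed_dim=512, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):                                 # x: [B, T, D]
        B, T, D = x.shape
        # [B, T, D] -> [B, num_heads, T, head_dim]
        def split(t):
            return t.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        q, k, v = split(self.q_proj(x)), split(self.k_proj(x)), split(self.v_proj(x))
        scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5      # [B, H, T, T]
        if mask is not None:                                         # e.g. causal mask
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = scores.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, T, D)            # concat heads
        return self.o_proj(out)
```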
Grouped Query Attention (GQA):
- Use many Query heads, but share Keys and Values across groups
- Reduces memory and compute, scales better to long sequences, decouples Q and KV capacity
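A hedged sketch of GQA, where the smaller set of KV heads is simply repeated so each group of query heads attends over shared keys and values (head counts and names are assumptions for illustration):

```python
import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    """Many query heads share a smaller set of key/value heads, shrinking
    the K/V projections and, during generation, the KV cache."""
    def __init__(self, embed_dim=512, num_heads=8, num_kv_heads=2):
        super().__init__()
        self.num_heads, self.num_kv_heads = num_heads, num_kv_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, num_heads * self.head_dim)
        self.k_proj = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.o_proj = nn.Linear(num_heads * self.head_dim, embed_dim)

    def forward(self, x):                                            # x: [B, T, D]
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Repeat each KV head so every group of query heads sees its shared K/V.
        rep = self.num_heads // self.num_kv_heads
        k, v = k.repeat_interleave(rep, dim=1), v.repeat_interleave(rep, dim=1)
        attn = (q @ k.transpose(-2, -1) / self.head_dim ** 0.5).softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, T, -1)
        return self.o_proj(out)
```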
KV Cache:
- Avoids recomputing the keys and values of all past tokens at every generation step
- Caches the previously computed K and V matrices and appends only the new token's entries
- Pre-filling: populates the KV cache with the K and V values of the input prompt before token-by-token generation begins
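A minimal sketch of the idea using per-layer key/value lists (method names and shapes are assumptions, not necessarily the project's exact implementation):

```python
import torch

class KVCache:
    """Append each new token's keys/values instead of recomputing them,
    so decoding step t attends over cached tensors of shape [B, H, t, d]."""
    def __init__(self):
        self.key_cache, self.value_cache = [], []   # one entry per decoder layer

    def update(self, key, value, layer_idx):
        # key/value: [B, num_kv_heads, new_tokens, head_dim]
        if len(self.key_cache) <= layer_idx:
            # Pre-fill: first forward pass stores the whole prompt's K and V.
            self.key_cache.append(key)
            self.value_cache.append(value)
        else:
            # Decode: append the single new token's K and V along the time axis.
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key], dim=2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value], dim=2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]
```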
Rotary Positional Embeddings (RoPE):
- Inject positional information by rotating the Q and K vectors according to their positions
- The resulting attention scores encode both content similarity and relative position
- Benefits: relative position awareness, better long-context generalization, no position embedding table, improved performance
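A short sketch of rotary embeddings using the common half-split rotation layout (function names and the `base=10000.0` default are assumptions for illustration):

```python
import torch

def rotate_half(x):
    """Swap the two halves of the feature dimension and negate the second."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, positions, head_dim, base=10000.0):
    """Rotate Q and K by position-dependent angles.

    q, k: [B, H, T, head_dim]; positions: [T] token indices.
    After rotation, q · k depends on the relative distance between tokens.
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = positions.float()[:, None] * inv_freq[None, :]     # [T, head_dim/2]
    emb = torch.cat((angles, angles), dim=-1)                   # [T, head_dim]
    cos, sin = emb.cos(), emb.sin()                             # broadcast over B, H
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot
```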
- Top-P (nucleus) sampling and temperature control the randomness and diversity of generation; see the sketch below.
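A small sketch of temperature plus nucleus sampling over a vector of next-token logits (the function name and default values are illustrative):

```python
import torch

def sample_top_p(logits, temperature=0.8, top_p=0.9):
    """Temperature + top-p (nucleus) sampling.

    logits: [vocab_size] scores for the next token. Temperature rescales the
    distribution; top-p keeps only the smallest set of tokens whose cumulative
    probability exceeds p, then samples from that renormalized set.
    """
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop tokens outside the nucleus (always keep at least the top token).
    mask = cumulative - sorted_probs > top_p
    sorted_probs[mask] = 0.0
    sorted_probs /= sorted_probs.sum()
    next_token = sorted_idx[torch.multinomial(sorted_probs, num_samples=1)]
    return next_token
```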
- Decoder flow: Embeddings → [Norm → Self-Attention → Skip Connection (+) → Norm → Feed-Forward Network → Skip Connection (+)] × N layers → Final Norm → Linear (logits) → Softmax
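A simplified pre-norm decoder block following that flow; it uses LayerNorm and standard multi-head attention as stand-ins (Gemma itself uses RMSNorm and grouped-query attention), so treat it as a sketch rather than the project's actual layer:

```python
import torch.nn as nn

class DecoderLayer(nn.Module):
    """Pre-norm block: x -> Norm -> Self-Attention -> skip (+) -> Norm -> FFN -> skip (+)."""
    def __init__(self, dim=512, num_heads=8, hidden=2048):
        super().__init__()
        self.attn_norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ffn_norm = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x, attn_mask=None):
        h = self.attn_norm(x)
        x = x + self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)[0]
        x = x + self.ffn(self.ffn_norm(x))
        return x
```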
Weight Tying:
- Sharing the same weight matrix between the input embedding layer and the output projection layer
- Reduces parameters and improves generalization
- In practice: W = E^T, i.e., the output projection reuses the transpose of the embedding matrix E (so logits = h · E^T) rather than learning a separate matrix.
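A tiny sketch of weight tying in PyTorch, where the output head's weight is literally the same tensor as the embedding matrix (the `TinyLM` class and sizes are illustrative, not the project's code):

```python
import torch.nn as nn

class TinyLM(nn.Module):
    """Weight tying: the output projection reuses the embedding matrix,
    so no separate [vocab_size, dim] matrix is learned for the logits."""
    def __init__(self, vocab_size=32000, dim=512):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight      # tie: same Parameter object

    def forward(self, token_ids):
        hidden = self.embed(token_ids)               # (decoder layers would go here)
        return self.lm_head(hidden)                  # [B, T, vocab_size]
```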
Download the model weights from: HuggingFace - paligemma-3b-pt-224
- Purpose: Main script for running inference with the PaliGemma model. Loads the model and tokenizer, processes the input image and prompt, and generates a response.
- Key Components:
- Loads the model and tokenizer using `utils.py`.
- Uses the processor to convert images and text into model-ready tensors.
- Handles device selection (CPU/GPU).
- Implements the inference loop, including token generation, sampling, and decoding.
- Can be run from the command line with various arguments (model path, prompt, image, etc.).
- Purpose: Shell script to automate running `inference.py` with preset arguments.
- Key Components:
- Sets environment variables for model path, prompt, image, and generation parameters.
- Calls `python inference.py` with these arguments for easy reproducibility.
- Purpose: Implements the SigLIP Vision Transformer (ViT) model, which encodes images into embeddings.
- Key Components:
- `SigLipConfig`: Configuration class for the vision transformer.
- `SigLipVisionEmbeddings`: Converts images into patch embeddings and adds positional encodings.
- `SigLipAttention`, `SigLipMLP`, `SigLipVisionEncoderLayer`, `SigLipVisionEncoder`: Core transformer blocks for processing image patches.
- `SigLipVisionTransformer`: Assembles the full vision transformer.
- `SigLipVisionModel`: Top-level class for the vision encoder.
- Purpose: Implements the Gemma language model and the overall multimodal architecture (PaliGemma).
- Key Components:
- `GemmaConfig`, `PaliGemmaConfig`: Configuration classes for the language and multimodal models.
- `GemmaAttention`, `GemmaMLP`, `GemmaDecoderLayer`, `GemmaModel`: Core transformer blocks for the language model.
- `GemmaForCausalLM`: Language model head for text generation.
- `PaliGemmaMultiModalProjector`: Projects vision features into the language model's embedding space.
- `PaliGemmaForConditionalGeneration`: The main multimodal model that combines vision and language, merges embeddings, and handles attention masks and the KV cache.
- Purpose: Handles preprocessing of images and text for the PaliGemma model.
- Key Components:
- Image resizing, normalization, and conversion to tensor.
- Adds special image tokens to the prompt.
- Tokenizes text using the HuggingFace tokenizer.
- Returns a dictionary of tensors ready for model input.
- Purpose: Utility functions for loading models and tokenizers.
- Key Components:
- `load_hf_model`: Loads the model weights, configuration, and tokenizer from a HuggingFace directory.
- Handles loading of safetensors and config files, and ties the weights between the embedding and output layers.
- Purpose: Documentation for the project.
- Key Components:
- Project overview, learning topics, detailed explanations of core concepts, and usage instructions.
- Purpose: Visual diagram of the overall architecture, showing how the vision and language components interact.
- Purpose: Example image used for testing the model's inference capabilities.
This project is purely for educational purposes. Please refer to the respective model and dataset licenses for usage restrictions. Thanks to Umar Jamil (@hkproj) for his wonderful explanation of each of these topics in his video.
