Conversation

@NiklasAbraham (Contributor)
PyEED Embedding Architecture Report

Executive Summary

The PyEED embedding system has undergone a significant architectural transformation to provide a more robust, scalable, and user-friendly interface for protein embedding calculations. This report details the new structure, legacy support mechanisms, and the design philosophy behind these changes.

Table of Contents

  1. New Architecture Overview
  2. Complete Module Structure
  3. Legacy Support Strategy
  4. Design Philosophy and Benefits
  5. Usage Guide
  6. Migration Path
  7. Technical Implementation Details

New Architecture Overview

Core Components

The new embedding system is built around several key components:

1. EmbeddingProcessor (Main Interface)

  • Location: src/pyeed/embeddings/processor.py
  • Purpose: Central coordinator for all embedding operations
  • Key Features:
    • Automatic device management (CPU/GPU)
    • Model lifecycle management
    • Simplified high-level APIs
    • Multi-GPU support with automatic fallback

2. BaseEmbeddingModel (Abstract Base)

  • Location: src/pyeed/embeddings/base.py
  • Purpose: Common interface for all embedding models
  • Benefits: Ensures consistency across different model implementations

3. ModelFactory (Creation Pattern)

  • Location: src/pyeed/embeddings/factory.py
  • Purpose: Centralized model creation and configuration
  • Benefits: Encapsulates model-specific initialization logic

4. Specialized Model Classes

  • Location: src/pyeed/embeddings/models/
  • Models Supported:
    • ESM2EmbeddingModel (ESM-2 variants)
    • ESMCEmbeddingModel (ESM-C models)
    • ESM3EmbeddingModel (ESM-3 models)
    • ProtT5EmbeddingModel (ProtT5 variants)

Architectural Patterns

Factory Pattern

# Old way - manual model instantiation
model = load_model_manually(device, config)

# New way - factory-based creation
model = ModelFactory.create_model(model_name, device)

Singleton Pattern

# Global processor instance for efficiency
_processor = EmbeddingProcessor()

def get_processor() -> EmbeddingProcessor:
    return _processor

Complete Module Structure

File Organization

src/pyeed/embeddings/
├── __init__.py              # Public API and backward compatibility
├── processor.py             # Main EmbeddingProcessor class
├── base.py                  # Abstract base classes and interfaces
├── factory.py               # Model factory for creating instances
├── utils.py                 # Utility functions and helpers
├── database.py              # Database integration functions
└── models/                  # Model-specific implementations
    ├── __init__.py
    ├── esm2.py             # ESM-2 model implementation
    ├── esmc.py             # ESM-C model implementation
    ├── esm3.py             # ESM-3 model implementation
    └── prott5.py           # ProtT5 model implementation

Public API Structure

The module exports a clean public API through __init__.py:

Core Classes

from pyeed.embeddings import (
    # Main interfaces
    EmbeddingProcessor,      # Primary interface
    get_processor,           # Singleton accessor
    
    # Base classes
    BaseEmbeddingModel,      # For custom implementations
    ModelFactory,            # For programmatic model creation
    
    # Model implementations
    ESM2EmbeddingModel,
    ESMCEmbeddingModel,
    ESM3EmbeddingModel,
    ProtT5EmbeddingModel,
)

Utility Functions

from pyeed.embeddings import (
    # Utilities
    get_hf_token,                    # Hugging Face authentication
    preprocess_sequence_for_prott5,  # ProtT5-specific preprocessing
    free_memory,                     # Memory management
    determine_model_type,            # Automatic model type detection
    normalize_embedding,             # L2 normalization
    
    # Database integration
    update_protein_embeddings_in_db,
)

Backward Compatibility

from pyeed.embeddings import (
    # Legacy functions (still supported)
    load_model_and_tokenizer,
    process_batches_on_gpu,
    calculate_single_sequence_embedding_last_hidden_state,
    # ... other legacy functions
)

Base Class Interface

The BaseEmbeddingModel defines the contract that all model implementations must follow:

class BaseEmbeddingModel(ABC):
    """Abstract base class for protein embedding models."""
    
    # Required methods for all implementations
    @abstractmethod
    def load_model(self) -> Tuple[Any, Optional[Any]]:
        """Load and return the model and tokenizer."""
    
    @abstractmethod
    def get_batch_embeddings(self, sequences: List[str], pool_embeddings: bool = True) -> List[NDArray[np.float64]]:
        """Get embeddings for a batch of sequences."""
    
    @abstractmethod
    def get_single_embedding_last_hidden_state(self, sequence: str) -> NDArray[np.float64]:
        """Get embedding from the last hidden state."""
    
    @abstractmethod
    def get_single_embedding_all_layers(self, sequence: str) -> NDArray[np.float64]:
        """Get embeddings from all layers."""
    
    @abstractmethod
    def get_single_embedding_first_layer(self, sequence: str) -> NDArray[np.float64]:
        """Get embedding from the first layer."""
    
    # Default implementation available
    def get_final_embeddings(self, sequence: str) -> NDArray[np.float64]:
        """Robust embedding method that works across all models."""
        return self.get_single_embedding_last_hidden_state(sequence)

Model Type Detection

The system automatically detects model types based on naming conventions:

def determine_model_type(model_name: str) -> str:
    """Determine the model type based on model name."""
    model_name_lower = model_name.lower()
    
    if "esmc" in model_name_lower:
        return "esmc"
    elif "esm3" in model_name_lower:
        return "esm3"
    elif "prot_t5" in model_name_lower or "prott5" in model_name_lower:
        return "prott5"
    else:
        return "esm2"  # Default to ESM-2

Legacy Support Strategy

Philosophy

The new architecture maintains 100% backward compatibility while encouraging migration to the improved interfaces. This is achieved through:

  1. Method Wrapping: Legacy methods are wrapped to use new implementations
  2. Parameter Mapping: Old parameter formats are mapped to new ones
  3. Deprecation Warnings: Gentle nudging toward new methods (see the sketch after this list)
  4. Feature Parity: All legacy functionality is preserved
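
A minimal sketch of the wrap-and-warn pattern, assuming the standard warnings module (the report shows the wrapping but not the exact warning mechanism or message):

import warnings

def calculate_single_sequence_embedding_last_hidden_state(self, sequence, model_name, device=None):
    """Legacy wrapper; delegates to the new unified interface."""
    warnings.warn(
        "Use calculate_single_embedding(..., embedding_type='last_hidden_state') instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return self.calculate_single_embedding(sequence, model_name, "last_hidden_state", device)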

Legacy Method Categories

1. Batch Processing Legacy Methods

# Legacy method (still works)
def process_batches_on_gpu(
    self,
    data: List[tuple[str, str]],
    batch_size: int,
    model: Union[Any, DataParallel[Module]],
    tokenizer: Union[Any, None],
    db: DatabaseConnector,
    device: torch.device,
) -> None:
    """Legacy method for backward compatibility."""
    # Internally calls calculate_batch_embeddings()

Modern Equivalent:

# New method (recommended)
processor.calculate_batch_embeddings(
    data=data,
    model_name="facebook/esm2_t33_650M_UR50D",
    batch_size=batch_size,
    db=db
)

2. Single Embedding Legacy Methods

# Legacy methods (still work)
calculate_single_sequence_embedding_last_hidden_state()
calculate_single_sequence_embedding_all_layers()
calculate_single_sequence_embedding_first_layer()

# Modern equivalent; embedding_type is one of
# "last_hidden_state", "all_layers", or "first_layer"
calculate_single_embedding(
    sequence=sequence,
    embedding_type="last_hidden_state",
)

3. Model-Specific Legacy Methods

# Legacy method (still works)
def get_batch_embeddings_unified(
    self,
    batch_sequences: List[str],
    model: Union[Any, DataParallel[Module]],
    tokenizer: Union[Any, None],
    device: torch.device = torch.device("cuda:0"),
    pool_embeddings: bool = True,
) -> List[NDArray[np.float64]]:
    """Automatically detect the model type and route to the appropriate handler."""

Migration Safety Net

The legacy support includes:

  • Automatic Model Type Detection: Legacy methods detect the model type and route to appropriate handlers
  • Parameter Translation: Old parameter formats are automatically converted
  • Warning System: Deprecation warnings guide users toward new methods
  • Error Handling: Robust fallback mechanisms for edge cases

Backward Compatibility Implementation

The __init__.py file provides comprehensive backward compatibility:

# Try to import from old embedding.py
try:
    from ..embedding import (
        load_model_and_tokenizer,
        process_batches_on_gpu,
        # ... other legacy functions
    )
except ImportError:
    # Fallback implementations using new processor
    _processor = get_processor()
    
    def load_model_and_tokenizer(model_name: str, device=None):
        """Backward compatibility function; model loading now happens inside the processor."""
        # Placeholders preserve the old (model, tokenizer, device) return shape.
        return None, None, device
    
    def process_batches_on_gpu(data, batch_size, model, tokenizer, db, device):
        """Backward compatibility function."""
        return _processor.process_batches_on_gpu(data, batch_size, model, tokenizer, db, device)

Design Philosophy and Benefits

Core Principles

1. Separation of Concerns

  • Model Management: Isolated in specialized classes
  • Device Management: Centralized in the processor
  • Embedding Logic: Separated by model type
  • Database Operations: Abstracted from core embedding logic

2. Unified Interface

# One method handles all embedding types
embedding = processor.calculate_single_embedding(
    sequence="MKLLVLSL...",
    model_name="facebook/esm2_t33_650M_UR50D",
    embedding_type="final_embeddings"  # Works across all models
)

3. Automatic Resource Management

  • Device Selection: Automatically chooses optimal device
  • Memory Management: Automatic cleanup and garbage collection
  • Model Caching: Intelligent model reuse across operations
  • Error Recovery: Automatic fallback for OOM errors

4. Extensibility

  • New Models: Easy to add via base class inheritance
  • New Features: Minimal impact on existing code
  • Custom Embedding Types: Pluggable embedding strategies

Key Benefits

Performance Improvements

  • Multi-GPU Support: Automatic parallelization across available GPUs
  • Smart Batching: Adaptive batch sizing for memory optimization
  • Model Reuse: Cached models reduce initialization overhead
  • Memory Efficiency: Automatic cleanup and garbage collection

Developer Experience

  • Simplified API: Single method for most use cases
  • Type Safety: Full type hints throughout
  • Better Error Messages: Clear, actionable error information
  • Documentation: Comprehensive docstrings and examples

Reliability

  • Error Handling: Graceful degradation and recovery
  • Resource Management: Automatic cleanup prevents memory leaks
  • Thread Safety: Safe for concurrent operations
  • Fallback Mechanisms: CPU fallback when GPU unavailable

Usage Guide

Basic Usage

Single Embedding Calculation

from pyeed.embeddings import get_processor

processor = get_processor()

# Simple case - uses defaults
embedding = processor.calculate_single_embedding("MKLLVLSL...")

# Full configuration
embedding = processor.calculate_single_embedding(
    sequence="MKLLVLSL...",
    model_name="facebook/esm2_t33_650M_UR50D",
    embedding_type="final_embeddings",
    device=torch.device("cuda:0")
)

Batch Processing

# Prepare data
data = [
    ("protein1", "MKLLVLSL..."),
    ("protein2", "AKLMVLSL..."),
    # ... more sequences
]

# Process batch
embeddings = processor.calculate_batch_embeddings(
    data=data,
    model_name="facebook/esm2_t33_650M_UR50D",
    batch_size=16,
    embedding_type="final_embeddings"
)

Database Integration

from pyeed.dbconnect import DatabaseConnector

db = DatabaseConnector()

# Process all sequences in database
processor.calculate_database_embeddings(
    db=db,
    model_name="facebook/esm2_t33_650M_UR50D",
    batch_size=32,
    embedding_type="final_embeddings"
)

Advanced Configuration

Multi-GPU Processing

# Use all available GPUs
processor.calculate_batch_embeddings(
    data=data,
    num_gpus=None  # Use all available
)

# Limit to specific number of GPUs
processor.calculate_batch_embeddings(
    data=data,
    num_gpus=2  # Use only 2 GPUs
)

Embedding Type Selection

# Available embedding types:
# - "last_hidden_state": Most common, model's final layer
# - "all_layers": Average across all transformer layers
# - "first_layer": First layer embeddings
# - "final_embeddings": Robust cross-model option (recommended)

embedding = processor.calculate_single_embedding(
    sequence="MKLLVLSL...",
    embedding_type="final_embeddings"  # Recommended for compatibility
)

Device Management

# Automatic device selection (recommended)
processor = get_processor()
devices = processor.get_available_devices()

# Manual device specification
embedding = processor.calculate_single_embedding(
    sequence="MKLLVLSL...",
    device=torch.device("cuda:1")  # Specific GPU
)

Custom Model Implementations

To add a new model type, inherit from BaseEmbeddingModel:

from pyeed.embeddings import BaseEmbeddingModel

class CustomEmbeddingModel(BaseEmbeddingModel):
    def load_model(self):
        # Implement model loading logic
        pass
    
    def get_single_embedding_last_hidden_state(self, sequence: str):
        # Implement embedding calculation
        pass
    
    # ... implement other required methods

Migration Path

Phase 1: Immediate (No Code Changes Required)

  • Existing code continues to work unchanged
  • Legacy methods route to new implementations
  • Deprecation warnings provide guidance

Phase 2: Gradual Migration (Recommended)

  • Replace legacy method calls with new equivalents
  • Benefit from improved error handling and performance
  • Maintain existing workflows

Phase 3: Full Adoption (Optional)

  • Adopt new patterns and best practices
  • Leverage advanced features like multi-GPU processing
  • Implement new embedding types

Migration Examples

Batch Processing Migration

# Old way (still works)
processor.process_batches_on_gpu(
    data=data,
    batch_size=16,
    model=model,
    tokenizer=tokenizer,
    db=db,
    device=device
)

# New way (recommended)
processor.calculate_batch_embeddings(
    data=data,
    model_name="facebook/esm2_t33_650M_UR50D",
    batch_size=16,
    db=db
)

Single Embedding Migration

# Old way (still works)
embedding = processor.calculate_single_sequence_embedding_last_hidden_state(
    sequence="MKLLVLSL...",
    device=device,
    model_name="facebook/esm2_t33_650M_UR50D"
)

# New way (recommended)
embedding = processor.calculate_single_embedding(
    sequence="MKLLVLSL...",
    model_name="facebook/esm2_t33_650M_UR50D",
    embedding_type="last_hidden_state"
)

Technical Implementation Details

Model Factory Pattern

The factory pattern centralizes model creation logic:

class ModelFactory:
    @staticmethod
    def create_model(model_name: str, device: torch.device) -> BaseEmbeddingModel:
        model_type = determine_model_type(model_name)
        
        if model_type == "esmc":
            return ESMCEmbeddingModel(model_name, device)
        elif model_type == "esm3":
            return ESM3EmbeddingModel(model_name, device)
        elif model_type == "prott5":
            return ProtT5EmbeddingModel(model_name, device)
        else:  # Default to ESM-2
            return ESM2EmbeddingModel(model_name, device)
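
Creating a model is then a single call. A short usage sketch (the ProtT5 model name and test sequence are illustrative):

import torch

from pyeed.embeddings import ModelFactory

model = ModelFactory.create_model("Rostlab/prot_t5_xl_uniref50", torch.device("cpu"))
# determine_model_type matches "prot_t5", so a ProtT5EmbeddingModel is returned
embedding = model.get_final_embeddings("MKTAYIAKQR")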

Device Management

Automatic device detection and management:

def _initialize_devices(self) -> None:
    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        self._devices = [torch.device(f"cuda:{i}") for i in range(device_count)]
    else:
        self._devices = [torch.device("cpu")]

Memory Management

Automatic cleanup and error recovery:

# Excerpt from the batch-processing loop:
except torch.cuda.OutOfMemoryError:
    torch.cuda.empty_cache()
    current_batch_size = max(1, current_batch_size // 2)
    logger.warning(f"Reduced batch size to {current_batch_size}")

Legacy Compatibility

Method wrapping preserves old interfaces:

def calculate_single_sequence_embedding_last_hidden_state(self, ...):
    """Legacy method for backward compatibility."""
    return self.calculate_single_embedding(sequence, model_name, "last_hidden_state", device)

Utility Functions

The system includes comprehensive utility functions:

def free_memory() -> None:
    """Frees up memory by invoking garbage collection and clearing GPU caches."""
    gc.collect()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()

def preprocess_sequence_for_prott5(sequence: str) -> str:
    """Preprocesses a protein sequence for ProtT5 models."""
    sequence = re.sub(r"[UZOB]", "X", sequence.upper())
    return " ".join(list(sequence))

def normalize_embedding(embedding: NDArray[np.float64]) -> NDArray[np.float64]:
    """Normalize embeddings using L2 normalization (expects a 2D array; axis=1 normalizes each row)."""
    return embedding / np.linalg.norm(embedding, axis=1, keepdims=True)
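
For illustration, the ProtT5 preprocessing maps the ambiguous residues U, Z, O, and B to X and space-separates the sequence, while normalize_embedding scales each row of a 2D array to unit L2 norm:

import numpy as np

from pyeed.embeddings import normalize_embedding, preprocess_sequence_for_prott5

preprocess_sequence_for_prott5("mktuz")   # -> "M K T X X"
normalize_embedding(np.ones((2, 4)))      # -> rows of 0.5 (each row's L2 norm is 2)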

Error Handling and Resilience

The system includes robust error handling:

  1. Out-of-Memory Recovery: Automatic batch size reduction
  2. Threading Issues: Fallback to single-GPU mode
  3. Model Loading Failures: Graceful degradation
  4. Device Unavailability: CPU fallback

# Excerpt from the multi-GPU model-loading loop:
try:
    model = self.get_or_create_model(model_name, device)
    models.append(model)
except Exception as e:
    if "tqdm" in str(e).lower() or "_lock" in str(e).lower():
        logger.warning(f"Model loading failed on {device} due to threading issue. Reducing to single GPU mode.")
        # Fall back to single GPU mode
        devices_to_use = [devices_to_use[0]]
        num_gpus = 1
        models = [self.get_or_create_model(model_name, devices_to_use[0])]
        break

Conclusion

The new embedding architecture represents a significant improvement in:

  • Usability: Simplified, unified interface
  • Performance: Multi-GPU support and optimizations
  • Reliability: Better error handling and resource management
  • Maintainability: Clean separation of concerns and extensibility

The legacy support ensures a smooth transition path, allowing teams to migrate at their own pace while immediately benefiting from the improved underlying implementation.

Recommendations

  1. For New Projects: Use the new interfaces exclusively
  2. For Existing Projects: Migrate gradually, starting with new features
  3. For Production Systems: Test thoroughly but expect seamless operation
  4. For Performance-Critical Applications: Leverage multi-GPU capabilities

The architecture is designed to grow with your needs while maintaining stability and backward compatibility.

@haeussma self-requested a review on June 4, 2025 14:06

@haeussma (Member) left a comment:

umap cannot be a dependency, since it is not cross-platform compatible

@haeussma merged commit 22f6bda into main on Jun 6, 2025; 7 checks passed.