24 changes: 24 additions & 0 deletions README.md
@@ -19,3 +19,27 @@ Install `pyeed` by running
```bash
pip install git+https://github.com/PyEED/pyeed.git
```

## Features

- Supports various protein language models including ESM2, ESMC, ESM3, ProtT5,
  and the newly integrated [SaProt](https://huggingface.co/westlake-repl/SaProt_650M_AF2); see the embedding sketch below.

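A minimal sketch of computing a per-residue embedding with the `SaProtEmbeddingModel` class introduced in this PR; the checkpoint name and the structure-aware sequence are illustrative:

```python
import torch

from pyeed.embeddings.models import SaProtEmbeddingModel

# Any SaProt checkpoint from Hugging Face should work here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SaProtEmbeddingModel("westlake-repl/SaProt_650M_AF2", device)
model.load_model()  # downloads weights and tokenizer on first use

# Structure-aware sequence: amino acids interleaved with foldseek tokens,
# with '#' wherever no reliable structure token is available.
embedding = model.get_single_embedding_last_hidden_state("M#EvVpQpL#VyQdYaKv")
print(embedding.shape)  # (number of residues, hidden size)
```
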
## SaProt mutation prediction

PyEED exposes utilities from SaProt for estimating mutation effects. Example:

```python
from pyeed.embeddings.models import SaProtFoldseekMutationModel

model = SaProtFoldseekMutationModel()
model.to("cuda")
model.eval()

seq = "M#EvVpQpL#VyQdYaKv" # '#' marks low-confidence regions
score = model.predict_mut(seq, "V3A")
print(score)
```
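The mutation string follows the usual `<wild type><position><substitution>` convention, so `V3A` scores replacing the valine at position 3 with alanine.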

By default, the weights are downloaded from Hugging Face. Set `config_path`
if you want to load the model from a local directory instead, as sketched below.
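A hypothetical sketch of loading from a local directory, assuming `config_path` is accepted by the constructor (check the signature in `saprot_foldseek_mutation.py`):

```python
from pyeed.embeddings.models import SaProtFoldseekMutationModel

# Assumption: config_path points at a locally downloaded SaProt checkpoint
# (a directory containing the model weights and tokenizer files).
model = SaProtFoldseekMutationModel(config_path="/models/SaProt_650M_AF2")
model.eval()

print(model.predict_mut("M#EvVpQpL#VyQdYaKv", "V3A"))
```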
3 changes: 2 additions & 1 deletion src/pyeed/analysis/embedding_analysis.py
@@ -6,9 +6,10 @@
import scipy.spatial as sp
from matplotlib.figure import Figure
from numpy.typing import NDArray
from pyeed.dbconnect import DatabaseConnector
from scipy.spatial.distance import cosine

from pyeed.dbconnect import DatabaseConnector

logger = logging.getLogger(__name__)


1 change: 1 addition & 0 deletions src/pyeed/analysis/mutation_detection.py
@@ -1,6 +1,7 @@
from typing import Any, Optional

from loguru import logger

from pyeed.dbconnect import DatabaseConnector


1 change: 1 addition & 0 deletions src/pyeed/analysis/network_analysis.py
@@ -2,6 +2,7 @@

import networkx as nx
from loguru import logger

from pyeed.dbconnect import DatabaseConnector


3 changes: 2 additions & 1 deletion src/pyeed/analysis/ontology_loading.py
@@ -1,8 +1,9 @@
from typing import Dict

from pyeed.dbconnect import DatabaseConnector
from rdflib import OWL, RDF, RDFS, Graph, Namespace, URIRef

from pyeed.dbconnect import DatabaseConnector


class OntologyAdapter:
"""
3 changes: 2 additions & 1 deletion src/pyeed/analysis/sequence_alignment.py
@@ -5,9 +5,10 @@
from Bio.Align import PairwiseAligner as BioPairwiseAligner
from Bio.Align.substitution_matrices import Array as BioSubstitutionMatrix
from joblib import Parallel, cpu_count, delayed
from rich.progress import Progress

from pyeed.dbconnect import DatabaseConnector
from pyeed.tools.utility import chunks
from rich.progress import Progress


class PairwiseAligner:
1 change: 1 addition & 0 deletions src/pyeed/analysis/standard_numbering.py
@@ -13,6 +13,7 @@
from typing import Any, Dict, List, Optional, Tuple

from loguru import logger

from pyeed.analysis.sequence_alignment import PairwiseAligner
from pyeed.dbconnect import DatabaseConnector
from pyeed.model import StandardNumbering
4 changes: 4 additions & 0 deletions src/pyeed/embeddings/__init__.py
@@ -26,6 +26,8 @@
ESM3EmbeddingModel,
ESMCEmbeddingModel,
ProtT5EmbeddingModel,
SaProtEmbeddingModel,
SaProtFoldseekMutationModel,
)
from .processor import EmbeddingProcessor, get_processor
from .utils import (
@@ -209,6 +211,8 @@ def get_single_embedding_first_layer(
"ESMCEmbeddingModel",
"ESM3EmbeddingModel",
"ProtT5EmbeddingModel",
"SaProtEmbeddingModel",
"SaProtFoldseekMutationModel",
# Backward compatibility functions
"load_model_and_tokenizer",
"process_batches_on_gpu",
1 change: 1 addition & 0 deletions src/pyeed/embeddings/base.py
@@ -108,6 +108,7 @@ class ModelType:
ESMC = "esmc"
ESM3 = "esm3"
PROTT5 = "prott5"
SAPROT = "saprot"


def normalize_embedding(embedding: NDArray[np.float64]) -> NDArray[np.float64]:
3 changes: 3 additions & 0 deletions src/pyeed/embeddings/factory.py
@@ -16,6 +16,7 @@
ESM3EmbeddingModel,
ESMCEmbeddingModel,
ProtT5EmbeddingModel,
SaProtEmbeddingModel,
)
from .utils import determine_model_type

@@ -43,6 +44,8 @@ def create_model(
return ESMCEmbeddingModel(model_name, device)
elif model_type == "esm3":
return ESM3EmbeddingModel(model_name, device)
elif model_type == "saprot":
return SaProtEmbeddingModel(model_name, device)
elif model_type == "prott5":
return ProtT5EmbeddingModel(model_name, device)
else: # Default to ESM-2
4 changes: 4 additions & 0 deletions src/pyeed/embeddings/models/__init__.py
@@ -8,10 +8,14 @@
from .esm3 import ESM3EmbeddingModel
from .esmc import ESMCEmbeddingModel
from .prott5 import ProtT5EmbeddingModel
from .saprot import SaProtEmbeddingModel
from .saprot_foldseek_mutation import SaProtFoldseekMutationModel

__all__ = [
"ESM2EmbeddingModel",
"ESMCEmbeddingModel",
"ESM3EmbeddingModel",
"ProtT5EmbeddingModel",
"SaProtEmbeddingModel",
"SaProtFoldseekMutationModel",
]
136 changes: 136 additions & 0 deletions src/pyeed/embeddings/models/saprot.py
@@ -0,0 +1,136 @@
"""SaProt model implementation for protein embeddings."""

from typing import List, Tuple, cast

import numpy as np
import torch
from numpy.typing import NDArray
from transformers import EsmForMaskedLM, EsmTokenizer

from ..base import BaseEmbeddingModel, normalize_embedding
from ..utils import get_hf_token


class SaProtEmbeddingModel(BaseEmbeddingModel):
"""SaProt model implementation."""

def __init__(self, model_name: str, device: torch.device):
super().__init__(model_name, device)

def load_model(self) -> Tuple[EsmForMaskedLM, EsmTokenizer]:
"""Load SaProt model and tokenizer."""
token = get_hf_token()

model = EsmForMaskedLM.from_pretrained(
self.model_name, use_auth_token=token
)
tokenizer = EsmTokenizer.from_pretrained(
self.model_name, use_auth_token=token
)

model = model.to(self.device)

self.model = model
self.tokenizer = tokenizer

return model, tokenizer

def preprocess_sequence(self, sequence: str) -> str:
"""SaProt doesn't need special preprocessing."""
return sequence

def get_batch_embeddings(
self, sequences: List[str], pool_embeddings: bool = True
) -> List[NDArray[np.float64]]:
"""Get embeddings for a batch of sequences using SaProt."""
if self.model is None or self.tokenizer is None:
self.load_model()

model = cast(EsmForMaskedLM, self.model)
tokenizer = cast(EsmTokenizer, self.tokenizer)

embeddings = []

for sequence in sequences:
inputs = tokenizer(
sequence, padding=True, truncation=True, return_tensors="pt"
).to(self.device)

with torch.no_grad():
outputs = model.esm(**inputs, output_hidden_states=True)

hidden_states = outputs.last_hidden_state.cpu().numpy()

if pool_embeddings:
embeddings.append(hidden_states.mean(axis=1)[0])
else:
embeddings.append(hidden_states)
return embeddings

def get_single_embedding_last_hidden_state(
self, sequence: str
) -> NDArray[np.float64]:
"""Get last hidden state embedding for a single sequence."""
if self.model is None or self.tokenizer is None:
self.load_model()

model = cast(EsmForMaskedLM, self.model)
tokenizer = cast(EsmTokenizer, self.tokenizer)

inputs = tokenizer(sequence, return_tensors="pt").to(self.device)

with torch.no_grad():
outputs = model.esm(**inputs)

embedding = outputs.last_hidden_state[0, 1:-1, :].detach().cpu().numpy()
return np.asarray(embedding, dtype=np.float64)

def get_single_embedding_all_layers(self, sequence: str) -> NDArray[np.float64]:
"""Get embeddings from all layers for a single sequence."""
if self.model is None or self.tokenizer is None:
self.load_model()

model = cast(EsmForMaskedLM, self.model)
tokenizer = cast(EsmTokenizer, self.tokenizer)

inputs = tokenizer(sequence, return_tensors="pt").to(self.device)

with torch.no_grad():
outputs = model.esm(**inputs, output_hidden_states=True)

embeddings_list = []
hidden_states = outputs.hidden_states

for layer_tensor in hidden_states:
emb = layer_tensor[0, 1:-1, :].detach().cpu().numpy()
emb = normalize_embedding(emb)
embeddings_list.append(emb)

return np.array(embeddings_list)

def get_single_embedding_first_layer(self, sequence: str) -> NDArray[np.float64]:
"""Get first layer embedding for a single sequence."""
if self.model is None or self.tokenizer is None:
self.load_model()

model = cast(EsmForMaskedLM, self.model)
tokenizer = cast(EsmTokenizer, self.tokenizer)

inputs = tokenizer(sequence, return_tensors="pt").to(self.device)

with torch.no_grad():
outputs = model.esm(**inputs, output_hidden_states=True)

embedding = outputs.hidden_states[0][0, 1:-1, :].detach().cpu().numpy()
embedding = normalize_embedding(embedding)
return embedding

def get_final_embeddings(self, sequence: str) -> NDArray[np.float64]:
"""Get final embeddings for SaProt with robust fallback."""
try:
embeddings = self.get_batch_embeddings([sequence], pool_embeddings=True)
if embeddings and len(embeddings) > 0:
return np.asarray(embeddings[0], dtype=np.float64)
raise ValueError("Batch embeddings method returned empty results")
except Exception as e:
raise ValueError(f"SaProt embedding extraction failed: {e}")