From f485019a6f6db83f71c27c843aedd853f017695e Mon Sep 17 00:00:00 2001
From: Niklas Abraham
Date: Sat, 5 Jul 2025 11:39:20 +0200
Subject: [PATCH] Update SaProt usage example

---
 README.md                                 |  24 ++++
 src/pyeed/analysis/embedding_analysis.py  |   3 +-
 src/pyeed/analysis/mutation_detection.py  |   1 +
 src/pyeed/analysis/network_analysis.py    |   1 +
 src/pyeed/analysis/ontology_loading.py    |   3 +-
 src/pyeed/analysis/sequence_alignment.py  |   3 +-
 src/pyeed/analysis/standard_numbering.py  |   1 +
 src/pyeed/embeddings/__init__.py          |   4 +
 src/pyeed/embeddings/base.py              |   1 +
 src/pyeed/embeddings/factory.py           |   3 +
 src/pyeed/embeddings/models/__init__.py   |   4 +
 src/pyeed/embeddings/models/saprot.py     | 136 ++++++++++++++++++
 .../models/saprot_foldseek_mutation.py     | 133 +++++++++++++++++
 src/pyeed/embeddings/utils.py             |   2 +
 14 files changed, 316 insertions(+), 3 deletions(-)
 create mode 100644 src/pyeed/embeddings/models/saprot.py
 create mode 100644 src/pyeed/embeddings/models/saprot_foldseek_mutation.py

diff --git a/README.md b/README.md
index 2232c6f3..ba5186d1 100644
--- a/README.md
+++ b/README.md
@@ -19,3 +19,27 @@ Install `pyeed` by running
 ```bash
 pip install git+https://github.com/PyEED/pyeed.git
 ```
+
+## Features
+
+- Supports various protein language models including ESM2, ESMC, ESM3, ProtT5,
+  and the newly integrated [SaProt](https://huggingface.co/westlake-repl/SaProt_650M_AF2).
+
+## SaProt mutation prediction
+
+PyEED exposes utilities from SaProt for estimating mutation effects. Example:
+
+```python
+from pyeed.embeddings.models import SaProtFoldseekMutationModel
+
+model = SaProtFoldseekMutationModel()
+model.to("cuda")
+model.eval()
+
+seq = "M#EvVpQpL#VyQdYaKv"  # '#' masks the structure token (e.g. low-confidence regions)
+score = model.predict_mut(seq, "V3A")
+print(score)
+```
+
+By default, the weights are downloaded from Hugging Face. Set
+`config_path` if you wish to load a local model directory.
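+For example, a minimal sketch (the local path below is a placeholder):
+
+```python
+model = SaProtFoldseekMutationModel(config_path="/path/to/SaProt_650M_AF2")
+```
+
+## SaProt embeddings
+
+The same checkpoint can also be used through the embedding interface. A minimal
+sketch (checkpoint name and device are examples):
+
+```python
+import torch
+
+from pyeed.embeddings.models import SaProtEmbeddingModel
+
+model = SaProtEmbeddingModel("westlake-repl/SaProt_650M_AF2", torch.device("cuda"))
+model.load_model()
+
+# Structure-aware sequence: amino acid + Foldseek structure token pairs.
+embedding = model.get_single_embedding_last_hidden_state("M#EvVpQpL#VyQdYaKv")
+print(embedding.shape)  # (number of residues, hidden size)
+```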
diff --git a/src/pyeed/analysis/embedding_analysis.py b/src/pyeed/analysis/embedding_analysis.py
index 73aaefb3..d8d104c0 100644
--- a/src/pyeed/analysis/embedding_analysis.py
+++ b/src/pyeed/analysis/embedding_analysis.py
@@ -6,9 +6,10 @@
 import scipy.spatial as sp
 from matplotlib.figure import Figure
 from numpy.typing import NDArray
-from pyeed.dbconnect import DatabaseConnector
 from scipy.spatial.distance import cosine
+
+from pyeed.dbconnect import DatabaseConnector
 
 logger = logging.getLogger(__name__)
diff --git a/src/pyeed/analysis/mutation_detection.py b/src/pyeed/analysis/mutation_detection.py
index 274e168b..a9881810 100644
--- a/src/pyeed/analysis/mutation_detection.py
+++ b/src/pyeed/analysis/mutation_detection.py
@@ -1,6 +1,7 @@
 from typing import Any, Optional
 
 from loguru import logger
+
 from pyeed.dbconnect import DatabaseConnector
diff --git a/src/pyeed/analysis/network_analysis.py b/src/pyeed/analysis/network_analysis.py
index dd66b45c..3ab9aeaa 100644
--- a/src/pyeed/analysis/network_analysis.py
+++ b/src/pyeed/analysis/network_analysis.py
@@ -2,6 +2,7 @@
 import networkx as nx
 from loguru import logger
+
 from pyeed.dbconnect import DatabaseConnector
diff --git a/src/pyeed/analysis/ontology_loading.py b/src/pyeed/analysis/ontology_loading.py
index ee909636..5b6341f5 100644
--- a/src/pyeed/analysis/ontology_loading.py
+++ b/src/pyeed/analysis/ontology_loading.py
@@ -1,8 +1,9 @@
 from typing import Dict
 
-from pyeed.dbconnect import DatabaseConnector
 from rdflib import OWL, RDF, RDFS, Graph, Namespace, URIRef
+
+from pyeed.dbconnect import DatabaseConnector
 
 
 class OntologyAdapter:
     """
diff --git a/src/pyeed/analysis/sequence_alignment.py b/src/pyeed/analysis/sequence_alignment.py
index 0ca43d02..5634bde1 100644
--- a/src/pyeed/analysis/sequence_alignment.py
+++ b/src/pyeed/analysis/sequence_alignment.py
@@ -5,9 +5,10 @@
 from Bio.Align import PairwiseAligner as BioPairwiseAligner
 from Bio.Align.substitution_matrices import Array as BioSubstitutionMatrix
 from joblib import Parallel, cpu_count, delayed
+from rich.progress import Progress
+
 from pyeed.dbconnect import DatabaseConnector
 from pyeed.tools.utility import chunks
-from rich.progress import Progress
 
 
 class PairwiseAligner:
diff --git a/src/pyeed/analysis/standard_numbering.py b/src/pyeed/analysis/standard_numbering.py
index 4bf9a8e8..fef383e7 100644
--- a/src/pyeed/analysis/standard_numbering.py
+++ b/src/pyeed/analysis/standard_numbering.py
@@ -13,6 +13,7 @@
 from typing import Any, Dict, List, Optional, Tuple
 
 from loguru import logger
+
 from pyeed.analysis.sequence_alignment import PairwiseAligner
 from pyeed.dbconnect import DatabaseConnector
 from pyeed.model import StandardNumbering
diff --git a/src/pyeed/embeddings/__init__.py b/src/pyeed/embeddings/__init__.py
index 729ec422..9296e2ea 100644
--- a/src/pyeed/embeddings/__init__.py
+++ b/src/pyeed/embeddings/__init__.py
@@ -26,6 +26,8 @@
     ESM3EmbeddingModel,
     ESMCEmbeddingModel,
     ProtT5EmbeddingModel,
+    SaProtEmbeddingModel,
+    SaProtFoldseekMutationModel,
 )
 from .processor import EmbeddingProcessor, get_processor
 from .utils import (
@@ -209,6 +211,8 @@ def get_single_embedding_first_layer(
     "ESMCEmbeddingModel",
     "ESM3EmbeddingModel",
     "ProtT5EmbeddingModel",
+    "SaProtEmbeddingModel",
+    "SaProtFoldseekMutationModel",
     # Backward compatibility functions
     "load_model_and_tokenizer",
     "process_batches_on_gpu",
diff --git a/src/pyeed/embeddings/base.py b/src/pyeed/embeddings/base.py
index c436937d..ec2c16d8 100644
--- a/src/pyeed/embeddings/base.py
+++ b/src/pyeed/embeddings/base.py
@@ -108,6 +108,7 @@ class ModelType:
     ESMC = "esmc"
     ESM3 = "esm3"
     PROTT5 = "prott5"
+    SAPROT = "saprot"
 
 
 def normalize_embedding(embedding: NDArray[np.float64]) -> NDArray[np.float64]:
diff --git a/src/pyeed/embeddings/factory.py b/src/pyeed/embeddings/factory.py
index 5f23b2c6..0aa62e15 100644
--- a/src/pyeed/embeddings/factory.py
+++ b/src/pyeed/embeddings/factory.py
@@ -16,6 +16,7 @@
     ESM3EmbeddingModel,
     ESMCEmbeddingModel,
     ProtT5EmbeddingModel,
+    SaProtEmbeddingModel,
 )
 from .utils import determine_model_type
@@ -43,6 +44,8 @@ def create_model(
         return ESMCEmbeddingModel(model_name, device)
     elif model_type == "esm3":
         return ESM3EmbeddingModel(model_name, device)
+    elif model_type == "saprot":
+        return SaProtEmbeddingModel(model_name, device)
     elif model_type == "prott5":
         return ProtT5EmbeddingModel(model_name, device)
     else:  # Default to ESM-2
diff --git a/src/pyeed/embeddings/models/__init__.py b/src/pyeed/embeddings/models/__init__.py
index fa7b5006..3f649a5a 100644
--- a/src/pyeed/embeddings/models/__init__.py
+++ b/src/pyeed/embeddings/models/__init__.py
@@ -8,10 +8,14 @@
 from .esm3 import ESM3EmbeddingModel
 from .esmc import ESMCEmbeddingModel
 from .prott5 import ProtT5EmbeddingModel
+from .saprot import SaProtEmbeddingModel
+from .saprot_foldseek_mutation import SaProtFoldseekMutationModel
 
 __all__ = [
     "ESM2EmbeddingModel",
     "ESMCEmbeddingModel",
     "ESM3EmbeddingModel",
     "ProtT5EmbeddingModel",
+    "SaProtEmbeddingModel",
+    "SaProtFoldseekMutationModel",
 ]
diff --git a/src/pyeed/embeddings/models/saprot.py b/src/pyeed/embeddings/models/saprot.py
new file mode 100644
index 00000000..64de3386
--- /dev/null
+++ b/src/pyeed/embeddings/models/saprot.py
@@ -0,0 +1,136 @@
+"""SaProt model implementation for protein embeddings."""
+
+from typing import List, Tuple, cast
+
+import numpy as np
+import torch
+from numpy.typing import NDArray
+from transformers import EsmForMaskedLM, EsmTokenizer
+
+from ..base import BaseEmbeddingModel, normalize_embedding
+from ..utils import get_hf_token
+
+
+class SaProtEmbeddingModel(BaseEmbeddingModel):
+    """SaProt model implementation."""
+
+    def __init__(self, model_name: str, device: torch.device):
+        super().__init__(model_name, device)
+
+    def load_model(self) -> Tuple[EsmForMaskedLM, EsmTokenizer]:
+        """Load SaProt model and tokenizer."""
+        token = get_hf_token()
+
+        model = EsmForMaskedLM.from_pretrained(
+            self.model_name, use_auth_token=token
+        )
+        tokenizer = EsmTokenizer.from_pretrained(
+            self.model_name, use_auth_token=token
+        )
+
+        model = model.to(self.device)
+
+        self.model = model
+        self.tokenizer = tokenizer
+
+        return model, tokenizer
+
+    def preprocess_sequence(self, sequence: str) -> str:
+        """SaProt expects structure-aware sequences; no further preprocessing is applied."""
+        return sequence
+
+    def get_batch_embeddings(
+        self, sequences: List[str], pool_embeddings: bool = True
+    ) -> List[NDArray[np.float64]]:
+        """Get embeddings for a batch of sequences using SaProt."""
+        if self.model is None or self.tokenizer is None:
+            self.load_model()
+
+        model = cast(EsmForMaskedLM, self.model)
+        tokenizer = cast(EsmTokenizer, self.tokenizer)
+
+        embeddings = []
+
+        for sequence in sequences:
+            inputs = tokenizer(
+                sequence, padding=True, truncation=True, return_tensors="pt"
+            ).to(self.device)
+
+            with torch.no_grad():
+                outputs = model.esm(**inputs, output_hidden_states=True)
+
+            hidden_states = outputs.last_hidden_state.cpu().numpy()
+
+            if pool_embeddings:
+                embeddings.append(hidden_states.mean(axis=1)[0])  # mean-pool over tokens
+            else:
+                embeddings.append(hidden_states)
+        return embeddings
+
+    def get_single_embedding_last_hidden_state(
+        self, sequence: str
+    ) -> NDArray[np.float64]:
+        """Get last hidden state embedding for a single sequence."""
+        if self.model is None or self.tokenizer is None:
+            self.load_model()
+
+        model = cast(EsmForMaskedLM, self.model)
+        tokenizer = cast(EsmTokenizer, self.tokenizer)
+
+        inputs = tokenizer(sequence, return_tensors="pt").to(self.device)
+
+        with torch.no_grad():
+            outputs = model.esm(**inputs)
+
+        embedding = outputs.last_hidden_state[0, 1:-1, :].detach().cpu().numpy()
+        return np.asarray(embedding, dtype=np.float64)
+
+    def get_single_embedding_all_layers(self, sequence: str) -> NDArray[np.float64]:
+        """Get embeddings from all layers for a single sequence."""
+        if self.model is None or self.tokenizer is None:
+            self.load_model()
+
+        model = cast(EsmForMaskedLM, self.model)
+        tokenizer = cast(EsmTokenizer, self.tokenizer)
+
+        inputs = tokenizer(sequence, return_tensors="pt").to(self.device)
+
+        with torch.no_grad():
+            outputs = model.esm(**inputs, output_hidden_states=True)
+
+        embeddings_list = []
+        hidden_states = outputs.hidden_states
+
+        for layer_tensor in hidden_states:
+            emb = layer_tensor[0, 1:-1, :].detach().cpu().numpy()
+            emb = normalize_embedding(emb)
+            embeddings_list.append(emb)
+
+        return np.array(embeddings_list)
+
+    def get_single_embedding_first_layer(self, sequence: str) -> NDArray[np.float64]:
+        """Get first layer embedding for a single sequence."""
+        if self.model is None or self.tokenizer is None:
+            self.load_model()
+
+        model = cast(EsmForMaskedLM, self.model)
+        tokenizer = cast(EsmTokenizer, self.tokenizer)
+
+        inputs = tokenizer(sequence, return_tensors="pt").to(self.device)
+
+        with torch.no_grad():
+            outputs = model.esm(**inputs, output_hidden_states=True)
+
+        embedding = outputs.hidden_states[0][0, 1:-1, :].detach().cpu().numpy()
+        embedding = normalize_embedding(embedding)
+        return embedding
+
+    def get_final_embeddings(self, sequence: str) -> NDArray[np.float64]:
+        """Get final embeddings for SaProt with robust fallback."""
+        try:
+            embeddings = self.get_batch_embeddings([sequence], pool_embeddings=True)
+            if embeddings and len(embeddings) > 0:
+                return np.asarray(embeddings[0], dtype=np.float64)
+            raise ValueError("Batch embeddings method returned empty results")
+        except Exception as e:
+            raise ValueError(f"SaProt embedding extraction failed: {e}") from e
diff --git a/src/pyeed/embeddings/models/saprot_foldseek_mutation.py b/src/pyeed/embeddings/models/saprot_foldseek_mutation.py
new file mode 100644
index 00000000..baf2f10d
--- /dev/null
+++ b/src/pyeed/embeddings/models/saprot_foldseek_mutation.py
@@ -0,0 +1,133 @@
+"""SaProt model for mutation effect prediction using Foldseek tokens."""
+
+from __future__ import annotations
+
+from typing import Dict
+
+import torch
+from transformers import EsmConfig, EsmForMaskedLM, EsmTokenizer
+
+from ..utils import get_hf_token
+
+AA_LIST = [
+    "A",
+    "C",
+    "D",
+    "E",
+    "F",
+    "G",
+    "H",
+    "I",
+    "K",
+    "L",
+    "M",
+    "N",
+    "P",
+    "Q",
+    "R",
+    "S",
+    "T",
+    "V",
+    "W",
+    "Y",
+]
+FOLDSEEK_STRUC_VOCAB = "pynwrqhgdlvtmfsaeikc#"
+
+
+class SaProtFoldseekMutationModel:
+    """Lightweight wrapper for mutation effect prediction with SaProt."""
+
+    def __init__(
+        self,
+        foldseek_path: str | None = None,
+        config_path: str | None = None,
+        load_pretrained: bool = True,
+    ) -> None:
+        self.foldseek_path = foldseek_path
+        self.config_path = config_path or "westlake-repl/SaProt_650M_AF2"
+        self.load_pretrained = load_pretrained
+
+        token = get_hf_token()
+        if load_pretrained:
+            self.model = EsmForMaskedLM.from_pretrained(self.config_path, use_auth_token=token)
+        else:
+            cfg = EsmConfig.from_pretrained(self.config_path)
+            self.model = EsmForMaskedLM(cfg)
+
+        self.tokenizer = EsmTokenizer.from_pretrained(self.config_path, use_auth_token=token)
+        self.device = torch.device("cpu")
+        self.model.eval()
+
+    # ------------------------------------------------------------------
+    # Basic helpers
+    # ------------------------------------------------------------------
+    def to(self, device: str | torch.device) -> "SaProtFoldseekMutationModel":
+        self.device = torch.device(device)
+        self.model.to(self.device)
+        return self
+
+    def eval(self) -> None:  # pragma: no cover - wrapper
+        self.model.eval()
+
+    # ------------------------------------------------------------------
+    # Mutation effect prediction utilities
+    # ------------------------------------------------------------------
+    def _mask_sequence(self, seq: str, mut_info: str) -> str:
+        tokens = self.tokenizer.tokenize(seq)
+        for single in mut_info.split(":"):
+            pos = int(single[1:-1])
+            tokens[pos - 1] = "#" + tokens[pos - 1][-1]  # mask the AA, keep the structure token
+        return " ".join(tokens)
+
+    def predict_mut(self, seq: str, mut_info: str) -> float:
+        """Predict effect of one or more mutations."""
+        mask_seq = self._mask_sequence(seq, mut_info)
+        inputs = self.tokenizer(mask_seq, return_tensors="pt").to(self.device)
+        with torch.no_grad():
+            probs = self.model(**inputs).logits.softmax(dim=-1)
+
+        score = 0.0
+        for single in mut_info.split(":"):
+            ori_aa, pos, mut_aa = single[0], int(single[1:-1]), single[-1]
+            ori_st = self.tokenizer.get_vocab()[ori_aa + FOLDSEEK_STRUC_VOCAB[0]]
+            mut_st = self.tokenizer.get_vocab()[mut_aa + FOLDSEEK_STRUC_VOCAB[0]]
+            ori_prob = probs[0, pos, ori_st : ori_st + len(FOLDSEEK_STRUC_VOCAB)].sum()
+            mut_prob = probs[0, pos, mut_st : mut_st + len(FOLDSEEK_STRUC_VOCAB)].sum()
+            score += torch.log(mut_prob / ori_prob)  # summed log-odds vs. wild type
+        return float(score.item())
+
+    def predict_pos_mut(self, seq: str, pos: int) -> Dict[str, float]:
+        """Predict mutation effect for all amino acids at a position."""
+        tokens = self.tokenizer.tokenize(seq)
+        ori_aa = tokens[pos - 1][0]
+        tokens[pos - 1] = "#" + tokens[pos - 1][-1]
+        mask_seq = " ".join(tokens)
+        inputs = self.tokenizer(mask_seq, return_tensors="pt").to(self.device)
+        with torch.no_grad():
+            probs = self.model(**inputs).logits.softmax(dim=-1)[0, pos]
+
+        ori_st = self.tokenizer.get_vocab()[ori_aa + FOLDSEEK_STRUC_VOCAB[0]]
+        ori_prob = probs[ori_st : ori_st + len(FOLDSEEK_STRUC_VOCAB)].sum()
+        scores = {}
+        for mut_aa in AA_LIST:
+            mut_st = self.tokenizer.get_vocab()[mut_aa + FOLDSEEK_STRUC_VOCAB[0]]
+            mut_prob = probs[mut_st : mut_st + len(FOLDSEEK_STRUC_VOCAB)].sum()
+            scores[f"{ori_aa}{pos}{mut_aa}"] = float(torch.log(mut_prob / ori_prob).item())
+        return scores
+
+    def predict_pos_prob(self, seq: str, pos: int) -> Dict[str, float]:
+        """Return probabilities for all amino acids at a position."""
+        tokens = self.tokenizer.tokenize(seq)
+        tokens[pos - 1] = "#" + tokens[pos - 1][-1]
+        mask_seq = " ".join(tokens)
+        inputs = self.tokenizer(mask_seq, return_tensors="pt").to(self.device)
+        with torch.no_grad():
+            probs = self.model(**inputs).logits.softmax(dim=-1)[0, pos]
+
+        scores = {}
+        for aa in AA_LIST:
+            st = self.tokenizer.get_vocab()[aa + FOLDSEEK_STRUC_VOCAB[0]]
+            prob = probs[st : st + len(FOLDSEEK_STRUC_VOCAB)].sum()
+            scores[aa] = float(prob.item())
+        return scores
+
diff --git a/src/pyeed/embeddings/utils.py b/src/pyeed/embeddings/utils.py
index da5e69cd..8784f1d1 100644
--- a/src/pyeed/embeddings/utils.py
+++ b/src/pyeed/embeddings/utils.py
@@ -88,6 +88,8 @@ def determine_model_type(model_name: str) -> str:
         return "esmc"
     elif "esm3" in model_name_lower:
         return "esm3"
+    elif "saprot" in model_name_lower:
+        return "saprot"
     elif "prot_t5" in model_name_lower or "prott5" in model_name_lower:
         return "prott5"
     else: