4 changes: 1 addition & 3 deletions environment.yaml
@@ -27,7 +27,7 @@ dependencies:
- biopython==1.81
- biotite==0.37.0
- chex==0.1.7
- deepspeed==0.9.2
- deepspeed==0.12.2
- dgllife==0.3.2
- dm-haiku==0.0.8
- dm-tree==0.1.7
@@ -88,8 +88,6 @@ dependencies:
- pytest
- requests
- subword-nmt
- tensorflow
- tensorflow_addons
- threadpoolctl
- typing-extensions
- urllib3
3 changes: 2 additions & 1 deletion prtm/common_modules/rosetta/contigs.py
@@ -6,6 +6,7 @@
import torch
import torch.nn as nn
from prtm import protein
from prtm.constants import residue_constants
from prtm.common_modules.rosetta.kinematics import *
from prtm.common_modules.rosetta.util import *

@@ -559,7 +560,7 @@ def get_idx0(self):
        hal_idx0_receptor = []
        pdb_idx = list(
            zip(
                [protein.PDB_CHAIN_IDS[i] for i in self.structure.chain_index],
                [residue_constants.PDB_CHAIN_IDS[i] for i in self.structure.chain_index],
                self.structure.residue_index,
            )
        )
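For reference, the swapped line in get_idx0 pairs per-residue chain letters with residue numbers. A minimal sketch of that pairing using the relocated constant, with hypothetical toy indices standing in for self.structure:

from prtm.constants import residue_constants

chain_index = [0, 0, 1, 1, 1]    # hypothetical: two chains
residue_index = [1, 2, 1, 2, 3]  # hypothetical per-chain numbering

pdb_idx = list(
    zip(
        [residue_constants.PDB_CHAIN_IDS[i] for i in chain_index],
        residue_index,
    )
)
assert pdb_idx == [("A", 1), ("A", 2), ("B", 1), ("B", 2), ("B", 3)]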
26 changes: 21 additions & 5 deletions prtm/constants/residue_constants.py
@@ -586,6 +586,7 @@ def make_bond_key(atom1_name, atom2_name):
"NZ",
"OXT",
]
backbone4_atoms = ["N", "CA", "C", "O"]
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
atom_type_num = len(atom_types) # := 37.

@@ -984,10 +985,8 @@ def make_bond_key(atom1_name, atom2_name):
restype_num = len(restypes) # := 20.
unk_restype_index = restype_num # Catch-all index for unknown restypes.

proteinmppn_restypes = sorted(restypes) + ["X"]
proteinmpnn_restype_order = {
    restype: i for i, restype in enumerate(proteinmppn_restypes)
}
alphabetical_restypes = sorted(restypes)
alphabetical_restypes_x = sorted(restypes) + ["X"]
restypes_with_x = restypes + ["X"]
restypes_with_x_dash = restypes_with_x + ["-"]
restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
@@ -1064,13 +1063,16 @@ def sequence_to_onehot(
"X": "UNK",
}


# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
# many more, and less common, three letter names as keys and maps many of these
# to the same one letter name (including 'X' and 'U' which we don't use here).
restype_3to1 = {v: k for k, v in restype_1to3.items()}

# Add handling of non-standard residues for Chroma
restype_3to1.update({res: "H" for res in ["HSD", "HSE", "HSC", "HSP"]})
restype_3to1.update({"MSE": "M", "SEC": "C"})

restype2atom14_mask = np.zeros([len(restypes_with_x), 14])
for k, v in restype_name_to_atom14_names.items():
    for i, atom in enumerate(v):
@@ -1486,8 +1488,15 @@ def make_atom14_dists_bounds(overlap_tolerance=1.5, bond_length_tolerance_factor
        mask_ambiguous[restype, atom_idx2] = 1

restype_3 = [restype_1to3[res] for res in restypes]
alphabetical_restype_3 = [restype_1to3[res] for res in alphabetical_restypes]
restype_3 += ["UNK"]

num_atoms_per_alphabetical_aa = [
    len(list(filter(None, restype_name_to_atom14_names[c])))
    for c in alphabetical_restype_3
]
num_chi_per_alphabetical_aa = [len(chi_angles_atoms[c]) for c in alphabetical_restype_3]

all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
for resname, swap in residue_atom_renaming_swaps.items():
    correspondences = np.arange(14)
@@ -1542,6 +1551,13 @@ def substitute(res: str):
    return None


PICO_TO_ANGSTROM = 0.01

PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)
assert PDB_MAX_CHAINS == 62


esm_proteinseq_toks = {
"toks": [
"L",
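To make the intent of the new constants concrete, a small sanity-check sketch; the asserted values follow from the atom14 and chi-angle tables in this module (alanine, first alphabetically, has heavy atoms N, CA, C, O, CB and no chi angles):

from prtm.constants import residue_constants as rc

assert rc.backbone4_atoms == ["N", "CA", "C", "O"]
assert rc.alphabetical_restypes[0] == "A"
assert rc.alphabetical_restypes_x[-1] == "X"

# Alanine: 5 heavy atoms in the atom14 table, 0 chi angles.
assert rc.num_atoms_per_alphabetical_aa[0] == 5
assert rc.num_chi_per_alphabetical_aa[0] == 0

# Non-standard residues added for Chroma resolve to standard one-letter codes.
assert rc.restype_3to1["HSD"] == "H"
assert rc.restype_3to1["MSE"] == "M"

# 26 uppercase + 26 lowercase + 10 digits = 62 possible chain identifiers.
assert rc.PDB_MAX_CHAINS == 62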
13 changes: 13 additions & 0 deletions prtm/models/chroma/__init__.py
@@ -0,0 +1,13 @@
# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
124 changes: 124 additions & 0 deletions prtm/models/chroma/api.py
@@ -0,0 +1,124 @@
# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import json
import os
import tempfile

import requests

from prtm.models import chroma

ROOT_DIR = os.path.dirname(os.path.dirname(chroma.__file__))


def register_key(key: str, key_directory=ROOT_DIR) -> None:
    """
    Registers the provided key by saving it to a JSON file.

    Args:
        key (str): The access token to be registered.
        key_directory (str, optional): The directory where the access key is registered.

    Returns:
        None
    """
    config_path = os.path.join(key_directory, "config.json")
    with open(config_path, "w") as f:
        json.dump({"access_token": key}, f)


def read_key(key_directory=ROOT_DIR) -> str:
    """
    Reads the registered key from the JSON file. If no key has been registered,
    it informs the user and raises a FileNotFoundError.

    Args:
        key_directory (str, optional): The directory where the access key is registered.

    Returns:
        str: The registered access token.

    Raises:
        FileNotFoundError: If no key has been registered.
    """
    config_path = os.path.join(key_directory, "config.json")

    if not os.path.exists(config_path):
        print("No access token has been registered.")
        print(
            "To obtain an access token, go to https://chroma-weights.generatebiomedicines.com/ and agree to the license."
        )
        raise FileNotFoundError("No token has been registered.")

    with open(config_path, "r") as f:
        config = json.load(f)

    return config["access_token"]


def download_from_generate(
    base_url: str,
    weights_name: str,
    force: bool = False,
    exist_ok: bool = False,
    key_directory=ROOT_DIR,
) -> str:
    """
    Downloads data from the provided URL using the registered access token.
    Provides caching behavior based on the force and exist_ok flags.

    Args:
        base_url (str): The base URL from which data should be fetched.
        weights_name (str): The name of the weights file to fetch.
        force (bool): If True, always fetches data from the URL regardless of cache existence.
        exist_ok (bool): If True and the cache exists (and force is False), uses the cached data.
        key_directory (str, optional): The directory where the access key is registered.

    Returns:
        str: Path to the downloaded (or cached) file.
    """

    # Create a hash of the URL + weights name to determine the path for the cached/temporary file
    url_hash = hashlib.md5((base_url + weights_name).encode()).hexdigest()
    temp_dir = os.path.join(tempfile.gettempdir(), "chroma_weights", url_hash)
    destination = os.path.join(temp_dir, "weights.pt")

    # Ensure the directory exists
    os.makedirs(temp_dir, exist_ok=True)

    # Check if the cache exists
    cache_exists = os.path.exists(destination)

    # Determine whether to use the cache
    use_cache = cache_exists and exist_ok and not force

    if use_cache:
        print(f"Using cached data from {destination}")
        return destination

    # If not using the cache, proceed with the download

    # Define the query parameters
    params = {"token": read_key(key_directory), "weights": weights_name}

    # Perform the GET request with the token as a query parameter
    response = requests.get(base_url, params=params)
    response.raise_for_status()  # Raise an error for HTTP errors

    with open(destination, "wb") as file:
        file.write(response.content)

    print(f"Data saved to {destination}")
    return destination
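
A hedged usage sketch of the new API module: register a token once, then fetch weights with caching. The endpoint path and weights filename below are illustrative placeholders, not values confirmed by this diff:

from prtm.models.chroma import api

# One-time setup: persists {"access_token": ...} to <key_directory>/config.json.
api.register_key("MY_ACCESS_TOKEN")

# Downloads to $TMPDIR/chroma_weights/<md5(base_url + weights_name)>/weights.pt;
# with exist_ok=True, a repeat call returns the cached file instead.
weights_path = api.download_from_generate(
    base_url="https://chroma-weights.generatebiomedicines.com/downloads",  # placeholder
    weights_name="chroma_backbone_v1.0.pt",  # placeholder
    exist_ok=True,
)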