diff --git a/environment.yaml b/environment.yaml index 896dfa1e..5c8b724f 100644 --- a/environment.yaml +++ b/environment.yaml @@ -27,7 +27,7 @@ dependencies: - biopython==1.81 - biotite==0.37.0 - chex==0.1.7 - - deepspeed==0.9.2 + - deepspeed==0.12.2 - dgllife==0.3.2 - dm-haiku==0.0.8 - dm-tree==0.1.7 @@ -88,8 +88,6 @@ dependencies: - pytest - requests - subword-nmt - - tensorflow - - tensorflow_addons - threadpoolctl - typing-extensions - urllib3 diff --git a/prtm/common_modules/rosetta/contigs.py b/prtm/common_modules/rosetta/contigs.py index 52ac0280..d36f1545 100644 --- a/prtm/common_modules/rosetta/contigs.py +++ b/prtm/common_modules/rosetta/contigs.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn from prtm import protein +from prtm.constants import residue_constants from prtm.common_modules.rosetta.kinematics import * from prtm.common_modules.rosetta.util import * @@ -559,7 +560,7 @@ def get_idx0(self): hal_idx0_receptor = [] pdb_idx = list( zip( - [protein.PDB_CHAIN_IDS[i] for i in self.structure.chain_index], + [residue_constants.PDB_CHAIN_IDS[i] for i in self.structure.chain_index], self.structure.residue_index, ) ) diff --git a/prtm/constants/residue_constants.py b/prtm/constants/residue_constants.py index 707c55b4..24d32729 100644 --- a/prtm/constants/residue_constants.py +++ b/prtm/constants/residue_constants.py @@ -586,6 +586,7 @@ def make_bond_key(atom1_name, atom2_name): "NZ", "OXT", ] +backbone4_atoms = ["N", "CA", "C", "O"] atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} atom_type_num = len(atom_types) # := 37. @@ -984,10 +985,8 @@ def make_bond_key(atom1_name, atom2_name): restype_num = len(restypes) # := 20. unk_restype_index = restype_num # Catch-all index for unknown restypes. -proteinmppn_restypes = sorted(restypes) + ["X"] -proteinmpnn_restype_order = { - restype: i for i, restype in enumerate(proteinmppn_restypes) -} +alphabetical_restypes = sorted(restypes) +alphabetical_restypes_x = sorted(restypes) + ["X"] restypes_with_x = restypes + ["X"] restypes_with_x_dash = restypes_with_x + ["-"] restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)} @@ -1064,13 +1063,16 @@ def sequence_to_onehot( "X": "UNK", } - # NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple # 1-to-1 mapping of 3 letter names to one letter names. The latter contains # many more, and less common, three letter names as keys and maps many of these # to the same one letter name (including 'X' and 'U' which we don't use here). 
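+# For example (an illustrative check, not exhaustive): together with the
+# Chroma-specific updates added just below, the mapping satisfies
+#   restype_3to1["ALA"] == "A", restype_3to1["HSD"] == "H", restype_3to1["MSE"] == "M"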
restype_3to1 = {v: k for k, v in restype_1to3.items()} +# Add handling of non-standard residues for Chroma +restype_3to1.update({res: "H" for res in ["HSD", "HSE", "HSC", "HSP"]}) +restype_3to1.update({"MSE": "M", "SEC": "C"}) + restype2atom14_mask = np.zeros([len(restypes_with_x), 14]) for k, v in restype_name_to_atom14_names.items(): for i, atom in enumerate(v): @@ -1486,8 +1488,15 @@ def make_atom14_dists_bounds(overlap_tolerance=1.5, bond_length_tolerance_factor mask_ambiguous[restype, atom_idx2] = 1 restype_3 = [restype_1to3[res] for res in restypes] +alphabetical_restype_3 = [restype_1to3[res] for res in alphabetical_restypes] restype_3 += ["UNK"] +num_atoms_per_alphabetical_aa = [ + len(list(filter(None, restype_name_to_atom14_names[c]))) + for c in alphabetical_restype_3 +] +num_chi_per_alphabetical_aa = [len(chi_angles_atoms[c]) for c in alphabetical_restype_3] + all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3} for resname, swap in residue_atom_renaming_swaps.items(): correspondences = np.arange(14) @@ -1542,6 +1551,13 @@ def substitute(res: str): return None +PICO_TO_ANGSTROM = 0.01 + +PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" +PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) +assert PDB_MAX_CHAINS == 62 + + esm_proteinseq_toks = { "toks": [ "L", diff --git a/prtm/models/chroma/__init__.py b/prtm/models/chroma/__init__.py new file mode 100644 index 00000000..af56d0ac --- /dev/null +++ b/prtm/models/chroma/__init__.py @@ -0,0 +1,13 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/prtm/models/chroma/api.py b/prtm/models/chroma/api.py new file mode 100644 index 00000000..bbd0d74b --- /dev/null +++ b/prtm/models/chroma/api.py @@ -0,0 +1,124 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import json +import os +import tempfile + +import requests + +from prtm.models import chroma + +ROOT_DIR = os.path.dirname(os.path.dirname(chroma.__file__)) + + +def register_key(key: str, key_directory=ROOT_DIR) -> None: + """ + Registers the provided key by saving it to a JSON file. + + Args: + key (str): The access token to be registered. + key_directory (str, optional): The directory where the access key is registered. 
+ + Returns: + None + """ + config_path = os.path.join(key_directory, "config.json") + with open(config_path, "w") as f: + json.dump({"access_token": key}, f) + + +def read_key(key_directory=ROOT_DIR) -> str: + """ + Reads the registered key from the JSON file. If no key has been registered, + it informs the user and raises a FileNotFoundError. + + Args: + key_directory (str, optional): The directory where the access key is registered. + + Returns: + str: The registered access token. + + Raises: + FileNotFoundError: If no key has been registered. + """ + config_path = os.path.join(key_directory, "config.json") + + if not os.path.exists(config_path): + print("No access token has been registered.") + print( + "To obtain an access token, go to https://chroma-weights.generatebiomedicines.com/ and agree to the license." + ) + raise FileNotFoundError("No token has been registered.") + + with open(config_path, "r") as f: + config = json.load(f) + + return config["access_token"] + + +def download_from_generate( + base_url: str, + weights_name: str, + force: bool = False, + exist_ok: bool = False, + key_directory=ROOT_DIR, +) -> str: + """ + Downloads data from the provided URL using the registered access token. + Provides caching behavior based on force and exist_ok flags. + + Args: + base_url (str): The base URL from which data should be fetched. + force (bool): If True, always fetches data from the URL regardless of cache existence. + exist_ok (bool): If True and cache exists (and force is False), uses the cached data. + key_directory (str, optional): The directory where the access key is registered. + + Returns: + str: Path to the downloaded (or cached) file. + """ + + # Create a hash of the URL + weight name to determine the path for the cached/temporary file + url_hash = hashlib.md5((base_url + weights_name).encode()).hexdigest() + temp_dir = os.path.join(tempfile.gettempdir(), "chroma_weights", url_hash) + destination = os.path.join(temp_dir, "weights.pt") + + # Ensure the directory exists + os.makedirs(temp_dir, exist_ok=True) + + # Check if cache exists + cache_exists = os.path.exists(destination) + + # Determine if we should use the cache or not + use_cache = cache_exists and exist_ok and not force + + if use_cache: + print(f"Using cached data from {destination}") + return destination + + # If not using cache, proceed with download + + # Define the query parameters + params = {"token": read_key(key_directory), "weights": weights_name} + + # Perform the GET request with the token as a query parameter + response = requests.get(base_url, params=params) + response.raise_for_status() # Raise an error for HTTP errors + + with open(destination, "wb") as file: + file.write(response.content) + + print(f"Data saved to {destination}") + return destination diff --git a/prtm/models/chroma/attention.py b/prtm/models/chroma/attention.py new file mode 100644 index 00000000..b6fcc96d --- /dev/null +++ b/prtm/models/chroma/attention.py @@ -0,0 +1,345 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + + +class ScaledDotProductAttention(nn.Module): + """Scaled dot product attention as described in Eqn 1 of Vaswani et al. 2017 [https://arxiv.org/abs/1706.03762]. + + Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V + + Note that the dimension of the query has to match the dimension of the keys (here specified as ```d_k```) and the length of keys has to match + the length of the values. See for instance 'The Illustrated Transformer' [http://jalammar.github.io/illustrated-transformer/] + for pictorial depiction of attention. + + Inputs: + Q (torch.tensor): of shape (batch_size, sequence_length_q, d_k) + K (torch.tensor): of shape (batch_size, sequence_length_k, d_k) + V (torch.tensor): of shape (batch_size, sequence_length_k, d_v) + mask (torch.tensor): of dtype (bool) or (byte) and shape (batch_size, 1, sequence_length_k), optional + zeroes (or False) indicate positions that cannot contribute to attention + Outputs: + output (torch.tensor) of shape (batch_size, sequence_length_q, d_v). The [i-j]-entry output[i,j,:] is formed as a convex combination of values: + \sum_k a_k V[i,k,:] and \sum_k a_k = 1. + attentions (torch.tensor) of shape (batch_size, sequence_length_q, sequence_length_k)) where the [b,i,j]-element + corresponds to the attention value (e.g relative contribution) of position j in the key-tensor to position i in the query tensor in element b of the batch. + """ + + def __init__(self): + super(ScaledDotProductAttention, self).__init__() + self.softmax = nn.Softmax(dim=-1) + + def forward(self, Q, K, V, mask=None): + _, _, d = K.size() + attn = torch.bmm(Q, K.transpose(1, 2)) / d**0.5 + if mask is not None: + attn = attn.float().masked_fill(mask == 0, -1e9) + + attn = self.softmax(attn) + if mask is not None: + attn = attn.float().masked_fill(mask == 0, 0) + + if V.dtype == torch.float16: + attn = attn.half() + output = torch.bmm(attn, V) + return output, attn + + +class MultiHeadAttention(nn.Module): + """Multi-head attention with scaled dot product attention. See 'The Annotated Transformer' + http://nlp.seas.harvard.edu/2018/04/03/attention.html or 'The Illustrated Transformer' http://jalammar.github.io/illustrated-transformer/ + for details and intuition. + + Args: + n_head (int): number of attention heads + d_k (int): dimension of the keys and queries in each attention head + d_v (int): dimension of the values in each attention head + d_model (int): input and output dimension for the layer + dropout (float): dropout rate, default is 0.1 + + Inputs: + Q (torch.tensor): query tensor of shape ```(batch_size, sequence_length_q, d_model)``` + K (torch.tensor): key tensor of shape ```(batch_size, sequence_length_k, d_model)``` + V (torch.tensor): value tensor of shape ```(batch_size, sequence_length_k, d_model)``` + mask (torch.tensor): (optional) of dtype ```bool`` or ```byte``` and size (batch_size, 1, sequence_length_k), + zeroes (or False) indicate positions that cannot contribute to attention + + Outputs: + output (torch.tensor) : of shape ```(batch_size, sequence_length_q, d_model)``` + attentions (torch.tensor): of shape ```(batch_size * n_head, sequence_length_q, sequence_length_k) where + ```attentions[batch_size*(i):batch_size*(i+1),:,:]``` corresponds to the batch of attention blocks for i'th head. 
See + ```chroma.layers.attention.ScaledDotProductAttention``` for more details + """ + + def __init__(self, n_head, d_k, d_v, d_model, dropout=0.1): + super(MultiHeadAttention, self).__init__() + self.n_head = n_head + self.d_k = d_k + self.d_v = d_v + self.d_model = d_model + self.Wq = nn.Parameter(torch.Tensor(n_head, d_model, d_k)) + self.Wk = nn.Parameter(torch.Tensor(n_head, d_model, d_k)) + self.Wv = nn.Parameter(torch.Tensor(n_head, d_model, d_v)) + self.Wo = nn.Parameter(torch.Tensor(n_head * d_v, d_model)) + self.attention = ScaledDotProductAttention() + self.dropout = nn.Dropout(p=dropout) + self.reset_parameters() + + def reset_parameters(self): + nn.init.xavier_normal_(self.Wq) + nn.init.xavier_normal_(self.Wk) + nn.init.xavier_normal_(self.Wv) + nn.init.kaiming_uniform_(self.Wo) + + def forward(self, Q, K, V, bias=None, mask=None): + mb_size, len_q, d_q_in = Q.size() + mb_size, len_k, d_k_in = K.size() + mb_size, len_v, d_v_in = V.size() + d_model = self.d_model + if d_q_in != d_model: + raise ValueError("Dimension of Q does not match d_model.") + + if d_k_in != d_model: + raise ValueError("Dimension of K does not match d_model.") + + if d_v_in != d_model: + raise ValueError("Dimension of V does not match d_model.") + + # treat as a (n_head) size batch and project to d_k and d_v + q_s = torch.cat([Q @ W for W in self.Wq]) # (n_head*mb_size) x len_q x d_k + k_s = torch.cat([K @ W for W in self.Wk]) # (n_head*mb_size) x len_k x d_k + v_s = torch.cat([V @ W for W in self.Wv]) # (n_head*mb_size) x len_v x d_v + + # Attention + if mask is not None: + mask = mask.repeat(self.n_head, 1, 1) + outputs, attns = self.attention(q_s, k_s, v_s, mask=mask) + + # Back to original mb_size batch, result size = mb_size x len_q x (n_head*d_v) + outputs = torch.cat(torch.split(outputs, mb_size, dim=0), dim=-1) + + # Project back to residual size + outputs = outputs @ self.Wo + outputs = self.dropout(outputs) + return outputs, attns + + +class AttentionChainPool(nn.Module): + """Pools residue-based representations to chain-based representations using a chain mask and attention. + Args: + n_head (int): number of attention heads + d_model (int): dimension of embeddings to be pooled + + Inputs: + h (torch.tensor): of size (batch_size, sequence_length, d_model) + C (torch.tensor): of size (batch_size, sequence_length) + + Outputs: + output (torch.tensor): of size (batch_size, n_chains, d_model) + chain_mask (torch.tensor): of size (batch_size, n_chains) + """ + + def __init__(self, n_head, d_model): + super().__init__() + self.attention = MultiHeadAttention( + n_head, d_model, d_model, d_model, dropout=0.0 + ) + + def get_query(self, x): + return torch.ones(x.size(0), 1, x.size(2)).type(x.dtype).to(x.device) + + def forward(self, h, C): + bs, num_res = C.size() + chains = C.abs().unique() + chains = ( + chains[chains > 0].unsqueeze(-1).repeat(1, bs).reshape(-1).unsqueeze(-1) + ) + num_chains = len(chains.unique()) + + h_repeat = h.repeat(num_chains, 1, 1) + C_repeat = C.repeat(num_chains, 1) + mask = (C_repeat == chains).unsqueeze(-2) + + output, _ = self.attention( + self.get_query(h_repeat), h_repeat, h_repeat, mask=mask + ) + output = torch.cat(output.split(bs), 1) + chain_mask = torch.stack(mask.squeeze(1).any(dim=-1).split(bs), -1) + return output, chain_mask + + +class Attention(nn.Module): + """ + A multi-head attention layer with optional gating and bias as implemented in Jumper et al. 
(2021) + Args: + n_head (int): Number of heads of attention + d_model (int): Dimension of input and outputs + d_k (int): Dimension of keys/queries + d_v (int): Dimension of values + gate (bool): Whether to include a gate connection (as in Jumper et al. (2021)) + + Inputs: + Q (torch.tensor): of size (batch_size, num_queries, d_model) + K (torch.tensor): of size (batch_size, num_keys, d_model) + V (torch.tensor): of size (batch_size, num_keys, d_model) + bias (torch.tensor): (optional) of size (batch_size, n_head, num_queries, num_keys) + mask (torch.tensor): (optional) of size (batch_size, n_head, num_queries, num_keys) + + Outputs: + output (torch.tensor): of size (batch_size, num_queries, d_model) + """ + + def __init__(self, n_head, d_model, d_k=None, d_v=None, gate=False): + super().__init__() + self.n_head = n_head + self.d_model = d_model + self.d_k = d_model // n_head if d_k is None else d_k + self.d_v = d_model // n_head if d_v is None else d_v + self.gate = gate + self.q_weights = nn.Parameter(torch.Tensor(d_model, n_head, self.d_k)) + self.k_weights = nn.Parameter(torch.Tensor(d_model, n_head, self.d_k)) + self.v_weights = nn.Parameter(torch.Tensor(d_model, n_head, self.d_v)) + self.o_weights = nn.Parameter(torch.Tensor(n_head, self.d_v, d_model)) + self.o_bias = nn.Parameter(torch.Tensor(d_model)) + if self.gate: + self.g_weights = nn.Parameter(torch.Tensor(d_model, n_head, self.d_v)) + self.g_bias = nn.Parameter(torch.Tensor(n_head, self.d_v)) + self.softmax = nn.Softmax(dim=-1) + self.reset_parameters() + + def reset_parameters(self): + nn.init.xavier_uniform_(self.q_weights) + nn.init.xavier_uniform_(self.k_weights) + nn.init.xavier_uniform_(self.v_weights) + nn.init.xavier_uniform_(self.o_weights) + nn.init.zeros_(self.o_bias) + if self.gate: + nn.init.zeros_(self.g_weights) + nn.init.ones_(self.g_bias) + + def forward(self, Q, K, V, bias=None, mask=None): + self._check_inputs(Q, K, V, bias, mask) + q = torch.einsum("bqa,ahc->bqhc", Q, self.q_weights) * self.d_k ** (-0.5) + k = torch.einsum("bka,ahc->bkhc", K, self.k_weights) + v = torch.einsum("bka,ahc->bkhc", V, self.v_weights) + logits = torch.einsum("bqhc,bkhc->bhqk", q, k) + + if bias is not None: + logits = logits + bias + + weights = torch.nn.functional.softmax(logits, dim=-1) + + if mask is not None: + weights = weights.masked_fill(~mask, 0.0) + + weighted_avg = torch.einsum("bhqk,bkhc->bqhc", weights, v) + + if self.gate: + gate_values = torch.einsum("bqa,ahc->bqhc", Q, self.g_weights) + self.g_bias + gate_values = torch.sigmoid(gate_values, dim=-1) + weighted_avg = weighted_avg * gate_values + + output = ( + torch.einsum("bqhc,hco->bqo", weighted_avg, self.o_weights) + self.o_bias + ) + return output + + def _check_inputs(self, Q, K, V, bias, mask): + batch_size_q, num_queries, d_q_in = Q.size() + batch_size_k, num_keys, d_k_in = K.size() + batch_size_v, num_values, d_v_in = V.size() + + if d_q_in != self.d_model: + raise ValueError( + f"Dimension of Q tensor needs to be (batch_size, number_queries, d_model)" + ) + + if d_k_in != self.d_model: + raise ValueError( + f"Dimension of K tensor needs to be (batch_size, number_keys, d_model)" + ) + + if d_v_in != self.d_model: + raise ValueError( + f"Dimension of V tensor needs to be (batch_size, number_values, d_model)" + ) + + if num_keys != num_values: + raise ValueError(f"Number of keys needs to match number of values passed") + + if (batch_size_q != batch_size_k) or (batch_size_k != batch_size_v): + raise ValueError( + f"Found batch size mismatch among inputs, all 
tensors must agree in size of dimension 0" + ) + + if bias is not None: + if (bias.dim() != 3) and (bias.dim() != 4): + raise ValueError( + f"Bias specified but dimension mismatched: passed {bias.dim()}-dimensional tensor but should be 3-dimensional" + f"of shape (n_head, num_queries, num_keys) or 4-dimensional of shape (batch_size, n_head, num_queries, num_keys)" + ) + if bias.dim() == 3: + n_head_b, num_queries_b, num_keys_b = bias.size() + if n_head_b != self.n_head: + raise ValueError( + f"Bias specified but number of heads (dim of axis=0) does not match number of heads: {self.n_head}" + ) + if num_queries_b != num_queries: + raise ValueError( + f"Bias specified but number of queries (dim of axis=1) does not match number of queries given in Q tensor" + ) + if num_keys_b != num_keys: + raise ValueError( + f"Bias specified but number of keys (dim of axis=2) does not match number of queries given in K tensor " + f"(dimenson of axis=1)" + ) + elif bias.dim() == 4: + if bias.dim() == 3: + n_batch_b, n_head_b, num_queries_b, num_keys_b = bias.size() + if n_head_b != self.n_head: + raise ValueError( + f"Bias specified but number of heads (dim of axis=0) does not match number of heads: {self.n_head}" + ) + if num_queries_b != num_queries: + raise ValueError( + f"Bias specified but number of queries (dim of axis=1) does not match number of queries given in Q tensor" + ) + if num_keys_b != num_keys: + raise ValueError( + f"Bias specified but number of keys (dim of axis=2) does not match number of queries given in K tensor " + f"(dimenson of axis=1)" + ) + + if mask is not None: + if mask.dtype != torch.bool: + raise ValueError( + f"Mask specified but not given by correct dtype, should be torch.bool but found {mask.dtype}" + ) + if mask.dim() != 4: + raise ValueError( + f"Mask specified but dimension mismatched: passed {mask.dim()}-dimensional tensor but should be 4-dimensional" + f"of shape (batch_size, n_head, num_queries, num_keys)" + ) + batch_size_b, _, num_queries_b, num_keys_b = mask.size() + if (num_queries_b != num_queries) and (num_queries_b != 1): + raise ValueError( + f"Bias specified but number of queries (dim of axis=2) does not match number of queries given in Q tensor" + ) + if (num_keys_b != num_keys) and (num_keys_b != 1): + raise ValueError( + f"Bias specified but number of keys (dim of axis=3) does not match number of queries given in K tensor " + f"(dimenson of axis=1)" + ) diff --git a/prtm/models/chroma/basic.py b/prtm/models/chroma/basic.py new file mode 100644 index 00000000..76f5392a --- /dev/null +++ b/prtm/models/chroma/basic.py @@ -0,0 +1,453 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from prtm.models.chroma.norm import MaskedBatchNorm1d + + +class NoOp(nn.Module): + """A dummy nn.Module wrapping an identity operation. 
+ + Inputs: + x (any) + + Outputs: + x (any) + """ + + def __init__(self): + super().__init__() + + def forward(self, x, **kwargs): + return x + + +class Transpose(nn.Module): + """An nn.Module wrapping ```torch.transpose```. + + Args: + d1 (int): the first (of two) dimensions to swap + d2 (int): the second (of two) dimensions to swap + + Inputs: + x (torch.tensor) + + Outputs: + y (torch.tensor): ```y = x.transpose(d1,d2)``` + """ + + def __init__(self, d1=1, d2=2): + super().__init__() + self.d1 = d1 + self.d2 = d2 + + def forward(self, x): + return x.transpose(self.d1, self.d2) + + +class Unsqueeze(nn.Module): + """An nn.Module wrapping ```torch.unsqueeze```. + + Args: + dim (int): the dimension to unsqueeze input tensors + + Inputs: + x (torch.tensor): + + Outputs: + y (torch.tensor): where ```y=x.unsqueeze(dim)``` + """ + + def __init__(self, dim=1): + super().__init__() + self.dim = dim + + def forward(self, x): + return x.unsqueeze(self.dim) + + +class OneHot(nn.Module): + """An nn.Module that wraps F.one_hot```. + + Args: + n_tokens (int): the number of tokens comprising input sequences + + Inputs: + x (torch.LongTensor): of size ```(batch_size, *)``` + + Outputs: + y (torch.ByteTensor): of size (batch_size, *, n_tokens) cast to input.device + """ + + def __init__(self, n_tokens): + super().__init__() + self.n_tokens = n_tokens + + def forward(self, x): + return F.one_hot(x, self.n_tokens) + + +class MeanEmbedding(nn.Module): + """A wrapper around ```nn.Embedding``` that allows for one-hot-like representation inputs (as well as standard tokenized representation), + optionally applying a softmax to the last dimension if the input corresponds to a log-PMF. + Args: + embedding (nn.Embedding): Embedding to wrap + use_softmax (bool): Whether to apply a softmax to the last dimension if input is one-hot-like. + + Inputs: + x (torch.tensor): of size (batch_size, sequence_length) (standard tokenized representation) -OR- (batch_size, sequence_length, number_tokens) (one-hot representation) + + Outputs: + y (torch.tensor): of size (batch_size, sequence_length, embedding_dimension) obtained via. lookup into ```self.embedding.weight``` if + input is in standard tokenized form or by matrix multiplication of input with ```self.embedding.weight``` if input is one-hot-like. Note + that if the input is a one-hot matrix the output is the same regardless of representation. 
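+
+    Example (illustrative sketch; shapes are arbitrary):
+        emb = MeanEmbedding(nn.Embedding(21, 64))
+        tokens = torch.randint(0, 21, (2, 10))   # tokenized input of shape (batch, length)
+        logits = torch.randn(2, 10, 21)          # one-hot-like input of shape (batch, length, n_tokens)
+        out_tokens, out_logits = emb(tokens), emb(logits)   # both have shape (2, 10, 64)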
+ """ + + def __init__(self, embedding, use_softmax=True): + super(MeanEmbedding, self).__init__() + self.embedding = embedding + self.use_softmax = use_softmax + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x): + if len(x.shape) == 2: + return self.embedding(x) + elif len(x.shape) == 3: + if self.use_softmax: + return self.softmax(x) @ self.embedding.weight + else: + return x @ self.embedding.weight + else: + raise (NotImplementedError) + + +class PeriodicPositionalEncoding(nn.Module): + """Positional encoding, adapted from 'The Annotated Transformer' + http://nlp.seas.harvard.edu/2018/04/03/attention.html + + Args: + d_model (int): input and output dimension for the layer + max_seq_len (int): maximum allowed sequence length + dropout (float): Dropout rate + + Inputs: + x (torch.tensor): of size (batch_size, sequence_length, d_model) + + Outputs: + y (torch.tensor): of size (batch_size, sequence_length, d_model) + """ + + def __init__(self, d_model, max_seq_len=4000, dropout=0.0): + super(PeriodicPositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_seq_len, d_model) + position = torch.arange(0.0, max_seq_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model) + ) + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, x): + x = x + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class PositionWiseFeedForward(nn.Module): + """Position-wise feed-forward using 1x1 convolutions, a building block of legacy Transformer code (not code optimized). + + Args: + d_model (int): input and output dimension for the layer + d_inner_hid (int): size of the hidden layer in the position-wise feed-forward sublayer + + Inputs: + x (torch.tensor): of size (batch_size, sequence_length, d_model) + Outputs: + y (torch.tensor): of size (batch_size, sequence_length, d_model) + """ + + def __init__(self, d_model, d_hidden, dropout=0.1): + super(PositionWiseFeedForward, self).__init__() + self.activation = nn.ReLU() + self.linear1 = nn.Conv1d(d_model, d_hidden, 1) + self.linear2 = nn.Conv1d(d_hidden, d_model, 1) + self.dropout = nn.Dropout(p=dropout) + + def reset_parameters(self): + self.linear1.reset_parameters() + self.linear2.reset_parameters() + + def forward(self, x): + output = self.activation(self.linear1(x.transpose(1, 2))) + output = self.linear2(output).transpose(1, 2) + return self.dropout(output) + + +class DropNormLin(nn.Module): + """nn.Module applying a linear layer, normalization, dropout, and activation + Args: + in_features (int): input dimension + out_features (int): output dimension + norm_type (str): ```'ln'``` for layer normalization or ```'bn'``` for batch normalization else skip normalization + dropout (float): dropout to apply + actn (nn.Module): activation function to apply + + Input: + x (torch.tensor): of size (batch_size, sequence_length, in_features) + input_mask (torch.tensor): of size (batch_size, 1, sequence_length) (optional) + + Output: + y (torch.tensor): of size (batch_size, sequence_length, out_features) + """ + + def __init__( + self, in_features, out_features, norm_type="ln", dropout=0.0, actn=nn.ReLU() + ): + super(DropNormLin, self).__init__() + self.linear = nn.Linear(in_features, out_features) + if norm_type == "ln": + self.norm_layer = nn.LayerNorm(out_features) + elif norm_type == "bn": + self.norm_layer = MaskedBatchNorm1d(out_features) + 
else: + self.norm_layer = NoOp() + self.dropout = nn.Dropout(p=dropout) + self.actn = actn + + def forward(self, x, input_mask=None): + h = self.linear(x) + if isinstance(self.norm_layer, MaskedBatchNorm1d): + h = self.norm_layer(h.transpose(1, 2), input_mask=input_mask).transpose( + 1, 2 + ) + else: + h = self.norm_layer(h) + return self.dropout(self.actn(h)) + + +class ResidualLinearLayer(nn.Module): + """A Simple Residual Layer using a linear layer a relu and an optional layer norm. + + Args: + d_model (int): Model Dimension + use_norm (bool, *optional*): Optionally Use a Layer Norm. Default `True`. + """ + + def __init__(self, d_model, use_norm=True): + super(ResidualLinearLayer, self).__init__() + self.linear = nn.Linear(d_model, d_model) + self.ReLU = nn.ReLU() + self.use_norm = use_norm + self.norm = nn.LayerNorm(d_model) + + def forward(self, x): + z = self.linear(x) + z = self.ReLU(z) + if self.use_norm: + z = self.norm(z) + return x + z + + +class TriangleMultiplication(nn.Module): + def __init__(self, d_model=512, mode="outgoing"): + """ + Triangle multiplication as defined in Jumper et al. (2021) + Args: + d_model (int): dimension of the embedding at each position + mode (str): Must be 'outgoing' (algorithm 11) or 'incoming' (algorithm 12). + + Inputs: + X (torch.tensor): Pair representations of size (batch, nres, nres, channels) + mask (torch.tensor): of dtype `torch.bool` and size (batch, nres, nres, channels) (or broadcastable to this size) + + Outputs: + Y (torch.tensor): Pair representations of size (batch, nres, nres, channels) + """ + super().__init__() + self.mode = mode + assert self.mode in ["outgoing", "incoming"] + self.equation = ( + "bikc,bjkc->bijc" if self.mode == "outgoing" else "bkjc,bkic->bijc" + ) + self.layer_norm = nn.LayerNorm(d_model) + self.left_edge_mlp = nn.Sequential( + nn.Linear(d_model, d_model), nn.Sigmoid(), nn.Linear(d_model, d_model) + ) + self.right_edge_mlp = nn.Sequential( + nn.Linear(d_model, d_model), nn.Sigmoid(), nn.Linear(d_model, d_model) + ) + self.skip = nn.Sequential(nn.Linear(d_model, d_model), nn.Sigmoid()) + self.combine = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, d_model)) + + def forward(self, X, mask=None): + h = self.layer_norm(X) + + A = self.left_edge_mlp(h) + B = self.right_edge_mlp(h) + G = self.skip(h) + + if mask is not None: + A = A.masked_fill(~mask, 0.0) + B = B.masked_fill(~mask, 0.0) + + h = torch.einsum(self.equation, A, B) + h = self.combine(h) * G + return h + + +class NodeProduct(nn.Module): + """Like Alg. 10 in Jumper et al. (2021) but instead of computing a mean over MSA dimension, + process for single-sequence inputs. 
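+    Each output edge embedding is formed by concatenating linear projections of its
+    two endpoint node embeddings and passing the result through a final linear layer.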
+ Args: + d_in (int): dimension of node embeddings (inputs) + d_out (int): dimension of edge embeddings (outputs) + + Inputs: + node_features (torch.tensor): of size (batch_size, nres, d_model) + node_mask (torch.tensor): of size (batch_size, nres) + edge_mask (torch.tensor): of size (batch_size, nres, nres) + + Outputs: + edge_features (torch.tensor): of size (batch_size, nres, nres, d_model) + """ + + def __init__(self, d_in, d_out): + super().__init__() + self.layer_norm = nn.LayerNorm(d_in) + self.left_lin = nn.Linear(d_in, d_in) + self.right_lin = nn.Linear(d_in, d_in) + self.edge_lin = nn.Linear(2 * d_in, d_out) + + def forward(self, node_features, node_mask=None, edge_mask=None): + _, nres, _ = node_features.size() + + node_features = self.layer_norm(node_features) + left_embs = self.left_lin(node_features) + right_embs = self.right_lin(node_features) + + if node_mask is not None: + mask = node_mask[:, :, None] + left_embs = left_embs.masked_fill(~mask, 0.0) + right_embs = right_embs.masked_fill(~mask, 0.0) + + left_embs = left_embs[:, None, :, :].repeat(1, nres, 1, 1) + right_embs = right_embs[:, :, None, :].repeat(1, 1, nres, 1) + edge_features = torch.cat([left_embs, right_embs], dim=-1) + edge_features = self.edge_lin(edge_features) + + if edge_mask is not None: + mask = edge_mask[:, :, :, None] + edge_features = edge_features.masked_fill(~mask, 0.0) + + return edge_features + + +class FourierFeaturization(nn.Module): + """Applies fourier featurization of low-dimensional (usually spatial) input data as described in [https://arxiv.org/abs/2006.10739] , + optionally trainable as described in [https://arxiv.org/abs/2106.02795]. + + Args: + d_input (int): dimension of inputs + d_model (int): dimension of outputs + trainable (bool): whether to learn the frequency of fourier features + scale (float): if not trainable, controls the scale of fourier feature periods (see reference for description, this parameter matters and should be tuned!) + + Inputs: + input (torch.tensor): of size (batch_size, *, d_input) + + Outputs: + output (torch.tensor): of size (batch_size, *, d_output) + """ + + def __init__(self, d_input, d_model, trainable=False, scale=1.0): + super().__init__() + self.scale = scale + + if d_model % 2 != 0: + raise ValueError( + "d_model needs to be even for this featurization, try again!" + ) + + B = 2 * math.pi * scale * torch.randn(d_input, d_model // 2) + self.trainable = trainable + if not trainable: + self.register_buffer("B", B) + else: + self.register_parameter("B", torch.nn.Parameter(B)) + + def forward(self, inputs): + h = inputs @ self.B + return torch.cat([h.cos(), h.sin()], -1) + + +class PositionalEncoding(nn.Module): + """Axis-aligned positional encodings with log-linear spacing. + + Args: + d_input (int): dimension of inputs + d_model (int): dimension of outputs + period_range (tuple of floats): Min and maximum periods for the + frequency components. Fourier features will be log-linearly spaced + between these values (inclusive). 
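+            For example (illustrative): with d_input=1, d_model=10 and the default
+            period_range, five periods p_1..p_5 are log-spaced between 1.0 and 1000.0,
+            and each input coordinate x is mapped to
+            [cos(2*pi*x/p_1), ..., cos(2*pi*x/p_5), sin(2*pi*x/p_1), ..., sin(2*pi*x/p_5)].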
+ + Inputs: + input (torch.tensor): of size (..., d_input) + + Outputs: + output (torch.tensor): of size (..., d_model) + """ + + def __init__(self, d_model, d_input=1, period_range=(1.0, 1000.0)): + super().__init__() + + if d_model % (2 * d_input) != 0: + raise ValueError( + "d_model needs to be divisible by 2*d_input for this featurization, " + f"but got {d_model} versus {d_input}" + ) + + num_frequencies = d_model // (2 * d_input) + log_bounds = np.log10(period_range) + p = torch.logspace(log_bounds[0], log_bounds[1], num_frequencies, base=10.0) + w = 2 * math.pi / p + self.register_buffer("w", w) + + def forward(self, inputs): + batch_dims = list(inputs.shape)[:-1] + # (..., 1, num_out) * (..., num_in, 1) + w = self.w.reshape(len(batch_dims) * [1] + [1, -1]) + h = w * inputs[..., None] + h = torch.cat([h.cos(), h.sin()], -1).reshape(batch_dims + [-1]) + return h + + +class MaybeOnehotEmbedding(nn.Embedding): + """Wrapper around :class:`torch.nn.Embedding` to support either int-encoded + LongTensors or one-hot encoded FloatTensors. + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.dtype.is_floating_point: # onehot + return x @ self.weight + return super().forward(x) diff --git a/prtm/models/chroma/chroma.py b/prtm/models/chroma/chroma.py new file mode 100644 index 00000000..4690cf56 --- /dev/null +++ b/prtm/models/chroma/chroma.py @@ -0,0 +1,766 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Joint model for protein complexes with applications to unconditional and conditional +protein design in a programmable manner. +""" + +import copy +import inspect +from collections import defaultdict, namedtuple +from typing import List, Literal, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from prtm.protein import ProteinBase +from prtm.constants.residue_constants import alphabetical_restype_3 +from prtm.models.chroma import graph_backbone, graph_design +from prtm.models.chroma.structure.backbone import ProteinBackbone + + +class Chroma(nn.Module): + """Chroma: A generative model for protein design. + + Chroma is a generative model for proteins and protein complexes. It combines + a diffusion model for generating protein backbones together with discrete + generative models for sequence and sidechain conformations given structure. + It enables programmatic design of proteins through a conditioning + framework. This class provides an interface to: + * Load model weights + * Sample protein complexes, both unconditionally and conditionally + * Perform sequence design of sampled backbones + + Args: + weights_backbone (str, optional): The name of the pre-trained weights + to use for the backbone network. + + weights_design (str, optional): The name of the pre-trained weights + to use for the autoregressive design network. + + device (Optional[str]): The device on which to load the networks. If + not specified, will automatically use a CUDA device if available, + otherwise CPU. 
+ + strict (bool): Whether to strictly enforce that all keys in `weights` + match the keys in the model's state_dict. + + verbose (bool, optional): Show outputs from download and loading. + Default False. + """ + + def __init__( + self, + weights_backbone: str = "named:public", + weights_design: str = "named:public", + device: Optional[str] = None, + strict: bool = False, + verbose: bool = False, + ) -> None: + super().__init__() + + import warnings + + warnings.filterwarnings("ignore") + + # If no device is explicity specified automatically set device + if device is None: + if torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + + self.backbone_network = graph_backbone.load_model( + weights_backbone, device=device, strict=strict, verbose=verbose + ).eval() + + self.design_network = graph_design.load_model( + weights_design, + device=device, + strict=strict, + verbose=False, + ).eval() + + def sample( + self, + # Backbone Args + samples: int = 1, + steps: int = 500, + chain_lengths: List[int] = [100], + tspan: List[float] = (1.0, 0.001), + protein_init: Optional[ProteinBase] = None, + conditioner: Optional[nn.Module] = None, + langevin_factor: float = 2, + langevin_isothermal: bool = False, + inverse_temperature: float = 10, + initialize_noise: bool = True, + integrate_func: Literal["euler_maruyama", "heun"] = "euler_maruyama", + sde_func: Literal["langevin", "reverse_sde", "ode"] = "reverse_sde", + trajectory_length: int = 200, + full_output: bool = False, + # Sidechain Args + design_ban_S: Optional[List[str]] = None, + design_method: Literal["potts", "autoregressive"] = "potts", + design_selection: Optional[Union[str, torch.Tensor]] = None, + design_t: Optional[float] = 0.5, + temperature_S: float = 0.01, + temperature_chi: float = 1e-3, + top_p_S: Optional[float] = None, + regularization: Optional[str] = "LCP", + potts_mcmc_depth: int = 500, + potts_proposal: Literal["dlmc", "chromatic"] = "dlmc", + potts_symmetry_order: int = None, + verbose: bool = False, + ) -> Union[ + Union[ProteinBase, List[ProteinBase]], Tuple[Union[ProteinBase, List[ProteinBase]], dict] + ]: + """ + Performs Backbone Sampling and Sequence Design and returns a Protein or list + of Proteins. Optionally this method can return additional arguments to show + details of the sampling procedure. + + Args: + Backbone sampling: + samples (int, optional): The number of proteins to sample. + Default is 1. + steps (int, optional): The number of integration steps for the SDE. + Default is 500. + chain_lengths (List[int], optional): The lengths of the protein chains. + Default is [100]. + conditioner (Conditioner, optional): The conditioner object that + provides the conditioning information. Default is None. + langevin_isothermal (bool, optional): Whether to use the isothermal + version of the Langevin SDE. Default is False. + integrate_func (str, optional): The name of the integration function to + use. Default is “euler_maruyama”. + sde_func (str, optional): The name of the SDE function to use. Defaults + to “reverse_sde”. + langevin_factor (float, optional): The factor that controls the strength + of the Langevin noise. Default is 2. + inverse_temperature (float, optional): The inverse temperature parameter + for the SDE. Default is 10. + protein_init (Protein, optional): The initial protein state. Defaults + to None. 
+ full_output (bool, optional): Whether to return the full outputs of the + SDE integration, including the protein sample trajectory, the + Xhat trajectory (the trajectory of the preceived denoising target) + and the Xunc trajectory (the trajectory of the unconditional sample + path). Default is False. + initialize_noise (bool, optional): Whether to initialize the noise for + the SDE integration. Default is True. + tspan (List[float], optional): The time span for the SDE integration. + Default is (1.0, 0.001). + trajectory_length (int, optional): The number of sampled steps in the + trajectory output. Maximum is `steps`. Default 200. + **kwargs: Additional keyword arguments for the integration function. + + Sequence and sidechain sampling: + design_ban_S (list of str, optional): List of amino acid single-letter + codes to ban, e.g. `["C"]` to ban cysteines. + design_method (str, optional): Specifies which method to use for design. + Can be `potts` and `autoregressive`. Default is `potts`. + design_selection (str, optional): Clamp selection for + conditioning on a subsequence during sequence sampling. Can be + either a selection string or a binary design mask indicating + positions to be sampled with shape `(num_batch, num_residues)` or + position-specific valid amino acid choices with shape + `(num_batch, num_residues, num_alphabet)`. + design_mask_sample (Tensor, optional): Binary design mask indicating + which positions can be sampled with shape `(num_batch, num_residues)`. + or which amino acids can be sampled at which position with + shape `(num_batch, num_residues, num_alphabet)`. + design_t (float or torch.Tensor, optional): Diffusion time for models + trained with diffusion augmentation of input structures. Setting `t=0` + or `t=None` will condition the model to treat the structure as + exact coordinates, while values of `t > 0` will condition + the model to treat structures as though they were drawn from + noise-augmented ensembles with that noise level. For robust design + (default) we recommend `t=0.5`, or for literal design we recommend + `t=0.0`. May be a float or a tensor of shape `(num_batch)`. + temperature_S (float, optional): Temperature for sequence sampling. + Default 0.01. + temperature_chi (float, optional): Temperature for chi angle sampling. + Default 1e-3. + top_p_S (float, optional): Top-p sampling cutoff for autoregressive + sampling. + regularization (str, optional): Complexity regularization for + sampling. + potts_mcmc_depth (int, optional): Depth of sampling (number of steps per + alphabet letter times number of sites) per cycle. + potts_proposal (str): MCMC proposal for Potts sampling. Currently implemented + proposals are `dlmc` (default) for Discrete Langevin Monte Carlo [1] or + `chromatic` for graph-colored block Gibbs sampling. + [1] Sun et al. Discrete Langevin Sampler via Wasserstein Gradient Flow (2023). + potts_symmetry_order (int, optional): Symmetric design. + The first `(num_nodes // symmetry_order)` residues in the protein + system will be variable, and all consecutively tiled sets of residues + will be locked to these during decoding. Internally this is accomplished by + summing the parameters Potts model under a symmetry constraint + into this reduced sized system and then back imputing at the end. + Currently only implemented for Potts models. + + Returns: + proteins: Sampled `Protein` object or list of sampled `Protein` objects in + the case of multiple outputs. 
+ full_output_dictionary (dict, optional): Additional outputs if + `full_output=True`. + """ + + # Get KWARGS + input_args = locals() + + # Dynamically get acceptable kwargs for each method + backbone_keys = set(inspect.signature(self._sample).parameters) + design_keys = set(inspect.signature(self.design).parameters) + + # Filter kwargs for each method using dictionary comprehension + backbone_kwargs = {k: input_args[k] for k in input_args if k in backbone_keys} + design_kwargs = {k: input_args[k] for k in input_args if k in design_keys} + + # Perform Sampling + sample_output = self._sample(**backbone_kwargs) + + if full_output: + protein_sample, output_dictionary = sample_output + else: + protein_sample = sample_output + output_dictionary = None + + # Perform Design + if design_method is None: + proteins = protein_sample + else: + if isinstance(protein_sample, list): + proteins = [ + self.design(protein, **design_kwargs) for protein in protein_sample + ] + else: + proteins = self.design(protein_sample, **design_kwargs) + + # Perform conditioner postprocessing + if (conditioner is not None) and hasattr(conditioner, "_postprocessing_"): + proteins, output_dictionary = self._postprocess( + conditioner, proteins, output_dictionary + ) + + if full_output: + return proteins, output_dictionary + else: + return proteins + + def _postprocess(self, conditioner, proteins, output_dictionary): + if output_dictionary is None: + if isinstance(proteins, list): + proteins = [ + conditioner._postprocessing_(protein) for protein in proteins + ] + else: + proteins = conditioner._postprocessing_(proteins) + else: + if isinstance(proteins, list): + p_dicts = [] + proteins = [] + for i, protein in enumerate(proteins): + p_dict = {} + for key, value in output_dictionary.items(): + p_dict[key] = value[i] + + protein, p_dict = conditioner._postprocessing_(protein, p_dict) + p_dicts.append(p_dict) + + # Merge Output Dictionaries + output_dictionary = defaultdict(list) + for p_dict in p_dicts: + for k, v in p_dict.items(): + output_dictionary[k].append(v) + else: + proteins, output_dictionary = conditioner._postprocessing_( + proteins, output_dictionary + ) + return proteins, output_dictionary + + def _sample( + self, + samples: int = 1, + steps: int = 500, + chain_lengths: List[int] = [100], + tspan: List[float] = (1.0, 0.001), + protein_init: Optional[ProteinBase] = None, + conditioner: Optional[nn.Module] = None, + langevin_factor: float = 2, + langevin_isothermal: bool = False, + inverse_temperature: float = 10, + initialize_noise: bool = True, + integrate_func: Literal["euler_maruyama", "heun"] = "euler_maruyama", + sde_func: Literal["langevin", "reverse_sde", "ode"] = "reverse_sde", + trajectory_length: int = 200, + full_output: bool = False, + **kwargs, + ) -> Union[ + Tuple[List[ProteinBase], List[ProteinBase]], + Tuple[List[ProteinBase], List[ProteinBase], List[ProteinBase], List[ProteinBase]], + ]: + """Samples backbones given chain lengths by integrating SDEs. + + Args: + samples (int, optional): The number of proteins to sample. Default is 1. + steps (int, optional): The number of integration steps for the SDE. + Default is 500. + chain_lengths (List[int], optional): The lengths of the protein chains. + Default is [100]. + conditioner (Conditioner, optional): The conditioner object that provides + the conditioning information. Default is None. + langevin_isothermal (bool, optional): Whether to use the isothermal version + of the Langevin SDE. Default is False. 
+ integrate_func (str, optional): The name of the integration function to use. + Default is `euler_maruyama`. + sde_func (str, optional): The name of the SDE function to use. Default is + “reverse_sde”. + langevin_factor (float, optional): The factor that controls the strength of + the Langevin noise. Default is 2. + inverse_temperature (float, optional): The inverse temperature parameter + for the SDE. Default is 10. + protein_init (Protein, optional): The initial protein state. Default is + None. + full_output (bool, optional): Whether to return the full outputs of the SDE + integration, including Xhat and Xunc. Default is False. + initialize_noise (bool, optional): Whether to initialize the noise for the + SDE integration. Default is True. + tspan (List[float], optional): The time span for the SDE integration. + Default is (1.0, 0.001). + trajectory_length (int, optional): The number of sampled steps in the + trajectory output. Maximum is `steps`. Default 200. + **kwargs: Additional keyword arguments for the integration function. + + Returns: + proteins: Sampled `Protein` object or list of sampled `Protein` objects in + the case of multiple outputs. + full_output_dictionary (dict, optional): Additional outputs if + `full_output=True`. + """ + + if protein_init is not None: + X_unc, C_unc, S_unc = protein_init.to_XCS() + else: + X_unc, C_unc, S_unc = self._init_backbones(samples, chain_lengths) + + outs = self.backbone_network.sample_sde( + C_unc, + X_init=X_unc, + conditioner=conditioner, + tspan=tspan, + langevin_isothermal=langevin_isothermal, + integrate_func=integrate_func, + sde_func=sde_func, + langevin_factor=langevin_factor, + inverse_temperature=inverse_temperature, + N=steps, + initialize_noise=initialize_noise, + **kwargs, + ) + + if S_unc.shape != outs["C"].shape: + S = torch.zeros_like(outs["C"]).long() + else: + S = S_unc + + assert S.shape == outs["C"].shape + + proteins = [ + Protein.from_XCS(outs_X[None, ...], outs_C[None, ...], outs_S[None, ...]) + for outs_X, outs_C, outs_S in zip(outs["X_sample"], outs["C"], S) + ] + if samples == 1: + proteins = proteins[0] + + if not full_output: + return proteins + else: + outs["S"] = S + trajectories = self._format_trajectory( + outs, "X_trajectory", trajectory_length + ) + + trajectories_Xhat = self._format_trajectory( + outs, "Xhat_trajectory", trajectory_length + ) + + # use unconstrained C and S for Xunc_trajectory + outs["S"] = S_unc + outs["C"] = C_unc + trajectories_Xunc = self._format_trajectory( + outs, "Xunc_trajectory", trajectory_length + ) + + if samples == 1: + full_output_dictionary = { + "trajectory": trajectories[0], + "Xhat_trajectory": trajectories_Xhat[0], + "Xunc_trajectory": trajectories_Xunc[0], + } + else: + full_output_dictionary = { + "trajectory": trajectories, + "Xhat_trajectory": trajectories_Xhat, + "Xunc_trajectory": trajectories_Xunc, + } + + return proteins, full_output_dictionary + + def _format_trajectory(self, outs, key, trajectory_length): + trajectories = [ + Protein.from_XCS_trajectory( + [ + outs_X[i][None, ...] + for outs_X in self._resample_trajectory( + trajectory_length, outs[key] + ) + ], + outs_C[None, ...], + outs_S[None, ...], + ) + for i, (outs_C, outs_S) in enumerate(zip(outs["C"], outs["S"])) + ] + return trajectories + + def _resample_trajectory(self, trajectory_length, trajectory): + if trajectory_length < 0: + raise ValueError( + "The trajectory length must fall on the interval [0, sample_steps]." 
+ ) + n = len(trajectory) + trajectory_length = min(n, trajectory_length) + idx = torch.linspace(0, n - 1, trajectory_length).long() + return [trajectory[i] for i in idx] + + def design( + self, + protein: Protein, + design_ban_S: Optional[List[str]] = None, + design_method: Literal["potts", "autoregressive"] = "potts", + design_selection: Optional[Union[str, torch.Tensor]] = None, + design_t: Optional[float] = 0.5, + temperature_S: float = 0.01, + temperature_chi: float = 1e-3, + top_p_S: Optional[float] = None, + regularization: Optional[str] = "LCP", + potts_mcmc_depth: int = 500, + potts_proposal: Literal["dlmc", "chromatic"] = "dlmc", + potts_symmetry_order: Optional[int] = None, + verbose: bool = False, + ) -> Protein: + """Performs sequence design and repacking on the specified Protein object + and returns an updated copy. + + Args: + protein (Protein): The protein to design. + design_ban_S (list of str, optional): List of amino acid single-letter + codes to ban, e.g. `["C"]` to ban cysteines. + design_method (str, optional): Specifies which method to use for design. valid + methods are potts and autoregressive. Default is potts. + design_selection (str or torch.Tensor, optional): Clamp selection for + conditioning on a subsequence during sequence sampling. Can be + either a selection string or a binary design mask indicating + positions to be sampled with shape `(num_batch, num_residues)` or + position-specific valid amino acid choices with shape + `(num_batch, num_residues, num_alphabet)`. + design_t (float or torch.Tensor, optional): Diffusion time for models + trained with diffusion augmentation of input structures. Setting `t=0` + or `t=None` will condition the model to treat the structure as + exact coordinates, while values of `t > 0` will condition + the model to treat structures as though they were drawn from + noise-augmented ensembles with that noise level. For robust design + (default) we recommend `t=0.5`, or for literal design we recommend + `t=0.0`. May be a float or a tensor of shape `(num_batch)`. + temperature_S (float, optional): Temperature for sequence sampling. + Default 0.01. + temperature_chi (float, optional): Temperature for chi angle sampling. + Default 1e-3. + top_p_S (float, optional): Top-p sampling cutoff for autoregressive + sampling. + regularization (str, optional): Complexity regularization for + sampling. + potts_mcmc_depth (int, optional): Depth of sampling (number of steps per + alphabet letter times number of sites) per cycle. + potts_proposal (str): MCMC proposal for Potts sampling. Currently implemented + proposals are `dlmc` (default) for Discrete Langevin Monte Carlo [1] or + `chromatic` for graph-colored block Gibbs sampling. + [1] Sun et al. Discrete Langevin Sampler via Wasserstein Gradient Flow (2023). + potts_symmetry_order (int, optional): Symmetric design. + The first `(num_nodes // symmetry_order)` residues in the protein + system will be variable, and all consecutively tiled sets of residues + will be locked to these during decoding. Internally this is accomplished by + summing the parameters Potts model under a symmetry constraint + into this reduced sized system and then back imputing at the end. + Currently only implemented for Potts models. + + Returns: + A new Protein object with updated sequence and, optionally, side-chains. 
+ """ + protein = copy.deepcopy(protein) + protein.canonicalize() + + X, C, S = protein.to_XCS() + if design_method not in set(["potts", "autoregressive"]): + raise NotImplementedError( + "Valid design methods are potts and autoregressive, recieved" + f" {design_method}" + ) + + # Optional sequence clamping + mask_sample = None + if design_selection is not None: + if isinstance(design_selection, str): + design_selection = protein.get_mask(design_selection) + mask_sample = design_selection + + X_sample, S_sample, _ = self.design_network.sample( + X, + C, + S, + t=design_t, + mask_sample=mask_sample, + temperature_S=temperature_S, + temperature_chi=temperature_chi, + ban_S=design_ban_S, + sampling_method=design_method, + regularization=regularization, + potts_sweeps=potts_mcmc_depth, + potts_proposal=potts_proposal, + verbose=verbose, + symmetry_order=potts_symmetry_order, + ) + protein.sys.update_with_XCS(X_sample, C=None, S=S_sample) + return protein + + def _design_ar(self, protein, alphabet=None, temp_S=0.1, temp_chi=1e-3): + X, C, S = protein.to_XCS() + ban_S = None + if alphabet is not None: + ban_S = set(alphabetical_restype_3).difference(alphabet) + + X_sample, S_sample, _, _ = self.design_network_ar.sample( + X, + C, + S, + temperature_S=temp_S, + temperature_chi=temp_chi, + return_scores=True, + ban_S=ban_S, + ) + + protein.sys.update_with_XCS(X_sample, C=None, S=S_sample) + + return protein + + def pack( + self, protein: Protein, temperature_chi: float = 1e-3, clamped: bool = False + ) -> Protein: + """Packs Sidechains of a Protein using the design network + + Args: + protein (Protein): The Protein to repack. + temperature_chi (float): Temperature parameter for sampling chi + angles. Even if a high temperature sequence is sampled, this is + recommended to always be low. Default is `1E-3`. + clamped (bool): If `True`, no sampling is done and the likelihood + values will be calculated for the input sequence and structure. + Used for validating the sequential versus parallel decoding + modes. Default is `False` + + Returns: + Protein: The Repacked Protein + """ + X, C, S = protein.to_XCS(all_atom=False) + + X_repack, _, _ = self.design_network.pack( + X, + C, + S, + temperature_chi=temperature_chi, + clamped=clamped, + return_scores=True, + ) + # Convert S_repack to seq + protein.sys.update_with_XCS(X_repack, C=None, S=S) + + return protein + + def score_backbone( + self, + proteins: Union[List[Protein], Protein], + num_samples: int = 50, + tspan: List[float] = [1e-4, 1.0], + ) -> Union[List[dict], dict]: + """ + Score Proteins with the following chroma scores: + elbo: + elbo_X: + rmsd_ratio: + fragment_mse: + neighborhood_mse: + distance_mse: + hb_local: + hb_nonlocal: + + Args: + proteins (list of Protein or Protein): The Proteins to be scored. + num_samples (int, optional): The number of time points to calculate the metrics. Default 50. + tspan (list of float, optional): A list of two times [t_initial, t_final] which represent + the range of times to draw samples. Default [1e-4, 1.0]. + + Returns: + List of dict or dict: A dictionary containing all of the score data. + Scores are returned as named tuples. 
+ """ + + # Extract XCS for scoring + device = next(self.parameters()).device + if isinstance(proteins, list): + X, C, S = self._protein_list_to_XCS(proteins, device=device) + else: + X, C, S = proteins.to_XCS(device=device) + + # Generate Scores + metrics, metrics_samples = self.backbone_network.estimate_metrics( + X, C, return_samples=True, num_samples=num_samples, tspan=tspan + ) + + if isinstance(proteins, list): + metric_dictionary = [ + self._make_metric_dictionary(metrics, metrics_samples, idx=i) + for i in range(len(proteins)) + ] + else: + metric_dictionary = self._make_metric_dictionary(metrics, metrics_samples) + + return metric_dictionary + + def score_sequence( + self, + proteins: Union[List[Protein], Protein], + t: Optional[torch.Tensor] = None, + ) -> dict: + """ + Scores designed Proteins with the following Chroma scores: + - -log(p) for sequences and chi angles + - average RMSD and number of clashes per side-chain + For further details on the scores computed, see + chroma.models.graph_design.GraphDesign.loss. + + Args: + proteins (list of Protein or Protein): The Proteins to be scored. + t (torch.Tensor, optional): Diffusion timesteps corresponding to + noisy input backbones, of shape `(num_batch)`. Default is no + noise. + + Returns: + List of dict or dict: A dictionary containing all of the score data. + Scores are returned as named tuples. + """ + + # Extract XCS for scoring + device = next(self.parameters()).device + if isinstance(proteins, list): + X, C, S = self._protein_list_to_XCS(proteins, all_atom=True, device=device) + output_scores = [{} for _ in range(len(proteins))] + else: + X, C, S = proteins.to_XCS(all_atom=True, device=device) + output_scores = {} + losses = self.design_network.loss(X, C, S, t=t, batched=False) + # each value in the losses dictionary contains the results for all proteins + for name, loss_tensor in losses.items(): + loss_list = [_t.squeeze() for _t in loss_tensor.split(1)] + if isinstance(proteins, list): + for i, loss in enumerate(loss_list): + output_scores[i][name] = loss + else: + output_scores[name] = loss_list[0] + return output_scores + + def _protein_list_to_XCS(self, list_of_proteins, all_atom=False, device=None): + """Package up proteins with padding""" + + # get all the XCS stuff + Xs, Cs, Ss = zip( + *[protein.to_XCS(all_atom=all_atom) for protein in list_of_proteins] + ) + + # Get Max Dims for Xs, Cs, Ss + Dmax = max([C.shape[1] for C in Cs]) + device = Xs[0].device + + # Augment each with zeros + with torch.no_grad(): + X = torch.cat( + [nn.functional.pad(X, (0, 0, 0, 0, 0, Dmax - X.shape[1])) for X in Xs] + ) + C = torch.cat([nn.functional.pad(C, (0, Dmax - C.shape[1])) for C in Cs]) + S = torch.cat([nn.functional.pad(S, (0, Dmax - S.shape[1])) for S in Ss]) + return X, C, S + + def score( + self, + proteins: Union[List[Protein], Protein], + num_samples: int = 50, + tspan: List[float] = [1e-4, 1.0], + ) -> Tuple[Union[List[dict], dict], dict]: + backbone_scores = self.score_backbone(proteins, num_samples, tspan) + sequence_scores = self.score_sequence(proteins) + if isinstance(proteins, list): + for ss in sequence_scores: + ss["t_seq"] = ss.pop("t") + return [bs | ss for bs, ss in zip(backbone_scores, sequence_scores)] + else: + sequence_scores["t_seq"] = sequence_scores.pop("t") + return backbone_scores | sequence_scores + + def _make_metric_dictionary(self, metrics, metrics_samples, idx=None): + # Process Metrics into a Single Dictionary + metric_dictionary = {} + for k, vs in metrics_samples.items(): + if k == "t": + 
metric_dictionary["t"] = vs + elif k in ["X", "X0_pred"]: + if idx is None: + v = metrics[k] + else: + vs = vs[idx] + v = metrics[k][idx] + score = namedtuple(k, ["value", "samples"]) + metric_dictionary[k] = score(value=v, samples=vs) + else: + if idx is None: + v = metrics[k].item() + else: + vs = vs[idx] + v = metrics[k][idx].item() + vs = [i.item() for i in vs] + score = namedtuple(k, ["score", "subcomponents"]) + metric_dictionary[k] = score(score=v, subcomponents=vs) + + return metric_dictionary + + def _init_backbones(self, num_backbones, length_backbones): + # Start with purely alpha backbones + X = ProteinBackbone( + num_batch=num_backbones, + num_residues=sum(length_backbones), + init_state="alpha", + )() + C = torch.cat( + [torch.full([rep], i + 1) for i, rep in enumerate(length_backbones)] + ).expand(X.shape[0], -1) + S = torch.zeros_like(C) + return [i.to(next(self.parameters()).device) for i in [X, C, S]] diff --git a/prtm/models/chroma/chroma_utils.py b/prtm/models/chroma/chroma_utils.py new file mode 100644 index 00000000..c8e47ccc --- /dev/null +++ b/prtm/models/chroma/chroma_utils.py @@ -0,0 +1,174 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np +import torch +from PIL import Image, ImageDraw, ImageFont +from sklearn.decomposition import PCA + +from prtm.models import chroma +from prtm.models.chroma.modeling import protein_to_xcs +from prtm.models.chroma.structure import backbone +from prtm.protein import ProteinBase + + +def letter_to_point_cloud( + letter="G", + width_pixels=35, + font=os.path.join( + os.path.dirname(chroma.__path__[0]), "assets/LiberationSans-Regular.ttf" + ), + depth_ratio=0.15, + fontsize_ratio=1.2, + stroke_width=1, + margin=0.5, + max_points=2000, +): + """Build a point cloud from a letter""" + depth = int(depth_ratio * width_pixels) + fontsize = int(fontsize_ratio * width_pixels) + + font = ImageFont.truetype(font, fontsize) + ascent, descent = font.getmetrics() + text_width = font.getmask(letter).getbbox()[2] + text_height = font.getmask(letter).getbbox()[3] + descent + + margin_width = int(text_width * margin) + margin_height = int(text_height * margin) + image_size = [text_width + margin_width, text_height + margin_height] + + image = Image.new("RGBA", image_size, (255, 255, 255)) + draw = ImageDraw.Draw(image) + draw.text( + (margin_width // 2, margin_height // 2), + letter, + (0, 0, 0), + font=font, + stroke_width=stroke_width, + stroke_fill="black", + ) + + A = np.asarray(image).mean(-1) + A = A < 100.0 + V = np.ones(list(A.shape[:2]) + [depth]) * A[:, :, None] + X_point_cloud = np.stack(np.nonzero(V), 1) + # Uniform dequantization + X_point_cloud = X_point_cloud + np.random.rand(*X_point_cloud.shape) + + if max_points is not None and X_point_cloud.shape[0] > max_points: + np.random.shuffle(X_point_cloud) + X_point_cloud = X_point_cloud[:max_points, :] + + return X_point_cloud + + +def point_cloud_rescale( + X, num_residues, neighbor_k=8, volume_per_residue=128.57, 
scale_ratio=0.4 +): + """Rescale target coordinates to occupy protein-sized volume""" + + # Use heuristic for radius value from the average m-th nearest neighbor + # This was tuned empirically for target problems (could be optimized on the fly as LB estimate) + D = np.sqrt(np.square(X[None, :] - X[:, None]).sum(-1)) + radius = 0.5 * np.sort(D, axis=1)[:, neighbor_k].mean() + D.max() + + # Estimate initial volume with 2nd order inclusion exclusion + V = point_cloud_volume(X, radius) + + # Compute target volume, which scales linearly with number of residues + V_target = num_residues * volume_per_residue + scale_factor = (scale_ratio * V_target / V) ** (1.0 / 3.0) + X_rescale = scale_factor * X + cutoff_D = scale_factor * radius + return X_rescale, cutoff_D + + +def point_cloud_volume(X, radius): + """Estimate the volume of a point cloud given sphere radii""" + N = X.shape[0] + + # Volume estimation - One body volumes + V_1 = N * (4.0 / 3.0) * np.pi * radius**3 + + # Volume estimation - 2nd order overlaps + D = np.sqrt(np.square(X[None, :] - X[:, None]).sum(-1)) + overlap_ij = ( + (D < 2.0 * radius) + * (np.pi / 12.0) + * (4.0 * radius + D) + * (2.0 * radius - D) ** 2 + ) + V_2 = np.tril(overlap_ij, k=-1).sum() + + # Inclusion-Exclusion Principle + volume = V_1 - V_2 + return volume + + +def plane_split_protein(protein: ProteinBase, mask_percent=0.5): + """Split a protein by plane, masking mask_percent of residues.""" + X, C, _ = protein_to_xcs(protein, all_atom=False) + + X = backbone.center_X(X, C) + points = X[C > 0].reshape(-1, 3) + pca = PCA(n_components=1) + normal = torch.from_numpy( + pca.fit_transform(points.detach().cpu().numpy().transpose(1, 0)) + ).to(X.device) + c_alphas = X[:, :, 1, :] + + c = 0 + tries = 0 + + def percent_masked(c): + C_mask = ((c_alphas @ normal) > c).squeeze(-1) & (C > 0) + return (~C_mask).float().sum().item() / (C > 0).sum().item() + + # In the first stage we find the minimum C such that all of the residues + # lie on one side of the plane (c_alphas @ normal = c) + while (percent_masked(c) < 1.0) and (tries < 300000): + tries += 1 + c += 100 + + # Now we drag the plane back until percent_masked - masked_percent is small. + size = X.size(1) + threshold = 0.1 if size < 100 else 0.05 if size < 500 else 0.01 + tries = 0 + while (np.abs(percent_masked(c) - mask_percent) > threshold) and (tries < 300000): + c -= 100 + tries += 1 + + if tries >= 300000: + print( + "Tried and failed to split protein by plane to grab" + f" {mask_percent} residues." + ) + c = 0 + C_mask = ((c_alphas @ normal) > c).squeeze(-1) & (C > 0) + print( + f"Returning {100 * percent_masked(0.0):.2f} percent residues masked" + " instead." + ) + + else: + C_mask = ((c_alphas @ normal) > c).squeeze(-1) & (C > 0) + print( + f"Split protein by plane, masking {100 * percent_masked(c):.2f} percent of" + " residues." + ) + + return C_mask diff --git a/prtm/models/chroma/complexity.py b/prtm/models/chroma/complexity.py new file mode 100644 index 00000000..e7457214 --- /dev/null +++ b/prtm/models/chroma/complexity.py @@ -0,0 +1,201 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layers for computing sequence complexities. +""" + +import numpy as np +import torch +import torch.nn.functional as F + +from prtm.constants.residue_constants import alphabetical_restypes +from prtm.models.chroma.graph import collect_neighbors + + +def compositions(S: torch.Tensor, C: torch.LongTensor, w: int = 30): + """Compute local compositions per residue. + + Args: + S (torch.Tensor): Sequence tensor with shape `(num_batch, num_residues)` + (long) or `(num_batch, num_residues, num_alphabet)` (float). + C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`. + w (int, optional): Window size. + + Returns: + P (torch.Tensor): Local compositions with shape + `(num_batch, num_residues - w + 1, num_alphabet)`. + N (torch.Tensor): Local counts with shape + `(num_batch, num_residues - w + 1, num_alphabet)`. + mask_P (torch.Tensor): Mask with shape + `(num_batch, num_residues - w + 1)`. + """ + S.device + Q = len(alphabetical_restypes) + mask_i = (C > 0).float() + if len(S.shape) == 2: + S = F.one_hot(S, Q) + + # Build neighborhoods and masks + S_onehot = mask_i[..., None] * S + kx = torch.arange(w, device=S.device) - w // 2 + edge_idx = ( + torch.arange(S.shape[1], device=S.device)[None, :, None] + kx[None, None, :] + ) + mask_ij = (edge_idx > 0) & (edge_idx < S.shape[1]) + edge_idx = edge_idx.clamp(min=0, max=S.shape[1] - 1) + C_i = C[..., None] + C_j = collect_neighbors(C_i, edge_idx)[..., 0] + mask_ij = (mask_ij & C_j.eq(C_i) & (C_i > 0) & (C_j > 0)).float() + + # Sum neighborhood composition + S_j = mask_ij[..., None] * collect_neighbors(S_onehot, edge_idx) + N = S_j.sum(2) + + num_N = N.sum(-1, keepdims=True) + P = N / (num_N + 1e-5) + mask_i = ((num_N[..., 0] > 0) & (C > 0)).float() + mask_ij = mask_i[..., None] * mask_ij + return P, N, edge_idx, mask_i, mask_ij + + +def complexity_lcp( + S: torch.LongTensor, + C: torch.LongTensor, + w: int = 30, + entropy_min: float = 2.32, + method: str = "naive", + differentiable=True, + eps: float = 1e-5, + min_coverage=0.9, + # entropy_min: float = 2.52, + # method = "chao-shen" +) -> torch.Tensor: + """Compute the Local Composition Perplexity metric. + + Args: + S (torch.Tensor): Sequence tensor with shape `(num_batch, num_residues)` + (index tensor) or `(num_batch, num_residues, num_alphabet)`. + C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`. + w (int): Window size. + grad_pseudocount (float): Pseudocount for stabilizing entropy gradients + on backwards pass. + eps (float): Small number for numerical stability in division and logarithms. + + Returns: + U (torch.Tensor): Complexities with shape `(num_batch)`. 
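+
+    Example (minimal sketch with random inputs; shapes only):
+
+        >>> S = torch.randint(0, 20, (1, 120))        # (num_batch, num_residues)
+        >>> C = torch.ones(1, 120, dtype=torch.long)  # single chain, all residues valid
+        >>> U = complexity_lcp(S, C)                  # shape (1,)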
+ """ + + # adjust window size based on sequence length + if S.shape[1] < w: + w = S.shape[1] + + P, N, edge_idx, mask_i, mask_ij = compositions(S, C, w) + + # Only count windows with `min_coverage` + int(min_coverage * w) + mask_coverage = N.sum(-1) > int(min_coverage * w) + + H = estimate_entropy(N, method=method) + U = mask_coverage * (torch.exp(H) - np.exp(entropy_min)).clamp(max=0).square() + + # Compute entropy as a function of perturbed counts + if differentiable and len(S.shape) == 3: + # Compute how a mutation changes entropy for each neighbor + N_neighbors = collect_neighbors(N, edge_idx) + mask_coverage_j = collect_neighbors(mask_coverage[..., None], edge_idx) + N_ij = (N_neighbors - S[:, :, None, :])[..., None, :] + torch.eye( + N.shape[-1], device=N.device + )[None, None, None, ...] + N_ij = N_ij.clamp(min=0) + H_ij = estimate_entropy(N_ij, method=method) + U_ij = (torch.exp(H_ij) - np.exp(entropy_min)).clamp(max=0).square() + U_ij = mask_ij[..., None] * mask_coverage_j * U_ij + U_differentiable = (U_ij.detach() * S[:, :, None, :]).sum([-1, -2]) + U = U.detach() + U_differentiable - U_differentiable.detach() + + U = (mask_i * U).sum(1) + return U + + +def complexity_scores_lcp_t( + t, + S: torch.LongTensor, + C: torch.LongTensor, + idx: torch.LongTensor, + edge_idx_t: torch.LongTensor, + mask_ij_t: torch.Tensor, + w: int = 30, + entropy_min: float = 2.515, + eps: float = 1e-5, + method: str = "chao-shen", +) -> torch.Tensor: + """Compute local LCP scores for autoregressive decoding.""" + Q = len(alphabetical_restypes) + O = F.one_hot(S, Q) + O_j = collect_neighbors(O, edge_idx_t) + idx_i = idx[:, t, None] + C_i = C[:, t, None] + idx_j = collect_neighbors(idx[..., None], edge_idx_t)[..., 0] + C_j = collect_neighbors(C[..., None], edge_idx_t)[..., 0] + + # Sum valid neighbor counts + is_near = (idx_i - idx_j).abs() <= w / 2 + same_chain = C_i == C_j + valid_ij_t = (is_near * same_chain * (mask_ij_t > 0)).float()[..., None] + N_k = (valid_ij_t * O_j).sum(-2) + + # Compute counts under all possible extensions + N_k = N_k[:, :, None, :] + torch.eye(Q, device=N_k.device)[None, None, ...] + + H = estimate_entropy(N_k, method=method) + U = -(torch.exp(H) - np.exp(entropy_min)).clamp(max=0).square() + return U + + +def estimate_entropy( + N: torch.Tensor, method: str = "chao-shen", eps: float = 1e-11 +) -> torch.Tensor: + """Estimate entropy from counts. + + See Chao, A., & Shen, T. J. (2003) for more details. + + Args: + N (torch.Tensor): Tensor of counts with shape `(..., num_bins)`. + + Returns: + H (torch.Tensor): Estimated entropy with shape `(...)`. 
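+
+    Example (sketch comparing estimators on a single histogram of counts):
+
+        >>> counts = torch.tensor([[4.0, 3.0, 2.0, 1.0]])
+        >>> H_cs = estimate_entropy(counts)                  # Chao-Shen (default)
+        >>> H_ml = estimate_entropy(counts, method="naive")  # plug-in estimate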
+ """ + N = N.float() + N_total = N.sum(-1, keepdims=True) + P = N / (N_total + eps) + + if method == "chao-shen": + # Estimate coverage and adjusted frequencies + singletons = N.long().eq(1).sum(-1, keepdims=True).float() + C = 1.0 - singletons / (N_total + eps) + P_adjust = C * P + P_inclusion = (1.0 - (1.0 - P_adjust) ** N_total).clamp(min=eps) + H = -(P_adjust * torch.log(P_adjust.clamp(min=eps)) / P_inclusion).sum(-1) + elif method == "miller-maddow": + bins = (N > 0).float().sum(-1) + bias = (bins - 1) / (2 * N_total[..., 0] + eps) + H = -(P * torch.log(P + eps)).sum(-1) + bias + elif method == "laplace": + N = N.float() + 1 / N.shape[-1] + N_total = N.sum(-1, keepdims=True) + P = N / (N_total + eps) + H = -(P * torch.log(P)).sum(-1) + else: + H = -(P * torch.log(P + eps)).sum(-1) + return H diff --git a/prtm/models/chroma/config.py b/prtm/models/chroma/config.py new file mode 100644 index 00000000..e69de29b diff --git a/prtm/models/chroma/conv.py b/prtm/models/chroma/conv.py new file mode 100644 index 00000000..b2a63dff --- /dev/null +++ b/prtm/models/chroma/conv.py @@ -0,0 +1,56 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import platform + +import torch +import torch.nn.functional as F + +MACHINE = platform.machine() + + +def filter1D_linear_decay(Z, B): + """Apply a low-pass filter with batch-heterogeneous coefficients. + + Computes `x_i = z_i + b * x_{i-1}` where `b` varies per batch member. + + Args: + Z (torch.Tensor): Batch of one-dimensional signals with shape `(N, W)`. + B (torch.Tensor): Batch of coefficients with shape `(N)`. + + Returns: + X (torch.Tensor): Result of applying linear recurrence with shape `(N, W)`. + """ + + # Build filter coefficients as powers of B + N, W = Z.shape + k = (W - 1) - torch.arange(W, device=Z.device) + kernel = B[:, None, None] ** k[None, None, :] + + # Pad on left to convolve from backwards in time + Z_pad = F.pad(Z, (W - 1, 0))[None, ...] + + # Group convolution can effectively do one filter per batch + while True: + X = F.conv1d(Z_pad, kernel, stride=1, padding=0, groups=N)[0, :, :] + # on arm64 (M1 Mac) this convolution erroneously sometimes produces NaNs + if ( + (MACHINE == "arm64") + and torch.isnan(X).any() + and (not torch.isnan(Z_pad).any()) + and (not torch.isnan(kernel).any()) + ): + continue + break + return X diff --git a/prtm/models/chroma/fetchdb.py b/prtm/models/chroma/fetchdb.py new file mode 100644 index 00000000..fa3b610f --- /dev/null +++ b/prtm/models/chroma/fetchdb.py @@ -0,0 +1,47 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions to retrieve information from external databases via their API; Uniprot and RCSB are the primary databases included here. + +""" + + +import requests + + +def _download_file(url, out_file): + try: + with requests.get(url, stream=True) as r: + r.raise_for_status() + with open(out_file, "wb") as f: + for chunk in r.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + return True + except requests.HTTPError: + return False + + +def RCSB_file_download(pdb_id, ext, local_filename): + """Downloads a file from the RCSB files section. + + Args: + pdb_id (str) : 4-letter pdb id, case-insensitive + ext (str) : Extension of file. E.g. ".pdb" or ".pdb1" + local_filename (str) : Name for downloaded file. + Returns: + None + """ + url = f"https://files.rcsb.org/view/{pdb_id.upper()}{ext}" + return _download_file(url, local_filename) diff --git a/prtm/models/chroma/geometry.py b/prtm/models/chroma/geometry.py new file mode 100644 index 00000000..7a3c946e --- /dev/null +++ b/prtm/models/chroma/geometry.py @@ -0,0 +1,567 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dictionary containing ideal internal coordinates and chi angle assignments + for building amino acid 3D coordinates""" +from typing import Dict + +ATOM_SYMMETRIES = { + "ARG": [("NH1", "NH2")], # Correct handling of NH1 and NH2 is relabeling + "ASP": [("OD1", "OD2")], + "GLU": [("OE1", "OE2")], + "PHE": [("CD1", "CD2"), ("CE1", "CE2")], + "TYR": [("CD1", "CD2"), ("CE1", "CE2")], +} + + +AA_GEOMETRY: Dict[str, dict] = { + "ALA": { + "atoms": ["CB"], + "chi_indices": [], + "parents": [["N", "C", "CA"]], + "types": {"C": "C", "CA": "CT1", "CB": "CT3", "N": "NH1", "O": "O"}, + "z-angles": [111.09], + "z-dihedrals": [123.23], + "z-lengths": [1.55], + }, + "ARG": { + "atoms": ["CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"], + "chi_indices": [1, 2, 3, 4], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CB", "CG", "CD"], + ["CG", "CD", "NE"], + ["CD", "NE", "CZ"], + ["NH1", "NE", "CZ"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD": "CT2", + "CG": "CT2", + "CZ": "C", + "N": "NH1", + "NE": "NC2", + "NH1": "NC2", + "NH2": "NC2", + "O": "O", + }, + "z-angles": [112.26, 115.95, 114.01, 107.09, 123.05, 118.06, 122.14], + "z-dihedrals": [123.64, 180.0, 180.0, 180.0, 180.0, 180.0, 178.64], + "z-lengths": [1.56, 1.55, 1.54, 1.5, 1.34, 1.33, 1.33], + }, + "ASN": { + "atoms": ["CB", "CG", "OD1", "ND2"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["OD1", "CB", "CG"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CG": "CC", + "N": "NH1", + "ND2": "NH2", + "O": "O", + "OD1": "O", + }, + "z-angles": [113.04, 114.3, 122.56, 116.15], + "z-dihedrals": [121.18, 180.0, 180.0, -179.19], + "z-lengths": [1.56, 1.53, 1.23, 1.35], + }, + "ASP": { + "atoms": ["CB", "CG", 
"OD1", "OD2"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["OD1", "CB", "CG"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2A", + "CG": "CC", + "N": "NH1", + "O": "O", + "OD1": "OC", + "OD2": "OC", + }, + "z-angles": [114.1, 112.6, 117.99, 117.7], + "z-dihedrals": [122.33, 180.0, 180.0, -170.23], + "z-lengths": [1.56, 1.52, 1.26, 1.25], + }, + "CYS": { + "atoms": ["CB", "SG"], + "chi_indices": [1], + "parents": [["N", "C", "CA"], ["N", "CA", "CB"]], + "types": {"C": "C", "CA": "CT1", "CB": "CT2", "N": "NH1", "O": "O", "SG": "S"}, + "z-angles": [111.98, 113.87], + "z-dihedrals": [121.79, 180.0], + "z-lengths": [1.56, 1.84], + }, + "GLN": { + "atoms": ["CB", "CG", "CD", "OE1", "NE2"], + "chi_indices": [1, 2, 3], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CB", "CG", "CD"], + ["OE1", "CG", "CD"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD": "CC", + "CG": "CT2", + "N": "NH1", + "NE2": "NH2", + "O": "O", + "OE1": "O", + }, + "z-angles": [111.68, 115.52, 112.5, 121.52, 116.84], + "z-dihedrals": [121.91, 180.0, 180.0, 180.0, 179.57], + "z-lengths": [1.55, 1.55, 1.53, 1.23, 1.35], + }, + "GLU": { + "atoms": ["CB", "CG", "CD", "OE1", "OE2"], + "chi_indices": [1, 2, 3], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CB", "CG", "CD"], + ["OE1", "CG", "CD"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2A", + "CD": "CC", + "CG": "CT2", + "N": "NH1", + "O": "O", + "OE1": "OC", + "OE2": "OC", + }, + "z-angles": [111.71, 115.69, 115.73, 114.99, 120.08], + "z-dihedrals": [121.9, 180.0, 180.0, 180.0, -179.1], + "z-lengths": [1.55, 1.56, 1.53, 1.26, 1.25], + }, + "GLY": { + "atoms": [], + "chi_indices": [], + "parents": [], + "types": {"C": "C", "CA": "CT2", "N": "NH1", "O": "O"}, + "z-angles": [], + "z-dihedrals": [], + "z-lengths": [], + }, + "HIS": { + "atoms": ["CB", "CG", "ND1", "CD2", "CE1", "NE2"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["ND1", "CB", "CG"], + ["CB", "CG", "ND1"], + ["CB", "CG", "CD2"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD2": "CPH1", + "CE1": "CPH2", + "CG": "CPH1", + "N": "NH1", + "ND1": "NR1", + "NE2": "NR2", + "O": "O", + }, + "z-angles": [109.99, 114.05, 124.1, 129.6, 107.03, 110.03], + "z-dihedrals": [122.46, 180.0, 90.0, -171.29, -173.21, 171.99], + "z-lengths": [1.55, 1.5, 1.38, 1.36, 1.35, 1.38], + }, + "HSD": { + "atoms": ["CB", "CG", "ND1", "CD2", "CE1", "NE2"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["ND1", "CB", "CG"], + ["CB", "CG", "ND1"], + ["CB", "CG", "CD2"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD2": "CPH1", + "CE1": "CPH2", + "CG": "CPH1", + "N": "NH1", + "ND1": "NR1", + "NE2": "NR2", + "O": "O", + }, + "z-angles": [109.99, 114.05, 124.1, 129.6, 107.03, 110.03], + "z-dihedrals": [122.46, 180.0, 90.0, -171.29, -173.21, 171.99], + "z-lengths": [1.55, 1.5, 1.38, 1.36, 1.35, 1.38], + }, + "HSE": { + "atoms": ["CB", "CG", "ND1", "CD2", "CE1", "NE2"], + "chi_indices": [], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["ND1", "CB", "CG"], + ["CB", "CG", "ND1"], + ["CB", "CG", "CD2"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD2": "CPH1", + "CE1": "CPH2", + "CG": "CPH1", + "N": "NH1", + "ND1": "NR2", + "NE2": "NR1", + "O": "O", + }, + "z-angles": [111.67, 
116.94, 120.17, 129.71, 105.2, 105.8], + "z-dihedrals": [123.52, 180.0, 90.0, -178.26, -179.2, 178.66], + "z-lengths": [1.56, 1.51, 1.39, 1.36, 1.32, 1.38], + }, + "HSP": { + "atoms": ["CB", "CG", "ND1", "CD2", "CE1", "NE2"], + "chi_indices": [], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["ND1", "CB", "CG"], + ["CB", "CG", "ND1"], + ["CB", "CG", "CD2"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2A", + "CD2": "CPH1", + "CE1": "CPH2", + "CG": "CPH1", + "N": "NH1", + "ND1": "NR3", + "NE2": "NR3", + "O": "O", + }, + "z-angles": [109.38, 114.18, 122.94, 128.93, 108.9, 106.93], + "z-dihedrals": [125.13, 180.0, 90.0, -165.26, -167.62, 167.13], + "z-lengths": [1.55, 1.52, 1.37, 1.35, 1.33, 1.37], + }, + "ILE": { + "atoms": ["CB", "CG1", "CG2", "CD1"], + "chi_indices": [1, 3], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CG1", "CA", "CB"], + ["CA", "CB", "CG1"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT1", + "CD": "CT3", + "CG1": "CT2", + "CG2": "CT3", + "N": "NH1", + "O": "O", + }, + "z-angles": [112.93, 113.63, 113.93, 114.09], + "z-dihedrals": [124.22, 180.0, -130.04, 180.0], + "z-lengths": [1.57, 1.55, 1.55, 1.54], + }, + "LEU": { + "atoms": ["CB", "CG", "CD1", "CD2"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CD1", "CB", "CG"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD1": "CT3", + "CD2": "CT3", + "CG": "CT1", + "N": "NH1", + "O": "O", + }, + "z-angles": [112.12, 117.46, 110.48, 112.57], + "z-dihedrals": [121.52, 180.0, 180.0, 120.0], + "z-lengths": [1.55, 1.55, 1.54, 1.54], + }, + "LYS": { + "atoms": ["CB", "CG", "CD", "CE", "NZ"], + "chi_indices": [1, 2, 3, 4], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CB", "CG", "CD"], + ["CG", "CD", "CE"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD": "CT2", + "CE": "CT2", + "CG": "CT2", + "N": "NH1", + "NZ": "NH3", + "O": "O", + }, + "z-angles": [111.36, 115.76, 113.28, 112.33, 110.46], + "z-dihedrals": [122.23, 180.0, 180.0, 180.0, 180.0], + "z-lengths": [1.56, 1.54, 1.54, 1.53, 1.46], + }, + "MET": { + "atoms": ["CB", "CG", "SD", "CE"], + "chi_indices": [1, 2, 3], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CB", "CG", "SD"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CE": "CT3", + "CG": "CT2", + "N": "NH1", + "O": "O", + "SD": "S", + }, + "z-angles": [111.88, 115.92, 110.28, 98.94], + "z-dihedrals": [121.62, 180.0, 180.0, 180.0], + "z-lengths": [1.55, 1.55, 1.82, 1.82], + }, + "PHE": { + "atoms": ["CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CD1", "CB", "CG"], + ["CB", "CG", "CD1"], + ["CB", "CG", "CD2"], + ["CG", "CD1", "CE1"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD1": "CA", + "CD2": "CA", + "CE1": "CA", + "CE2": "CA", + "CG": "CA", + "CZ": "CA", + "N": "NH1", + "O": "O", + }, + "z-angles": [112.45, 112.76, 120.32, 120.76, 120.63, 120.62, 119.93], + "z-dihedrals": [122.49, 180.0, 90.0, -177.96, -177.37, 177.2, -0.12], + "z-lengths": [1.56, 1.51, 1.41, 1.41, 1.4, 1.4, 1.4], + }, + "PRO": { + "atoms": ["CB", "CG", "CD"], + "chi_indices": [1, 2], + "parents": [["N", "C", "CA"], ["N", "CA", "CB"], ["CA", "CB", "CG"]], + "types": { + "C": "C", + "CA": "CP1", + "CB": "CP2", + "CD": "CP3", + "CG": "CP2", + "N": "N", + "O": "O", + }, + 
"z-angles": [111.74, 104.39, 103.21], + "z-dihedrals": [113.74, 31.61, -34.59], + "z-lengths": [1.54, 1.53, 1.53], + }, + "SER": { + "atoms": ["CB", "OG"], + "chi_indices": [1], + "parents": [["N", "C", "CA"], ["N", "CA", "CB"]], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "N": "NH1", + "O": "O", + "OG": "OH1", + }, + "z-angles": [111.4, 112.45], + "z-dihedrals": [124.75, 180.0], + "z-lengths": [1.56, 1.43], + }, + "THR": { + "atoms": ["CB", "OG1", "CG2"], + "chi_indices": [1], + "parents": [["N", "C", "CA"], ["N", "CA", "CB"], ["OG1", "CA", "CB"]], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT1", + "CG2": "CT3", + "N": "NH1", + "O": "O", + "OG1": "OH1", + }, + "z-angles": [112.74, 112.16, 115.91], + "z-dihedrals": [126.46, 180.0, -124.13], + "z-lengths": [1.57, 1.43, 1.53], + }, + "TRP": { + "atoms": ["CB", "CG", "CD2", "CD1", "CE2", "NE1", "CE3", "CZ3", "CH2", "CZ2"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CD2", "CB", "CG"], + ["CD1", "CG", "CD2"], + ["CG", "CD2", "CE2"], + ["CE2", "CG", "CD2"], + ["CE2", "CD2", "CE3"], + ["CD2", "CE3", "CZ3"], + ["CE3", "CZ3", "CH2"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD1": "CA", + "CD2": "CPT", + "CE2": "CPT", + "CE3": "CAI", + "CG": "CY", + "CH2": "CA", + "CZ2": "CAI", + "CZ3": "CA", + "N": "NH1", + "NE1": "NY", + "O": "O", + }, + "z-angles": [ + 111.23, + 115.14, + 123.95, + 129.18, + 106.65, + 107.87, + 132.54, + 118.16, + 120.97, + 120.87, + ], + "z-dihedrals": [ + 122.68, + 180.0, + 90.0, + -172.81, + -0.08, + 0.14, + 179.21, + -0.2, + 0.1, + 0.01, + ], + "z-lengths": [1.56, 1.52, 1.44, 1.37, 1.41, 1.37, 1.4, 1.4, 1.4, 1.4], + }, + "TYR": { + "atoms": ["CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CD1", "CB", "CG"], + ["CB", "CG", "CD1"], + ["CB", "CG", "CD2"], + ["CG", "CD1", "CE1"], + ["CE1", "CE2", "CZ"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD1": "CA", + "CD2": "CA", + "CE1": "CA", + "CE2": "CA", + "CG": "CA", + "CZ": "CA", + "N": "NH1", + "O": "O", + "OH": "OH1", + }, + "z-angles": [112.34, 112.94, 120.49, 120.46, 120.4, 120.56, 120.09, 120.25], + "z-dihedrals": [122.27, 180.0, 90.0, -176.46, -175.49, 175.32, -0.19, -178.98], + "z-lengths": [1.56, 1.51, 1.41, 1.41, 1.4, 1.4, 1.4, 1.41], + }, + "VAL": { + "atoms": ["CB", "CG1", "CG2"], + "chi_indices": [1], + "parents": [["N", "C", "CA"], ["N", "CA", "CB"], ["CG1", "CA", "CB"]], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT1", + "CG1": "CT3", + "CG2": "CT3", + "N": "NH1", + "O": "O", + }, + "z-angles": [111.23, 113.97, 112.17], + "z-dihedrals": [122.95, 180.0, 123.99], + "z-lengths": [1.57, 1.54, 1.54], + }, +} diff --git a/prtm/models/chroma/graph.py b/prtm/models/chroma/graph.py new file mode 100644 index 00000000..365d338d --- /dev/null +++ b/prtm/models/chroma/graph.py @@ -0,0 +1,1124 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layers for building graph neural networks. + +This module contains layers for building neural networks that can process +graph-structured data. The internal representations of these layers +are node and edge embeddings. +""" + +from typing import Callable, List, Optional, Tuple + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from tqdm.autonotebook import tqdm + +from prtm.models.chroma.attention import Attention + + +class GraphNN(nn.Module): + """Graph neural network with optional edge updates. + + Args: + num_layers (int): Number of layers. + dim_nodes (int): Hidden dimension of node tensor. + dim_edges (int): Hidden dimension of edge tensor. + dropout (float): Dropout rate. + node_mlp_layers (int): Node update function, number of hidden layers. + Default is 1. + node_mlp_dim (int): Node update function, hidden dimension. + Default is to match MLP output dimension. + update_edge (Boolean): Include an edge-update step. Default: True + edge_mlp_layers (int): Edge update function, number of hidden layers. + Default is 1. + edge_mlp_dim (int): Edge update function, hidden dimension. + Default is to match MLP output dimension. + mlp_activation (str): MLP nonlinearity. + `'relu'`: Rectified linear unit. + `'softplus'`: Softplus. + norm (str): Which normalization function to apply between layers. + `'transformer'`: Default layernorm + `'layer'`: Masked Layer norm with shape (input.shape[1:]) + `'instance'`: Masked Instance norm + scale (float): Scaling factor of edge input when updating node (default=1.0) + attentional (bool): If True, use attention for message aggregation function + instead of a sum. Default is False. + num_attention_heads (int): Number of attention heads (if attentional) to use. + Default is 4. + + Inputs: + node_h (torch.Tensor): Node features with shape + `(num_batch, num_nodes, dim_nodes)`. + edge_h (torch.Tensor): Edge features with shape + `(num_batch, num_nodes, num_neighbors, dim_edges)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_nodes, num_neighbors)`. + mask_i (tensor, optional): Node mask with shape `(num_batch, num_nodes)` + mask_ij (tensor, optional): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)` + + Outputs: + node_h_out (torch.Tensor): Updated node features with shape + `(num_batch, num_nodes, dim_nodes)`. + edge_h_out (torch.Tensor): Updated edge features with shape + `(num_batch, num_nodes, num_neighbors, dim_edges)`. 
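+
+    Example (illustrative; dimensions and neighbor counts are arbitrary placeholders):
+
+        >>> gnn = GraphNN(num_layers=3, dim_nodes=128, dim_edges=64)
+        >>> node_h = torch.randn(2, 50, 128)
+        >>> edge_h = torch.randn(2, 50, 16, 64)
+        >>> edge_idx = torch.randint(0, 50, (2, 50, 16))
+        >>> mask_i = torch.ones(2, 50)
+        >>> mask_ij = torch.ones(2, 50, 16)
+        >>> node_h, edge_h = gnn(node_h, edge_h, edge_idx, mask_i, mask_ij)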
+ """ + + def __init__( + self, + num_layers: int, + dim_nodes: int, + dim_edges: int, + node_mlp_layers: int = 1, + node_mlp_dim: Optional[int] = None, + edge_update: bool = True, + edge_mlp_layers: int = 1, + edge_mlp_dim: Optional[int] = None, + mlp_activation: str = "relu", + dropout: float = 0.0, + norm: str = "transformer", + scale: float = 1.0, + skip_connect_input: bool = False, + attentional: bool = False, + num_attention_heads: int = 4, + checkpoint_gradients: bool = False, + ): + super(GraphNN, self).__init__() + + self.skip_connect_input = skip_connect_input + self.checkpoint_gradients = checkpoint_gradients + self.layers = nn.ModuleList( + [ + GraphLayer( + dim_nodes=dim_nodes, + dim_edges=dim_edges, + node_mlp_layers=node_mlp_layers, + node_mlp_dim=node_mlp_dim, + edge_update=edge_update, + edge_mlp_layers=edge_mlp_layers, + edge_mlp_dim=edge_mlp_dim, + mlp_activation=mlp_activation, + dropout=dropout, + norm=norm, + scale=scale, + attentional=attentional, + num_attention_heads=num_attention_heads, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + node_h: torch.Tensor, + edge_h: torch.Tensor, + edge_idx: torch.LongTensor, + mask_i: Optional[torch.Tensor] = None, + mask_ij: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Run every layer sequentially + node_h_init = node_h + edge_h_init = edge_h + for i, layer in enumerate(self.layers): + if self.skip_connect_input: + node_h = node_h + node_h_init + edge_h = edge_h + edge_h_init + + # Update edge and node + node_h, edge_h = self.checkpoint( + layer, node_h, edge_h, edge_idx, mask_i, mask_ij + ) + + if self.skip_connect_input: + node_h = node_h - node_h_init + edge_h = edge_h - edge_h_init + + # If mask was provided, apply it + if mask_i is not None: + node_h = node_h * (mask_i.unsqueeze(-1) != 0).type(torch.float32) + if mask_ij is not None: + edge_h = edge_h * (mask_ij.unsqueeze(-1) != 0).type(torch.float32) + return node_h, edge_h + + def checkpoint(self, layer, *args): + if self.checkpoint_gradients: + return checkpoint(layer, *args) + else: + return layer(*args) + + def sequential( + self, + tensors: dict, + pre_step_function: Callable = None, + post_step_function: Callable = None, + ) -> dict: + """Decode the GNN sequentially along the node index `t`, with callbacks. + + Args: + tensors (dict): Initial set of state tensors. At minimum this should + include the arguments to `forward`, namely `node_h`, `edge_h`, + `edge_idx`, `mask_i`, and `mask_ij`. + pre_step_function (function, optional): Callback function that is + optionally applied to `tensors` before each sequential GNN step as + `tensors_new = pre_step_function(t, pre_step_function)` where `t` is + the node index being updated. It should update elements of the + `tensors` dictionary, and it can access and update the intermediate + GNN state cache via the keyed lists of tensors in `node_h_cache` and + `edge_h_cache`. + post_step_function (function, optional): Same as `pre_step_function`, but + optionally applied after each sequential GNN step. + + Returns: + tensors (dict): Processed set of tensors. 
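+
+        Example (sketch of a no-op callback; `gnn` and `tensors` are assumed
+        to be set up as described above):
+
+            >>> def pre_step(t, tensors):
+            ...     # e.g. overwrite entries of tensors["node_h_cache"] at index t
+            ...     return tensors
+            >>> tensors = gnn.sequential(tensors, pre_step_function=pre_step)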
+ """ + + # Initialize the state cache + tensors["node_h_cache"], tensors["edge_h_cache"] = self.init_steps( + tensors["node_h"], tensors["edge_h"] + ) + + # Sequential iteration + num_steps = tensors["node_h"].size(1) + for t in tqdm(range(num_steps), desc="Sequential decoding"): + if pre_step_function is not None: + tensors = pre_step_function(t, tensors) + + tensors["node_h_cache"], tensors["edge_h_cache"] = self.step( + t, + tensors["node_h_cache"], + tensors["edge_h_cache"], + tensors["edge_idx"], + tensors["mask_i"], + tensors["mask_ij"], + ) + + if post_step_function is not None: + tensors = post_step_function(t, tensors) + + return tensors + + def init_steps( + self, node_h: torch.Tensor, edge_h: torch.Tensor + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """Initialize cached node and edge features. + + Args: + node_h (torch.Tensor): Node features with shape + `(num_batch, num_nodes, dim_nodes)`. + edge_h (torch.Tensor): Edge features with shape + `(num_batch, num_nodes, num_neighbors, dim_edges)`. + + Returns: + node_h_cache (torch.Tensor): List of cached node features with `num_layers + 1` + tensors of shape `(num_batch, num_nodes, dim_nodes)`. + edge_h_cache (torch.Tensor): List of cached edge features with `num_layers + 1` + tensors of shape `(num_batch, num_nodes, num_neighbors, dim_edges)`. + """ + num_layers = len(self.layers) + node_h_cache = [node_h.clone() for _ in range(num_layers + 1)] + edge_h_cache = [edge_h.clone() for _ in range(num_layers + 1)] + return node_h_cache, edge_h_cache + + def step( + self, + t: int, + node_h_cache: List[torch.Tensor], + edge_h_cache: List[torch.Tensor], + edge_idx: torch.LongTensor, + mask_i: Optional[torch.Tensor] = None, + mask_ij: Optional[torch.Tensor] = None, + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """Process GNN update for a specific node index t from cached intermediates. + + Inputs: + t (int): Node index to decode. + node_h_cache (List[torch.Tensor]): List of cached node features with + `num_layers + 1` tensors of shape `(num_batch, num_nodes, dim_nodes)`. + edge_h_cache (List[torch.Tensor]): List of cached edge features with + `num_layers + 1` tensors of shape + `(num_batch, num_nodes, num_neighbors, dim_edges)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_nodes, num_neighbors)`. + mask_i (torch.Tensor, optional): Node mask with shape + `(num_batch, num_nodes)`. + mask_ij (torch.Tensor, optional): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. + + Outputs: + node_h_cache (List[torch.Tensor]): Updated list of cached node features + with `num_layers + 1` tensors of shape + `(num_batch, num_nodes, dim_nodes)`. This method updates the tensors + in place for memory. + edge_h_cache (List[torch.Tensor]): Updated list of cached edge features + with `num_layers + 1` tensors of shape + `(num_batch, num_nodes, num_neighbors, dim_edges)`. 
+ """ + if self.skip_connect_input: + raise NotImplementedError + + for i, layer in enumerate(self.layers): + # Because the edge updates depend on the updated nodes, + # we need both the input node features node_h and also + # the previous output node states node_h + node_h = node_h_cache[i] + node_h_out = node_h_cache[i + 1] + edge_h = edge_h_cache[i] + # Update edge and node + node_h_t, edge_h_t = checkpoint( + layer.step, t, node_h, node_h_out, edge_h, edge_idx, mask_i, mask_ij + ) + + # Scatter them in place + node_h_cache[i + 1].scatter_( + 1, (t * torch.ones_like(node_h_t)).long(), node_h_t + ) + edge_h_cache[i + 1].scatter_( + 1, (t * torch.ones_like(edge_h_t)).long(), edge_h_t + ) + + return node_h_cache, edge_h_cache + + +class GraphLayer(nn.Module): + """Graph layer that updates each node i given adjacent nodes and edges. + + Args: + dim_nodes (int): Hidden dimension of node tensor. + dim_edges (int): Hidden dimension of edge tensor. + node_mlp_layers (int): Node update function, number of hidden layers. + Default: 1. + node_mlp_dim (int): Node update function, hidden dimension. + Default: Matches MLP output dimension. + update_edge (Boolean): Include an edge-update step. Default: True + edge_mlp_layers (int): Edge update function, number of hidden layers. + Default: 1. + edge_mlp_dim (int): Edge update function, hidden dimension. + Default: Matches MLP output dimension. + mlp_activation (str): MLP nonlinearity. + `'relu'`: Rectified linear unit. + `'softplus'`: Softplus. + dropout (float): Dropout rate. + norm (str): Which normalization function to apply between layers. + `'transformer'`: Default layernorm + `'layer'`: Masked Layer norm with shape (input.shape[1:]) + `'instance'`: Masked Instance norm + scale (float): Scaling factor of edge input when updating node (default=1.0) + + Inputs: + node_h (torch.Tensor): Node features with shape + `(num_batch, num_nodes, dim_nodes)`. + edge_h (torch.Tensor): Edge features with shape + `(num_batch, num_nodes, num_neighbors, dim_edges)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_nodes, num_neighbors)`. + mask_i (tensor, optional): Node mask with shape `(num_batch, num_nodes)` + mask_ij (tensor, optional): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)` + + Outputs: + node_h_out (torch.Tensor): Updated node features with shape + `(num_batch, num_nodes, dim_nodes)`. + edge_h_out (torch.Tensor): Updated edge features with shape + `(num_batch, num_nodes, num_neighbors, dim_nodes)`. 
+ """ + + def __init__( + self, + dim_nodes: int, + dim_edges: int, + node_mlp_layers: int = 1, + node_mlp_dim: Optional[int] = None, + edge_update: bool = True, + edge_mlp_layers: int = 1, + edge_mlp_dim: Optional[int] = None, + mlp_activation: str = "relu", + dropout: float = 0.0, + norm: str = "transformer", + scale: float = 1.0, + attentional: bool = False, + num_attention_heads: int = 4, + ): + super(GraphLayer, self).__init__() + + # Store scale + self.scale = scale + self.dim_nodes = dim_nodes + self.dim_edges = dim_edges + self.attentional = attentional + + self.node_norm_layer = MaskedNorm( + dim=1, num_features=dim_nodes, affine=True, norm=norm + ) + + self.message_mlp = MLP( + dim_in=2 * dim_nodes + dim_edges, + dim_out=dim_nodes, + num_layers_hidden=edge_mlp_layers, + dim_hidden=edge_mlp_dim, + activation=mlp_activation, + dropout=dropout, + ) + self.update_mlp = MLP( + dim_in=2 * dim_nodes, + dim_out=dim_nodes, + num_layers_hidden=node_mlp_layers, + dim_hidden=node_mlp_dim, + activation=mlp_activation, + dropout=dropout, + ) + self.edge_update = edge_update + self.edge_norm_layer = MaskedNorm( + dim=2, num_features=dim_edges, affine=True, norm=norm + ) + if self.edge_update: + self.edge_mlp = MLP( + dim_in=2 * dim_nodes + dim_edges, + dim_out=dim_edges, + num_layers_hidden=edge_mlp_layers, + dim_hidden=edge_mlp_dim, + activation=mlp_activation, + dropout=dropout, + ) + + if self.attentional: + self.attention = Attention(n_head=num_attention_heads, d_model=dim_nodes) + + def attend( + self, node_h: torch.Tensor, messages: torch.Tensor, mask_ij: torch.Tensor + ) -> torch.Tensor: + B, L, K, D = messages.size() + queries = node_h.reshape(-1, 1, D) + keys = messages.reshape(-1, K, D) + values = messages.reshape(-1, K, D) + mask = mask_ij.reshape(-1, 1, 1, K).bool() if mask_ij is not None else None + return self.attention(queries, keys, values, mask=mask).reshape(B, L, D) + + def _normalize(self, node_h, edge_h, mask_i=None, mask_ij=None): + # Normalize node and edge embeddings + node_h_norm = self.node_norm_layer(node_h, mask_i) + edge_h_norm = self.edge_norm_layer(edge_h, mask_ij) + return node_h_norm, edge_h_norm + + def _normalize_t( + self, edge_node_stack_t, mask_ij_t, include_nodes=True, include_edges=True + ): + # Apply normalization (since we have only normalized time t information) + node_i_t = edge_node_stack_t[:, :, :, : self.dim_nodes] + node_j_t = edge_node_stack_t[:, :, :, self.dim_nodes : 2 * self.dim_nodes] + edge_h_t = edge_node_stack_t[:, :, :, 2 * self.dim_nodes :] + if include_nodes: + node_i_t = self.node_norm_layer(node_i_t, mask_ij_t) + node_j_t = self.node_norm_layer(node_j_t, mask_ij_t) + if include_edges: + edge_h_t = self.edge_norm_layer(edge_h_t, mask_ij_t) + edge_node_stack_t = torch.cat([node_i_t, node_j_t, edge_h_t], -1) + return edge_node_stack_t + + def _update_nodes( + self, node_h, node_h_norm, edge_h_norm, edge_idx, mask_i=None, mask_ij=None + ): + """Update nodes given adjacent nodes and edges""" + # Compute messages at each ij + edge_node_stack = pack_edges(node_h_norm, edge_h_norm, edge_idx) + messages = self.message_mlp(edge_node_stack) + if mask_ij is not None: + messages = messages * mask_ij.unsqueeze(-1) + + # Aggregate messages + if self.attentional: + message = self.attend(node_h_norm, messages, mask_ij) + else: + message = messages.sum(2) / self.scale + + node_stack = torch.cat([node_h_norm, message], -1) + + # Update nodes given aggregated messages + node_h_out = node_h + self.update_mlp(node_stack) + if mask_i is not None: + 
node_h_out = node_h_out * mask_i.unsqueeze(-1) + return node_h_out + + def _update_nodes_t( + self, + t, + node_h, + node_h_norm_t, + edge_h_norm_t, + edge_idx_t, + mask_i_t=None, + mask_ij_t=None, + ): + """Update nodes at index t given adjacent nodes and edges""" + # Compute messages at each ij + edge_node_stack_t = mask_ij_t.unsqueeze(-1) * pack_edges_step( + t, node_h, edge_h_norm_t, edge_idx_t + ) + + # Apply normalization of gathered tensors + edge_node_stack_t = self._normalize_t( + edge_node_stack_t, mask_ij_t, include_edges=False + ) + + messages_t = self.message_mlp(edge_node_stack_t) + if mask_ij_t is not None: + messages_t = messages_t * mask_ij_t.unsqueeze(-1) + + # Aggregate messages + if self.attentional: + message_t = self.attend(node_h_norm_t, messages_t, mask_ij_t) + else: + message_t = messages_t.sum(2) / self.scale + + node_stack_t = torch.cat([node_h_norm_t, message_t], -1) + # Update nodes given aggregated messages + node_h_t = node_h[:, t, :].unsqueeze(1) + node_h_out_t = node_h_t + self.update_mlp(node_stack_t) + if mask_i_t is not None: + node_h_out_t = node_h_out_t * mask_i_t.unsqueeze(-1) + return node_h_out_t + + def _update_edges(self, edge_h, node_h_out, edge_h_norm, edge_idx, mask_ij): + """Update edges given adjacent nodes and edges""" + edge_node_stack = pack_edges(node_h_out, edge_h_norm, edge_idx) + + edge_h_out = edge_h + self.edge_mlp(edge_node_stack) + if mask_ij is not None: + edge_h_out = edge_h_out * mask_ij.unsqueeze(-1) + return edge_h_out + + def _update_edges_t( + self, t, edge_h_t, node_h_out, edge_h_t_norm, edge_idx_t, mask_ij_t + ): + """Update edges given adjacent nodes and edges""" + edge_node_stack_t = pack_edges_step(t, node_h_out, edge_h_t_norm, edge_idx_t) + + edge_h_out_t = edge_h_t + self.edge_mlp(edge_node_stack_t) + if mask_ij_t is not None: + edge_h_out_t = edge_h_out_t * mask_ij_t.unsqueeze(-1) + return edge_h_out_t + + def forward( + self, + node_h: torch.Tensor, + edge_h: torch.Tensor, + edge_idx: torch.LongTensor, + mask_i: Optional[torch.Tensor] = None, + mask_ij: Optional[torch.Tensor] = None, + ): + node_h_norm, edge_h_norm = self._normalize(node_h, edge_h, mask_i, mask_ij) + if mask_i is not None: + mask_i = (mask_i != 0).type(torch.float32) + if mask_ij is not None: + mask_ij = (mask_ij != 0).type(torch.float32) + node_h_out = self._update_nodes( + node_h, node_h_norm, edge_h_norm, edge_idx, mask_i, mask_ij + ) + edge_h_out = None + if self.edge_update: + edge_h_out = self._update_edges( + edge_h, node_h_out, edge_h_norm, edge_idx, mask_ij + ) + return node_h_out, edge_h_out + + def step( + self, + t: int, + node_h: torch.Tensor, + node_h_out: torch.Tensor, + edge_h: torch.Tensor, + edge_idx: torch.LongTensor, + mask_i: Optional[torch.Tensor] = None, + mask_ij: Optional[torch.Tensor] = None, + ): + """Compute update for a single node index `t`. + + This function can be useful for sequential computation of graph + updates, for example with autoregressive architectures. + + Args: + t (int): Index of node dimension to update + node_h (torch.Tensor): Node features with shape + `(num_batch, num_nodes, dim_nodes)`. + node_h_out (torch.Tensor): Cached outputs of preceding steps with shape + `(num_batch, num_nodes, dim_nodes)`. + edge_h (torch.Tensor): Edge features with shape + `(num_batch, num_nodes, num_neighbors, dim_edges)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_nodes, num_neighbors)`. 
+ mask_i (tensor, optional): Node mask with shape `(num_batch, num_nodes)` + mask_ij (tensor, optional): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)` + + Resturns: + node_h_t (torch.Tensor): Updated node features with shape + `(num_batch, 1, dim_nodes)`. + edge_h_t (torch.Tensor): Updated edge features with shape + `(num_batch, 1, num_neighbors, dim_nodes)`. + """ + node_h_t = node_h[:, t, :].unsqueeze(1) + edge_h_t = edge_h[:, t, :, :].unsqueeze(1) + edge_idx_t = edge_idx[:, t, :].unsqueeze(1) + mask_i_t = mask_i[:, t].unsqueeze(1) + mask_ij_t = mask_ij[:, t, :].unsqueeze(1) + + """ For a single step we need to apply the normalization both at node t and + also for all of the neighborhood tensors that feed in at t. + """ + node_h_t_norm, edge_h_t_norm = self._normalize( + node_h_t, edge_h_t, mask_i_t, mask_ij_t + ) + node_h_t = self._update_nodes_t( + t, node_h, node_h_t_norm, edge_h_t_norm, edge_idx_t, mask_i_t, mask_ij_t + ) + + if self.edge_update: + node_h_out = node_h_out.scatter( + 1, (t * torch.ones_like(node_h_t)).long(), node_h_t + ) + edge_h_t = self._update_edges_t( + t, edge_h_t, node_h_out, edge_h_t_norm, edge_idx_t, mask_ij_t + ) + return node_h_t, edge_h_t + + +class MLP(nn.Module): + """Multilayer perceptron with variable input, hidden, and output dims. + + Args: + dim_in (int): Feature dimension of input tensor. + dim_hidden (int or None): Feature dimension of intermediate layers. + Defaults to matching output dimension. + dim_out (int or None): Feature dimension of output tensor. + Defaults to matching input dimension. + num_layers_hidden (int): Number of hidden MLP layers. + activation (str): MLP nonlinearity. + `'relu'`: Rectified linear unit. + `'softplus'`: Softplus. + dropout (float): Dropout rate. Default is 0. + + Inputs: + h (torch.Tensor): Input tensor with shape `(..., dim_in)` + + Outputs: + h (torch.Tensor): Input tensor with shape `(..., dim_in)` + """ + + def __init__( + self, + dim_in: int, + dim_hidden: Optional[int] = None, + dim_out: Optional[int] = None, + num_layers_hidden: int = 1, + activation: str = "relu", + dropout: float = 0.0, + ): + super(MLP, self).__init__() + + # Default is dimension preserving + dim_out = dim_out if dim_out is not None else dim_in + dim_hidden = dim_hidden if dim_hidden is not None else dim_out + + nonlinearites = {"relu": nn.ReLU, "softplus": nn.Softplus} + activation_func = nonlinearites[activation] + + if num_layers_hidden == 0: + layers = [nn.Linear(dim_in, dim_out)] + else: + layers = [] + for i in range(num_layers_hidden): + d_1 = dim_in if i == 0 else dim_hidden + layers = layers + [ + nn.Linear(d_1, dim_hidden), + activation_func(), + nn.Dropout(dropout), + ] + layers = layers + [nn.Linear(dim_hidden, dim_out)] + self.layers = nn.Sequential(*layers) + + def forward(self, h: torch.Tensor) -> torch.Tensor: + return self.layers(h) + + +def collect_neighbors(node_h: torch.Tensor, edge_idx: torch.Tensor) -> torch.Tensor: + """Collect neighbor node features as edge features. + + For each node i, collect the embeddings of neighbors {j in N(i)} as edge + features neighbor_ij. + + Args: + node_h (torch.Tensor): Node features with shape + `(num_batch, num_nodes, num_features)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_nodes, num_neighbors)`. + + Returns: + neighbor_h (torch.Tensor): Edge features containing neighbor node information + with shape `(num_batch, num_nodes, num_neighbors, num_features)`. 
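+
+    Example (shape sketch):
+
+        >>> node_h = torch.randn(1, 10, 8)
+        >>> edge_idx = torch.randint(0, 10, (1, 10, 4))
+        >>> collect_neighbors(node_h, edge_idx).shape
+        torch.Size([1, 10, 4, 8])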
+ """ + num_batch, num_nodes, num_neighbors = edge_idx.shape + num_features = node_h.shape[2] + + # Flatten for the gather operation then reform the full tensor + idx_flat = edge_idx.reshape([num_batch, num_nodes * num_neighbors, 1]) + idx_flat = idx_flat.expand(-1, -1, num_features) + neighbor_h = torch.gather(node_h, 1, idx_flat) + neighbor_h = neighbor_h.reshape((num_batch, num_nodes, num_neighbors, num_features)) + return neighbor_h + + +def collect_edges( + edge_h_dense: torch.Tensor, edge_idx: torch.LongTensor +) -> torch.Tensor: + """Collect sparse edge features from a dense pairwise tensor. + + Args: + edge_h_dense (torch.Tensor): Dense edges features with shape + `(num_batch, num_nodes, num_nodes, num_features)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_nodes, num_neighbors)`. + + Returns: + edge_h (torch.Tensor): Edge features with shape + (num_batch, num_nodes, num_neighbors, num_features)`. + """ + gather_idx = edge_idx.unsqueeze(-1).expand(-1, -1, -1, edge_h_dense.size(-1)) + edge_h = torch.gather(edge_h_dense, 2, gather_idx) + return edge_h + + +def collect_edges_transpose( + edge_h: torch.Tensor, edge_idx: torch.LongTensor, mask_ij: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Collect edge embeddings of reversed (transposed) edges in-place. + + Args: + edge_h (torch.Tensor): Edge features with shape + `(num_batch, num_nodes, num_neighbors, num_features_edges)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_nodes, num_neighbors)`. + mask_ij (torch.Tensor): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)` + + Returns: + edge_h_transpose (torch.Tensor): Edge features of transpose with shape + `(num_batch, num_nodes, num_neighbors, num_features_edges)`. + mask_ji (torch.Tensor): Mask indicating presence of reversed edge with shape + `(num_batch, num_nodes, num_neighbors)`. + """ + num_batch, num_residues, num_k, num_features = list(edge_h.size()) + + # Get indices of reverse edges + ij_to_ji, mask_ji = transpose_edge_idx(edge_idx, mask_ij) + + # Gather features at reverse edges + edge_h_flat = edge_h.reshape(num_batch, num_residues * num_k, -1) + ij_to_ji = ij_to_ji.unsqueeze(-1).expand(-1, -1, num_features) + edge_h_transpose = torch.gather(edge_h_flat, 1, ij_to_ji) + edge_h_transpose = edge_h_transpose.reshape( + num_batch, num_residues, num_k, num_features + ) + edge_h_transpose = mask_ji.unsqueeze(-1) * edge_h_transpose + return edge_h_transpose, mask_ji + + +def scatter_edges(edge_h: torch.Tensor, edge_idx: torch.LongTensor) -> torch.Tensor: + """Scatter sparse edge features into a dense pairwise tensor. + Args: + edge_h (torch.Tensor): Edge features with shape + `(num_batch, num_nodes, num_neighbors, num_features_edges)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_nodes, num_neighbors)`. + + Returns: + edge_h_dense (torch.Tensor): Dense edge features with shape + `(batch_size, num_nodes, num_nodes, dimensions)`. + """ + assert edge_h.dim() == 4 + assert edge_idx.dim() == 3 + bs, nres, _, dim = edge_h.size() + edge_indices = edge_idx.unsqueeze(-1).repeat(1, 1, 1, dim) + result = torch.zeros( + size=(bs, nres, nres, dim), + dtype=edge_h.dtype, + device=edge_h.device, + ) + return result.scatter(dim=2, index=edge_indices, src=edge_h) + + +def pack_edges( + node_h: torch.Tensor, edge_h: torch.Tensor, edge_idx: torch.LongTensor +) -> torch.Tensor: + """Pack nodes and edge features into edge features. 
+ + Expands each edge_ij by packing node i, node j, and edge ij into + {node,node,edge}_ij. + + Args: + node_h (torch.Tensor): Node features with shape + `(num_batch, num_nodes, num_features_nodes)`. + edge_h (torch.Tensor): Edge features with shape + `(num_batch, num_nodes, num_neighbors, num_features_edges)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_nodes, num_neighbors)`. + + Returns: + edge_packed (torch.Tensor): Concatenated node and edge features with shape + (num_batch, num_nodes, num_neighbors, num_features_nodes + + 2*num_features_edges)`. + """ + num_neighbors = edge_h.shape[2] + node_i = node_h.unsqueeze(2).expand(-1, -1, num_neighbors, -1) + node_j = collect_neighbors(node_h, edge_idx) + edge_packed = torch.cat([node_i, node_j, edge_h], -1) + return edge_packed + + +def pack_edges_step( + t: int, node_h: torch.Tensor, edge_h_t: torch.Tensor, edge_idx_t: torch.LongTensor +) -> torch.Tensor: + """Pack node and edge features into edge features for a single node index t. + + Expands each edge_ij by packing node i, node j, and edge ij into + {node,node,edge}_ij. + + Args: + t (int): Node index to decode. + node_h (torch.Tensor): Node features at all positions with shape + `(num_batch, num_nodes, num_features_nodes)`. + edge_h_t (torch.Tensor): Edge features at index `t` with shape + `(num_batch, 1, num_neighbors, num_features_edges)`. + edge_idx_t (torch.LongTensor): Edge indices at index `t` for neighbors with shape + `(num_batch, 1, num_neighbors)`. + + Returns: + edge_packed (torch.Tensor): Concatenated node and edge features + for index `t` with shape + (num_batch, 1, num_neighbors, num_features_nodes + + 2*num_features_edges)`. + """ + node_h.shape[1] + num_neighbors = edge_h_t.shape[2] + node_h_t = node_h[:, t, :].unsqueeze(1) + node_i = node_h_t.unsqueeze(2).expand(-1, -1, num_neighbors, -1) + node_j = collect_neighbors(node_h, edge_idx_t) + edge_packed = torch.cat([node_i, node_j, edge_h_t], -1) + return edge_packed + + +def transpose_edge_idx( + edge_idx: torch.LongTensor, mask_ij: torch.Tensor +) -> Tuple[torch.LongTensor, torch.Tensor]: + """Collect edge indices of reverse edges in-place at each edge. + + The tensor `edge_idx` stores a directed graph topology as a tensor of + neighbor indices, where an element `edge_idx[b,i,k]` corresponds to the + node index of neighbor `k` of node `i` in batch member `b`. + + This function takes a directed graph topology and returns an index tensor + that maps, in-place, to the reversed edges (if they exist). The indices + correspond to the contracted dimension of `edge_index` when it is viewed as + `(num_batch, num_nodes * num_neighbors)`. These indices can be used in + conjunction with `torch.gather` to collect edge embeddings of `j->i` at + `i->j`. See `collect_edges_transpose` for an example. + + For reverse `j->i` edges that do not exist in the directed graph, the + function also returns a binary mask `mask_ji` indicating which edges + have both `i->j` and `j->i` present in the graph. + + Args: + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_nodes, num_neighbors)`. + mask_ij (torch.Tensor): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)` + + Returns: + ij_to_ji (torch.LongTensor): Flat indices for indexing ji in-place at ij with + shape `(num_batch, num_nodes * num_neighbors)`. + mask_ji (torch.Tensor): Mask indicating presence of reversed edge with shape + `(num_batch, num_nodes, num_neighbors)`. 
+ """ + num_batch, num_residues, num_k = list(edge_idx.size()) + + # 1. Collect neighbors of neighbors + edge_idx_flat = edge_idx.reshape([num_batch, num_residues * num_k, 1]).expand( + -1, -1, num_k + ) + edge_idx_neighbors = torch.gather(edge_idx, 1, edge_idx_flat) + # (b,i,j,k) gives the kth neighbor of the jth neighbor of i + edge_idx_neighbors = edge_idx_neighbors.reshape( + [num_batch, num_residues, num_k, num_k] + ) + + # 2. Determine which k at j maps back to i (if it exists) + residue_i = torch.arange(num_residues, device=edge_idx.device).reshape( + (1, -1, 1, 1) + ) + edge_idx_match = (edge_idx_neighbors == residue_i).type(torch.float32) + return_mask, return_idx = torch.max(edge_idx_match, -1) + + # 3. Build flat indices + ij_to_ji = edge_idx * num_k + return_idx + ij_to_ji = ij_to_ji.reshape(num_batch, -1) + + # 4. Transpose mask + mask_ji = torch.gather(mask_ij.reshape(num_batch, -1), -1, ij_to_ji) + mask_ji = mask_ji.reshape(num_batch, num_residues, num_k) + mask_ji = mask_ij * return_mask * mask_ji + return ij_to_ji, mask_ji + + +def permute_tensor( + tensor: torch.Tensor, dim: int, permute_idx: torch.LongTensor +) -> torch.Tensor: + """Permute a tensor along a dimension given a permutation vector. + + Args: + tensor (torch.Tensor): Input tensor with shape + `([batch_dims], permutation_length, [content_dims])`. + dim (int): Dimension to permute along. + permute_idx (torch.LongTensor): Permutation index tensor with shape + `([batch_dims], permutation_length)`. + + Returns: + tensor_permute (torch.Tensor): Permuted node features with shape + `([batch_dims], permutation_length, [content_dims])`. + """ + # Resolve absolute dimension + dim = range(len(list(tensor.shape)))[dim] + + # Flatten content dimensions + shape = list(tensor.shape) + batch_dims, permute_length = shape[:dim], shape[dim] + tensor_flat = tensor.reshape(batch_dims + [permute_length] + [-1]) + + # Exap content dimensions + permute_idx_expand = permute_idx.unsqueeze(-1).expand(tensor_flat.shape) + + tensor_permute_flat = torch.gather(tensor_flat, dim, permute_idx_expand) + tensor_permute = tensor_permute_flat.reshape(tensor.shape) + return tensor_permute + + +def permute_graph_embeddings( + node_h: torch.Tensor, + edge_h: torch.Tensor, + edge_idx: torch.LongTensor, + mask_i: torch.Tensor, + mask_ij: torch.Tensor, + permute_idx: torch.LongTensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor, torch.Tensor, torch.Tensor]: + """Permute graph embeddings given a permutation vector. + + Args: + node_h (torch.Tensor): Node features with shape + `(num_batch, num_nodes, dim_nodes)`. + edge_h (torch.Tensor): Edge features with shape + `(num_batch, num_nodes, num_neighbors, dim_edges)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_nodes, num_neighbors)`. + mask_i (tensor, optional): Node mask with shape `(num_batch, num_nodes)` + mask_ij (tensor, optional): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. + permute_idx (torch.LongTensor): Permutation vector with shape + `(num_batch, num_nodes)`. + + Returns: + node_h_permute (torch.Tensor): Permuted node features with shape + `(num_batch, num_nodes, dim_nodes)`. + edge_h_permute (torch.Tensor): Permuted edge features with shape + `(num_batch, num_nodes, num_neighbors, dim_edges)`. + edge_idx_permute (torch.LongTensor): Permuted edge indices for neighbors with shape + `(num_batch, num_nodes, num_neighbors)`. 
+ mask_i_permute (tensor, optional): Permuted node mask with shape `(num_batch, num_nodes)` + mask_ij_permute (tensor, optional): Permuted edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. + """ + + # Permuting one-dimensional objects is straightforward gathering + node_h_permute = permute_tensor(node_h, 1, permute_idx) + edge_h_permute = permute_tensor(edge_h, 1, permute_idx) + mask_i_permute = permute_tensor(mask_i, 1, permute_idx) + mask_ij_permute = permute_tensor(mask_ij, 1, permute_idx) + + """ + For edge_idx, there are two-dimensions set each edge idx that + previously pointed to j to now point to the new location + of j which is p^(-1)[j] + edge^(p)[i,k] = p^(-1)[edge[p(i),k]] + """ + # First, permute on the i dimension + edge_idx_permute_1 = permute_tensor(edge_idx, 1, permute_idx) + # Second, permute on the j dimension by using the inverse + permute_idx_inverse = torch.argsort(permute_idx, dim=-1) + edge_idx_1_flat = edge_idx_permute_1.reshape([edge_idx.shape[0], -1]) + edge_idx_permute_flat = torch.gather(permute_idx_inverse, 1, edge_idx_1_flat) + edge_idx_permute = edge_idx_permute_flat.reshape(edge_idx.shape) + + return ( + node_h_permute, + edge_h_permute, + edge_idx_permute, + mask_i_permute, + mask_ij_permute, + ) + + +def edge_mask_causal(edge_idx: torch.LongTensor, mask_ij: torch.Tensor) -> torch.Tensor: + """Make an edge mask causal with mask_ij = 0 for j >= i. + + Args: + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_nodes, num_neighbors)`. + mask_ij (torch.Tensor): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. + + Returns: + mask_ij_causal (torch.Tensor): Causal edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. + """ + idx = torch.arange(edge_idx.size(1), device=edge_idx.device) + idx_expand = idx.reshape([1, -1, 1]) + mask_ij_causal = (edge_idx < idx_expand).float() * mask_ij + return mask_ij_causal + + +class MaskedNorm(nn.Module): + """Masked normalization layer. + + Args: + dim (int): Dimensionality of the normalization. Can be 1 for 1D + normalization along dimension 1 or 2 for 2D normalization along + dimensions 1 and 2. + num_features (int): Channel dimension; only needed if `affine` is True. + affine (bool): If True, inclde a learnable affine transformation + post-normalization. Default is False. + norm (str): Type of normalization, can be `instance`, `layer`, or + `transformer`. + eps (float): Small number for numerical stability. + + Inputs: + data (torch.Tensor): Input tensor with shape + `(num_batch, num_nodes, num_channels)` (1D) or + `(num_batch, num_nodes, num_nodes, num_channels)` (2D). + mask (torch.Tensor): Mask tensor with shape + `(num_batch, num_nodes)` (1D) or + `(num_batch, num_nodes, num_nodes)` (2D). + + Outputs: + norm_data (torch.Tensor): Mask-normalized tensor with shape + `(num_batch, num_nodes, num_channels)` (1D) or + `(num_batch, num_nodes, num_nodes, num_channels)` (2D). 
+ """ + + def __init__( + self, + dim: int, + num_features: int = -1, + affine: bool = False, + norm: str = "instance", + eps: float = 1e-5, + ): + super(MaskedNorm, self).__init__() + + self.norm_type = norm + self.dim = dim + self.norm = norm + str(dim) + self.affine = affine + self.eps = eps + + # Dimension to sum + if self.norm == "instance1": + self.sum_dims = [1] + elif self.norm == "layer1": + self.sum_dims = [1, 2] + elif self.norm == "transformer1": + self.sum_dims = [-1] + elif self.norm == "instance2": + self.sum_dims = [1, 2] + elif self.norm == "layer2": + self.sum_dims = [1, 2, 3] + elif self.norm == "transformer2": + self.sum_dims = [-1] + else: + raise NotImplementedError + + # Number of features, only required if affine + self.num_features = num_features + + # Affine transformation is a linear layer on the C channel + if self.affine: + self.weights = nn.Parameter(torch.rand(self.num_features)) + self.bias = nn.Parameter(torch.zeros(self.num_features)) + + def forward( + self, data: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + # Add optional trailing singleton dimension and expand if necessary + if mask is not None: + if len(mask.shape) == len(data.shape) - 1: + mask = mask.unsqueeze(-1) + if data.shape != mask.shape: + mask = mask.expand(data.shape) + + # Input shape is Batch, Channel, Dim1, (dim2 if 2d) + dims = self.sum_dims + if (mask is None) or (self.norm_type == "transformer"): + mask_mean = data.mean(dim=dims, keepdim=True) + mask_std = torch.sqrt( + (((data - mask_mean)).pow(2)).mean(dim=dims, keepdim=True) + self.eps + ) + + # Norm + norm_data = (data - mask_mean) / mask_std + + else: + # Zeroes vector to sum all mask data + norm_data = torch.zeros_like(data).to(data.device).type(data.dtype) + for mask_id in mask.unique(): + # Skip zero, since real mask + if mask_id == 0: + continue + + # Transform mask to temp mask that match mask id + tmask = (mask == mask_id).type(torch.float32) + + # Sum mask for mean + mask_sum = tmask.sum(dim=dims, keepdim=True) + + # Data is tmask, so that mean is only for unmasked pos + mask_mean = (data * tmask).sum(dim=dims, keepdim=True) / mask_sum + mask_std = torch.sqrt( + (((data - mask_mean) * tmask).pow(2)).sum(dim=dims, keepdim=True) + / mask_sum + + self.eps + ) + + # Calculate temp norm, apply mask + tnorm = ((data - mask_mean) / mask_std) * tmask + # Sometime mask is empty, so generate nan that are conversted to 0 + tnorm[tnorm != tnorm] = 0 + + # Add to init zero norm data + norm_data += tnorm + + # Apply affine + if self.affine: + norm_data = norm_data * self.weights + self.bias + + # If mask, apply mask + if mask is not None: + norm_data = norm_data * (mask != 0).type(data.dtype) + return norm_data diff --git a/prtm/models/chroma/graph_backbone.py b/prtm/models/chroma/graph_backbone.py new file mode 100644 index 00000000..acaf3d6d --- /dev/null +++ b/prtm/models/chroma/graph_backbone.py @@ -0,0 +1,416 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Models for generating protein backbone structure via diffusion. +""" + +from types import SimpleNamespace +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from prtm.models.chroma import basic, graph +from prtm.models.chroma.graph_design import BackboneEncoderGNN +from prtm.models.chroma.model_utils import load_model as utility_load_model +from prtm.models.chroma.structure import backbone, diffusion, transforms +from prtm.models.chroma.xcs import validate_XC + + +class GraphBackbone(nn.Module): + """Graph-based backbone generation for protein complexes. + + GraphBackbone parameterizes a generative model of the backbone coordinates + of protein complexes. + + Args: + See documention of `layers.structure.protein_graph.ProteinFeatureGraph`, + `graph.GraphNN`, `layers.structure.backbone.GraphBackboneUpdate` and + `layers.structure.diffusion.DiffusionChainCov` for more details on + hyperparameters. + + Inputs: + X (Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atoms, 3)`. + C (LongTensor): Chain map with shape `(num_batch, num_residues)`. + + Outputs: + neglogp (Tensor): Sum of `neglogp_S` and `neglogp_chi`. + """ + + def __init__( + self, + dim_nodes: int = 128, + dim_edges: int = 128, + num_neighbors: int = 30, + node_features: Tuple = (("internal_coords", {"log_lengths": True}),), + edge_features: Tuple = ( + "distances_2mer", + "orientations_2mer", + "distances_chain", + ), + num_layers: int = 3, + dropout: float = 0.1, + node_mlp_layers: int = 1, + node_mlp_dim: Optional[int] = None, + edge_update: bool = True, + edge_mlp_layers: int = 1, + edge_mlp_dim: Optional[int] = None, + skip_connect_input: bool = False, + mlp_activation: str = "softplus", + decoder_num_hidden: int = 512, + graph_criterion: str = "knn", + graph_random_min_local: int = 20, + backbone_update_method: str = "neighbor", + backbone_update_iterations: int = 1, + backbone_update_num_weights: int = 1, + backbone_update_unconstrained: bool = True, + use_time_features: bool = True, + time_feature_type: str = "t", + time_log_feature_scaling: float = 0.05, + noise_schedule: str = "log_snr", + noise_covariance_model: str = "brownian", + noise_beta_min: float = 0.2, + noise_beta_max: float = 70.0, + noise_log_snr_range: Tuple[float] = (-7.0, 13.5), + noise_complex_scaling: bool = False, + loss_scale: float = 10.0, + loss_scale_ssnr_cutoff: float = 0.99, + loss_function: str = "squared_fape", + checkpoint_gradients: bool = False, + prediction_type: str = "X0", + num_graph_cycles: int = 1, + **kwargs, + ): + """Initialize GraphBackbone network.""" + super(GraphBackbone, self).__init__() + + # Save configuration in kwargs + self.kwargs = locals() + self.kwargs.pop("self") + for key in list(self.kwargs.keys()): + if key.startswith("__") and key.endswith("__"): + self.kwargs.pop(key) + args = SimpleNamespace(**self.kwargs) + + # Important global options + self.dim_nodes = args.dim_nodes + self.dim_edges = args.dim_edges + + # Encoder GNN process backbone + self.num_graph_cycles = args.num_graph_cycles + self.encoders = nn.ModuleList( + [ + BackboneEncoderGNN( + dim_nodes=args.dim_nodes, + dim_edges=args.dim_edges, + num_neighbors=args.num_neighbors, + node_features=args.node_features, + edge_features=args.edge_features, + num_layers=args.num_layers, + node_mlp_layers=args.node_mlp_layers, + node_mlp_dim=args.node_mlp_dim, + edge_update=args.edge_update, + edge_mlp_layers=args.edge_mlp_layers, + edge_mlp_dim=args.edge_mlp_dim, + mlp_activation=args.mlp_activation, + 
dropout=args.dropout, + skip_connect_input=args.skip_connect_input, + graph_criterion=args.graph_criterion, + graph_random_min_local=args.graph_random_min_local, + checkpoint_gradients=checkpoint_gradients, + ) + for i in range(self.num_graph_cycles) + ] + ) + + self.backbone_updates = nn.ModuleList( + [ + backbone.GraphBackboneUpdate( + dim_nodes=args.dim_nodes, + dim_edges=args.dim_edges, + method=args.backbone_update_method, + iterations=args.backbone_update_iterations, + num_transform_weights=args.backbone_update_num_weights, + unconstrained=args.backbone_update_unconstrained, + ) + for i in range(self.num_graph_cycles) + ] + ) + + self.use_time_features = args.use_time_features + self.time_feature_type = args.time_feature_type + self.time_log_feature_scaling = time_log_feature_scaling + if self.use_time_features: + self.time_features = basic.FourierFeaturization( + d_input=1, d_model=dim_nodes, trainable=False, scale=16.0 + ) + + self.noise_perturb = diffusion.DiffusionChainCov( + noise_schedule=args.noise_schedule, + beta_min=args.noise_beta_min, + beta_max=args.noise_beta_max, + log_snr_range=args.noise_log_snr_range, + covariance_model=args.noise_covariance_model, + complex_scaling=args.noise_complex_scaling, + ) + self.noise_schedule = self.noise_perturb.noise_schedule + method = "symeig" + self.loss_scale = args.loss_scale + self.loss_scale_ssnr_cutoff = loss_scale_ssnr_cutoff + self.loss_function = args.loss_function + self.prediction_type = args.prediction_type + self._loss_eps = 1e-5 + + self.loss_diffusion = diffusion.ReconstructionLosses( + diffusion=self.noise_perturb, rmsd_method=method, loss_scale=args.loss_scale + ) + + if self.prediction_type.startswith("scale"): + self.mlp_W = graph.MLP( + dim_in=args.dim_nodes, num_layers_hidden=args.node_mlp_layers, dim_out=1 + ) + + # Wrap sampling functions + _X0_func = lambda X, C, t: self.denoise(X, C, t) + self.sample_sde = lambda C, **kwargs: self.noise_perturb.sample_sde( + _X0_func, C, **kwargs + ) + self.sample_baoab = lambda C, **kwargs: self.noise_perturb.sample_baoab( + _X0_func, C, **kwargs + ) + self.sample_ode = lambda C, **kwargs: self.noise_perturb.sample_ode( + _X0_func, C, **kwargs + ) + self.estimate_metrics = ( + lambda X, C, **kwargs: self.loss_diffusion.estimate_metrics( + _X0_func, X, C, **kwargs + ) + ) + self.estimate_elbo = lambda X, C, **kwargs: self.noise_perturb.estimate_elbo( + _X0_func, X, C, **kwargs + ) + self.estimate_pseudoelbo_X = ( + lambda X, C, **kwargs: self.noise_perturb.estimate_pseudoelbo_X( + _X0_func, X, C, **kwargs + ) + ) + + def _time_features(self, t): + h = {"t": lambda: t, "log_snr": lambda: self.noise_schedule.log_SNR(t)}[ + self.time_feature_type + ]() + + if "log" in self.time_feature_type: + h = self.time_log_feature_scaling * h + + time_h = self.time_features(h[:, None, None]) + return time_h + + @validate_XC() + def denoise( + self, + X: torch.Tensor, + C: torch.Tensor, + t: Optional[Union[float, torch.Tensor]] = None, + return_geometry: bool = False, + ): + if not isinstance(t, torch.Tensor): + t = torch.Tensor([t]).float().to(X.device) + if t.shape == torch.Size([]): + t = t.unsqueeze(-1) + + time_h = self._time_features(t) if self.use_time_features else None + node_h = time_h + edge_h, edge_idx, mask_ij = [None] * 3 + + # Normalize minimum average C-alpha distances + X_update = X + + for i in range(self.num_graph_cycles): + # Encode as graph + node_h, edge_h, edge_idx, mask_i, mask_ij = self.encoders[i]( + X_update, + C, + node_h_aux=node_h, + edge_h_aux=edge_h, + 
edge_idx=edge_idx, + mask_ij=mask_ij, + ) + # Update backbone + X_update, R_ji, t_ji, logit_ji = self.backbone_updates[i]( + X_update, C, node_h, edge_h, edge_idx, mask_i, mask_ij + ) + + # Shrink towards the input + if time_h is None: + time_h = torch.zeros( + [node_h.shape[0], 1, node_h.shape[2]], device=node_h.device + ) + if self.prediction_type == "scale": + scale_shift = self.mlp_W(time_h) + ssnr = self.noise_perturb.noise_schedule.SSNR(t) + logit_bias = torch.logit(torch.sqrt(1 - ssnr)) + scale = torch.sigmoid(scale_shift + logit_bias[:, None, None])[..., None] + X_update = scale * X_update + (1 - scale) * X + elif self.prediction_type == "scale_cutoff": + # Scale below a given hard-coded noise floor cutoff + scale_shift = self.mlp_W(time_h) + ssnr = self.noise_perturb.noise_schedule.SSNR(t) + logit_bias = torch.logit(torch.sqrt(1 - ssnr)) + scale = torch.sigmoid(scale_shift + logit_bias[:, None, None])[..., None] + + # Skip connect for values of alpha close to 1 + skip = (1 - scale) * (ssnr > self.loss_scale_ssnr_cutoff).float().reshape( + scale.shape + ) + X_update = skip * X + (1 - skip) * X_update + + if not return_geometry: + return X_update + else: + return X_update, R_ji, t_ji, logit_ji, edge_idx, mask_ij + + @validate_XC(all_atom=False) + def _debug_plot_denoising_geometry(self, X, C, t=None): + """Debug plots for analyzing denoising geometry""" + if t is None: + X_noise, t = self.noise_perturb(X, C) + else: + X_noise = self.noise_perturb(X, C, t=t) + + # Compute denoised geometry + ( + X_denoise, + R_ji_pred, + t_ji_pred, + logit_ji_pred, + edge_idx, + mask_ij, + ) = self.denoise(X_noise, C, t, return_geometry=True) + + # Featurize other inputs and outpus + R_ji_native, t_ji_native = self.backbone_updates[0]._inner_transforms( + X, C, edge_idx + ) + R_ji_noise, t_ji_noise = self.backbone_updates[0]._inner_transforms( + X_noise, C, edge_idx + ) + R_ji_denoise, t_ji_denoise = self.backbone_updates[0]._inner_transforms( + X_denoise, C, edge_idx + ) + + R_ji = torch.cat([R_ji_native, R_ji_noise, R_ji_pred, R_ji_denoise], 0) + t_ji = torch.cat([t_ji_native, t_ji_noise, t_ji_pred, t_ji_denoise], 0) + logit_ji = torch.cat([mask_ij, mask_ij, logit_ji_pred[:, :, :, 0], mask_ij], 0) + edge_idx = edge_idx.expand([4, -1, -1]) + from matplotlib import pyplot as plt + + transforms._debug_plot_transforms(R_ji, t_ji, logit_ji, edge_idx, mask_ij) + plt.show() + return X_denoise, X_noise + + @validate_XC(all_atom=False) + def forward( + self, + X: torch.Tensor, + C: torch.Tensor, + t: Optional[Union[torch.Tensor, float]] = None, + **kwargs, + ): + # If all atom structure is passed, discard side chains + X = X[:, :, :4, :] if X.size(2) == 14 else X + + # Sample perturbed structure + if t is None: + X_t, t = self.noise_perturb(X, C) + else: + X_t = self.noise_perturb(X, C, t=t) + + X0_pred, R_ji_pred, t_ji_pred, logit_ji_pred, edge_idx, mask_ij = self.denoise( + X_t, C, t, return_geometry=True + ) + + losses = self.loss_diffusion(X0_pred, X, C, t) + + # Per complex weights + weights = (C > 0).float().sum(-1) + + ssnr = self.noise_perturb.noise_schedule.SSNR(t) + prob_ssnr = self.noise_perturb.noise_schedule.prob_SSNR(ssnr) + importance_weights = 1 / prob_ssnr + + _importance_weight = lambda h: h * importance_weights.reshape( + [-1] + [1] * (len(h.shape) - 1) + ) + _weighted_avg = lambda h: (weights * _importance_weight(h)).sum() / ( + weights.sum() + self._loss_eps + ) + # Interresidue geometry predictions agreement + if self.backbone_updates[0].method != "local": + R_ij_mse, t_ij_mse = 
self.backbone_updates[0]._transform_loss( + R_ji_pred, t_ji_pred, X, C, edge_idx, mask_ij + ) + losses["batch_translate_mse"] = _weighted_avg( + t_ij_mse / (self.loss_scale**2) + ) + losses["batch_rotate_mse"] = _weighted_avg(R_ij_mse) + losses["batch_transform_mse"] = ( + losses["batch_translate_mse"] + losses["batch_rotate_mse"] + ) + + losses_extend = {} + for k, v in losses.items(): + if "elbo" in k: + losses_extend[k.replace("elbo", "neg_elbo")] = -v + losses.update(losses_extend) + return losses + + +def load_model( + weight_file: str, + device: str = "cpu", + strict: bool = False, + strict_unexpected: bool = False, + verbose: bool = True, +) -> GraphBackbone: + """Load model `GraphBackbone` + + Args: + weight_file (str): The destination path of the model weights to load. + Compatible with files saved by `save_model`. + device (str, optional): Pytorch device specification, e.g. `'cuda'` for + GPU. Default is `'cpu'`. + strict (bool): Whether to require that the keys match between the + input file weights and the model created from the parameters stored + in the model kwargs. + strict_unexpected (bool): Whether to require that there are no + unexpected keys when loading model weights, as distinct from the + strict option which doesn't allow for missing keys either. By + default, we use this option rather than strict for ease of + development when adding model features. + verbose (bool, optional): Show outputs from download and loading. + Default True. + + Returns: + model (GraphBackbone): Instance of `GraphBackbone` with loaded weights. + """ + return utility_load_model( + weight_file, + GraphBackbone, + device=device, + strict=strict, + strict_unexpected=strict_unexpected, + verbose=verbose, + ) diff --git a/prtm/models/chroma/graph_classifier.py b/prtm/models/chroma/graph_classifier.py new file mode 100644 index 00000000..64fe7d95 --- /dev/null +++ b/prtm/models/chroma/graph_classifier.py @@ -0,0 +1,355 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Models for generating protein sequence and side chain conformations +given backbones. These can be used for sequence design and packing. +""" + + +from types import SimpleNamespace + +import torch +import torch.nn as nn + +from prtm.models.chroma import basic +from prtm.models.chroma.attention import AttentionChainPool +from prtm.models.chroma.basic import NodeProduct, NoOp +from prtm.models.chroma.graph import MLP, MaskedNorm +from prtm.models.chroma.graph_design import BackboneEncoderGNN +from prtm.models.chroma.model_utils import load_model as utility_load_model +from prtm.models.chroma.structure import diffusion +from prtm.models.chroma.xcs import validate_XC + + +class GraphClassifier(nn.Module): + """Graph-based protein classification + + Args: + See documention of `structure.protein_graph.ProteinFeatureGraph`, + and `graph.GraphNN` for more details. + + Inputs: + X (Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atoms, 3)`. 
+ C (LongTensor): Chain map with shape `(num_batch, num_residues)`. + O (Tensor) (optional): One-hot sequence tensor of shape `(num_batch, num_residues)` + + Outputs: + node_h (Tensor): residue-based representations that can be used to project various classification predictions + """ + + def __init__( + self, + dim_nodes=128, + dim_edges=128, + num_neighbors=30, + node_features=(("internal_coords", {"log_lengths": True}),), + edge_features=["random_fourier_2mer", "orientations_2mer", "distances_chain"], + num_layers=3, + dropout=0.1, + node_mlp_layers=1, + node_mlp_dim=None, + edge_update=True, + edge_mlp_layers=1, + edge_mlp_dim=None, + skip_connect_input=False, + mlp_activation="softplus", + graph_criterion="knn", + graph_random_min_local=20, + use_time_features=True, + noise_schedule="log_snr", + noise_beta_min=0.2, + noise_beta_max=70.0, + checkpoint_gradients=False, + class_config={}, + out_mlp_layers=2, + noise_covariance_model="globular", + noise_log_snr_range=(-7.0, 13.5), + time_feature_type="t", + time_log_feature_scaling=0.05, + fourier_scale=16.0, + zero_grad_fix=False, + **kwargs, + ): + """Initialize GraphBackbone network.""" + super().__init__() + + # Save configuration in kwargs + self.kwargs = locals() + self.kwargs.pop("self") + for key in list(self.kwargs.keys()): + if key.startswith("__") and key.endswith("__"): + self.kwargs.pop(key) + args = SimpleNamespace(**self.kwargs) + + self.class_config = class_config + # Important global options + self.dim_nodes = args.dim_nodes + self.dim_edges = args.dim_edges + self.mlp_activation = args.mlp_activation + self.zero_grad_fix = zero_grad_fix + + if "random_fourier_2mer" in args.edge_features: + index = args.edge_features.index("random_fourier_2mer") + args.edge_features.pop(index) + args.edge_features.append( + ( + "random_fourier_2mer", + { + "dim_embedding": args.dim_edges, + "trainable": False, + "scale": args.fourier_scale, + }, + ) + ) + + # Encoder GNN process backbone + self.encoder = BackboneEncoderGNN( + dim_nodes=args.dim_nodes, + dim_edges=args.dim_edges, + num_neighbors=args.num_neighbors, + node_features=args.node_features, + edge_features=args.edge_features, + num_layers=args.num_layers, + node_mlp_layers=args.node_mlp_layers, + node_mlp_dim=args.node_mlp_dim, + edge_update=args.edge_update, + edge_mlp_layers=args.edge_mlp_layers, + edge_mlp_dim=args.edge_mlp_dim, + mlp_activation=args.mlp_activation, + dropout=args.dropout, + skip_connect_input=args.skip_connect_input, + graph_criterion=args.graph_criterion, + graph_random_min_local=args.graph_random_min_local, + checkpoint_gradients=checkpoint_gradients, + ) + + self.time_feature_type = args.time_feature_type + self.time_log_feature_scaling = time_log_feature_scaling + + self.use_time_features = use_time_features + if self.use_time_features: + self.time_features = basic.FourierFeaturization( + d_input=1, d_model=dim_nodes, trainable=False, scale=16.0 + ) + + self.sequence_embedding = nn.Embedding(20, dim_nodes) + + self.noise_perturb = diffusion.DiffusionChainCov( + noise_schedule=args.noise_schedule, + beta_min=args.noise_beta_min, + beta_max=args.noise_beta_max, + log_snr_range=args.noise_log_snr_range, + covariance_model=args.noise_covariance_model, + ) + + self._init_heads(class_config, dim_nodes, out_mlp_layers, dropout) + self.condition_sequence_frequency = 0.3 + + def _init_heads(self, class_config, dim_nodes, out_mlp_layers, dropout): + self.heads = {"chain": {}, "first_order": {}, "second_order": {}, "complex": {}} + + for label, config in 
class_config.items(): + group = config["level"] + if label == "is_interface" or label == "contact": + dim_out = 1 + else: + dim_out = len(config["tokens"]) + if group == "chain": + pool = AttentionChainPool(8, dim_nodes) + elif group == "complex": + raise NotImplementedError + elif group == "second_order": + pool = NoOp() + else: + pool = NoOp() + + if group != "second_order": + if self.zero_grad_fix: + node_norm_layer = MaskedNorm( + dim=1, num_features=dim_nodes, affine=True, norm="layer" + ) + mlp = MLP( + dim_nodes, + dim_hidden=None, + dim_out=dim_out, + num_layers_hidden=out_mlp_layers, + activation=self.mlp_activation, + dropout=dropout, + ) + head = nn.Sequential(node_norm_layer, mlp) + else: + mlp = MLP( + dim_nodes, + dim_hidden=None, + dim_out=dim_out, + num_layers_hidden=out_mlp_layers, + activation="relu", + dropout=dropout, + ) + head = mlp + else: + head = nn.Sequential(nn.Linear(dim_nodes, 16), NodeProduct(16, 1)) + + self.heads[group][label] = head, pool + self.add_module(f"{label}_head", head) + if pool is not None: + self.add_module(f"{label}_pool", pool) + + def _time_features(self, t): + h = { + "t": lambda: t, + "log_snr": lambda: self.noise_perturb.noise_schedule.log_SNR(t), + }[self.time_feature_type]() + + if "log" in self.time_feature_type: + h = self.time_log_feature_scaling * h + + time_h = self.time_features(h[:, None, None]) + return time_h + + @validate_XC() + def encode(self, X, C, O=None, t=None): + if not isinstance(t, torch.Tensor): + t = torch.Tensor([t]).float().to(X.device) + elif t.dim() == 0: + t = t.unsqueeze(0) + + node_h = self._time_features(t) + + if O is not None: + if (not self.training) or ( + torch.rand( + 1, + ).item() + < self.condition_sequence_frequency + ): + node_h = node_h + O @ self.sequence_embedding.weight + + edge_h, edge_idx, mask_ij = [None] * 3 + node_h, edge_h, edge_idx, mask_i, mask_ij = self.encoder( + X, + C, + node_h_aux=node_h, + edge_h_aux=edge_h, + edge_idx=edge_idx, + mask_ij=mask_ij, + ) + + return node_h, edge_h, edge_idx, mask_i, mask_ij + + @validate_XC() + def gradient( + self, X, C, t, label, mask=None, value=None, O=None, scale=1.0, max_norm=None + ): + """ + Args: + X (torch.tensor): (batch, num_res, 4, 3) or (batch, num_res, 14, 3) + C (torch.tensor): (batch, num_res) + t (float): 0 < t <= 1 + label (string): class label to condition on, chosen from `self.class_config.keys()` + mask (torch.tensor): (optional) bool tensor of shape (batch, num_res) for first order scores, (batch, num_chains) for + chain-based scores, and (batch, num_res, num_res) for second order scores. The order of + your score can be determined by inspecting self.class_config[label]['level'] + value (string): (optional) the token-based representation of the value you would like to condition `label` on, + you can select options from `self.class_config[label]['tokens']` for all scores except `is_interface` + or `contact` for which you should leave a `value` of None. 
+ O (torch.tensor): one-hot sequence tensor of size (batch, num_res, num_alphabet) + scale (float): scale factor to multiply gradient by + max_norm (float): if not None, the maximum norm of the gradient (set grad = max_norm * (grad / grad.norm()) if grad.norm() > max_norm) + """ + self.eval() + _bak = self.encoder.checkpoint_gradients + self.encoder.checkpoint_gradients = False + + level = self.class_config[label]["level"] + head, pool = self.heads[level][label] + with torch.enable_grad(): + X.requires_grad = True + node_h, edge_h, edge_idx, mask_i, mask_ij = self.encode(X, C, O, t) + + if level == "chain": + node_h, c_mask = pool(node_h, C) + c_mask = c_mask + elif level == "first_order": + c_mask = C > 0 + elif level == "second_order": + c_mask = (C > 0).unsqueeze(-2) & (C > 0).unsqueeze(-1) + + node_h = head(node_h) + + if mask is not None: + c_mask = mask & c_mask + + if self.class_config[label]["loss"] == "ce": + neglogp = node_h.log_softmax(dim=-1).mul(-1) + else: + neglogp = node_h.sigmoid().log().mul(-1) + + index = ( + self.class_config[label]["tokenizer"][value] if value is not None else 0 + ) + neglogp = neglogp[..., index][c_mask].sum() + neglogp.backward() + grad = scale * X.grad + + if max_norm is not None: + if grad.norm() > max_norm: + grad = max_norm * (grad / grad.norm()) + + self.encoder.checkpoint_gradients = _bak + return grad + + @validate_XC(all_atom=False) + def forward(self, X, C, O=None, **kwargs): + # Sample perturbed structure + X_perturb, t = self.noise_perturb(X, C) + node_h, edge_h, edge_idx, mask_i, mask_ij = self.encode(X_perturb, C, O, t) + return node_h, edge_h + + +def load_model( + weight_file: str, + device: str = "cpu", + strict: bool = False, + strict_unexpected: bool = True, + verbose: bool = True, +) -> GraphClassifier: + """Load model `GraphClassifier` + + Args: + weight_file (str): The destination path of the model weights to load. + Compatible with files saved by `save_model`. + device (str, optional): Pytorch device specification, e.g. `'cuda'` for + GPU. Default is `'cpu'`. + strict (bool): Whether to require that the keys match between the + input file weights and the model created from the parameters stored + in the model kwargs. + strict_unexpected (bool): Whether to require that there are no + unexpected keys when loading model weights, as distinct from the + strict option which doesn't allow for missing keys either. By + default, we use this option rather than strict for ease of + development when adding model features. + + Returns: + model (GraphClassifier): Instance of `GraphClassifier` with loaded weights. + """ + return utility_load_model( + weight_file, + GraphClassifier, + device=device, + strict=strict, + strict_unexpected=strict_unexpected, + verbose=verbose, + ) diff --git a/prtm/models/chroma/graph_design.py b/prtm/models/chroma/graph_design.py new file mode 100644 index 00000000..67244935 --- /dev/null +++ b/prtm/models/chroma/graph_design.py @@ -0,0 +1,2312 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Models for generating protein sequence and side chain conformations +given backbones. These can be used for sequence design and packing. +""" + + +from types import SimpleNamespace +from typing import Callable, Literal, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + +from prtm.constants.residue_constants import alphabetical_restypes +from prtm.models.chroma import complexity, graph +from prtm.models.chroma.model_utils import load_model as utility_load_model +from prtm.models.chroma.structure import diffusion, potts, protein_graph, sidechain +from prtm.models.chroma.structure.protein_graph_allatom import ( + EdgeSidechainsDirect, + NodeChiRBF, +) +from prtm.models.chroma.xcs import validate_XC + + +class GraphDesign(nn.Module): + """Graph-based sequence design and sidechain packing. + + Given a fixed backbone, a GraphDesign model yields probabilities of residue type + and angles by position. It encodes backbones with a `BackboneEncoderGNN` + and then autoregressively factorizes the joint distribution of + sequence and sidechain conformations given these graph embeddings. + Optional first order marginal and Potts sequence decoders are also available. + + Some `GraphDesign` models are trained in a diffusion-aware mannner + to model sequence likelihoods given a noised structure and a particular time point + along a forwards diffusion process. + + Args: + See documention of `structure.protein_graph.ProteinFeatureGraph`, + and `graph.GraphNN` for more details. + + dim_nodes (int): Hidden dimension of node tensors of underlying GNNs. + dim_edges (int): Hidden dimension of edge tensors of underlying GNNs. + num_neighbors (int): Number of neighbors per node for underlying GNNs. + node_features (tuple): List of node feature specifications for + structure encoder. Features can be given as strings or as + dictionaries. + edge_features (tuple): List of edge feature specifications for + structure encoder. Features can be given as strings or as + dictionaries. + sequence_embedding (str): How to represent sequence when decoding. + Currently the only option is `linear`. + sidechain_embedding (str): How to represent chi angles when decoding. + Options include `chi_linear` for a simple linear layer, `chi_rbf` + for a featurization based on smooth binning of chi angles, + `X_direct` which directly encodes the all-atom coordinates using + random Fourier features, and `mixed_chi_X` which uses both the + featurizations of `chi_rbf` and of `X_direct`. + sidechains (bool): Whether to use a joint sequence/sidechain + autoregressive model to decode the backbones. + num_layers (int): Number of layers of underlying GNNs. Can be overridden + for the structure encoder by `num_layers_encoder`. + num_layers_encoder (int, optional): Number of layers for structure + encoder GNN. + dropout (float): Dropout fraction used for all encoders and decoders + except for the marginal sequence likelihood decoder in + `decoder_S_marginals`. + node_mlp_layers (int): Number of hidden layers for node update function + of underlying GNNs. + node_mlp_dim (int, optional): Dimension of hidden layers for node update + function of underlying GNNs, defaults to match output dimension. + edge_update (bool): Whether to include an edge update step in the GNNs. 
+ edge_mlp_layers (int): Number of hidden layers for edge update function + of underlying GNNs. + edge_mlp_dim (int, optional): Dimension of hidden layers for edge update + function of underlying GNNs, defaults to match output dimension. + skip_connect_input (bool): Whether to include skip connections between + layers of underlying GNNs. + mlp_activation (str): MLP nonlinearity function, `relu` or `softplus` + accepted. + num_alphabet (int): Number of possible residues for sequence decoder. + num_chi_bins (int): Number of chi bins for smooth binning of chi angles + used when `sidechain_embedding` is `chi_rbf` or `mixed_chi_X`. + decoder_num_hidden (int): Dimension of hidden decoder layers. + label_smoothing (float): Level of smoothing to apply to sequence and + sidechain labels. + separate_packing (bool): If True, then autoregressively factorize + sequence and sidechains in two stages where the full sequence is predicted + before all of the chi angles. Otherwise an interleaved factorization + will be used that autoregressively predicts both the residue identity + and chi angles in an alternating manner. Default is True. + graph_criterion (str): Graph criterion for structure encoder, defines + how neighbors are chosen. See + `chroma.models.graph_design.BackboneEncoderGNN` for + allowed values. + graph_random_min_local (int): Minimum number of neighbors in GNN that + come from local neighborhood, before random neighbors are chosen. + graph_attentional (bool): Currently unused, previously used for + experimental GNN attention mechanism. + graph_num_attention_heads (int): Currently unused, previously used for + experimental GNN attention mechanism. + predict_S_marginals (bool): Whether to train marginal sequence decoder. + predict_S_potts (bool): Whether to train Potts sequence decoder. + potts_parameterization (str): How to parametrize Potts sequence decoder, + see `chroma.layer.structure.potts` for allowed values. + potts_num_factors (int, optional): Number of factors to use for Potts + sequence decoder. + potts_symmetric_J (bool): Whether to force J tensor of Potts model to be + symmetric. + noise_schedule (str, optional): Noise schedule for mapping between + diffusion time and noise level, see + chroma.layers.structure.diffusion.DiffusionChainCov for allowed + values. If not set, model should only be provided with denoised + backbones. + noise_covariance_model (str): Covariance mode for mapping between + diffusion time and noise level, see + chroma.layers.structure.diffusion.DiffusionChainCov for allowed + values. + noise_complex_scaling (bool): Whether to scale noise for complexes. + noise_beta_range (Tuple[float, float]): Minimum and maximum noise levels + for noise schedule. + noise_log_snr_range (Tuple[float, float]): Range of log signal-to-noise + ratio for noising. + + Inputs: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atoms, 3)`. + C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`. + S (torch.LongTensor): Sequence tensor with shape + `(num_batch, num_residues)`. + t (torch.Tensor, optional): Diffusion timesteps corresponding to noisy + input backbones, of shape `(num_batch)`. Use zeros when passing + structures without noise. + sample_noise (bool, optional): Whether to apply noise to input + backbones. + permute_idx (torch.LongTensor, optional): Permutation tensor for fixing + the autoregressive decoding order `(num_batch, num_residues)`. If + `None` (default), a random decoding order will be generated. 
+ priority (torch.Tensor, optional): Priority values for constraining + residue orderings with shape `(num_batch, num_residues)`. + If residues are assigned to integer-valued groups, the sampled + permutation will be ordered such that all residues within a lower- + valued priority group will occur before residues with higher-valued + priority assignments. + + Outputs (dict): + logp_S (torch.Tensor): Sequence log likelihoods per residue with shape + `(num_batch, num_residues)`. + logp_chi (torch.Tensor): Chi angle Log likelihoods per residue with + shape `(num_batch, num_residues, 4)`. + logp_S_marginals (torch.Tensor, optional): Sequence log likelihoods + per residue from marginal decoder with shape + `(num_batch, num_residues)`. + logp_S_potts (torch.Tensor, optional): Sequence log likelihoods per + residue from Potts decoder with shape + `(num_batch, num_residues)`. + chi (torch.Tensor): Chi angles with shape + `(num_batch, num_residues, 4)`. + mask_chi (torch.Tensor): Chi angle mask with shape + `(num_batch, num_residues, 4)`. + node_h_chi (torch.Tensor): Node features used for predicting chi + angles with shape `(num_batch, num_residues, dim_nodes)`. + mask_i (torch.Tensor): Node mask with shape + `(num_batch, num_residues)`. + mask_ij (torch.Tensor): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. + mask_ij_causal (torch.Tensor): Causal edge mask for autoregressive + decoding with shape `(num_batch, num_nodes, num_neighbors)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_residues, num_neighbors)`. + permute_idx (torch.LongTensor, optional): Permutation tensor that was + used for the autoregressive decoding order with shape + `(num_batch, num_residues)`. + X_noise (torch.Tensor): Noised structure coordinates with shape + `(num_batch, num_residues, num_atoms, 3)`. 
+ """ + + def __init__( + self, + dim_nodes: int = 128, + dim_edges: int = 128, + num_neighbors: int = 30, + node_features: tuple = (("internal_coords", {"log_lengths": True}),), + edge_features: tuple = ( + "distances_2mer", + "orientations_2mer", + "distances_chain", + ), + sequence_embedding: str = "linear", + sidechain_embedding: str = "chi_rbf", + sidechains: bool = True, + num_layers: int = 3, + num_layers_encoder: Optional[int] = None, + dropout: float = 0.1, + node_mlp_layers: int = 1, + node_mlp_dim: Optional[int] = None, + edge_update: bool = True, + edge_mlp_layers: int = 1, + edge_mlp_dim: Optional[int] = None, + skip_connect_input: bool = False, + mlp_activation: str = "softplus", + num_alphabet: int = 20, + num_chi_bins: int = 20, + decoder_num_hidden: int = 512, + label_smoothing: float = 0.1, + separate_packing: bool = True, + graph_criterion: str = "knn", + graph_random_min_local: int = 20, + graph_attentional: bool = False, + graph_num_attention_heads: int = 4, + predict_S_marginals: bool = False, + predict_S_potts: bool = False, + potts_parameterization: str = "factor", + potts_num_factors: Optional[int] = None, + potts_symmetric_J: bool = True, + noise_schedule: Optional[str] = None, + noise_covariance_model: str = "brownian", + noise_complex_scaling: bool = False, + noise_beta_range: Tuple[float, float] = (0.2, 70.0), + noise_log_snr_range: Tuple[float, float] = (-7.0, 13.5), + checkpoint_gradients: bool = False, + **kwargs + ) -> None: + """Initialize GraphDesign network.""" + super(GraphDesign, self).__init__() + + # Save configuration in kwargs + self.kwargs = locals() + self.kwargs.pop("self") + for key in list(self.kwargs.keys()): + if key.startswith("__") and key.endswith("__"): + self.kwargs.pop(key) + args = SimpleNamespace(**self.kwargs) + + # Important global options + self.dim_nodes = dim_nodes + self.dim_edges = dim_edges + self.num_alphabet = num_alphabet + self.num_chi_bins = num_chi_bins + self.separate_packing = separate_packing + self.sidechains = sidechains + self.predict_S_potts = predict_S_potts + self.traversal = ProteinTraversalSpatial() + + # Encoder GNN process backbone + self.encoder = BackboneEncoderGNN( + dim_nodes=args.dim_nodes, + dim_edges=args.dim_edges, + num_neighbors=args.num_neighbors, + node_features=args.node_features, + edge_features=args.edge_features, + num_layers=( + args.num_layers + if args.num_layers_encoder is None + else args.num_layers_encoder + ), + node_mlp_layers=args.node_mlp_layers, + node_mlp_dim=args.node_mlp_dim, + edge_update=args.edge_update, + edge_mlp_layers=args.edge_mlp_layers, + edge_mlp_dim=args.edge_mlp_dim, + mlp_activation=args.mlp_activation, + dropout=args.dropout, + skip_connect_input=args.skip_connect_input, + graph_criterion=args.graph_criterion, + graph_random_min_local=args.graph_random_min_local, + checkpoint_gradients=checkpoint_gradients, + ) + + # Time features for diffusion + if args.noise_schedule is not None: + self.noise_perturb = diffusion.DiffusionChainCov( + noise_schedule=args.noise_schedule, + beta_min=args.noise_beta_range[0], + beta_max=args.noise_beta_range[1], + log_snr_range=args.noise_log_snr_range, + covariance_model=args.noise_covariance_model, + complex_scaling=args.noise_complex_scaling, + ) + self.time_features = diffusion.NoiseTimeEmbedding( + dim_embedding=args.dim_nodes, + noise_schedule=self.noise_perturb.noise_schedule, + ) + + # Decoder GNN process backbone + if self.sidechains: + self.decoder = SidechainDecoderGNN( + dim_nodes=args.dim_nodes, + 
dim_edges=args.dim_edges, + num_neighbors=args.num_neighbors, + predict_S=True, + predict_chi=(not args.separate_packing), + sequence_embedding=args.sequence_embedding, + sidechain_embedding=args.sidechain_embedding, + num_layers=args.num_layers, + node_mlp_layers=args.node_mlp_layers, + node_mlp_dim=args.node_mlp_dim, + edge_update=args.edge_update, + edge_mlp_layers=args.edge_mlp_layers, + edge_mlp_dim=args.edge_mlp_dim, + mlp_activation=args.mlp_activation, + dropout=args.dropout, + skip_connect_input=args.skip_connect_input, + num_alphabet=args.num_alphabet, + num_chi_bins=args.num_chi_bins, + decoder_num_hidden=args.decoder_num_hidden, + label_smoothing=args.label_smoothing, + checkpoint_gradients=checkpoint_gradients, + ) + + if args.predict_S_marginals: + self.decoder_S_marginals = NodePredictorS( + num_alphabet=args.num_alphabet, + dim_nodes=args.dim_nodes, + dim_hidden=args.decoder_num_hidden, + label_smoothing=args.label_smoothing, + ) + + if args.predict_S_potts: + self.decoder_S_potts = potts.GraphPotts( + dim_nodes=args.dim_nodes, + dim_edges=args.dim_edges, + num_states=args.num_alphabet, + parameterization=args.potts_parameterization, + num_factors=args.potts_num_factors, + symmetric_J=args.potts_symmetric_J, + dropout=args.dropout, + label_smoothing=args.label_smoothing, + ) + + if args.separate_packing: + # Optionally do a two-stage autoregressive prediction + self.embed_S = nn.Embedding(args.num_alphabet, args.dim_nodes) + self.encoder_S_gnn = graph.GraphNN( + dim_nodes=args.dim_nodes, + dim_edges=args.dim_edges, + num_layers=args.num_layers, + node_mlp_layers=args.node_mlp_layers, + node_mlp_dim=args.node_mlp_dim, + edge_update=args.edge_update, + edge_mlp_layers=args.edge_mlp_layers, + edge_mlp_dim=args.edge_mlp_dim, + mlp_activation=args.mlp_activation, + dropout=args.dropout, + norm="transformer", + scale=args.num_neighbors, + skip_connect_input=args.skip_connect_input, + checkpoint_gradients=checkpoint_gradients, + ) + self.decoder_chi = SidechainDecoderGNN( + dim_nodes=args.dim_nodes, + dim_edges=args.dim_edges, + num_neighbors=args.num_neighbors, + predict_S=False, + predict_chi=True, + sequence_embedding=args.sequence_embedding, + sidechain_embedding=args.sidechain_embedding, + num_layers=args.num_layers, + node_mlp_layers=args.node_mlp_layers, + node_mlp_dim=args.node_mlp_dim, + edge_update=args.edge_update, + edge_mlp_layers=args.edge_mlp_layers, + edge_mlp_dim=args.edge_mlp_dim, + mlp_activation=args.mlp_activation, + dropout=args.dropout, + skip_connect_input=args.skip_connect_input, + num_alphabet=args.num_alphabet, + num_chi_bins=args.num_chi_bins, + decoder_num_hidden=args.decoder_num_hidden, + label_smoothing=args.label_smoothing, + checkpoint_gradients=checkpoint_gradients, + ) + + if sidechains: + self.chi_to_X = sidechain.SideChainBuilder() + self.X_to_chi = sidechain.ChiAngles() + self.loss_rmsd = sidechain.LossSideChainRMSD() + self.loss_clash = sidechain.LossSidechainClashes() + + self.loss_eps = 1e-5 + + @validate_XC() + def forward( + self, + X: torch.Tensor, + C: torch.LongTensor, + S: torch.LongTensor, + t: Optional[torch.Tensor] = None, + sample_noise: bool = False, + permute_idx: Optional[torch.LongTensor] = None, + priority: Optional[torch.LongTensor] = None, + ) -> dict: + # Sample noisy backbones + X_noise = X + if sample_noise and hasattr(self, "noise_perturb"): + X_bb = X[:, :, :4, :] + self.noise_perturb.noise_schedule + t = self.noise_perturb.sample_t(C, t) + X_noise_bb = self.noise_perturb(X_bb, C, t=t) + if self.sidechains: + # Rebuild 
sidechains on noised backbone from native chi angles + chi, mask_chi = self.X_to_chi(X, C, S) + X_noise, mask_X = self.chi_to_X(X_noise_bb, C, S, chi) + else: + pass + # TODO IDK what to return here + + node_h, edge_h, edge_idx, mask_i, mask_ij = self.encode(X_noise, C, t=t) + + logp_S_marginals = None + if self.kwargs["predict_S_marginals"]: + logp_S_marginals, _ = self.decoder_S_marginals(S, node_h, mask_i) + + logp_S_potts = None + if self.kwargs["predict_S_potts"]: + logp_S_potts = self.decoder_S_potts.loss( + S, node_h, edge_h, edge_idx, mask_i, mask_ij + ) + + # Sample random permutations and build autoregressive mask + if permute_idx is None: + permute_idx = self.traversal(X, C, priority=priority) + + if self.sidechains: + # In one-stage packing, predict S and chi angles in an interleaved manner + ( + logp_S, + logp_chi, + chi, + mask_chi, + node_h_chi, + _, + _, + _, + mask_ij_causal, + ) = self.decoder( + X_noise, C, S, node_h, edge_h, edge_idx, mask_i, mask_ij, permute_idx + ) + else: + logp_S = (None,) + logp_chi = None + chi = None + mask_chi = None + node_h_chi = None + mask_ij_causal = None + + if self.separate_packing: + # In two-stage packing, re-process embeddings with sequence + node_h = node_h + mask_i.unsqueeze(-1) * self.embed_S(S) + node_h, edge_h = self.encoder_S_gnn( + node_h, edge_h, edge_idx, mask_i, mask_ij + ) + _, logp_chi, chi, mask_chi, node_h_chi, _, _, _, _ = self.decoder_chi( + X_noise, C, S, node_h, edge_h, edge_idx, mask_i, mask_ij, permute_idx + ) + if t is None: + t = torch.zeros(C.size(0), device=C.device) + outputs = { + "logp_S": logp_S, + "logp_chi": logp_chi, + "logp_S_marginals": logp_S_marginals, + "logp_S_potts": logp_S_potts, + "chi": chi, + "mask_chi": mask_chi, + "node_h_chi": node_h_chi, + "mask_i": mask_i, + "mask_ij": mask_ij, + "mask_ij_causal": mask_ij_causal, + "edge_idx": edge_idx, + "permute_idx": permute_idx, + "X_noise": X_noise, + "t": t, + } + return outputs + + def set_gradient_checkpointing(self, flag: bool): + """Sets gradient checkpointing to `flag` on all relevant modules""" + self.encoder.checkpoint_gradients = flag + self.encoder.gnn.checkpoint_gradients = flag + if self.sidechains: + self.decoder.checkpoint_gradients = flag + self.decoder.gnn.checkpoint_gradients = flag + if self.separate_packing: + self.encoder_S_gnn.checkpoint_gradients = flag + self.decoder_chi.checkpoint_gradients = flag + self.decoder_chi.gnn.checkpoint_gradients = flag + + @validate_XC() + def encode( + self, X: torch.Tensor, C: torch.Tensor, t: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Encode the backbone and (optionally) the noise level. + + Args: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atoms, 3)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. + t (torch.Tensor, optional): Diffusion timesteps corresponding to + noisy input backbones, of shape `(num_batch)`. Default is no + noise. + + Returns: + node_h (torch.Tensor): Node features with shape + `(num_batch, num_residues, dim_nodes)`. + edge_h (torch.Tensor): Edge features with shape + `(num_batch, num_residues, num_neighbors, dim_edges)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_residues, num_neighbors)`. + mask_i (torch.Tensor): Node mask with shape + `(num_batch, num_residues)`. + mask_ij (torch.Tensor): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. 
+ """ + + node_h_aux = None + if hasattr(self, "time_features"): + t = 0.0 if t is None else t + node_h_aux = self.time_features(t) + + node_h, edge_h, edge_idx, mask_i, mask_ij = self.encoder( + X, C, node_h_aux=node_h_aux + ) + return node_h, edge_h, edge_idx, mask_i, mask_ij + + @validate_XC() + def predict_marginals( + self, X: torch.Tensor, C: torch.Tensor, t: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predict sequence marginal likelihoods. + + Args: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atoms, 3)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. + t (torch.Tensor, optional): Diffusion timesteps corresponding to + noisy input backbones, of shape `(num_batch)`. Default is no + noise. + + Returns: + log_probs_S (torch.Tensor): Node-wise sequence log probabilities + with shape `(num_batch, num_residues, 20)`. + mask_i (torch.Tensor): Node mask with shape + `(num_batch, num_residues)`. + """ + + if not self.kwargs["predict_S_marginals"]: + raise Exception( + "This version of GraphDesign was not trained with marginal prediction" + ) + node_h, edge_h, edge_idx, mask_i, mask_ij = self.encode(X, C, t) + log_probs_S = self.decoder_S_marginals.log_probs_S(node_h, mask_i) + return log_probs_S, mask_i + + @validate_XC() + def predict_potts( + self, X: torch.Tensor, C: torch.Tensor, t: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]: + """Predict sequence Potts model. + + Args: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atoms, 3)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. + t (torch.Tensor, optional): Diffusion timesteps corresponding to + noisy input backbones, of shape `(num_batch)`. Default is no + noise. + + Returns: + h (torch.Tensor): The h tensor of a Potts model with dimensions + `(seq_length, n_tokens)`. + J (torch.Tensor): The J tensor of a Potts model with dimensions + `(seq_length, seq_length, n_tokens, n_tokens)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_residues, num_neighbors)` from GNN encoding. + """ + if not self.kwargs["predict_S_potts"]: + raise Exception( + "This version of GraphDesign was not trained with Potts prediction" + ) + node_h, edge_h, edge_idx, mask_i, mask_ij = self.encode(X, C, t) + h, J = self.decoder_S_potts(node_h, edge_h, edge_idx, mask_i, mask_ij) + return h, J, edge_idx + + @validate_XC() + def loss( + self, + X: torch.Tensor, + C: torch.LongTensor, + S: torch.LongTensor, + t: Optional[torch.Tensor] = None, + permute_idx: Optional[torch.LongTensor] = None, + sample_noise: bool = False, + batched: bool = True, + **kwargs + ) -> dict: + """Compute losses used for training. + + Args: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atoms, 3)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. + S (torch.LongTensor): Sequence tensor with shape + `(num_batch, num_residues)`. + t (torch.Tensor, optional): Diffusion timesteps corresponding to + noisy input backbones, of shape `(num_batch)`. Default is no + noise. + permute_idx (torch.LongTensor, optional): Permutation tensor for + fixing the autoregressive decoding order + `(num_batch, num_residues)`. If `None` (default), a random + decoding order will be generated. + sample_noise (bool): Whether to apply noise to input backbones. + batched (bool): Whether to batch average losses. 
+ + Returns (dict): + neglogp (torch.Tensor): Sum of `neglogp_S` and `neglogp_chi` with + shape `(num_batch, num_residues)`. + neglogp_S (torch.Tensor): Average negative log probability per + residue identity with shape `(num_batch, num_residues)`. + neglogp_S_marginals (torch.Tensor): Average negative log probability + per residue identity from marginal decoder with shape + `(num_batch, num_residues)`. + neglogp_S_potts (torch.Tensor): Average negative log probability per + residue identity from Potts decoder with shape + `(num_batch, num_residues)`. + neglogp_chi (torch.Tensor): Average negative log probability per chi + angle with shape `(num_batch, num_residues)`. + mask_chi (torch.Tensor): Chi angle mask with shape + `(batch_size, num_residues, 4)`. + rmsd (torch.Tensor): Average RMSD per side-chain after sampling. + clash (torch.Tensor): Average number of clashes per side-chain after + sampling. + permute_idx (LongTensor, optional): Permutation tensor that was + used for the autoregressive decoding order with shape + `(num_batch, num_residues)`. + """ + + o = self.forward( + X, C, S, t=t, permute_idx=permute_idx, sample_noise=sample_noise + ) + + # Aggregate into per-residue scores for the batch + if batched: + _avg = lambda m, l: (m * l).sum() / (m.sum() + self.loss_eps) + else: + _avg = lambda m, l: (m * l).sum(dim=tuple(range(1, l.dim()))) / ( + m.sum(dim=tuple(range(1, l.dim()))) + self.loss_eps + ) + mask_S = o["mask_i"] + neglogp_S = -_avg(mask_S, o["logp_S"]) + neglogp_chi = -_avg(o["mask_chi"], o["logp_chi"]) + neglogp = neglogp_S + neglogp_chi + if o["logp_S_marginals"] is not None: + neglogp_S_marginals = -_avg(mask_S, o["logp_S_marginals"]) + neglogp = neglogp + neglogp_S_marginals + else: + neglogp_S_marginals = None + if o["logp_S_potts"] is not None: + neglogp_S_potts = -_avg(mask_S, o["logp_S_potts"]) + neglogp = neglogp + neglogp_S_potts + else: + neglogp_S_potts = None + + # Evaluate sampled side chains + decoder = self.decoder_chi if self.separate_packing else self.decoder + chi_sample = decoder.decoder_chi.sample( + S, o["mask_chi"], o["node_h_chi"], o["mask_i"], temperature=0.01 + ) + X_sample, mask_X = self.chi_to_X(o["X_noise"][:, :, :4, :], C, S, chi_sample) + + # RMSD loss + rmsd_i = self.loss_rmsd(o["X_noise"], X_sample, C, S) + rmsd = _avg(mask_S, rmsd_i) + + # Clash loss measures clashes generated to the past + clashes = self.loss_clash( + X_sample, C, S, edge_idx=o["edge_idx"], mask_ij=o["mask_ij_causal"] + ) + clash = _avg(mask_S, clashes) + + losses = { + "neglogp": neglogp, + "neglogp_S": neglogp_S, + "neglogp_S_marginals": neglogp_S_marginals, + "neglogp_S_potts": neglogp_S_potts, + "neglogp_chi": neglogp_chi, + "mask_chi": o["mask_chi"], + "rmsd": rmsd, + "clash": clash, + "permute_idx": o["permute_idx"], + "t": o["t"], + } + return losses + + @torch.no_grad() + @validate_XC() + def sample( + self, + X: torch.Tensor, + C: torch.LongTensor, + S: Optional[torch.LongTensor] = None, + t: Optional[Union[float, torch.Tensor]] = None, + t_packing: Optional[Union[float, torch.Tensor]] = None, + mask_sample: Optional[torch.Tensor] = None, + permute_idx: Optional[torch.LongTensor] = None, + temperature_S: float = 0.1, + temperature_chi: float = 1e-3, + clamped: bool = False, + resample_chi: bool = True, + return_scores: bool = False, + top_p_S: Optional[float] = None, + ban_S: Optional[tuple] = None, + sampling_method: Literal["potts", "autoregressive"] = "autoregressive", + regularization: Optional[str] = "LCP", + potts_sweeps: int = 500, + potts_proposal: 
Literal["dlmc", "chromatic"] = "dlmc", + verbose: bool = False, + symmetry_order: Optional[int] = None, + ) -> tuple: + """Sample sequence and side chain conformations given an input structure. + + Args: + X (torch.Tensor): All atom coordinates with shape + `(num_batch, num_residues, 14, 3)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. + S (torch.LongTensor): Sequence tensor with shape + `(num_batch, num_residues)`. + t (float or torch.Tensor, optional): Diffusion time for models trained with + diffusion augmentation of input structures. Setting `t=0` or + `t=None` will condition the model to treat the structure as + exact coordinates, while values of `t > 0` will condition + the model to treat structures as though they were drawn from + noise-augmented ensembles with that noise level. Default is `None`, + while for robust design we recommend `t=0.5`. May be a float or + a tensor of shape `(num_batch)`. + t_packing (float or torch.Tensor, optional): Potentially separate diffusion + time for packing. + mask_sample (torch.Tensor, optional): Binary tensor mask indicating + positions to be sampled with shape `(num_batch, num_residues)` or + position-specific valid amino acid choices with shape + `(num_batch, num_residues, num_alphabet)`. If `None` (default), all + positions will be sampled. + permute_idx (LongTensor, optional): Permutation tensor for fixing + the autoregressive decoding order `(num_batch, num_residues)`. + If `None` (default), a random decoding order will be generated. + temperature_S (float): Temperature parameter for sampling sequence + tokens. A value of `temperature_S=1.0` corresponds to the + model's unadjusted positions, though because of training such as + label smoothing values less than 1.0 are recommended. Default is + `0.1`. + temperature_chi (float): Temperature parameter for sampling chi + angles. Even if a high temperature sequence is sampled, this is + recommended to always be low. Default is `1E-3`. + clamped (bool): If `True`, no sampling is done and the likelihood + values will be calculated for the input sequence and structure. + Used for validating the sequential versus parallel decoding + modes. Default is `False`. + resample_chi (bool): If `True`, all chi angles will be resampled, + even for sequence positions that were not sampled (i.e. the model + will perform global repacking). Default is `True`. + return_scores (bool): If `True`, return dictionary containing + likelihood scores similar to those produced by `forward`. + top_p_S (float, optional): Option to perform top-p sampling for + autoregressive sequence decoding. If not `None` it will be the + top-p value [1]. + [1] Holtzman et al. The Curious Case of Neural Text Degeneration. (2020) + ban_S (tuple, optional): An optional set of token indices from + `alphabetical_restypes` to ban during sampling. + sampling_method (str): Sampling method for decoding sequence from structure. + If `autoregressive`, sequences will be designed by ancestral sampling with + the autoregessive decoder head. If `potts`, sequences will be designed + via MCMC with the potts decoder head. + regularization (str, optional): Optional sequence regularization to use + during decoding. Can be `LCP` for Local Composition Perplexity regularization + which penalizes local sequence windows from having unnaturally low + compositional entropies. (Implemented for both `potts` and `autoregressive`) + potts_sweeps (int): Number of sweeps to perform for MCMC sampling of `potts` + decoder. 
A sweep corresponds to a sufficient number of Monte Carlo steps + such that every position could have changed. + potts_proposal (str): MCMC proposal for Potts sampling. Currently implemented + proposals are `dlmc` for Discrete Langevin Monte Carlo [1] or `chromatic` + for Gibbs sampling with graph coloring. + [1] Sun et al. Discrete Langevin Sampler via Wasserstein Gradient Flow (2023). + symmetry_order (int, optional): Optional integer argument to enable + symmetric sequence decoding under `symmetry_order`-order symmetry. + The first `(num_nodes // symmetry_order)` states will be free to + move, and all consecutively tiled sets of states will be locked + to these during decoding. Internally this is accomplished by + summing the parameters Potts model under a symmetry constraint + into this reduced sized system and then back imputing at the end. + Currently only implemented for Potts models. + + Returns: + X_sample (torch.Tensor): Sampled all atom coordinates with shape + `(num_batch, num_residues, 14, 3)`. + S_sample (torch.LongTensor): Sampled sequence tensor with shape + `(num_batch, num_residues)`. + permute_idx (torch.LongTensor): Permutation tensor that was used + for the autoregressive decoding order with shape + `(num_batch, num_residues)`. + scores (dict, optional): Dictionary containing likelihood scores + similar to those produced by `forward`. + """ + if X.shape[2] == 4: + X = F.pad(X, [0, 0, 0, 10]) + alphabet = alphabetical_restypes + node_h, edge_h, edge_idx, mask_i, mask_ij = self.encode(X, C, t=t) + + # Process sampling mask + logits_init = torch.zeros( + list(C.shape) + [len(alphabet)], device=C.device + ).float() + if ban_S is not None: + ban_S = [alphabet.index(c) for c in ban_S] + mask_sample, mask_sample_1D, S_init = potts.init_sampling_masks( + logits_init, mask_sample, S=S, ban_S=ban_S + ) + if not clamped: + S = S_init + + # Sample random permutations and build autoregressive mask + if permute_idx is None: + permute_idx = self.traversal(X, C, priority=mask_sample_1D) + + if symmetry_order is not None and not (sampling_method == "potts"): + raise NotImplementedError( + "Symmetric decoding is currently only supported for Potts models" + ) + + if sampling_method == "potts": + if not self.kwargs["predict_S_potts"]: + raise Exception( + "This GraphDesign model was not trained with Potts prediction" + ) + + # Complexity regularization + penalty_func = None + mask_ij_coloring = None + edge_idx_coloring = None + if regularization == "LCP": + C_complexity = ( + C + if symmetry_order is None + else C[:, : C.shape[1] // symmetry_order] + ) + penalty_func = lambda _S: complexity.complexity_lcp(_S, C_complexity) + # edge_idx_coloring, mask_ij_coloring = complexity.graph_lcp(C, edge_idx, mask_ij) + + S_sample, _ = self.decoder_S_potts.sample( + node_h, + edge_h, + edge_idx, + mask_i, + mask_ij, + S=S, + mask_sample=mask_sample, + temperature=temperature_S, + num_sweeps=potts_sweeps, + penalty_func=penalty_func, + proposal=potts_proposal, + rejection_step=(potts_proposal == "chromatic"), + verbose=verbose, + edge_idx_coloring=edge_idx_coloring, + mask_ij_coloring=mask_ij_coloring, + symmetry_order=symmetry_order, + ) + chi_sample, logp_S, logp_chi = None, None, None + else: + # Sample sequence (and chi angles if one-stage) + + # Complexity regularization + bias_S_func = None + if regularization == "LCP": + bias_S_func = complexity.complexity_scores_lcp_t + + S_sample, chi_sample, logp_S, logp_chi, _ = self.decoder.decode( + X, + C, + S, + node_h, + edge_h, + edge_idx, + mask_i, + 
mask_ij, + permute_idx, + temperature_S=temperature_S, + temperature_chi=temperature_chi, + sample=not clamped, + mask_sample=mask_sample, + resample_chi=resample_chi, + top_p_S=top_p_S, + ban_S=ban_S, + bias_S_func=bias_S_func, + ) + + if self.separate_packing: + if t != t_packing: + node_h, edge_h, edge_idx, mask_i, mask_ij = self.encode( + X, C, t=t_packing + ) + + # In two-stage packing, re-process embeddings with sequence + node_h = node_h + mask_i.unsqueeze(-1) * self.embed_S(S_sample) + node_h, edge_h = self.encoder_S_gnn( + node_h, edge_h, edge_idx, mask_i, mask_ij + ) + _, chi_sample, _, logp_chi, _ = self.decoder_chi.decode( + X, + C, + S_sample, + node_h, + edge_h, + edge_idx, + mask_i, + mask_ij, + permute_idx, + temperature_chi=temperature_chi, + sample=not clamped, + mask_sample=mask_sample_1D, + resample_chi=resample_chi, + ) + + # Rebuild side chains + X_sample, mask_X = self.chi_to_X(X[:, :, :4, :], C, S_sample, chi_sample) + + if return_scores: + if sampling_method == "potts": + raise NotImplementedError + + # Summarize + mask_chi = sidechain.chi_mask(C, S_sample) + neglogp_S = -(mask_i * logp_S).sum([1]) / ( + (mask_i).sum([1]) + self.loss_eps + ) + neglogp_chi = -(mask_chi * logp_chi).sum([1, 2]) / ( + mask_chi.sum([1, 2]) + self.loss_eps + ) + + scores = { + "neglogp_S": neglogp_S, + "neglogp_chi": neglogp_chi, + "logp_S": logp_S, + "logp_chi": logp_chi, + "mask_i": mask_i, + "mask_chi": mask_chi, + } + return X_sample, S_sample, permute_idx, scores + else: + return X_sample, S_sample, permute_idx + + @validate_XC() + def pack( + self, + X: torch.Tensor, + C: torch.LongTensor, + S: torch.LongTensor, + permute_idx: Optional[torch.LongTensor] = None, + temperature_chi: float = 1e-3, + clamped: bool = False, + resample_chi: bool = True, + return_scores: bool = False, + ) -> tuple: + """Sample side chain conformations given an input structure. + + Args: + X (torch.Tensor): All atom coordinates with shape + `(num_batch, num_residues, 14, 3)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. + S (torch.LongTensor): Sequence tensor with shape + `(num_batch, num_residues)`. + permute_idx (LongTensor, optional): Permutation tensor for fixing + the autoregressive decoding order `(num_batch, num_residues)`. + If `None` (default), a random decoding order will be generated. + temperature_chi (float): Temperature parameter for sampling chi + angles. Even if a high temperature sequence is sampled, this is + recommended to always be low. Default is `1E-3`. + clamped (bool): If `True`, no sampling is done and the likelihood + values will be calculated for the input sequence and structure. + Used for validating the sequential versus parallel decoding + modes. Default is `False` + resample_chi (bool): If `True`, all chi angles will be resampled, + even for sequence positions that were not sampled (i.e. global + repacking). Default is `True`. + return_scores (bool): If `True`, return dictionary containing + likelihood scores similar to those produced by `forward`. + + Returns: + X_sample (torch.Tensor): Sampled all atom coordinates with shape + `(num_batch, num_residues, 14, 3)`. + neglogp_chi (torch.Tensor, optional): Average negative log + probability per chi angle. + permute_idx (torch.LongTensor): Permutation tensor that was used + for the autoregressive decoding order with shape + `(num_batch, num_residues)`. + scores (dict, optional): Dictionary containing likelihood scores + similar to those produced by `forward`. 
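+
+        Example:
+            An illustrative sketch, assuming `model` is a loaded `GraphDesign`
+            instance trained with `separate_packing`, using placeholder tensors
+            with the shapes documented above::
+
+                X = torch.randn(1, 128, 14, 3)
+                C = torch.ones(1, 128, dtype=torch.long)
+                S = torch.randint(0, 20, (1, 128))
+                X_repacked, permute_idx = model.pack(X, C, S)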
+ """ + assert self.separate_packing + + with torch.no_grad(): + if X.shape[2] == 4: + X = F.pad(X, [0, 0, 0, 10]) + + node_h, edge_h, edge_idx, mask_i, mask_ij = self.encode(X, C) + + # Sample random permutations and build autoregressive mask + if permute_idx is None: + permute_idx = self.traversal(X, C) + + # In two-stage packing, re-process embeddings with sequence + node_h = node_h + mask_i.unsqueeze(-1) * self.embed_S(S) + node_h, edge_h = self.encoder_S_gnn( + node_h, edge_h, edge_idx, mask_i, mask_ij + ) + _, chi_sample, _, logp_chi, _ = self.decoder_chi.decode( + X, + C, + S, + node_h, + edge_h, + edge_idx, + mask_i, + mask_ij, + permute_idx, + temperature_chi=temperature_chi, + sample=not clamped, + resample_chi=resample_chi, + ) + + X_sample, mask_X = self.chi_to_X(X[:, :, :4, :], C, S, chi_sample) + + # Summarize + mask_chi = sidechain.chi_mask(C, S) + neglogp_chi = -(mask_chi * logp_chi).sum([1, 2]) / ( + mask_chi.sum([1, 2]) + self.loss_eps + ) + if return_scores: + scores = { + "neglogp_chi": neglogp_chi, + "logp_chi": logp_chi, + "mask_i": mask_i, + "mask_chi": mask_chi, + } + return X_sample, permute_idx, scores + else: + return X_sample, permute_idx + + return X_sample, neglogp_chi, permute_idx + + +class BackboneEncoderGNN(nn.Module): + """Graph Neural Network for processing protein structure into graph embeddings. + + Args: + See documention of `structure.protein_graph.ProteinFeatureGraph`, + and `graph.GraphNN` for more details. + + dim_nodes (int): Hidden dimension of node tensors. + dim_edges (int): Hidden dimension of edge tensors. + num_neighbors (int): Number of neighbors per nodes. + node_features (tuple): List of node feature specifications. Features + can be given as strings or as dictionaries. + edge_features (tuple): List of edge feature specifications. Features + can be given as strings or as dictionaries. + num_layers (int): Number of layers. + node_mlp_layers (int): Number of hidden layers for node update + function. + node_mlp_dim (int, optional): Dimension of hidden layers for node update + function, defaults to match output dimension. + edge_update (bool): Whether to include an edge update step. + edge_mlp_layers (int): Number of hidden layers for edge update + function. + edge_mlp_dim (int, optional): Dimension of hidden layers for edge update + function, defaults to match output dimension. + skip_connect_input (bool): Whether to include skip connections between + layers. + mlp_activation (str): MLP nonlinearity function, `relu` or `softplus` + accepted. + dropout (float): Dropout fraction. + graph_distance_atom_type (int): Atom type for computing residue-residue + distances for graph construction. Negative values will specify + centroid across atom types. Default is `-1` (centroid). + graph_cutoff (float, optional): Cutoff distance for graph construction: + mask any edges further than this cutoff. Default is `None`. + graph_mask_interfaces (bool): Restrict connections only to within + chains, excluding-between chain interactions. Default is `False`. + graph_criterion (str): Method used for building graph from distances. + Currently supported methods are `{knn, random_log, random_linear}`. + Default is `knn`. + graph_random_min_local (int): Minimum number of neighbors in GNN that + come from local neighborhood, before random neighbors are chosen. + checkpoint_gradients (bool): Switch to implement gradient checkpointing + during training. + + Inputs: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atoms, 3)`. 
+ C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`. + node_h_aux (torch.LongTensor, optional): Auxiliary node features with + shape `(num_batch, num_residues, dim_nodes)`. + edge_h_aux (torch.LongTensor, optional): Auxiliary edge features with + shape `(num_batch, num_residues, num_neighbors, dim_edges)`. + edge_idx (torch.LongTensor, optional): Input edge indices for neighbors + with shape `(num_batch, num_residues, num_neighbors)`. + mask_ij (torch.Tensor, optional): Input edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. + + Outputs: + node_h (torch.Tensor): Node features with shape + `(num_batch, num_residues, dim_nodes)`. + edge_h (torch.Tensor): Edge features with shape + `(num_batch, num_residues, num_neighbors, dim_edges)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_residues, num_neighbors)`. + mask_i (torch.Tensor): Node mask with shape `(num_batch, num_residues)`. + mask_ij (torch.Tensor): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. + """ + + def __init__( + self, + dim_nodes: int = 128, + dim_edges: int = 128, + num_neighbors: int = 30, + node_features: tuple = (("internal_coords", {"log_lengths": True}),), + edge_features: tuple = ( + "distances_2mer", + "orientations_2mer", + "distances_chain", + ), + num_layers: int = 3, + node_mlp_layers: int = 1, + node_mlp_dim: Optional[int] = None, + edge_update: bool = True, + edge_mlp_layers: int = 1, + edge_mlp_dim: Optional[int] = None, + skip_connect_input: bool = False, + mlp_activation: str = "softplus", + dropout: float = 0.1, + graph_distance_atom_type: int = -1, + graph_cutoff: Optional[float] = None, + graph_mask_interfaces: bool = False, + graph_criterion: str = "knn", + graph_random_min_local: int = 20, + checkpoint_gradients: bool = False, + **kwargs + ) -> None: + """Initialize BackboneEncoderGNN.""" + super(BackboneEncoderGNN, self).__init__() + + # Save configuration in kwargs + self.kwargs = locals() + self.kwargs.pop("self") + for key in list(self.kwargs.keys()): + if key.startswith("__") and key.endswith("__"): + self.kwargs.pop(key) + args = SimpleNamespace(**self.kwargs) + + # Important global options + self.dim_nodes = dim_nodes + self.dim_edges = dim_edges + self.checkpoint_gradients = checkpoint_gradients + + graph_kwargs = { + "distance_atom_type": args.graph_distance_atom_type, + "cutoff": args.graph_cutoff, + "mask_interfaces": args.graph_mask_interfaces, + "criterion": args.graph_criterion, + "random_min_local": args.graph_random_min_local, + } + + self.feature_graph = protein_graph.ProteinFeatureGraph( + dim_nodes=args.dim_nodes, + dim_edges=args.dim_edges, + num_neighbors=args.num_neighbors, + graph_kwargs=graph_kwargs, + node_features=args.node_features, + edge_features=args.edge_features, + ) + + self.gnn = graph.GraphNN( + dim_nodes=args.dim_nodes, + dim_edges=args.dim_edges, + num_layers=args.num_layers, + node_mlp_layers=args.node_mlp_layers, + node_mlp_dim=args.node_mlp_dim, + edge_update=args.edge_update, + edge_mlp_layers=args.edge_mlp_layers, + edge_mlp_dim=args.edge_mlp_dim, + mlp_activation=args.mlp_activation, + dropout=args.dropout, + norm="transformer", + scale=args.num_neighbors, + skip_connect_input=args.skip_connect_input, + checkpoint_gradients=checkpoint_gradients, + ) + + @validate_XC(all_atom=False) + def forward( + self, + X: torch.Tensor, + C: torch.LongTensor, + node_h_aux: Optional[torch.Tensor] = None, + edge_h_aux: Optional[torch.Tensor] = None, + edge_idx: Optional[torch.Tensor] = 
None, + mask_ij: Optional[torch.Tensor] = None, + ) -> Tuple[ + torch.Tensor, torch.Tensor, torch.LongTensor, torch.Tensor, torch.Tensor + ]: + """Encode XC backbone structure into node and edge features.""" + num_batch, num_residues = C.shape + + # Hack to enable checkpointing + if self.checkpoint_gradients and (not X.requires_grad): + X.requires_grad = True + + node_h, edge_h, edge_idx, mask_i, mask_ij = self._checkpoint( + self.feature_graph, X, C, edge_idx, mask_ij + ) + + if node_h_aux is not None: + node_h = node_h + mask_i.unsqueeze(-1) * node_h_aux + if edge_h_aux is not None: + edge_h = edge_h + mask_ij.unsqueeze(-1) * edge_h_aux + + node_h, edge_h = self.gnn(node_h, edge_h, edge_idx, mask_i, mask_ij) + return node_h, edge_h, edge_idx, mask_i, mask_ij + + def _checkpoint(self, module: nn.Module, *args) -> nn.Module: + if self.checkpoint_gradients: + return checkpoint(module, *args) + else: + return module(*args) + + +class SidechainDecoderGNN(nn.Module): + """Autoregressively generate sidechains given backbone graph embeddings. + + Args: + See documention of `structure.protein_graph.ProteinFeatureGraph`, + and `graph.GraphNN` for more details. + + dim_nodes (int): Hidden dimension of node tensors. + dim_edges (int): Hidden dimension of edge tensors. + num_neighbors (int): Number of neighbors per nodes. + predict_S (bool): Whether to predict sequence. + predict_chi (bool): Whether to predict chi angles. + sequence_embedding (str): How to represent sequence when decoding. + Currently the only option is `linear`. + sidechain_embedding (str): How to represent chi angles when decoding. + Options include `chi_linear` for a simple linear layer, `chi_rbf` + for a featurization based on smooth binning of chi angles, + `X_direct` which directly encodes the all-atom coordinates using + random Fourier features, and `mixed_chi_X` which uses both the + featurizations of `chi_rbf` and of `X_direct`. + num_layers (int): Number of layers. + node_mlp_layers (int): Number of hidden layers for node update + function. + node_mlp_dim (int, optional): Dimension of hidden layers for node update + function, defaults to match output dimension. + edge_update (bool): Whether to include an edge update step. + edge_mlp_layers (int): Number of hidden layers for edge update + function. + edge_mlp_dim (int, optional): Dimension of hidden layers for edge update + function, defaults to match output dimension. + skip_connect_input (bool): Whether to include skip connections between + layers. + mlp_activation (str): MLP nonlinearity function, `relu` or `softplus` + accepted. + dropout (float): Dropout fraction. + num_alphabet (int): Number of possible residues. + num_chi_bins (int): Number of chi bins for smooth binning of chi angles + used when `sidechain_embedding` is `chi_rbf` or `mixed_chi_X`. + decoder_num_hidden (int): Dimension of hidden layers. + label_smoothing (float): Level of smoothing to apply to sequence and + sidechain labels. + + Inputs: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atoms, 3)`. + C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`. + S (torch.LongTensor): Sequence tensor with shape + `(num_batch, num_residues)`. + node_h (torch.Tensor): Node features with shape + `(num_batch, num_residues, dim_nodes)`. + edge_h (torch.Tensor): Edge features with shape + `(num_batch, num_residues, num_neighbors, dim_edges)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_residues, num_neighbors)`. 
+ mask_i (torch.Tensor): Node mask with shape + `(num_batch, num_residues)`. + mask_ij (torch.Tensor): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. + permute_idx (torch.LongTensor): Permutation tensor for fixing the + autoregressive decoding order `(num_batch, num_residues)`. If + `None` (default), a random decoding order will be generated. + + Outputs: + logp_S (torch.Tensor): Sequence log likelihoods per residue with shape + `(num_batch, num_residues)`. + logp_chi (torch.Tensor): Chi angle Log likelihoods per residue with + shape `(num_batch, num_residues, 4)`. + chi (torch.Tensor): Sidechain chi angles in radians with shape + `(num_batch, num_residues, 4)`. + mask_chi (torch.Tensor): Mask for chi angles with shape + `(num_batch, num_residues, 4)`. + node_h (torch.Tensor): Node features with shape + `(num_batch, num_residues, dim_nodes)`. + edge_h (torch.Tensor): Edge features with shape + `(num_batch, num_residues, num_neighbors, dim_edges)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_residues, num_neighbors)`. + mask_i (torch.Tensor): Node mask with shape `(num_batch, num_residues)`. + mask_ij (torch.Tensor): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. + """ + + def __init__( + self, + dim_nodes: int = 128, + dim_edges: int = 128, + num_neighbors: int = 30, + predict_S: bool = True, + predict_chi: bool = True, + sequence_embedding: str = "linear", + sidechain_embedding: str = "mixed_chi_X", + num_layers: int = 3, + node_mlp_layers: int = 1, + node_mlp_dim: Optional[int] = None, + edge_update: bool = True, + edge_mlp_layers: int = 1, + edge_mlp_dim: Optional[int] = None, + skip_connect_input: bool = False, + mlp_activation: str = "softplus", + dropout: float = 0.1, + num_alphabet: int = 20, + num_chi_bins: int = 20, + decoder_num_hidden: int = 512, + label_smoothing: float = 0.1, + checkpoint_gradients: bool = False, + **kwargs + ): + super(SidechainDecoderGNN, self).__init__() + + # Save configuration in kwargs + self.kwargs = locals() + self.kwargs.pop("self") + for key in list(self.kwargs.keys()): + if key.startswith("__") and key.endswith("__"): + self.kwargs.pop(key) + args = SimpleNamespace(**self.kwargs) + + # Important global options + self.dim_nodes = dim_nodes + self.dim_edges = dim_edges + self.num_alphabet = num_alphabet + self.num_chi_bins = num_chi_bins + + # Predict S, chi or both? 
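+        # GraphDesign uses this decoder in two configurations: in one-stage
+        # packing a single decoder predicts sequence and chi angles jointly
+        # (predict_S=True, predict_chi=True), while with `separate_packing`
+        # the first-stage decoder predicts only sequence and a second-stage
+        # chi decoder is built with predict_S=False, predict_chi=True.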
+ assert predict_S or predict_chi + self.predict_S = predict_S + self.predict_chi = predict_chi + + self.sequence_embedding = sequence_embedding + self.sidechain_embedding = sidechain_embedding + if self.sequence_embedding == "linear": + self.W_S = nn.Embedding(num_alphabet, dim_edges) + + # If we are predicting chi angles, then embed them + if self.predict_chi: + if self.sidechain_embedding == "chi_linear": + self.W_chi = nn.Linear(8, dim_edges) + elif self.sidechain_embedding == "chi_rbf": + self.embed_chi = NodeChiRBF( + dim_out=args.dim_edges, num_chi=4, num_chi_bins=args.num_chi_bins + ) + elif self.sidechain_embedding == "X_direct": + self.embed_X = EdgeSidechainsDirect(dim_out=dim_edges) + elif self.sidechain_embedding == "mixed_chi_X": + self.embed_chi = NodeChiRBF( + dim_out=args.dim_edges, num_chi=4, num_chi_bins=args.num_chi_bins + ) + self.embed_X = EdgeSidechainsDirect(dim_out=dim_edges, basis_type="rff") + + # Decoder GNN process backbone + self.gnn = graph.GraphNN( + dim_nodes=args.dim_nodes, + dim_edges=args.dim_edges, + num_layers=args.num_layers, + node_mlp_layers=args.node_mlp_layers, + node_mlp_dim=args.node_mlp_dim, + edge_update=args.edge_update, + edge_mlp_layers=args.edge_mlp_layers, + edge_mlp_dim=args.edge_mlp_dim, + mlp_activation=args.mlp_activation, + dropout=args.dropout, + norm="transformer", + scale=args.num_neighbors, + skip_connect_input=args.skip_connect_input, + checkpoint_gradients=checkpoint_gradients, + ) + + if self.predict_S: + self.decoder_S = NodePredictorS( + num_alphabet=args.num_alphabet, + dim_nodes=args.dim_nodes, + dim_hidden=args.decoder_num_hidden, + label_smoothing=args.label_smoothing, + ) + + if self.predict_chi: + self.decoder_chi = NodePredictorChi( + num_alphabet=args.num_alphabet, + num_chi_bins=args.num_chi_bins, + dim_nodes=args.dim_nodes, + dim_hidden=args.decoder_num_hidden, + label_smoothing=args.label_smoothing, + ) + + self.loss_eps = 1e-5 + self.chi_to_X = sidechain.SideChainBuilder() + self.X_to_chi = sidechain.ChiAngles() + + @validate_XC() + def forward( + self, + X: torch.Tensor, + C: torch.LongTensor, + S: torch.LongTensor, + node_h: torch.Tensor, + edge_h: torch.Tensor, + edge_idx: torch.LongTensor, + mask_i: torch.Tensor, + mask_ij: torch.Tensor, + permute_idx: torch.LongTensor, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.LongTensor, + torch.Tensor, + torch.Tensor, + ]: + """Predict sequence and chi angles autoregressively given graph features.""" + + # Permute graph representation + ( + node_h_p, + edge_h_p, + edge_idx_p, + mask_i_p, + mask_ij_p, + ) = graph.permute_graph_embeddings( + node_h, edge_h, edge_idx, mask_i, mask_ij, permute_idx + ) + + # Permute sequence and side chain chi angles + X_p = graph.permute_tensor(X, 1, permute_idx) + C_p = graph.permute_tensor(C, 1, permute_idx) + S_p = graph.permute_tensor(S, 1, permute_idx) + chi, mask_chi = self.X_to_chi(X, C, S) + chi_p = graph.permute_tensor(chi, -2, permute_idx) + + # Decode system autoregressively in the permuted coordinates + node_h_p, edge_h_p, edge_idx_p, mask_i_p, mask_ij_p = self._decode_inner( + X_p, C_p, S_p, chi_p, node_h_p, edge_h_p, edge_idx_p, mask_i_p, mask_ij_p + ) + + # Unpermute graph representation + permute_idx_inverse = torch.argsort(permute_idx, dim=-1) + node_h, edge_h, edge_idx, mask_i, mask_ij = graph.permute_graph_embeddings( + node_h_p, edge_h_p, edge_idx_p, mask_i_p, mask_ij_p, permute_idx_inverse + ) + + # Predict per-position joint probabilities of each 
side-chain's sequence and structure + logp_S, log_probs_S, logp_chi, log_probs_chi = None, None, None, None + if self.predict_S: + ( + logp_S, + log_probs_S, + ) = self.decoder_S(S, node_h, mask_i) + if self.predict_chi: + ( + logp_chi, + log_probs_chi, + ) = self.decoder_chi(S, chi, mask_chi, node_h, mask_i) + return ( + logp_S, + logp_chi, + chi, + mask_chi, + node_h, + edge_h, + edge_idx, + mask_i, + mask_ij, + ) + + def _decode_inner( + self, X_p, C_p, S_p, chi_p, node_h_p, edge_h_p, edge_idx_p, mask_i_p, mask_ij_p + ): + # Build autoregressive mask + mask_ij_p = graph.edge_mask_causal(edge_idx_p, mask_ij_p) + + # Add sequence context + h_S_p = self.W_S(S_p) + h_S_p_ij = graph.collect_neighbors(h_S_p, edge_idx_p) + edge_h_p = edge_h_p + mask_ij_p.unsqueeze(-1) * h_S_p_ij + + # Add side chain context + if self.predict_chi: + if self.sidechain_embedding in ["chi_rbf", "mixed_chi_X"]: + h_chi_p = self.embed_chi(chi_p) + h_chi_p_ij = graph.collect_neighbors(h_chi_p, edge_idx_p) + edge_h_p = edge_h_p + mask_ij_p.unsqueeze(-1) * h_chi_p_ij + + if self.sidechain_embedding == "mixed_chi_X": + edge_feature = self.embed_X(X_p, C_p, S_p, edge_idx_p) + edge_h_p = edge_h_p + mask_ij_p.unsqueeze(-1) * edge_feature + + # Run decoder GNN in parallel (permuted) + node_h_p, edge_h_p = self.gnn( + node_h_p, edge_h_p, edge_idx_p, mask_i_p, mask_ij_p + ) + return node_h_p, edge_h_p, edge_idx_p, mask_i_p, mask_ij_p + + def _decode_scatter(self, tensor, src, t): + """Decoding utility function: Scatter.""" + idx = (t * torch.ones_like(src)).long() + tensor.scatter_(1, idx, src) + + def _decode_pre_func(self, t, tensors_t): + """Decoding pre-step function adds features based on current S and chi.""" + _scatter_t = lambda tensor, src: self._decode_scatter(tensor, src, t) + + # Gather relevant tensors at step t + edge_h_p_t = tensors_t["edge_h_cache"][0][:, t, :, :].unsqueeze(1) + edge_idx_p_t = tensors_t["edge_idx"][:, t, :].unsqueeze(1) + mask_ij_p_t = tensors_t["mask_ij"][:, t, :].unsqueeze(1) + + # Update the edge embeddings at t with the relevant context + mask_ij_p_t = mask_ij_p_t.unsqueeze(-1) + + # Add sequence context + h_S_p_ij_t = graph.collect_neighbors(tensors_t["h_S_p"], edge_idx_p_t) + edge_h_p_t = edge_h_p_t + mask_ij_p_t * h_S_p_ij_t + + # Add chi context + if self.predict_chi: + if self.sidechain_embedding in ["chi_rbf", "mixed_chi_X"]: + h_chi_p_ij_t = graph.collect_neighbors( + tensors_t["h_chi_p"], edge_idx_p_t + ) + edge_h_p_t = edge_h_p_t + mask_ij_p_t * h_chi_p_ij_t + if self.sidechain_embedding == "mixed_chi_X": + h_chi_p_ij_t = self.embed_X.step( + t, + tensors_t["X_p"], + tensors_t["C_p"], + tensors_t["S_p"], + edge_idx_p_t, + ) + edge_h_p_t = edge_h_p_t + mask_ij_p_t * h_chi_p_ij_t + + _scatter_t(tensors_t["edge_h_cache"][0], edge_h_p_t) + return tensors_t + + def _decode_post_func( + self, + t, + tensors_t, + S_p_input, + chi_p_input, + temperature_S, + temperature_chi, + sample, + resample_chi, + mask_sample, + mask_sample_p=None, + top_p_S=None, + ban_S=None, + bias_S_func=None, + ): + """Decoding post-step function updates S and chi.""" + _scatter_t = lambda tensor, src: self._decode_scatter(tensor, src, t) + + # Gather relevant tensors at step t + C_p_t = tensors_t["C_p"][:, t].unsqueeze(1) + edge_h_p_t = tensors_t["edge_h_cache"][0][:, t, :, :].unsqueeze(1) + edge_idx_p_t = tensors_t["edge_idx"][:, t, :].unsqueeze(1) + mask_i_p_t = tensors_t["mask_i"][:, t].unsqueeze(1) + mask_ij_p_t = tensors_t["mask_ij"][:, t, :].unsqueeze(1) + node_h_p_t = tensors_t["node_h_cache"][-1][:, t, 
:].unsqueeze(1) + idx_p_t = tensors_t["idx_p"][:, t].unsqueeze(1) + + # Sample updated sequence + S_p_t = S_p_input[:, t].unsqueeze(1).clone() + if self.predict_S and sample: + bias_S = None + if bias_S_func is not None: + bias_S = bias_S_func( + t, + tensors_t["S_p"], + tensors_t["C_p"], + tensors_t["idx_p"], + edge_idx_p_t, + mask_ij_p_t, + ) + mask_S_t = None + if mask_sample_p is not None: + mask_S_t = mask_sample_p[:, t] + S_p_t = self.decoder_S.sample( + node_h_p_t, + mask_i_p_t, + temperature=temperature_S, + top_p=top_p_S, + bias=bias_S, + mask_S=mask_S_t, + ) + + _scatter_t(tensors_t["S_p"], S_p_t) + + # Sample updated side chain conformations + mask_chi_p_t = sidechain.chi_mask(C_p_t, S_p_t) + chi_p_t = chi_p_input[:, t].unsqueeze(1).clone() + if self.predict_chi and sample: + # Sample chi angles + chi_p_t_sample = self.decoder_chi.sample( + S_p_t, mask_chi_p_t, node_h_p_t, mask_i_p_t, temperature=temperature_chi + ) + + if mask_sample_p is not None and not resample_chi: + m = mask_sample_p[:, t].unsqueeze(-1).expand([-1, 4]) + chi_p_t = torch.where(m > 0, chi_p_t_sample, chi_p_t) + else: + chi_p_t = chi_p_t_sample + + # Rebuild side chain + X_p_t_bb = tensors_t["X_p"][:, t, :4, :].unsqueeze(1) + X_p_t, _ = self.chi_to_X(X_p_t_bb, C_p_t, S_p_t, chi_p_t) + _scatter_t(tensors_t["X_p"], X_p_t) + _scatter_t(tensors_t["chi_p"], chi_p_t) + + # Score the updated sequence and chi angles + if self.predict_S: + logp_S_p_t, _ = self.decoder_S(S_p_t, node_h_p_t, mask_i_p_t) + _scatter_t(tensors_t["logp_S_p"], logp_S_p_t) + if self.predict_chi: + logp_chi_p_t, _ = self.decoder_chi( + S_p_t, chi_p_t, mask_chi_p_t, node_h_p_t, mask_i_p_t + ) + _scatter_t(tensors_t["logp_chi_p"], logp_chi_p_t) + + # Update sequence and chi features (permuted) + h_S_p_t = self.W_S(S_p_t) + _scatter_t(tensors_t["h_S_p"], h_S_p_t) + + # Cache chi embeddings + if self.predict_chi and self.sidechain_embedding in ["chi_rbf", "mixed_chi_X"]: + h_chi_p_t = self.embed_chi(chi_p_t) + _scatter_t(tensors_t["h_chi_p"], h_chi_p_t) + return tensors_t + + @validate_XC() + def decode( + self, + X: torch.Tensor, + C: torch.LongTensor, + S: torch.LongTensor, + node_h: torch.Tensor, + edge_h: torch.Tensor, + edge_idx: torch.LongTensor, + mask_i: torch.Tensor, + mask_ij: torch.Tensor, + permute_idx: torch.LongTensor, + temperature_S: float = 0.1, + temperature_chi: float = 1e-3, + sample: bool = True, + mask_sample: Optional[torch.Tensor] = None, + resample_chi: bool = True, + top_p_S: Optional[float] = None, + ban_S: Optional[tuple] = None, + bias_S_func: Optional[Callable] = None, + ) -> Tuple[torch.LongTensor, torch.Tensor, torch.Tensor, torch.Tensor, dict]: + """Autoregressively decode sequence and chi angles from graph features. + + Args: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atoms, 3)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. + S (torch.LongTensor): Sequence tensor with shape + `(num_batch, num_residues)`. + node_h (torch.Tensor): Node features with shape + `(num_batch, num_residues, dim_nodes)`. + edge_h (torch.Tensor): Edge features with shape + `(num_batch, num_residues, num_neighbors, dim_edges)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_residues, num_neighbors)`. + mask_i (torch.Tensor): Node mask with shape + `(num_batch, num_residues)`. + mask_ij (torch.Tensor): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. 
+            permute_idx (torch.LongTensor): Permutation tensor for fixing the
+                autoregressive decoding order `(num_batch, num_residues)`.
+            temperature_S (float): Temperature parameter for sampling sequence
+                tokens. Default is `0.1`.
+ temperature_chi (float): Temperature parameter for sampling chi + angles. Even if a high temperature sequence is sampled, this is + recommended to always be low. Default is `1E-3`. + sample (bool): Whether to sample sequence and chi angles. + mask_sample (torch.Tensor, optional): Binary tensor mask indicating + positions to be sampled with shape `(num_batch, num_residues)`. + If `None` (default), all positions will be sampled. + resample_chi (bool): If `True`, all chi angles will be resampled, + even for sequence positions that were not sampled (i.e. global + repacking). Default is `True`. + top_p_S (float, optional): Top-p cutoff for Nucleus Sampling, see + Holtzman et al ICLR 2020. + ban_S (tuple, optional): An optional set of token indices from + `alphabetical_restypes` to ban during sampling. + + Returns: + S (torch.LongTensor): Sequence tensor with shape + `(num_batch, num_residues)`. + chi (torch.Tensor): Chi angles with shape + `(num_batch, num_residues, 4)`. + logp_S (torch.Tensor): Sequence log likelihoods per residue with + shape `(num_batch, num_residues)`. + logp_chi (torch.Tensor): Chi angle Log likelihoods per residue with + shape `(num_batch, num_residues, 4)`. + tensors (dict): Processed tensors from GNN decoding. + """ + + # Permute graph representation + ( + node_h_p, + edge_h_p, + edge_idx_p, + mask_i_p, + mask_ij_p, + ) = graph.permute_graph_embeddings( + node_h, edge_h, edge_idx, mask_i, mask_ij, permute_idx + ) + chi, mask_chi = self.X_to_chi(X, C, S) + + # Build autoregressive mask + mask_ij_p = graph.edge_mask_causal(edge_idx_p, mask_ij_p) + + # Initialize tensors + B, N, K = list(edge_idx.shape) + device = node_h.device + idx = torch.arange(end=N, device=device)[None, :].expand(C.shape) + tensors_init = { + "X_p": graph.permute_tensor(X, 1, permute_idx), + "C_p": graph.permute_tensor(C, 1, permute_idx), + "idx_p": graph.permute_tensor(idx, 1, permute_idx), + "S_p": torch.zeros_like(S), + "chi_p": torch.zeros([B, N, 4], device=device), + "h_S_p": torch.zeros([B, N, self.dim_edges], device=device), + "h_chi_p": torch.zeros([B, N, self.dim_edges], device=device), + "node_h": node_h_p, + "edge_h": edge_h_p, + "edge_idx": edge_idx_p, + "mask_i": mask_i_p, + "mask_ij": mask_ij_p, + "logp_S_p": torch.zeros([B, N], device=device), + "logp_chi_p": torch.zeros([B, N, 4], device=device), + } + + # As a sanity check against future state leakage, + # we initialize S and chi and zero and write in the true value + # during sequential decoding + S_p_input = graph.permute_tensor(S, 1, permute_idx) + chi_p_input = graph.permute_tensor(chi, 1, permute_idx) + mask_sample_p = None + if mask_sample is not None: + mask_sample_p = graph.permute_tensor(mask_sample, 1, permute_idx) + + # Pre-step function features current sequence and chi angles + pre_step_func = self._decode_pre_func + + # Post-step function samples sequence and/or chi angles at step t + post_step_func = lambda t, tensors_t: self._decode_post_func( + t, + tensors_t, + S_p_input, + chi_p_input, + temperature_S, + temperature_chi, + sample, + resample_chi, + mask_sample, + mask_sample_p, + top_p_S=top_p_S, + ban_S=ban_S, + bias_S_func=bias_S_func, + ) + + # Sequentially step through a forwards pass of the GNN at each + # position along the node dimension (1), running _pre_func + # and each iteration and _post_func after each iteration + tensors = self.gnn.sequential( + tensors_init, + pre_step_function=pre_step_func, + post_step_function=post_step_func, + ) + + # Unpermute sampled sequence and chi angles + 
permute_idx_inverse = torch.argsort(permute_idx, dim=-1) + S = graph.permute_tensor(tensors["S_p"], 1, permute_idx_inverse) + chi = graph.permute_tensor(tensors["chi_p"], 1, permute_idx_inverse) + logp_S = graph.permute_tensor(tensors["logp_S_p"], 1, permute_idx_inverse) + logp_chi = graph.permute_tensor(tensors["logp_chi_p"], 1, permute_idx_inverse) + + return S, chi, logp_S, logp_chi, tensors + + +def _filter_logits_top_p(logits, p=0.9): + """Filter logits by top-p (Nucleus sampling). + + See Holtzman et al, ICLR 2020. + + Args: + logits (Tensor): Logits with shape `(..., num_classes)`. + p (float): Cutoff probability. + + Returns: + logits_filters (Tensor): Filtered logits + with shape `(..., num_classes)`. + """ + logits_sort, indices_sort = torch.sort(logits, dim=-1, descending=True) + probs_sort = F.softmax(logits_sort, dim=-1) + probs_cumulative = torch.cumsum(probs_sort, dim=-1) + + # Remove tokens outside nucleus (aside from top token) + logits_sort_filtered = logits_sort.clone() + logits_sort_filtered[probs_cumulative > p] = -float("Inf") + logits_sort_filtered[..., 0] = logits_sort[..., 0] + + # Unsort + logits_filtered = logits_sort_filtered.gather(-1, indices_sort.argsort(-1)) + return logits_filtered + + +class NodePredictorS(nn.Module): + """Predict sequence tokens at each node given embeddings `P(S_i | h_i)`. + + Args: + num_alphabet (int): Number of amino acids. + dim_nodes (int): Node dimension of graph input. + dim_hidden (int): Hidden layer dimension. + loss_eps (float): Small number to avoid division by zero errors when + taking averages. + label_smoothing (float): Level of smoothing to apply. + + Inputs: + S (torch.LongTensor): Sequence tensor with shape + `(num_batch, num_residues)`. + node_h (torch.Tensor): Node features with shape + `(num_batch, num_residues, dim_nodes)`. + mask_i (torch.Tensor): Node mask with shape `(num_batch, num_residues)`. + + Outputs: + logp_S (torch.Tensor): Log likelihoods per residue with shape + `(num_batch, num_residues)`. During training, this applies label + smoothing. + log_probs_S (torch.Tensor): Log probabilities for each token for + at each residue with shape + `(num_batch, num_residues, num_alphabet)`. 
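+
+    Note:
+        As implemented in `forward` below, `logp_S` is computed with
+        label-smoothed cross entropy during training, and at evaluation time
+        it is the log-softmax probability gathered at the true residue
+        identity.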
+ """ + + def __init__( + self, + num_alphabet: int, + dim_nodes: int, + dim_hidden: int, + loss_eps: float = 1e-5, + label_smoothing: float = 0.1, + ) -> None: + super(NodePredictorS, self).__init__() + self.num_alphabet = num_alphabet + self.dim_nodes = dim_nodes + self.dim_hidden = dim_hidden + self.loss_eps = loss_eps + + self.label_smoothing = label_smoothing + self.training_loss = torch.nn.CrossEntropyLoss( + reduction="none", label_smoothing=self.label_smoothing + ) + + # Layers for predicting sequence and chi angles + self.S_mlp = graph.MLP( + dim_in=dim_nodes, + dim_hidden=dim_hidden, + dim_out=self.num_alphabet, + num_layers_hidden=2, + ) + + def log_probs_S(self, node_h: torch.Tensor, mask_i: torch.Tensor) -> torch.Tensor: + """Compute `log P(S | X, C)`.""" + mask_i_expand = mask_i.unsqueeze(-1) + S_logits = self.S_mlp(node_h) + log_probs_S = mask_i_expand * F.log_softmax(S_logits, -1) + return log_probs_S + + def forward( + self, S: torch.LongTensor, node_h: torch.Tensor, mask_i: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Evaluate chi angle joint likelihood given graph embeddings.""" + log_probs_S = self.log_probs_S(node_h, mask_i) + + if self.training: + logp_S = -self.training_loss(log_probs_S.permute([0, 2, 1]), S) + else: + logp_S = torch.gather(log_probs_S, 2, S.unsqueeze(-1)).squeeze(-1) + + return logp_S, log_probs_S + + def sample( + self, + node_h: torch.Tensor, + mask_i: torch.Tensor, + temperature: float = 1.0, + top_p: Optional[float] = None, + mask_S: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.LongTensor: + """Sample sequence and graph embeddings. + + Args: + node_h (torch.Tensor): Node features with shape + `(num_batch, num_residues, dim_nodes)`. + mask_i (torch.Tensor): Node mask with shape + `(num_batch, num_residues)`. + temperature (float): Temperature parameter for sampling sequence + tokens. The default value of 1.0 corresponds to the model's + unadjusted positions, though because of training such as label + smoothing values less than 1.0 are recommended. + top_p (float, optional): Top-p cutoff for Nucleus Sampling, see + Holtzman et al ICLR 2020. + ban_S (tuple, optional): An optional set of token indices from + `alphabetical_restypes` to ban during sampling. + + Returns: + S_sample (torch.LongTensor): Sampled sequence of shape `(num_batch, + num_residues)`. + + """ + num_batch, num_residues, _ = node_h.shape + log_probs_S = self.log_probs_S(node_h, mask_i) + if bias is not None: + log_probs_S = log_probs_S + bias + if mask_S is not None: + log_probs_S = torch.where( + mask_S > 0, log_probs_S, -float("Inf") * torch.ones_like(log_probs_S) + ) + if top_p is not None: + log_probs_S = _filter_logits_top_p(log_probs_S, p=top_p) + p = torch.distributions.categorical.Categorical( + logits=log_probs_S / temperature + ) + S_sample = p.sample() + return S_sample + + +class NodePredictorChi(nn.Module): + """Predict chi angles autoregressively at each node given embeddings. + + Decomposes as `P(chi_i_{1-4} | h_i) = P(chi_i_4 | chi_i_<4 h_i) ... P(chi_i_1 | h_i)`. + + Args: + num_alphabet (int): Number of amino acids. + num_chi_bins (int): Number of discretization bins per chi angle. + dim_nodes (int): Node dimension of graph input. + dim_hidden (int): Hidden layer dimension. + loss_eps (float): Small number to avoid division by zero errors when + taking averages. + label_smoothing (float): Level of smoothing to apply. + + Inputs: + S (torch.LongTensor): Sequence tensor with shape + `(num_batch, num_residues)`. 
+        chi (torch.Tensor): Chi angles with shape
+            `(num_batch, num_residues, 4)`.
+        mask_chi (torch.Tensor): Chi angle mask with shape
+            `(num_batch, num_residues, 4)`.
+        node_h (torch.Tensor): Node features with shape
+            `(num_batch, num_residues, dim_nodes)`.
+        mask_i (torch.Tensor): Node mask with shape `(num_batch, num_residues)`.
+
+    Outputs:
+        logp_chi (torch.Tensor): Log likelihoods per residue with shape
+            `(num_batch, num_residues, 4)`. During training, this applies label
+            smoothing.
+        log_probs_chi (torch.Tensor): Log probabilities for each chi angle
+            token at each residue with shape
+            `(num_batch, num_residues, 4, num_chi_bins)`.
+    """
+
+    def __init__(
+        self,
+        num_alphabet: int,
+        num_chi_bins: int,
+        dim_nodes: int,
+        dim_hidden: int,
+        loss_eps: float = 1e-5,
+        label_smoothing: float = 0.1,
+    ) -> None:
+        super(NodePredictorChi, self).__init__()
+        self.num_alphabet = num_alphabet
+        self.num_chi_bins = num_chi_bins
+        self.dim_nodes = dim_nodes
+        self.dim_hidden = dim_hidden
+        self.loss_eps = loss_eps
+        self.label_smoothing = label_smoothing
+        self.training_loss = torch.nn.CrossEntropyLoss(
+            reduction="none", label_smoothing=self.label_smoothing
+        )
+        self._init_chi_bins(num_chi_bins)
+
+        # Layers for embedding sequence and chi angles
+        self.W_S = nn.Embedding(num_alphabet, dim_nodes)
+        self.chi_embedding = nn.ModuleList(
+            [
+                NodeChiRBF(dim_out=dim_nodes, num_chi=i, num_chi_bins=num_chi_bins)
+                for i in [1, 2, 3]
+            ]
+        )
+
+        # Layers for chi angles
+        self.chi_mlp = nn.ModuleList(
+            [
+                graph.MLP(
+                    dim_in=dim_nodes,
+                    dim_hidden=dim_hidden,
+                    dim_out=num_chi_bins,
+                    num_layers_hidden=2,
+                )
+                for t in range(4)
+            ]
+        )
+
+    def _init_chi_bins(self, num_chi_bins):
+        # Setup bins
+        bins = torch.tensor(
+            np.linspace(-np.pi, np.pi, num_chi_bins + 1), dtype=torch.float32
+        ).reshape([1, 1, 1, -1])
+        self.register_buffer("bins_left", bins[:, :, :, 0:-1])
+        self.register_buffer("bins_right", bins[:, :, :, 1:])
+        return
+
+    def _log_probs_t(self, t, S, chi, node_h, mask_i):
+        """Compute `log P(chi_t | chi_<t, S, node_h)`."""
+        mask_i_expand = mask_i.unsqueeze(-1)
+
+        # Condition on sequence
+        node_h = node_h + self.W_S(S)
+
+        # Condition on chi angles decoded so far
+        if t > 0:
+            chi_t = chi[:, :, :t]
+            if len(chi_t.shape) == 2:
+                chi_t = chi_t.unsqueeze(-1)
+            node_h = node_h + self.chi_embedding[t - 1](chi_t)
+
+        chi_logits = mask_i_expand * self.chi_mlp[t](node_h)
+        log_probs_chi_t = mask_i_expand * F.log_softmax(chi_logits, -1)
+        return log_probs_chi_t
+
+    def _sample_continuous(self, logits, left, right):
+        """Reparameterization gradients via CDF inversion."""
+        base_shape = list(logits.shape)[:-1]
+        CMF = torch.cumsum(F.softmax(logits, dim=-1), dim=-1)
+        u = torch.rand(base_shape, device=logits.device)
+        _, max_idx = torch.max((u.unsqueeze(-1) < CMF).float(), dim=-1)
+        max_idx = max_idx.unsqueeze(-1)
+
+        left = left.expand(base_shape + [-1])
+        right = right.expand(base_shape + [-1])
+
+        # Gather panel bounds
+        CMF_pad = F.pad(CMF, ((1, 0)))
+        Y_left = torch.gather(left, -1, max_idx)
+        Y_right = torch.gather(right, -1, max_idx)
+        CMF_left = torch.gather(CMF_pad, -1, max_idx)
+        CMF_right = torch.gather(CMF_pad, -1, max_idx + 1)
+
+        # Local CDF inversion
+        z = Y_left + (Y_right - Y_left) * (u.unsqueeze(-1) - CMF_left) / (
+            CMF_right - CMF_left + 1e-5
+        )
+        z = z.squeeze(-1)
+        return z
+
+    def forward(
+        self,
+        S: torch.LongTensor,
+        chi: torch.Tensor,
+        mask_chi: torch.Tensor,
+        node_h: torch.Tensor,
+        mask_i: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Evaluate chi angle joint likelihood given graph embeddings."""
+        # Build the likelihood sequentially
+        log_probs_chi_list = []
+        for t in range(4):
+            log_probs_chi_t =
self._log_probs_t(t, S, chi, node_h, mask_i) + log_probs_chi_list.append(log_probs_chi_t) + log_probs_chi = torch.stack(log_probs_chi_list, -2) + + # Loss function + chi = chi.unsqueeze(-1) + chi_onehot = ((chi >= self.bins_left) * (chi < self.bins_right)).float() + if self.training: + scale = self.label_smoothing / (self.num_chi_bins - 1) + chi_onehot = ( + chi_onehot * (1 - self.label_smoothing) + (1 - chi_onehot) * scale + ) + logp_chi = mask_chi * (chi_onehot * log_probs_chi).sum(-1) + return logp_chi, log_probs_chi + + def sample( + self, + S: torch.LongTensor, + mask_chi: torch.Tensor, + node_h: torch.Tensor, + mask_i: torch.Tensor, + temperature: float = 1.0, + ) -> torch.Tensor: + """Sample chi angles given sequence and graph embeddings. + + Args: + S (torch.LongTensor): Sequence tensor with shape + `(num_batch, num_residues)`. + mask_chi (torch.Tensor): Chi angle mask with shape + `(num_batch, num_residues, 4)`. + node_h (torch.Tensor): Node features with shape + `(num_batch, num_residues, dim_nodes)`. + mask_i (torch.Tensor): Node mask with shape + `(num_batch, num_residues)`. + temperature (float): Temperature parameter for sampling sequence + tokens. The default value of 1.0 corresponds to the model's + unadjusted positions, though because of training such as label + smoothing values less than 1.0 are recommended. + + Returns: + chi_sample (torch.Tensor): Chi angles with shape + `(num_batch, num_residues, 4)`. + + """ + + # Sample chi angles sequentially + num_batch, num_residues, _ = node_h.shape + chi = torch.zeros( + [num_batch, num_residues, 4], dtype=torch.float32, device=node_h.device + ) + left = self.bins_left.reshape([1, 1, self.num_chi_bins]) + right = self.bins_right.reshape([1, 1, self.num_chi_bins]) + for t in range(4): + log_probs_chi_t = self._log_probs_t(t, S, chi, node_h, mask_i) + chi_t = self._sample_continuous(log_probs_chi_t / temperature, left, right) + chi = chi + F.pad(chi_t.unsqueeze(-1), (t, 3 - t)) + return mask_chi * chi + + +class ProteinTraversalSpatial(nn.Module): + """Samples spatial correlated residue permutations in a protein. + + Args: + smooth_alpha (float): Smoothing parameter for graph smoothing where + 0 corresponds to no smoothing and 1 corresponds to maximal + smoothing. Default is 1. + smooth_steps (int): Number of graph smoothing steps, which must be + nonnegative. More steps will increase the amount of smoothing. + Default is 5. + smooth_randomize (bool): Enables uniform randomization of + `smooth_alpha` on the interval `(0, smooth_alpha)`. Default is + True. + graph_num_neighbors (int): Number of neighbors for graph + construction. Default is 30. + deterministic (bool): Whether to force determinism. Default is + False. + + Inputs: + X (torch.Tensor): All atom coordinates with shape + `(num_batch, num_residues, 14, 3)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. + priority (torch.Tensor, optional): Priority values for constraining + residue orderings with shape `(num_batch, num_residues)`. + If residues are assigned to integer-valued groups, the sampled + permutation will be ordered such that all residues within a + lower-valued priority group will occur before residues with + higher-valued priority assignments. + + Outputs: + permute_idx (LongTensor): Permutation tensor containing reordered + residue indices with shape `(num_batch, num_residues)`. 
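+
+    Example:
+        An illustrative sketch using placeholder tensors with the documented
+        shapes::
+
+            traversal = ProteinTraversalSpatial()
+            X = torch.randn(1, 128, 14, 3)
+            C = torch.ones(1, 128, dtype=torch.long)
+            permute_idx = traversal(X, C)  # shape (1, 128)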
+ """ + + def __init__( + self, + smooth_alpha: float = 1.0, + smooth_steps: int = 5, + smooth_randomize: bool = True, + graph_num_neighbors: int = 30, + deterministic: bool = False, + ) -> None: + super(ProteinTraversalSpatial, self).__init__() + + self.smooth_alpha = smooth_alpha + self.smooth_steps = smooth_steps + self.smooth_randomize = smooth_randomize + self.deterministic = deterministic + self._determistic_seed = 10 + + self.norm_eps = 1e-5 + self.protein_graph = protein_graph.ProteinGraph( + num_neighbors=graph_num_neighbors + ) + + @validate_XC() + def forward( + self, + X: torch.Tensor, + C: torch.LongTensor, + priority: Optional[torch.Tensor] = None, + ): + # Sample random node keys + if not self.deterministic: + z = torch.rand_like(C.float()) + else: + with torch.random.fork_rng(): + torch.random.manual_seed(self._determistic_seed) + z = torch.rand((1, C.shape[1]), device=C.device).expand(C.shape) + + # Graph-based smoothing + alpha = self.smooth_alpha + if self.smooth_randomize and not self.deterministic: + alpha = torch.rand((), device=X.device) + + if alpha > 0: + edge_idx, mask_ij = self.protein_graph(X, C) + for i in range(self.smooth_steps): + z_neighbors = graph.collect_neighbors( + z.unsqueeze(-1), edge_idx + ).squeeze(-1) + z_average = (mask_ij * z_neighbors).sum(2) / ( + mask_ij.sum(2) + self.norm_eps + ) + z = alpha * z_average + (1.0 - alpha) * z + + if priority is not None: + z = z + priority + + # Create permutation + permute_idx = torch.argsort(z, dim=-1) + return permute_idx + + +def load_model( + weight_file: str, + device: str = "cpu", + strict: bool = False, + strict_unexpected: bool = True, + verbose: bool = True, +) -> GraphDesign: + """Load model `GraphDesign` + + Args: + weight_file (str): The destination path of the model weights to load. + Compatible with files saved by `save_model`. + device (str, optional): Pytorch device specification, e.g. `'cuda'` for + GPU. Default is `'cpu'`. + strict (bool): Whether to require that the keys match between the + input file weights and the model created from the parameters stored + in the model kwargs. + strict_unexpected (bool): Whether to require that there are no + unexpected keys when loading model weights, as distinct from the + strict option which doesn't allow for missing keys either. By + default, we use this option rather than strict for ease of + development when adding model features. + + Returns: + model (GraphDesign): Instance of `GraphDesign` with loaded weights. + """ + return utility_load_model( + weight_file, + GraphDesign, + device=device, + strict=strict, + strict_unexpected=strict_unexpected, + verbose=verbose, + ) diff --git a/prtm/models/chroma/graph_energy.py b/prtm/models/chroma/graph_energy.py new file mode 100644 index 00000000..d750b96c --- /dev/null +++ b/prtm/models/chroma/graph_energy.py @@ -0,0 +1,111 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Models for building energy functions for protein sequence and structure. 
+ +This module contains pytorch models for building energy functions that score +protein sequence and structure and that can be used for partial and full +protein de novo design. +""" + + +import torch.nn as nn + +from prtm.models.chroma import graph + + +class GraphHarmonicFeatures(nn.Module): + """Layer for quadratic node and edge features. + + Args: + dim_nodes (int): Hidden dimension of node tensor. + dim_edges (int): Hidden dimension of edge tensor. + node_mlp_layers (int): Number of hidden layers for node update. + node_mlp_dim (int): Node update function, hidden dimension. + edge_mlp_layers (int): Edge update function, number of hidden layers. + edge_mlp_dim (int): Edge update function, hidden dimension. + mlp_activation (str): MLP nonlinearity. + `'relu'`: Rectified linear unit. + `'softplus'`: Softplus. + + Inputs: + node_h (Tensor): Node embeddings with shape + `(num_batch, num_nodes, dim_nodes)`. + node_feature (Tensor): Node features with shape + `(num_batch, num_nodes, dim_nodes)`. + edge_h (Tensor): Edge embeddings with shape + `(num_batch, num_nodes, num_neighbors, dim_edges)`. + edge_feature (Tensor): Edge features with shape + `(num_batch, num_nodes, num_neighbors, dim_edges)`. + edge_idx (LongTensor): Edge indices for neighbors with shape + `(num_batch, num_nodes, num_neighbors)`. + mask_i (tensor, optional): Node mask with shape `(num_batch, num_nodes)` + mask_ij (tensor, optional): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. + + Outputs: + node_h (Tensor): Updated node embeddings with shape + `(num_batch, num_nodes, dim_nodes)`. + edge_h (Tensor): Updated edge embeddings with shape + `(num_batch, num_nodes, num_neighbors, dim_edges)`. + """ + + def __init__( + self, + dim_nodes, + dim_edges, + node_mlp_layers, + node_mlp_dim, + edge_mlp_layers, + edge_mlp_dim, + mlp_activation="softplus", + dropout=0.0, + ): + super(GraphHarmonicFeatures, self).__init__() + self.dim_nodes = dim_nodes + self.dim_edges = dim_edges + self.node_mlp = graph.MLP( + dim_in=dim_nodes, + dim_out=2 * dim_nodes, + num_layers_hidden=node_mlp_layers, + dim_hidden=node_mlp_dim, + activation=mlp_activation, + dropout=dropout, + ) + self.edge_mlp = graph.MLP( + dim_in=dim_edges, + dim_out=2 * dim_edges, + num_layers_hidden=edge_mlp_layers, + dim_hidden=edge_mlp_dim, + activation=mlp_activation, + dropout=dropout, + ) + self.node_out = nn.Linear(dim_nodes, dim_nodes) + self.edge_out = nn.Linear(dim_edges, dim_edges) + + def forward(self, node_h, node_feature, edge_h, edge_feature, mask_i, mask_ij): + node_h_pred = self.node_mlp(node_h) + node_mu = node_h_pred[:, :, : self.dim_nodes] + node_coeff = node_h_pred[:, :, self.dim_nodes :] + node_errors = node_coeff * (node_feature - node_mu) ** 2 + node_h = node_h + self.node_out(node_errors) + node_h = mask_i.unsqueeze(-1) * node_h + + edge_h_pred = self.edge_mlp(edge_h) + edge_mu = edge_h_pred[:, :, :, : self.dim_edges] + edge_coeff = edge_h_pred[:, :, :, self.dim_edges :] + edge_errors = edge_coeff * (edge_feature - edge_mu) ** 2 + edge_h = edge_h + self.edge_out(edge_errors) + edge_h = mask_ij.unsqueeze(-1) * edge_h + return node_h, edge_h diff --git a/prtm/models/chroma/linalg.py b/prtm/models/chroma/linalg.py new file mode 100644 index 00000000..638f932d --- /dev/null +++ b/prtm/models/chroma/linalg.py @@ -0,0 +1,97 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layers for linear algebra. + +This module contains additional pytorch layers for linear algebra operations, +such as a more parallelization-friendly implementation of eigvenalue estimation. +""" + +import torch + + +def eig_power_iteration(A, num_iterations=50, eps=1e-5): + """Estimate largest magnitude eigenvalue and associated eigenvector. + + This uses a simple power iteration algorithm to estimate leading + eigenvalues, which can often be considerably faster than torch's built-in + eigenvalue routines. All steps are differentiable and small constants are + added to any division to preserve the stability of the gradients. For more + information on power iteration, see + https://en.wikipedia.org/wiki/Power_iteration. + + Args: + A (tensor): Batch of square matrices with shape + `(..., num_dims, num_dims)`. + num_iterations (int, optional): Number of iterations for power + iteration. Default: 50. + eps (float, optional): Small number to prevent division by zero. + Default: 1E-5. + + Returns: + lam (tensor): Batch of estimated highest-magnitude eigenvalues with + shape `(...)`. + v (tensor): Associated eigvector with shape `(..., num_dims)`. + """ + _safe = lambda x: x + eps + + dims = list(A.size())[:-1] + v = torch.randn(dims, device=A.device).unsqueeze(-1) + for i in range(num_iterations): + v_prev = v + Av = torch.matmul(A, v) + v = Av / _safe(Av.norm(p=2, dim=-2, keepdim=True)) + + # Compute eigenvalue + v_prev = v_prev.transpose(-1, -2) + lam = torch.matmul(v_prev, Av) / _safe(torch.abs(torch.matmul(v_prev, v))) + + # Reshape + v = v.squeeze(-1) + lam = lam.view(list(lam.size())[:-2]) + return lam, v + + +def eig_leading(A, num_iterations=50): + """Estimate largest positive eigenvalue and associated eigenvector. + + This estimates the *most positive* eigenvalue of each matrix in a batch of + matrices by using two consecutive power iterations with spectral shifting. + + Args: + A (tensor): Batch of square matrices with shape + `(..., num_dims, num_dims)`. + num_iterations (int, optional): Number of iterations for power + iteration. Default: 50. + + Returns: + lam (tensor): Estimated most positive eigenvalue with shape `(...)`. + v (tensor): Associated eigenvectors with shape `(..., num_dims)`. + """ + batch_dims = list(A.size())[:-2] + + # First pass gets largest magnitude + lam_1, vec_1 = eig_power_iteration(A, num_iterations) + + # Second pass guaranteed to grab most positive eigenvalue + lam_1_abs = torch.abs(lam_1) + lam_I = lam_1_abs.reshape(batch_dims + [1, 1]) * torch.eye(4, device=A.device).view( + [1 for _ in batch_dims] + [4, 4] + ) + A_shift = A + lam_I + lam_2, vec = eig_power_iteration(A_shift, num_iterations) + + # Shift back to original specta + lam = lam_2 - lam_1_abs + return lam, vec diff --git a/prtm/models/chroma/model_utils.py b/prtm/models/chroma/model_utils.py new file mode 100644 index 00000000..39550c6f --- /dev/null +++ b/prtm/models/chroma/model_utils.py @@ -0,0 +1,113 @@ +# Copyright Generate Biomedicines, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utilities to save and load models with metadata. +""" + +from pathlib import Path +from urllib.parse import parse_qs, urlparse +from uuid import uuid4 + +import torch + +import prtm.models.chroma.api as api +from prtm.models.chroma.named_models import NAMED_MODELS + + +def save_model(model, weight_file, metadata=None): + """Save model, including optional metadata. + + Args: + model (nn.Module): The model to save. Details about the model needed + for initialization, such as layer sizes, should be in model.kwargs. + weight_file (str): The destination path for saving model weights. + metadata (dict): A dictionary of additional metadata to add to the model + weights. For example, when saving models during training it can be + useful to store `args` representing the CLI args, the date and time + of training, etc. + """ + save_dict = {"init_kwargs": model.kwargs, "model_state_dict": model.state_dict()} + if metadata is not None: + save_dict.update(metadata) + local_path = str( + Path("/tmp", str(uuid4())[:8]) if weight_file.startswith("s3:") else weight_file + ) + torch.save(save_dict, local_path) + if weight_file.startswith("s3:"): + raise NotImplementedError("Uploading to an s3 link not supported.") + + +def load_model( + weights, + model_class, + device="cpu", + strict=False, + strict_unexpected=True, + verbose=True, +): + """Load model saved with save_model. + + Args: + weights (str): The destination path of the model weights to load. + Compatible with files saved by `save_model`. + model_class: Name of model class. + device (str, optional): Pytorch device specification, e.g. `'cuda'` for + GPU. Default is `'cpu'`. + strict (bool): Whether to require that the keys match between the + input file weights and the model created from the parameters stored + in the model kwargs. + strict_unexpected (bool): Whether to require that there are no + unexpected keys when loading model weights, as distinct from the + strict option which doesn't allow for missing keys either. By + default, we use this option rather than strict for ease of + development when adding model features. + verbose (bool, optional): Show outputs from download and loading. Default True. + + Returns: + model (nn.Module): Torch model with loaded weights. 
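+
+    Example (illustrative only; assumes the named public `GraphDesign`
+    weights, which are fetched from the public URL on first use):
+
+        from prtm.models.chroma.graph_design import GraphDesign
+        model = load_model("named:public", GraphDesign, device="cpu")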
+ """ + + # Process weights path + if str(weights).startswith("named:"): + weights = weights.split("named:")[1] + if weights not in NAMED_MODELS[model_class.__name__]: + raise Exception(f"Unknown {model_class.__name__} model name: {weights},") + weights = NAMED_MODELS[model_class.__name__][weights]["s3_uri"] + + # resolve s3 paths + if str(weights).startswith("s3:"): + raise NotImplementedError("Loading Models from an S3 link not supported.") + + # download public models from generate + if str(weights).startswith("https:"): + # Decompose into arguments + parsed_url = urlparse(weights) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}" + model_name = parse_qs(parsed_url.query).get("weights", [None])[0] + weights = api.download_from_generate( + base_url, model_name, force=False, exist_ok=True + ) + + # load model weights + params = torch.load(weights, map_location="cpu") + model = model_class(**params["init_kwargs"]).to(device) + missing_keys, unexpected_keys = model.load_state_dict( + params["model_state_dict"], strict=strict + ) + if strict_unexpected and len(unexpected_keys) > 0: + raise Exception( + f"Error loading model from checkpoint file: {weights} contains {len(unexpected_keys)} unexpected keys: {unexpected_keys}" + ) + return model diff --git a/prtm/models/chroma/modeling.py b/prtm/models/chroma/modeling.py new file mode 100644 index 00000000..343844b1 --- /dev/null +++ b/prtm/models/chroma/modeling.py @@ -0,0 +1,242 @@ +from typing import Optional, Tuple + +import numpy as np +import torch + +from prtm.constants import residue_constants +from prtm.models.chroma.system import System +from prtm.protein import ProteinBase, Protein14, Protein4 + + +def dihedral( + a: np.ndarray, + b: np.ndarray, + c: np.ndarray, + d: np.ndarray, + radians: bool = False, +): + """Computes the dihedral angle formed by four 3D points represented by AtomLocationView objects. + + Args: + a1, a2, a3, a4 (AtomLocationView): four 3D points. + radian (bool, optional): if True (default False), will return the angle in radians. + Otherwise, in degrees. + + Returns: + Dihedral angle `a1`-`a2`-`a3`-`a4`. + """ + AB = a - b + CB = c - b + DC = d - c + + if min([np.linalg.norm(p) for p in [AB, CB, DC]]) == 0.0: + raise Exception("some points coincide in dihedral calculation") + + ABxCB = np.cross(AB, CB) + ABxCB = ABxCB / np.linalg.norm(ABxCB) + DCxCB = np.cross(DC, CB) + DCxCB = DCxCB / np.linalg.norm(DCxCB) + + # the following is necessary for values very close to 1 but just above + dotp = np.dot(ABxCB, DCxCB) + if dotp > 1.0: + dotp = 1.0 + elif dotp < -1.0: + dotp = -1.0 + + angle = np.arccos(dotp) + if np.dot(ABxCB, DC) > 0: + angle *= -1 + if not radians: + angle *= 180.0 / np.pi + + return angle + + +def canonicalize_structure(prot: ProteinBase) -> ProteinBase: + """ + Returns the canonical chroma structure for the given protein. 
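+
+    Concretely, the structure is converted to the 37-atom format and, for each
+    arginine with a complete guanidinium group, the NH1/NH2 coordinates are
+    swapped when needed so that the CD-NE-CZ-NH1 dihedral has the smaller
+    magnitude (see the loop below).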
+ """ + # Make sure the structure is in 37 atom format + prot = prot.to_protein37() + prot = prot.to_numpy() + + # Get the index for arginine + arg_res_idx = residue_constants.restype_3.index("ARG") + arg_indices = np.where(prot.aatype == arg_res_idx)[0] + + # Check if sidechain atoms are present in the structure + arg_sidechain_atom_indices = np.array( + [ + residue_constants.atom_types.index(atom) + for atom in ["CD", "NE", "CZ", "NH1", "NH2"] + ] + ) + has_sidechain_atoms = np.where( + prot.atom_mask[:, arg_sidechain_atom_indices].sum(axis=1) == 5 + )[0] + + arg_indices_with_sidechains = np.intersect1d(arg_indices, has_sidechain_atoms) + for index in arg_indices_with_sidechains: + dihe1 = dihedral( + prot.atom_positions[index, arg_sidechain_atom_indices[0]], + prot.atom_positions[index, arg_sidechain_atom_indices[1]], + prot.atom_positions[index, arg_sidechain_atom_indices[2]], + prot.atom_positions[index, arg_sidechain_atom_indices[3]], + ) + dihe2 = dihedral( + prot.atom_positions[index, arg_sidechain_atom_indices[0]], + prot.atom_positions[index, arg_sidechain_atom_indices[1]], + prot.atom_positions[index, arg_sidechain_atom_indices[2]], + prot.atom_positions[index, arg_sidechain_atom_indices[4]], + ) + if abs(dihe1) > abs(dihe2): + # Swap NH1 and NH2 positions + nh1_pos = prot.atom_positions[index, arg_sidechain_atom_indices[3]] + nh2_pos = prot.atom_positions[index, arg_sidechain_atom_indices[4]] + prot.atom_positions[index, arg_sidechain_atom_indices[3]] = nh2_pos + prot.atom_positions[index, arg_sidechain_atom_indices[4]] = nh1_pos + + # TODO: Convert back to the original prot atom type? + return prot + + +def protein_to_xcs( + prot: ProteinBase, + all_atom: bool = False, + device: Optional[torch.device] = None, +) -> Tuple[torch.tensor, torch.tensor, torch.tensor]: + """Convert System object to XCS format. + + `C` tensor has shape [num_residues], where it codes positions as 0 + when masked, positive integers for chain indices, and negative integers + to represent missing residues of the corresponding positive integers. + + `S` tensor has shape [num_residues], it will map residue amino acid to alphabet integers. + If it is not found in `alphabet`, it will default to `unknown_token`. Set `mask_unknown` to true if + also want to mask `unk residue` in `chain_map` + + This function takes into account missing residues and updates chain_map + accordingly. + + Args: + system (type): generate System object to convert. + all_atom (bool): Include side chain atoms. Default is `False`. + batch_dimension (bool): Include a batch dimension. Default is `True`. + mask_unknown (bool): Mask residues not found in the alphabet. Default is + `True`. + unknown_token (int): Default token index if a residue is not found in + the alphabet. Default is `0`. + reorder_chain (bool): If set to true will start indexing chain at 1, + else will use the alphabet index (Default: True) + altenate_alphabet (str): Alternative alphabet if not `None`. + alternate_atoms (list): Alternate atom name subset for `X` if not `None`. + get_indices (bool): Also return the location indices corresponding to the + returned `X` tensor. + + Returns: + X (torch.Tensor): Coordinates with shape `(1, num_residues, num_atoms, 3)`. + `num_atoms` will be 14 if `all_atom=True` or 4 otherwise. + C (torch.LongTensor): Chain map with shape `(1, num_residues)`. It codes + positions as 0 when masked, positive integers for chain indices, + and negative integers to represent missing residues of the + corresponding positive integers. 
+ S (torch.LongTensor): Sequence with shape `(1, num_residues)`. + location_indices (np.ndaray, optional): location indices corresponding to + the coordinates in `X`. + + """ + # Either all_atom (backbone plus sidechains is 14 atoms) + # or the backbone (4 atoms) + if all_atom: + prot = prot.to_protein14() + else: + prot = prot.to_protein4() + + prot = prot.to_numpy() + + # We want to reorder the chain_ids to start at 1 and be sequential + chain_reordering = { + chain_id: i for i, chain_id in enumerate(np.unique(prot.chain_index), 1) + } + C = np.vectorize(chain_reordering.get)(prot.chain_index) + X = prot.atom_positions + S = prot.aatype + + # Unknown residues are mapped to 0 for Chroma + # but in prtm parsing they will be marked as X, + # so make the conversion + masked_res_idx = residue_constants.restypes_with_x.index("X") + S[S == masked_res_idx] = 0 + + # Map from the default protein restypes to the alphabetical restypes + chroma_alphabet = residue_constants.alphabetical_restypes + restype_to_alphabetical = { + res_idx: chroma_alphabet.index(res) + for res_idx, res in enumerate(residue_constants.restypes) + } + # Map the restypes to the alphabetical restypes + S = np.vectorize(restype_to_alphabetical.get)(S) + + # Tensor everything and add batch dimension + if device is None: + device = torch.device("cpu") + + X = torch.tensor(X, device=device).float()[None] + C = torch.tensor(C, device=device).type(torch.long)[None] + S = torch.tensor(S, device=device).type(torch.long)[None] + + return X, C, S + + +def protein_from_xcs( + X: torch.Tensor, + C: torch.Tensor, + S: torch.Tensor, +) -> ProteinBase: + """Convert XCS format to System object. + + Args: + X (torch.Tensor): Coordinates with shape `(1, num_residues, num_atoms, 3)`. + `num_atoms` will be 14 if `all_atom=True` or 4 otherwise. + C (torch.LongTensor): Chain map with shape `(1, num_residues)`. It codes + positions as 0 when masked, positive integers for chain indices, + and negative integers to represent missing residues of the + corresponding positive integers. + S (torch.LongTensor): Sequence with shape `(1, num_residues)`. + + Returns: + system (System): System object. + """ + + # Convert to numpy + X = X.detach().cpu().numpy()[0] + C = C.detach().cpu().numpy()[0] + S = S.detach().cpu().numpy()[0] + return X, C, S + + +def get_mask( + protein: ProteinBase, + selection: str, + device: Optional[torch.device] = None, +) -> torch.Tensor: + """ + Generate a mask tensor based on the provided residue selection. + + Args: + protein_system (System): A Chroma System object. + selection (str): A selection string to specify which residues should be included in the mask. + + Returns: + torch.Tensor: A mask tensor of shape `(1, protein length)`, where positions corresponding to selected residues have a value of 1. + """ + canon_protein = canonicalize_structure(protein) + protein_system = System.from_XCS(*protein_to_xcs(canon_protein, all_atom=True)) + + residue_gtis = protein_system.select_residues(selection, gti=True) + D = torch.zeros(1, protein_system.num_residues(), device=device) + for gti in residue_gtis: + D[0, gti] = 1 + + return D diff --git a/prtm/models/chroma/named_models.py b/prtm/models/chroma/named_models.py new file mode 100644 index 00000000..e99f7dd5 --- /dev/null +++ b/prtm/models/chroma/named_models.py @@ -0,0 +1,54 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Paths for named models in the zoo """ + +GRAPH_BACKBONE_MODELS = { + "public": { + "s3_uri": "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_backbone_v1.0.pt", + "data": "Generate Structure ETL: July 25 2022", + "task": "BLNL backbone model training with EMA, trained July 2023", + }, +} + +GRAPH_CLASSIFIER_MODELS = { + "public": { + "s3_uri": "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_proclass_v1.0.pt", + "data": "Generate Structure ETL: June 2022", + "task": "Backbone classification model training with cross-entropy loss", + }, +} + +GRAPH_DESIGN_MODELS = { + "public": { + "s3_uri": "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_design_v1.0.pt", + "data": "Generate Structure ETL: July 25 2022", + "task": "Autoregressive joint prediction of sequence and chi angles, two-stage", + }, +} + +PROCAP_MODELS = { + "public": { + "s3_uri": "https://chroma-weights.generatebiomedicines.com/downloads?weights=chroma_procap_v1.0.pt", + "data": "Generate Structure ETL: June 2022", + "task": "Backbone caption model training with cross-entropy loss, using M5 ProClass GNN embeddings", + }, +} + +NAMED_MODELS = { + "GraphBackbone": GRAPH_BACKBONE_MODELS, + "GraphDesign": GRAPH_DESIGN_MODELS, + "GraphClassifier": GRAPH_CLASSIFIER_MODELS, + "ProteinCaption": PROCAP_MODELS, +} diff --git a/prtm/models/chroma/norm.py b/prtm/models/chroma/norm.py new file mode 100644 index 00000000..ff706e2e --- /dev/null +++ b/prtm/models/chroma/norm.py @@ -0,0 +1,297 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + + +class MaskedBatchNorm1d(nn.Module): + """A masked version of nn.BatchNorm1d. Only tested for 3D inputs. + + Args: + num_features (int): :math:`C` from an expected input of size + :math:`(N, C, L)` + eps (float): a value added to the denominator for numerical stability. + Default: 1e-5 + momentum (float): the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine (bool): a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats (bool) : a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. 
Default: ``True`` + + Inputs: + x (torch.tensor): of size (batch_size, num_features, sequence_length) + input_mask (torch.tensor or None) : (optional) of dtype (byte) or (bool) of shape (batch_size, 1, sequence_length) zeroes (or False) indicate positions that cannot contribute to computation + Outputs: + y (torch.tensor): of size (batch_size, num_features, sequence_length) + """ + + def __init__( + self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + ): + super(MaskedBatchNorm1d, self).__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.affine = affine + if affine: + self.weight = nn.Parameter(torch.Tensor(num_features, 1)) + self.bias = nn.Parameter(torch.Tensor(num_features, 1)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + self.track_running_stats = track_running_stats + if self.track_running_stats: + self.register_buffer("running_mean", torch.zeros(num_features, 1)) + self.register_buffer("running_var", torch.ones(num_features, 1)) + self.register_buffer( + "num_batches_tracked", torch.tensor(0, dtype=torch.long) + ) + else: + self.register_parameter("running_mean", None) + self.register_parameter("running_var", None) + self.register_parameter("num_batches_tracked", None) + self.reset_parameters() + + def reset_running_stats(self): + if self.track_running_stats: + self.running_mean.zero_() + self.running_var.fill_(1) + self.num_batches_tracked.zero_() + + def reset_parameters(self): + self.reset_running_stats() + if self.affine: + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, input, input_mask=None): + # Calculate the masked mean and variance + B, C, L = input.shape + if input_mask is not None and input_mask.shape != (B, 1, L): + raise ValueError("Mask should have shape (B, 1, L).") + if C != self.num_features: + raise ValueError( + "Expected %d channels but input has %d channels" + % (self.num_features, C) + ) + if input_mask is not None: + masked = input * input_mask + n = input_mask.sum() + else: + masked = input + n = B * L + # Sum + masked_sum = masked.sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) + # Divide by sum of mask + current_mean = masked_sum / n + current_var = ((masked - current_mean) ** 2).sum(dim=0, keepdim=True).sum( + dim=2, keepdim=True + ) / n + # Update running stats + if self.track_running_stats and self.training: + if self.num_batches_tracked == 0: + self.running_mean = current_mean + self.running_var = current_var + else: + self.running_mean = ( + 1 - self.momentum + ) * self.running_mean + self.momentum * current_mean + self.running_var = ( + 1 - self.momentum + ) * self.running_var + self.momentum * current_var + self.num_batches_tracked += 1 + # Norm the input + if self.track_running_stats and not self.training: + normed = (masked - self.running_mean) / ( + torch.sqrt(self.running_var + self.eps) + ) + else: + normed = (masked - current_mean) / (torch.sqrt(current_var + self.eps)) + # Apply affine parameters + if self.affine: + normed = normed * self.weight + self.bias + return normed + + +class MaskedBatchNorm2d(nn.Module): + """A masked version of nn.BatchNorm1d. Only tested for 3D inputs. + + Args: + num_features (int): :math:`C` from an expected input of size + :math:`(N, C, L)` + eps (float): a value added to the denominator for numerical stability. + Default: 1e-5 + momentum (float): the value used for the running_mean and running_var + computation. 
Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine (bool): a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats (bool) : a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``True`` + + Inputs: + x (torch.tensor): of size (batch_size, num_features, sequence_length) + input_mask (torch.tensor or None) : (optional) of dtype (byte) or (bool) of shape (batch_size, 1, sequence_length) zeroes (or False) indicate positions that cannot contribute to computation + Outputs: + y (torch.tensor): of size (batch_size, num_features, sequence_length) + """ + + def __init__( + self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + ): + super().__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.affine = affine + if affine: + self.weight = nn.Parameter( + torch.ones( + num_features, + ) + ) + self.bias = nn.Parameter( + torch.zeros( + num_features, + ) + ) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + self.track_running_stats = track_running_stats + if self.track_running_stats: + self.register_buffer("running_mean", torch.zeros(1, 1, 1, num_features)) + self.register_buffer("running_var", torch.ones(1, 1, 1, num_features)) + self.register_buffer( + "num_batches_tracked", torch.tensor(0, dtype=torch.long) + ) + else: + self.register_parameter("running_mean", None) + self.register_parameter("running_var", None) + self.register_parameter("num_batches_tracked", None) + self.reset_parameters() + + def reset_running_stats(self): + if self.track_running_stats: + self.running_mean.zero_() + self.running_var.fill_(1) + self.num_batches_tracked.zero_() + + def reset_parameters(self): + self.reset_running_stats() + if self.affine: + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, input, mask=None): + # Calculate the masked mean and variance + B, L, L, C = input.size() + if mask is not None: + if mask.dim() != 4: + raise ValueError( + f"Input mask must have four dimensions, but has {mask.dim()}" + ) + b, l, l, d = mask.size() + if (b != B) or (l != L): + raise ValueError( + f"Input mask must have shape {(B, L, L, 1)} or {(B, L, L, C)} to match input." 
+ ) + if d == 1: + mask = mask.expand(input.size()) + + if C != self.num_features: + raise ValueError( + "Expected %d channels but input has %d channels" + % (self.num_features, C) + ) + + if mask is None: + mask = input.new_ones(input.size()) + + masked = input * mask + n = mask.sum(dim=(0, 1, 2), keepdim=True) + masked_sum = (masked).sum(dim=(0, 1, 2), keepdim=True) + + current_mean = masked_sum / n + current_var = (mask * (masked - current_mean).pow(2)).sum( + dim=(0, 1, 2), keepdim=True + ) / n + # Update running stats + with torch.no_grad(): + if self.track_running_stats and self.training: + if self.num_batches_tracked == 0: + self.running_mean = current_mean.detach() + self.running_var = current_var.detach() + else: + self.running_mean = ( + 1 - self.momentum + ) * self.running_mean + self.momentum * current_mean.detach() + self.running_var = ( + 1 - self.momentum + ) * self.running_var + self.momentum * current_var.detach() + self.num_batches_tracked += 1 + # Norm the input + if self.track_running_stats and not self.training: + normed = (masked - self.running_mean) / ( + torch.sqrt(self.running_var + self.eps) + ) + else: + normed = (masked - current_mean) / (torch.sqrt(current_var + self.eps)) + # Apply affine parameters + if self.affine: + normed = normed * self.weight + self.bias + + normed = normed * mask + return normed + + +class NormedReductionLayer(nn.Module): + """A ReductionLayer with LayerNorms after the hidden layers.""" + + def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.0): + super().__init__() + self.d1 = nn.Dropout(p=dropout) + self.d2 = nn.Dropout(p=dropout) + self.hidden = nn.Linear(input_dim, hidden_dim) + self.relu = nn.ReLU() + self.output = nn.Linear(hidden_dim, output_dim) + self.norm1 = nn.LayerNorm(input_dim) + self.norm2 = nn.LayerNorm(hidden_dim) + + def reduce(self, x, mask): + masked_x = x * mask + mean_x = masked_x.sum(dim=1) / torch.sum(mask, dim=1) + return mean_x + + def forward(self, x, mask): + reduced_x = self.norm1(self.reduce(x, mask)) + h = self.norm2(self.hidden(reduced_x)) + return self.output(self.relu(h)) diff --git a/prtm/models/chroma/polyseq.py b/prtm/models/chroma/polyseq.py new file mode 100644 index 00000000..be8fe706 --- /dev/null +++ b/prtm/models/chroma/polyseq.py @@ -0,0 +1,277 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Standard residue names for polymers of different types (e.g., L- or D-amino acid proteins, +mixed-chirality proteins, DNA/RNA, etc.) 
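+
+Illustrative usage (a sketch based on the residue tables registered below;
+the defaults use the L-protein alphabet):
+
+    to_triple("A")        # "ALA"
+    to_single("MSE")      # "M"
+    is_canonical("MSE")   # False; only the 20 standard residues are canonical
+    is_unknown("ZZZ")     # True; unrecognized names map to the UNK token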
+""" + +from enum import Enum + + +class polymerType(Enum): + LPROT = 0 + DPROT = 1 + LDPROT = 2 + DNA = 3 + RNA = 4 + + +def polymer_type_name(ptype: polymerType): + if ptype == polymerType.LPROT: + return "polypeptide(L)" + elif ptype == polymerType.DPROT: + return "polypeptide(D)" + elif ptype == polymerType.LDPROT: + return "polypeptide(L,D)" + elif ptype == polymerType.DNA: + return "polydeoxyribonucleotide" + elif ptype == polymerType.RNA: + return "polyribonucleotide" + else: + raise Exception(f"unknown polymer type {ptype}") + + +_res3 = [[] for _ in range(len(polymerType))] + +_res1 = [[] for _ in range(len(polymerType))] + +_res_to_idx = [dict() for _ in range(len(polymerType))] + +_unk_idx = [set() for _ in range(len(polymerType))] + +_gap_idx = [set() for _ in range(len(polymerType))] + +_stp_idx = [set() for _ in range(len(polymerType))] + + +def _add_residue(ptype: polymerType, res3, res1): + if isinstance(ptype, list): + for pt, r3, r1 in zip(ptype, res3, res1): + _add_residue(pt, r3, r1) + else: + _res_to_idx[ptype.value][res3] = len(_res3[ptype.value]) + # single-letter code is ambiguous, so take the first residue when going from single-letter code to index + if res1 not in _res_to_idx[ptype.value]: + _res_to_idx[ptype.value][res1] = _res_to_idx[ptype.value][res3] + _res3[ptype.value].append(res3) + _res1[ptype.value].append(res1) + if res3 == "---": + _gap_idx[ptype.value].add(_res_to_idx[ptype.value][res3]) + elif res3 == "UNK": + _unk_idx[ptype.value].add(_res_to_idx[ptype.value][res3]) + elif res3 == "STP": + _stp_idx[ptype.value].add(_res_to_idx[ptype.value][res3]) + + +def num_tokens(ptype=polymerType.LPROT): + return len(_res3[ptype.value]) + + +def num_known_molecular_tokens(ptype=polymerType.LPROT): + return sum( + [ + not is_punctuation_index(idx) and not is_unknown(idx) + for idx in range(len(_res3[ptype.value])) + ] + ) + + +def res_to_index(res: str, ptype=polymerType.LPROT): + return _res_to_idx[ptype.value].get(res, next(iter(_unk_idx[ptype.value]))) + + +def index_to_single(idx: int, ptype=polymerType.LPROT): + return _res1[ptype.value][idx] + + +def index_to_triple(idx: int, ptype=polymerType.LPROT): + return _res3[ptype.value][idx] + + +def to_single(res: str, ptype=polymerType.LPROT): + return index_to_single(res_to_index(res, ptype)) + + +def to_triple(res: str, ptype=polymerType.LPROT): + return index_to_triple(res_to_index(res, ptype)) + + +def is_gap_index(idx: int, ptype=polymerType.LPROT): + return idx in _gap_idx[ptype.value] + + +def is_stop_index(idx: int, ptype=polymerType.LPROT): + return idx in _stp_idx[ptype.value] + + +def is_unknown(res: str, ptype=polymerType.LPROT): + return is_unknown_index(res_to_index(res, ptype), ptype) + + +def is_unknown_index(idx: int, ptype=polymerType.LPROT): + return idx in _unk_idx[ptype.value] + + +def is_polymer_residue(res: str, ptype: polymerType): + if ptype is None: + # determine if this is a polymer residue for any known polymer + for ptype in polymerType: + if res in _res_to_idx[ptype.value]: + return True + return False + return res in _res_to_idx[ptype.value] + + +def is_punctuation_index(idx: int, ptype=polymerType.LPROT): + return is_gap_index(idx, ptype) or is_stop_index(idx, ptype) + + +def is_canonical(res: str, ptype=polymerType.LPROT): + if ptype == polymerType.LPROT or ptype == polymerType.DPROT: + idx = res_to_index(res, ptype) + return (idx < 20) and (idx >= 0) + elif ptype == polymerType.LDPROT: + return is_canonical(res, polymerType.LPROT) or is_canonical( + mirror_amino_acid(res), 
polymerType.DPROT + ) + raise Exception(f"do not known how to deal with polymer type {ptype}") + + +def canonical_amino_acids(ptype=polymerType.LPROT): + canonicals = [] + for aa in _res3[ptype.value]: + if is_canonical(aa, ptype): + canonicals.append(aa) + return canonicals + + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["ALA", "DAL"], ["A", "a"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["CYS", "DCY"], ["C", "c"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["ASP", "DAS"], ["D", "d"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["GLU", "DGL"], ["E", "e"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["PHE", "DPN"], ["F", "f"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["GLY", "GLY"], ["G", "G"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["HIS", "DHI"], ["H", "h"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["ILE", "DIL"], ["I", "i"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["LYS", "DLY"], ["K", "k"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["LEU", "DLE"], ["L", "l"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["MET", "MED"], ["M", "m"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["ASN", "DSG"], ["N", "n"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["PRO", "DPR"], ["P", "p"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["GLN", "DGN"], ["Q", "q"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["ARG", "DAR"], ["R", "r"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["SER", "DSN"], ["S", "s"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["THR", "DTH"], ["T", "t"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["VAL", "DVA"], ["V", "v"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["TRP", "DTR"], ["W", "w"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["TYR", "DTY"], ["Y", "y"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["HSD", "DSD"], ["H", "h"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["HSE", "DSE"], ["H", "h"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["HSC", "DSC"], ["H", "h"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["HSP", "DSP"], ["H", "h"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["MSE", "DMS"], ["M", "m"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["CSO", "DCS"], ["C", "c"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["SEC", "DEC"], ["C", "c"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["CSX", "DCX"], ["C", "c"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["HIP", "DHP"], ["H", "h"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["SEP", "DEP"], ["S", "s"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["TPO", "DTP"], ["T", "t"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["PTR", "DPT"], ["Y", "y"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["UNK", "UNK"], ["X", "X"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["STP", "STP"], ["*", "*"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["---", "---"], ["-", "-"]) + +_add_residue([polymerType.LPROT, polymerType.DPROT], ["---", "---"], [".", "."]) + +for grp in [1, 2, 3]: + for tp in [polymerType.LPROT, polymerType.DPROT]: + for idx in range(num_tokens(tp)): + if grp == 1: + if not is_punctuation_index(idx, tp) and ( + not is_unknown_index(idx, tp) + ): + if _res3[tp.value][idx] not in 
_res3[polymerType.LDPROT.value]: + _add_residue( + polymerType.LDPROT, + _res3[tp.value][idx], + _res1[tp.value][idx], + ) + elif grp == 2: + if is_unknown_index(idx, tp): + if _res3[tp.value][idx] not in _res3[polymerType.LDPROT.value]: + _add_residue( + polymerType.LDPROT, + _res3[tp.value][idx], + _res1[tp.value][idx], + ) + elif grp == 3: + if is_punctuation_index(idx, tp): + if _res3[tp.value][idx] not in _res3[polymerType.LDPROT.value]: + _add_residue( + polymerType.LDPROT, + _res3[tp.value][idx], + _res1[tp.value][idx], + ) + + +def mirror_amino_acid(res: str): + idx = mirror_amino_acid_index(res_to_index(res, polymerType.LDPROT)) + if len(res) == 1: + return index_to_single(idx) + return index_to_triple(idx) + + +def mirror_amino_acid_index(idx: int): + N = num_known_molecular_tokens(polymerType.LDPROT) + + # if this is an unknown residue or a punctuation mark, return as is + if idx >= N: + return idx + + # otherwise, flip chirality + return (idx + N // 2) % N diff --git a/prtm/models/chroma/procap.py b/prtm/models/chroma/procap.py new file mode 100644 index 00000000..52aacaa2 --- /dev/null +++ b/prtm/models/chroma/procap.py @@ -0,0 +1,751 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from types import SimpleNamespace +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn.functional import pad + +from prtm.models.chroma import graph_classifier +from prtm.models.chroma.basic import FourierFeaturization +from prtm.models.chroma.graph_classifier import GraphClassifier +from prtm.models.chroma.graph_design import BackboneEncoderGNN +from prtm.models.chroma.model_utils import load_model as utility_load_model +from prtm.models.chroma.model_utils import save_model as utility_save_model +from prtm.models.chroma.structure import diffusion +from prtm.models.chroma.xcs import validate_XC + + +class ProteinCaption(nn.Module): + """ProCap model for caption likelihood given a noised structure. + + Provides an architecture to model the likelihood of a caption representing a + protein backbone at an arbitrary diffusion time. For caption processing, it + uses a pretrained language model from Hugging Face which can be + user-specified and fine-tuned. For structures, ProteinCaption uses a + `BackboneEncoderGNN` that encodes a structure and its noise level in the + embedding space of the language model. There are several options for + interfacing between the representations of the backbone residues and those + of the caption. + + A `ProteinCaption` model can be used to conditionally generate backbones + given a natural language caption, through the creation of a + `ProCapConditioner` using the model. In this case, the noising parameters + used for the `ProteinCaption` model should be identical to those that were + used to train the underlying backbone diffusion model. + + Args: + lm_id (str): Base language model to pull from Hugging Face. + gnn_dim_edges (int): Number of edges for structure encoder. 
+ context_size (int): When encoding structures by chains, specifies the + maximum number of chains to be used for the encodings. Not used when + `direct_gnn` is specified. + context_per_chain (int): When encoding structures by chain, the number + of context tokens to use per chain. Not used when `direct_gnn` is + specified. + gnn_num_neighbors (int): Number of neighbors per node for structure + encoder. + gnn_num_layers (int): Number of layers for structure encoder. + only_encode_caption_chain (bool): Whether to pass structure of only + chain whose caption is being predicted, as opposed to entire + structure. + gnn_embed_ratio (int): Number of context tokens to extract from GNN per + chain, stacks with gnn_embed_ratio. + graph_criterion (str): Graph criterion for structure encoder, defines + how neighbors are chosen. See + `chroma.models.graph_design.BackboneEncoderGNN` for + allowed values. + node_mlp_layers (int): Number of hidden layers for node update function + of structure encoder. + node_mlp_dim (int, optional): Dimension of hidden layers for node update + function of structure encoder, defaults to match output dimension. + noise_schedule (str): Noise schedule for mapping between diffusion time + and noise level, see + chroma.layers.structure.diffusion.DiffusionChainCov for allowed + values. + covariance_model (str): Covariance mode for mapping between diffusion + time and noise level, see + chroma.layers.structure.diffusion.DiffusionChainCov for allowed + values. + noise_complex_scaling (bool): Whether to scale noise for complexes. + noiseless (bool): Whether to train with denoised structures only, useful + for debugging but resulting model cannot be used for classifier + guidance. + normalize_context_embeddings (bool): Whether to normalize context + embeddings to an overall length of 1. + standardize_context_embeddings (bool): Whether to standardize context + embeddings to have mean 0 and variance 1. + time_feature_type (str): Method of encoding diffusion timestep. + time_log_feature_scaling (float): Scaling of diffusion timestep in + preprocessing when encoding with `time_feature_type = "log_snr"`. + use_transformer (bool): Whether to use transformer to embed context from + residue-level GNN outputs. + classifier_checkpoint (str, optional): Path to pre-trained graph + classifier checkpoint, whose encoder head will be used for structure + encoding. + direct_gnn (bool): Whether to pass in GNN encodings for chains/complexes + directly to the language model, without any pooling or transformer + layers. + classifier_kwargs (dict, optional): Dictionary of parameters to create + classifier network for encoding. Will override classifier_checkpoint + if given. + + + Inputs: + X (torch.Tensor): Backbone tensor of shape `(num_batch, num_residues, + 4, 3)`. + C (torch.Tensor): Chain map of shape `(num_batch, num_residues)`. + Positions with 0 are masked, positive integers are used for chain + indices, and negative integers are used for missing residues of the + chains with indices equal to the corresponding positive integers. + caption (List[str]): List of captions with length `num_batch`. + chain_id (torch.Tensor): Chain indices for given captions of shape + `(num_batch)`. For a caption corresponding to an entire complex, use + -1. + O (torch.Tensor, optional): One-hot sequence tensor of shape + `(num_batch, num_residues, num_alphabet)`. If not given, the loss is + computed without sequence information. + add_noise (bool): Whether to randomly add noise to the input backbones. 
+ If structures are already noised, use `t` instead. + t (torch.Tensor, optional): Diffusion timesteps corresponding to noisy + input backbones, of shape `(num_batch)`. Use zeros when passing + structures without noise. + by_sample (bool): Whether to return loss per sample, as opposed to + overall batch loss. + + Outputs: + loss (Union[transformers.modeling_outputs.CausalLMOutputWithCrossAttentions, + torch.Tensor]): Loss containing average -log(p) of caption tokens + given output structures. If `by_sample` is specified, loss is output + as a tensor of length `(num_batch)`. + """ + + def __init__( + self, + lm_id: str = "EleutherAI/gpt-neo-125m", + gnn_dim_edges: int = 128, + context_size: int = 16, + context_per_chain: int = 1, + gnn_num_neighbors: int = 30, + gnn_num_layers: int = 3, + only_encode_caption_chain: bool = False, + gnn_embed_ratio: int = 1, + graph_criterion: str = "knn", + node_mlp_layers: int = 1, + node_mlp_dim: Optional[int] = None, + noise_schedule: str = "log_snr", + covariance_model: str = "globular", + noise_complex_scaling: bool = False, + noiseless: bool = False, + normalize_context_embeddings: bool = False, + standardize_context_embeddings: bool = False, + time_feature_type: str = "t", + time_log_feature_scaling: float = 0.05, + use_transformer: bool = False, + classifier_checkpoint: Optional[str] = None, + direct_gnn: bool = False, + classifier_kwargs: Optional[dict] = None, + ) -> None: + super().__init__() + + # Save configuration in kwargs + self.kwargs = locals() + self.kwargs.pop("self") + for key in list(self.kwargs.keys()): + if key.startswith("__") and key.endswith("__"): + self.kwargs.pop(key) + args = SimpleNamespace(**self.kwargs) + + try: + import transformers + except ImportError: + print("Install the hugging face package `transformers` to use ProCap") + + self.context_size = context_size + self.context_per_chain = context_per_chain + self.only_encode_caption_chain = only_encode_caption_chain + self.gnn_embed_ratio = gnn_embed_ratio + self.normalize_context_embeddings = normalize_context_embeddings + self.standardize_context_embeddings = standardize_context_embeddings + self.time_feature_type = time_feature_type + self.time_log_feature_scaling = time_log_feature_scaling + self.use_transformer = use_transformer + self.classifier_checkpoint = classifier_checkpoint + self.direct_gnn = direct_gnn + self.classifier_kwargs = classifier_kwargs + + if self.normalize_context_embeddings and self.standardize_context_embeddings: + print( + "Warning: both normalization and standardization of context embeddings" + " are selected, choosing only standardization" + ) + self.normalize_context_embeddings = False + + # Use Pretrained Tokenizer From Hugging Face + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + lm_id, + additional_special_tokens=["<|pdb|>", "<|unconditioned|>"], + eos_token="<|endoftext|>", + pad_token="<|pad|>", + ) + + # Use Pretrained Language Model From Hugging Face + self.language_model = transformers.AutoModelForCausalLM.from_pretrained(lm_id) + + # Embedding + self.language_model.resize_token_embeddings(len(self.tokenizer)) + self.embedder = self.language_model.get_input_embeddings() + self.d_model = self.embedder.embedding_dim + + # Standardization for context embeddings + if self.standardize_context_embeddings: + self.context_normalization = nn.LayerNorm( + self.d_model, elementwise_affine=False + ) + + # Transformer for context embeddings + if self.use_transformer: + self.transformer = nn.Transformer( + nhead=8, + 
d_model=self.d_model, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=2048, + batch_first=True, + ) + if gnn_embed_ratio != 1: + print( + "Warning: both use_transformer and gnn_embed_ratio are set, setting" + " gnn_embed_ratio to 1" + ) + self.gnn_embed_ratio = 1 + if context_per_chain != 1: + print( + "Warning: both use_transformer and context_per_chain are set," + " setting context_per_chain to 1" + ) + self.context_per_chain = 1 + if not self.only_encode_caption_chain: + print( + "Warning: use_transformer is set but only_encode_caption_chain is" + " not, this is unsupported! Setting only_encode_caption_chain to" + " True" + ) + self.only_encode_caption_chain = True + + # Pass in GNN encodings without averaging or transformer + if self.direct_gnn: + if gnn_embed_ratio != 1: + print( + "Warning: both direct_gnn and gnn_embed_ratio are set, setting" + " gnn_embed_ratio to 1" + ) + self.gnn_embed_ratio = 1 + if context_per_chain != 1: + print( + "Warning: both direct_gnn and context_per_chain are set, setting" + " context_per_chain to 1" + ) + self.context_per_chain = 1 + if not self.only_encode_caption_chain: + print( + "Warning: direct_gnn is set but only_encode_caption_chain is not," + " this is unsupported! Setting only_encode_caption_chain to True" + ) + self.only_encode_caption_chain = True + if self.use_transformer: + print( + "Warning: direct_gnn and use_transformer are both set, turning off" + " use_transformer" + ) + self.use_transformer = False + if self.context_size is not None: + print( + "Warning: context_size given but not used for direct_gnn, setting" + " context_size to None" + ) + self.context_size = None + + # Use Standard Protein Encoder + if self.classifier_checkpoint is not None or self.classifier_kwargs is not None: + if self.classifier_kwargs is not None: + self.protein_encoder = GraphClassifier(**classifier_kwargs) + else: + self.protein_encoder = graph_classifier.load_model( + classifier_checkpoint + ) + self.classifier_kwargs = self.protein_encoder.kwargs + self.kwargs["classifier_kwargs"] = self.classifier_kwargs + self.protein_encoder_linear = nn.Sequential( + nn.Linear( + self.protein_encoder.dim_nodes, self.d_model * self.gnn_embed_ratio + ), + nn.ReLU(), + ) + else: + self.protein_encoder = BackboneEncoderGNN( + dim_nodes=self.d_model * self.gnn_embed_ratio, + dim_edges=gnn_dim_edges, + num_neighbors=gnn_num_neighbors, + num_layers=gnn_num_layers, + node_mlp_layers=node_mlp_layers, + node_mlp_dim=node_mlp_dim, + graph_criterion=graph_criterion, + ) + + # Use same Noise Layer as in Graph Energy model + if not noiseless: + self.noise_generator = diffusion.DiffusionChainCov( + log_snr_range=(-7.0, 13.5), + noise_schedule=noise_schedule, + covariance_model=covariance_model, + complex_scaling=noise_complex_scaling, + ) + else: + self.noise_generator = None + self.time_features = FourierFeaturization( + d_input=1, + d_model=self.d_model * self.gnn_embed_ratio, + trainable=False, + scale=16.0, + ) + + # Embed Tokens for 21 Residue Possibilities + self.sequence_embedding = nn.Embedding(22, self.d_model * self.gnn_embed_ratio) + + @validate_XC() + def forward( + self, + X: torch.Tensor, + C: torch.Tensor, + caption: List[str], + chain_id: torch.Tensor, + O: Optional[torch.Tensor] = None, + add_noise: bool = True, + t: Optional[Union[torch.Tensor, float]] = None, + by_sample: bool = False, + ) -> Union[ + "transformers.modeling_outputs.CausalLMOutputWithCrossAttentions", torch.Tensor + ]: + if self.noise_generator is None: + t = 
torch.zeros(X.shape[0]).to(X.device) + + if isinstance(t, float): + t = torch.Tensor([t]).to(X.device) + + elif isinstance(t, torch.Tensor) and t.dim() == 0: + t = t.unsqueeze(0) + + if add_noise and self.noise_generator is not None: + # Add Chain Noise + X, t = self._noise(X, C) + assert all(t <= 1) and all(t >= 0), ( + "Noise Temperatures must be between 0 and 1, but got values" + f" {t[(t > 1) | (t < 0)]}" + ) + else: + assert t is not None, "Must pass diffusion timestep if not adding noise!" + + # Encode Protein Context + + if self.classifier_kwargs is None: + # Aux feature encoding + node_h = self._time_features(t) + if O is not None: + # pad one-hot tensor by two to account for special tokens used + node_h = node_h + pad(O, (0, 2)) @ self.sequence_embedding.weight.to( + X.device + ) + Xe, _, _, Me, _ = self.protein_encoder.to(X.device)(X, C, node_h_aux=node_h) + else: + # TODO: is there a better way to deal with sequence padding tokens when batch size > 1? + if O is not None and O[:, :, -1].any(): + O = None + Xe0, _, _, Me, _ = self.protein_encoder.to(X.device).encode(X, C, O, t) + Xe = self.protein_encoder_linear.to(X.device)(Xe0) + + context_embedding, attention_mask_context = self._encode_context( + Xe, C, Me, chain_id + ) + if self.standardize_context_embeddings: + context_embedding = self.context_normalization.to(Xe.device)( + context_embedding + ) + elif self.normalize_context_embeddings: + context_embedding = torch.nn.functional.normalize(context_embedding, dim=-1) + + # Encode Text Input + if self.direct_gnn: + max_caption_tokens = ( + self.tokenizer.model_max_length - context_embedding.shape[1] + ) + else: + max_caption_tokens = ( + self.tokenizer.model_max_length + - (self.context_size - 1) + * self.gnn_embed_ratio + * self.context_per_chain + - 1 + ) + Y, attention_mask_caption = self._tokenize( + caption, add_stop=True, max_length=max_caption_tokens + ) + Y = Y.to(X.device) + attention_mask_caption = attention_mask_caption.to(X.device) + caption_embedding = self._embed_text(Y) + + # Caption + inputs_embeds = torch.cat([context_embedding, caption_embedding], dim=1) + attention_mask = torch.cat( + [attention_mask_context, attention_mask_caption], dim=1 + ) + labels = torch.cat( + [ + torch.tensor(-100, device=X.device).expand( + attention_mask_context.shape + ), + Y * attention_mask_caption + (-100) * (1 - attention_mask_caption), + ], + dim=1, + ) + + # returns a transformers.modeling_outputs.CausalLMOutputWithCrossAttentions object + # can get logits with output.logits + output = self.language_model.to(X.device).forward( + inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels + ) + if not by_sample: + return output + else: # below code adapted from transformers/modeling_gpt2.py + shift_logits = output.logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = nn.CrossEntropyLoss(reduction="none") + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ).reshape(X.shape[0], -1) + return torch.Tensor( + (loss * (shift_labels != -100).int()).sum(dim=-1) + / (shift_labels != -100).int().sum(dim=-1) + ) + + return output + + @validate_XC(all_atom=False) + def _noise( + self, X: torch.Tensor, C: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Takes in a Structure Tensor X and Chain Tensor C, adds chain noise with quasi-uniformly sampled temperature. + Returns the noised X and the time steps used.""" + assert self.noise_generator is not None, "Model does not have noising!" 
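+        # The chain-covariance diffusion returns the noised coordinates X_t together
+        # with the sampled timesteps t; both are moved onto the input's device below.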
+ return [x.to(X.device) for x in self.noise_generator.to(X.device)(X, C)] + + # Taken from graph classifier model + def _time_features(self, t: torch.Tensor) -> torch.Tensor: + h = { + "t": lambda: t, + "log_snr": lambda: self.noise_generator.noise_schedule.log_SNR(t), + }[self.time_feature_type]() + + if "log" in self.time_feature_type: + h = self.time_log_feature_scaling * h + + time_h = self.time_features.to(t.device)(h[:, None, None]) + return time_h + + def _encode_context( + self, Xe: torch.Tensor, C: torch.Tensor, M: torch.Tensor, polymer_id: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Average Pool over Chains after accounting for masking + input: + Xe (torch.Tensor): embedding tensor of shape [batch, residue, d_model] + C (torch.Tensor): chain tensor indexing which chain each residue belongs to [batch, residue] + M (torch.Tensor): mask tensor of shape [batch, residue] + polymer_id (int): index in C of chain, or -1 for entire structure, or 0 to apply no conditioning + """ + + Cm = C * M # Mask Chain Map + Cm[Cm < 0] = 0 # Remove Negatives from Chain Map + + B, R, Dm = Xe.shape + pooled_encoding = [] + for x, c, pid in zip(Xe, Cm, polymer_id): + batch_encoding = [] + + # The predict whole complex token is added under this syntax + if pid == -1: + pdb_embedding = self._embed_text( + self._tokenize(["<|pdb|>"], add_stop=False)[0].to(Xe.device) + ).squeeze(0) + batch_encoding.append(pdb_embedding) + + if pid == 0: + pdb_embedding = ( + self._embed_text( + self._tokenize(["<|unconditioned|>"], add_stop=False)[0] + ) + .squeeze(0) + .to(Xe.device) + ) + batch_encoding.append(pdb_embedding) + + # Power Average Pool By Chain + if pid != 0: + if self.only_encode_caption_chain and (pid != -1): + cid = self._pid_2_cid(pid, c) + residue_mask = c == cid + n_residues = residue_mask.sum(-1) + if self.use_transformer: + encodings = [ + self.transformer.to(Xe.device)( + x[residue_mask].unsqueeze(0), + torch.zeros(1, self.context_size, self.d_model).to( + Xe.device + ), + ).squeeze(0) + ] + elif self.direct_gnn: + encodings = x[residue_mask].unsqueeze(0) + else: + encodings = [ + (x[residue_mask].pow(p).sum(0).unsqueeze(0) / n_residues) + .abs() + .pow(1 / p) + * ( + x[residue_mask].pow(p).sum(0).unsqueeze(0).sign() + if p % 2 == 1 + else 1 + ) + for p in range(1, self.context_per_chain + 1) + ] + encodings = [ + enc.reshape(self.gnn_embed_ratio, -1) for enc in encodings + ] + batch_encoding.extend(encodings) + else: + if self.use_transformer or self.direct_gnn: + residue_mask = ( + c > 0 + ) # just use all embeddings, no chain structure + if self.use_transformer: + # should have pid == -1 to get here, so need encoding of size context_size - 1 because of <|pdb|> token + assert self.only_encode_caption_chain, ( + "only_encode_caption chain = False not supported when" + " use_transformer = True!" 
+ ) + batch_encoding.append( + self.transformer.to(Xe.device)( + x[residue_mask].unsqueeze(0), + torch.zeros( + 1, self.context_size - 1, self.d_model + ).to(Xe.device), + ).squeeze(0) + ) + else: # direct_gnn + batch_encoding.extend(x[residue_mask].unsqueeze(0)) + else: + for cid in torch.unique(c): + if cid == 0: + continue + residue_mask = c == cid + n_residues = residue_mask.sum(-1) + encodings = [ + ( + x[residue_mask].pow(p).sum(0).unsqueeze(0) + / n_residues + ) + .abs() + .pow(1 / p) + * ( + x[residue_mask].pow(p).sum(0).unsqueeze(0).sign() + if p % 2 == 1 + else 1 + ) + for p in range(1, self.context_per_chain + 1) + ] + batch_encoding.extend( + [ + enc.reshape(self.gnn_embed_ratio, -1) + for enc in encodings + ] + ) + + # Reorder the chain embedding to caption to be first + if pid != -1: + first_cid = self._pid_2_cid(pid, c) + try: + if first_cid != 0: + ( + batch_encoding[ + (first_cid - 1) + * self.gnn_embed_ratio + * self.context_per_chain : (first_cid) + * self.gnn_embed_ratio + * self.context_per_chain + ], + batch_encoding[ + 0 : self.gnn_embed_ratio + * self.context_per_chain + ], + ) = ( + batch_encoding[ + 0 : self.gnn_embed_ratio + * self.context_per_chain + ], + batch_encoding[ + (first_cid - 1) + * self.gnn_embed_ratio + * self.context_per_chain : (first_cid) + * self.gnn_embed_ratio + * self.context_per_chain + ], + ) + except IndexError: + print( + "Problem: tried to switch encodings at positions 0 and" + f" {first_cid}, but failed!" + ) + # raise + + pooled_encoding.append(torch.cat(batch_encoding)) + + # Pad with Zero Tensor + X_pooled = torch.nn.utils.rnn.pad_sequence(pooled_encoding, batch_first=True) + + if self.context_size is not None: + if ( + X_pooled.shape[1] + > (self.context_size - 1) + * self.gnn_embed_ratio + * self.context_per_chain + + 1 + ): + print([x.shape for x in pooled_encoding]) + print(polymer_id) + assert ( + X_pooled.shape[1] + <= (self.context_size - 1) + * self.gnn_embed_ratio + * self.context_per_chain + + 1 + ), ( + f"Context is of length {X_pooled.shape[1]}, which is larger than the" + " allowed number of tokens" + f" {(self.context_size - 1) * self.gnn_embed_ratio * self.context_per_chain + 1};" + " this will cause the model to behave poorly!" + ) + if ( + X_pooled.shape[1] + < (self.context_size - 1) + * self.gnn_embed_ratio + * self.context_per_chain + + 1 + and not self.direct_gnn + ): + pad_shape = ( + (self.context_size - 1) + * self.gnn_embed_ratio + * self.context_per_chain + + 1 + - X_pooled.shape[1] + ) + zero_pad = torch.zeros( + [B, pad_shape, int(Dm / self.gnn_embed_ratio)], device=Xe.device + ) + X_pooled = torch.cat([X_pooled, zero_pad], dim=1) + + M_pooled = (X_pooled != 0)[ + :, :, 0 + ] # This is a bit dangerous because very rarely X_pooled could contain zeros in masked regions... + return X_pooled, M_pooled + + def _pid_2_cid(self, pid: int, c: int) -> int: + """This function converts the polymer_entity_id in the pdb to the chain_id in the XCS format of generate.""" + assert pid in c, f"pid value {pid} must be in the chain map!" 
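+        # A chain id of 0 marks masked or unconditioned positions, so pid is mapped
+        # to its 1-indexed rank among the nonzero chain values in the chain map.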
+ chain_values = torch.unique(c) + nonzero_chain_values = chain_values[chain_values != 0] + cid = (nonzero_chain_values == pid).nonzero(as_tuple=True)[0].item() + 1 + return cid + + def _tokenize( + self, text: list, add_stop: bool = True, max_length: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Converts list of strings into a padded tensor, returning the tokenized strings as well as the associated masks.""" + if add_stop: + text = [x + self.tokenizer.eos_token for x in text] + # Note that there are no stop tokens in truncated sequences + tokenized_dict = self.tokenizer( + text, + padding=True, + truncation=True, + max_length=max_length, + return_tensors="pt", + ) + return tokenized_dict["input_ids"], tokenized_dict["attention_mask"] + + def _embed_text(self, tokenized_text: torch.Tensor) -> torch.Tensor: + """Embeds tokenized text.""" + + return self.embedder.to(tokenized_text.device)(tokenized_text) + + +def load_model( + weight_file: str, + device: str = "cpu", + strict: bool = False, + strict_unexpected: bool = True, +) -> ProteinCaption: + """Loads a ProCap model. + + Args: + weight_file (str): Path to the saved model weights. + device (str): Device on which to load the model. + strict (bool): Whether to require that the keys match between the + input file weights and the model created from the parameters stored + in the model kwargs. + strict_unexpected (bool): Whether to require that there are no + unexpected keys when loading model weights, as distinct from the + strict option which doesn't allow for missing keys either. By + default, we use this option rather than strict for ease of + development when adding model features. + + Returns: + model (ProteinCaption): Instance of `ProteinCaption` with loaded + weights. For inference the returned model should be set to eval mode + with `model.eval()`. + """ + return utility_load_model( + weight_file, + ProteinCaption, + device=device, + strict=strict, + strict_unexpected=strict_unexpected, + ) + + +def save_model( + model: ProteinCaption, weight_file: str, metadata: Optional[dict] = None +) -> None: + """Save model, including optional metadata. + + Args: + model (ProteinCaption): An instance of `ProteinCaption`. + weight_file (str): The destination path for saving model weights. + metadata (dict): A dictionary of additional metadata to add to the model + weights. For example, when saving models during training it can be + useful to store `args` representing the CLI args, the date and time + of training, etc. + """ + utility_save_model(model, weight_file, metadata=metadata) diff --git a/prtm/models/chroma/sde.py b/prtm/models/chroma/sde.py new file mode 100644 index 00000000..17c73980 --- /dev/null +++ b/prtm/models/chroma/sde.py @@ -0,0 +1,120 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layers for integrating Stochastic Differential Equations (SDEs). 
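+
+This module provides an Euler-Maruyama integrator (`sde_integrate`) and a Heun
+predictor-corrector integrator (`sde_integrate_heun`). A minimal sketch of driving
+the Euler-Maruyama integrator with a toy, hypothetical drift/diffusion function:
+
+    import torch
+    from prtm.models.chroma.sde import sde_integrate
+
+    def toy_sde(t, y):
+        # drift pulls y toward zero; the diffusion term is isotropic noise
+        return -y, 0.1 * torch.randn_like(y)
+
+    y0 = torch.randn(8, 3)
+    trajectory = sde_integrate(toy_sde, y0, tspan=(1.0, 0.0), N=100)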
+ + +""" + + +from typing import Callable, Tuple + +import torch +from tqdm.autonotebook import tqdm + + +def sde_integrate( + sde_func: Callable, + y0: torch.Tensor, + tspan: Tuple, + N: int, + project_func: Callable = None, + T_grid: torch.Tensor = None, +) -> list: + """Integrate an Ito SDE with the Euler-Maruyama method. + + args: + sde_func (function): a function that takes in time and y and returns SDE drift and diffusion terms for the evolution of y + y0 (torch.tensor): the initial value of y, e.g. a noised protein structure tensor + tspan (tuple): a tuple (t_i, t_f) with t_i being the initial time and t_f being the final time for integration + N (int): number of integration steps + + returns: + y_trajectory (list): a list of snapshots of the evolution of y as the SDE is integrated + + """ + + with torch.no_grad(): + # Integrate SDE + y_trajectory = [y0] + + if T_grid is None: + T_grid = torch.linspace(tspan[0], tspan[1], N + 1).to(y0.device) + else: + assert T_grid.shape[0] == N + 1 + + y = y0 + for t0, t1 in tqdm( + zip(T_grid[:-1], T_grid[1:]), total=N, desc="Integrating SDE" + ): + t = t0 + dT = t1 - t0 + + f, gZ = sde_func(t, y) + y = y + dT * f + dT.abs().sqrt() * gZ + y = y if project_func is None else project_func(t, y) + + y_trajectory.append(y) + return y_trajectory + + +def sde_integrate_heun( + sde_func: Callable, + y0: torch.Tensor, + tspan: Tuple, + N: int, + project_func: Callable = None, + T_grid: torch.Tensor = None, +) -> list: + """Integrate an Ito SDE with Heun's method. + + args: + sde_func (function): a function that takes in time and y and returns SDE drift and diffusion terms for the evolution of y + y0 (torch.tensor): the initial value of y, e.g. a noised protein structure tensor + tspan (tuple): a tuple (t_i, t_f) with t_i being the initial time and t_f being the final time for integration + N (int): number of integration steps + + returns: + y_trajectory (list): a list of snapshots of the evolution of y as the SDE is integrated + + """ + + with torch.no_grad(): + # Integrate SDE + y_trajectory = [y0] + dT = (tspan[1] - tspan[0]) / N + + if T_grid is None: + T_grid = torch.linspace(tspan[0], tspan[1], N + 1).to(y0.device) + else: + assert T_grid.shape[0] == N + 1 + + y = y0 + + for t0, t1 in tqdm( + zip(T_grid[:-1], T_grid[1:]), total=N, desc="Integrating SDE" + ): + # for i in tqdm(range(N)): + # t = tspan[0] + i * dT + t = t0 + dT = t1 - t0 + f, gZ = sde_func(t, y) + y_pred = y + dT * f + dT.abs().sqrt() * gZ + f_pred, gZ_pred = sde_func(t, y_pred) + y_correct = y + dT * f_pred + dT.abs().sqrt() * gZ + y = (y_pred + y_correct) / 2.0 + y = y if project_func is None else project_func(t, y) + y_trajectory.append(y) + + return y_trajectory diff --git a/prtm/models/chroma/starparser.py b/prtm/models/chroma/starparser.py new file mode 100644 index 00000000..23478078 --- /dev/null +++ b/prtm/models/chroma/starparser.py @@ -0,0 +1,209 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
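+
+# Line-oriented helpers for reading and writing STAR/CIF-style data files:
+# loop_ blocks, key-value data items, and quoting/escaping of values.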
+ +import shlex +from dataclasses import dataclass + + +@dataclass +class PeekedLine: + line: str + next_position: int + + +def peek_line(f, peeked: PeekedLine, rewind=True): + ret = True + pos = f.tell() + line = f.readline() + if line == "": # at EOF + ret = False + elif line[-1] == "\n": + line = line[:-1] + peeked.line = line + if rewind: + peeked.next_position = f.tell() + f.seek(pos) + else: + peeked.next_position = pos + return ret + + +def advance(f, peeked: PeekedLine): + f.seek(peeked.next_position) + + +def star_item_parse(line: str): + parts = line.split(".") + if len(parts) < 2: + raise Exception(f"expected at least two parts in the STAR data line {line}") + cat = parts[0] + name_parts = parts[1].split() + name = name_parts[0] + if len(name_parts) >= 2: + val = name_parts[1] + else: + val = "" + return (cat, name, val) + + +def star_read_data(f, names: list, in_loop: bool, cols=False, has_blocks=True): + tab = [] + line = "" + if cols: + tab = [[] for _ in range(len(names))] + peeked = PeekedLine("", 0) + if in_loop: + heads = [] + while peek_line(f, peeked): + if not peeked.line.startswith("_"): + break + parts = peeked.line.split(".") + if len(parts) != 2: + raise Exception(f"expected two parts in the STAR data line {line}") + heads.append(parts[1].strip()) + advance(f, peeked) + + # figure out which columns we want + indices = [-1] * len(names) + for i, name in enumerate(names): + if name in heads: + indices[i] = heads.index(name) + + # read each row and get the corresponding columns + row = [None] * len(heads) + ma = max(indices) + while star_read_data_row(f, row, in_loop, has_blocks): + if (ma >= 0) and (len(row) <= ma): + raise Exception(f"loop row has insufficient elements: {line}") + if not cols: + tab.append([""] * len(names)) + for i, index in enumerate(indices): + if cols: + tab[i].append(row[index] if index >= 0 else "") + else: + tab[-1][i] = row[index] if index >= 0 else "" + else: + if not cols: + tab = [[""] * len(names)] + category, cat, name = "", "", "" + + row = ["", ""] + while star_read_data_row(f, row, in_loop, has_blocks, peeked): + cat, name, _ = star_item_parse(row[0]) + if category == "": + category = cat + elif category != cat: + advance(f, peeked) + break + + if name not in names: + continue + idx = names.index(name) + if cols: + tab[idx].push_back(row[1]) + else: + tab[0][idx] = row[1] + + return tab + + +def star_read_data_row( + f, row: list, in_loop: bool, has_blocks: bool, peeked: PeekedLine = None +): + i = 0 + ret = True + if peeked is None: + peeked = PeekedLine("", 0) + while i < len(row): + if not peek_line(f, peeked, rewind=False): + if peeked.line == "" and i == 0: + return False + raise Exception(f"read {i} tokens when {len(row)} were requested: {row}") + if ( + peeked.line.startswith("loop_") + or peeked.line.startswith("data_") + or (in_loop and peeked.line.startswith("_")) + ): + if i == 0: + advance(f, peeked) + return False + raise Exception( + f"data block ended while reading requested number of tokens: {len(row)}" + ) + + if peeked.line.startswith(";"): + row[i] = peeked.line[1:] + while peek_line(f, peeked, rewind=False): + if peeked.line.startswith(";"): + break + row[i] += peeked.line + i = i + 1 + elif peeked.line.startswith("#"): + pass + else: + elems = ( + [part for part in shlex.split(peeked.line.strip())] + if has_blocks + else peeked.line.strip().split() + ) + if i + len(elems) > len(row): + raise Exception( + f"too many elements when trying to read {len(row)} tokens; last read: {elems}, row was: {row}, i = {i}" + ) + 
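+            # Copy the whitespace- or shlex-split tokens into the caller's row buffer.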
for elem in elems: + row[i] = elem + i = i + 1 + + return ret + + +def star_string_escape(text): + # NOTE: has_space designates whether the string really should be quoted, not + # based on having quote characters within it, but just because of some other + # reason (e.g., it has spaces or is empty or starts with underscore, which can + # have special meaning in CIF). + has_space = (" " in text) or (text == "") or ((len(text) > 0) and (text[0] == "_")) + has_single = "'" in text + has_double = '"' in text + + if not has_single and not has_double: + if not has_space: + return text + else: + return f"'{text}'" + elif not has_single: + return f"'{text}'" + elif not has_double: + return '"' + text + '"' + return "\n;" + str + "\n;" + + +def star_loop_header_write(f, category, names): + f.write("loop_\n") + for name in names: + f.write(f"{category}.{name} \n") + + +def star_value_defined(val): + return (val != ".") and (val != "?") + + +def star_value(val, default): + if star_value_defined(val): + return val + return default + + +def atom_site_token(value): + return "." if value == " " else value diff --git a/prtm/models/chroma/structure/__init__.py b/prtm/models/chroma/structure/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/prtm/models/chroma/structure/backbone.py b/prtm/models/chroma/structure/backbone.py new file mode 100644 index 00000000..50f0d594 --- /dev/null +++ b/prtm/models/chroma/structure/backbone.py @@ -0,0 +1,1126 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layers for generating protein structure. + +This module contains pytorch layers for parametrically generating and +manipulating protein backbones. These can be used in tandem with loss functions +to generate and optimize protein structure (e.g. folding from predictions) or +used as an intermediate layer in a learned structure generation model. +""" + +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from prtm.models.chroma.structure import geometry, transforms + + +class ProteinBackbone(nn.Module): + """Protein backbone layer with optimizable geometry (batch form). + + This layer stores the parameters for a protein backbone, which can be based + on either internal coordinate or Cartesian parameterizations. + It outputs coordinates in Cartesian form as 4D tensors with indices + `[batch, position, atom_type, xyz]`. The `atom_type` index runs over the + heavy atoms of a protein backbone in PDB order, i.e. `[N, CA, C, O]`. + The resulting coordinates can be directly optimized with pytorch optimizers. + + Args: + num_residues (int): Number of residues. + num_batch (int): Batch size. + init_state (str): Initialization state. Can be ['alpha', 'beta', ''] + use_internal_coords (Boolean): Use a phi,psi parameterization. + Default is True. + X_init (torch.Tensor, optional): Initialize with pre-specified coordinates. + Requires that use_internal_coords=False. 
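+
+    Example (a minimal sketch):
+        >>> backbone = ProteinBackbone(num_residues=64, init_state="alpha")
+        >>> X = backbone()  # idealized alpha-helical backbone, shape (1, 64, 4, 3)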
+ + Outputs: + X (torch.Tensor): Backbone coordinates with shape + `(batch_size, num_residues, 4, 3)`. + """ + + def __init__( + self, + num_residues: int, + num_batch: int = 1, + init_state: str = "alpha", + use_internal_coords: bool = True, + X_init: Optional[torch.Tensor] = None, + ): + super(ProteinBackbone, self).__init__() + + # Dimensions + self.num_batch = num_batch + self.num_residues = num_residues + + # Rigid body translation and rotation + self.transform = RigidTransform(num_batch=num_batch, keep_centered=True) + + self.use_internal_coords = use_internal_coords + if self.use_internal_coords: + # Internal coordinate parameterization + self.phi = nn.Parameter(torch.zeros(num_batch, num_residues)) + self.psi = nn.Parameter(torch.zeros(num_batch, num_residues)) + + # Initializer + phi_psi = { + "alpha": (np.radians(-60.0), np.radians(-45.0)), + "beta": (np.radians(-140.0), np.radians(135.0)), + } + if init_state in phi_psi: + torch.nn.init.constant_(self.phi, phi_psi[init_state][0]) + torch.nn.init.constant_(self.psi, phi_psi[init_state][1]) + else: + torch.nn.init.uniform_(self.phi, a=-np.pi, b=np.pi) + torch.nn.init.uniform_(self.psi, a=-np.pi, b=np.pi) + + self.backbone_geometry = BackboneBuilder() + else: + # Use a Cartesian parameterization + if X_init is not None: + assert not use_internal_coords + else: + X_init = ProteinBackbone( + num_residues=num_residues, + num_batch=num_batch, + init_state=init_state, + use_internal_coords=True, + )() + self.X = nn.Parameter(X_init) + + def forward(self) -> torch.Tensor: + if self.use_internal_coords: + X = self.backbone_geometry(self.phi, self.psi) + else: + X = self.X + + # Apply rotation and translation + X = self.transform(X) + return X + + +class RigidTransform(nn.Module): + """Rigid-body rotation and translation (batch form). + + This layer stores the parameters for a rigid body rotation and translation. + It can be composed with other generative geometry layers to optimize over + poses. + + Args: + num_batch (int): Number of poses to store parameters for. + keep_centered (Boolean): If True, center the input coordinates by + default. + scale_dX (float): Scale factor which affects the rate of change of + translation. + scale_q (float): Scale factor which affects the rate of change of + rotation. + + Inputs: + X (torch.Tensor): Input coordinates with shape `(batch_size, ..., 3)`. + + Outputs: + X_t (torch.Tensor): Transformed coordinates with shape: + `(batch_size, ..., 3)`. + """ + + def __init__( + self, + num_batch: int = 1, + keep_centered: bool = False, + scale_dX: float = 1.0, + scale_q: float = 1.0, + ): + super(RigidTransform, self).__init__() + self.num_batch = num_batch + + # Cartesian offset initialized to 0 + self.dX = nn.Parameter(torch.zeros(self.num_batch, 3)) + self.scale_dX = scale_dX + + # Unconstrained quaternion initialized to identity + self.scale_q = scale_q + q_init = np.asarray([[1.0, 0, 0, 0]] * self.num_batch) + q_init = torch.tensor(q_init, dtype=torch.float32) / self.scale_q + self.q_unc = nn.Parameter(q_init) + + self.rigid_transform = RigidTransformer(keep_centered=keep_centered) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + dX = self.scale_dX * self.dX + q_unc = self.scale_q * self.q_unc + X_t = self.rigid_transform(X, dX, q_unc) + return X_t + + +class RigidTransformer(nn.Module): + """Rigid-body rotation and translation (batch form). + + This layer applies a rigid body rotation and translation, + and can be composed with other generative geometry layers to modify poses. 
+ + Internally, the coordinates are centered before rotation and translation. + The rotation itself is parameterized as a quaternion to prevent + Gimbal lock (https://en.wikipedia.org/wiki/Gimbal_lock). + + Args: + center_intput (Boolean): Center the input coordinates (default: True) + default. + + Inputs: + X (torch.Tensor): Input coordinates with shape `(batch_size, ..., 3)`. + dX (torch.Tensor): Translation vector with shape `(batch_size, 3)`. + q (torch.Tensor): Rotation vector (quaternion) with shape `(batch_size, 4)`. + It can be any 4-element real vector, but will internally be + normalized to a unit quaternion. + mask (tensor,optional): Mask tensor with shape `(batch_size, ..., 3)`. + + Outputs: + X_t (torch.Tensor): Transformed coordinates with shape `(batch_size, ..., 3)`. + """ + + def __init__(self, center_rotation: bool = True, keep_centered: bool = False): + super(RigidTransformer, self).__init__() + self.center_rotation = center_rotation + self.keep_centered = keep_centered + self.dist_eps = 1e-5 + + def _rotation_matrix(self, q_unc: torch.Tensor) -> torch.Tensor: + """Build rotation matrix from quaternion parameters. + + See en.wikipedia.org/wiki/Quaternions_and_spatial_rotation for further + details on converting between quaternions and rotation matrices. + + Args: + q_unc (torch.Tensor): Unnormalized quaternion representing rotation with + shape `(batch_size, 3)`. + + Returns: + R (torch.Tensor): Rotation matrix with shape `(batch_size, 3)`. + """ + num_batch = q_unc.shape[0] + q = F.normalize(q_unc, dim=-1) + + # fmt: off + a,b,c,d = q.unbind(-1) + a2,b2,c2,d2 = a**2, b**2, c**2, d**2 + R = torch.stack([ + a2 + b2 - c2 - d2, 2*b*c - 2*a*d, 2*b*d + 2*a*c, + 2*b*c + 2*a*d, a2 - b2 + c2 - d2, 2*c*d - 2*a*b, + 2*b*d - 2*a*c, 2*c*d + 2*a*b, a2 - b2 - c2 + d2 + ], dim=-1) + # fmt: on + + R = R.view([num_batch, 3, 3]) + return R + + def forward( + self, + X: torch.Tensor, + dX: torch.Tensor, + q: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + num_batch = X.shape[0] + X_flat = X.reshape([num_batch, -1, 3]) + + # Flatten mask + if mask is not None: + shape_mask = list(mask.size()) + shape_X = list(X.size()) + shape_mask_expand = shape_mask + [ + 1 for i in range(len(shape_X) - 1 - len(shape_mask)) + ] + shape_mask_flat = list(X_flat.size())[:-1] + [1] + + mask_flat = mask.reshape(shape_mask_expand).expand(shape_X[:-1]) + mask_flat = mask_flat.reshape(shape_mask_flat) + + # Compute center + X_mean = torch.sum(mask_flat * X_flat, 1, keepdims=True) / ( + torch.sum(mask_flat, 1, keepdims=True) + self.dist_eps + ) + else: + X_mean = torch.mean(X_flat, 1, keepdims=True) + + # Rotate around center of mass + if self.center_rotation: + X_centered = X_flat - X_mean + else: + X_centered = X_flat + R = self._rotation_matrix(q) + X_rotate = torch.einsum("bxr,bir->bix", R, X_centered) + + # Optionally preserve original centering + if self.center_rotation and not self.keep_centered: + X_rotate = X_rotate + X_mean + + # Translate + X_transform = X_rotate + dX.unsqueeze(1) + + if mask is not None: + X_transform = mask_flat * X_transform + (1 - mask_flat) * X_flat + + X_transform = X_transform.view(X.shape) + return X_transform + + +class BackboneBuilder(nn.Module): + """Protein backbone builder from dihedral angles (batch form). + + See ProteinBackbone() for further explanation of output coordinates. 
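+
+    A minimal sketch of building an idealized 10-residue helix from dihedrals
+    alone, leaving bond lengths and angles at their defaults:
+
+        phi = torch.full((1, 10), -1.047)  # approx. -60 degrees in radians
+        psi = torch.full((1, 10), -0.785)  # approx. -45 degrees in radians
+        X = BackboneBuilder()(phi, psi)    # shape (1, 10, 4, 3)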
+ + When only partial information is given such as phi & psi angles, this module + will fall default to using the ideal geometries given in + Engh & Huber, International Tables for Crystallography (2001). + https://doi.org/10.1107/97809553602060000857 + + Todo: + * Add shifting and padding logic to associate phis and psis with their + 'natural' residue indices rather than the child atoms that they + create during NERF recurrence + * Add control over the bond lengths and angles for Oxygen + + Inputs: + phi (torch.Tensor): Phi dihedral angles with shape `(batch_size, length)`. + psi (torch.Tensor): Psi dihedral angles with shape `(batch_size, length)`. + omega (torch.Tensor, optional): Omega dihedral angles with shape + `(batch_size, length)`. Defaults to ideal geometry. + angles (torch.Tensor, optional): Bond angles with shape + `(batch_size, 3*length)` Defaults to ideal geometry. + lengths (torch.Tensor, optional): Bond lengths with shape + `(batch_size, 3*length)`. Defaults to ideal geometry. + + Outputs: + X (torch.Tensor): Backbone coordinates with shape + `(batch_size, num_residues, 4, 3)`. + """ + + def __init__(self): + super(BackboneBuilder, self).__init__() + + # From "Structure Quality and Target Parameters", Engh & Huber, 2001 + # fmt: off + self.lengths = { + 'N_CA': 1.459, + 'CA_C': 1.525, + 'C_N': 1.336, + 'C_O': 1.229 + } + angles = { + 'N_CA_C': 111.0, + 'CA_C_N': 117.2, + 'C_N_CA': 121.7, + 'omega': 179.3 + } + self.angles = { + k: v * np.pi / 180. for k,v in angles.items() + } + # fmt: on + return + + def forward( + self, + phi: torch.Tensor, + psi: torch.Tensor, + omega: Optional[torch.Tensor] = None, + angles: Optional[torch.Tensor] = None, + lengths: Optional[torch.Tensor] = None, + add_O: bool = True, + ) -> torch.Tensor: + N_batch, N_residues = phi.shape[0], phi.shape[1] + linear_shape = [N_batch, N_residues] + device = phi.device + + """ + This uses a similar (but not identical) approach as NERF: + Parsons et al, Computational Chemistry (2005). + https://doi.org/10.1002/jcc.20237 + See the reference for further explanation about converting from internal + coordinates to Cartesian coordinates. + ____________________________________________________________________ + | N-to-C backbone geometry for NERF | + | i.e. 
which internal coords create which Cartesian coords | + | [% indicates preceding residue] | + |______________________ _________________________________________ ___| + |i-1 |Residue i |i+1| + | | | | + |Atom: [C%]--omega%--[N]----phi----[CA]----psi---[C]---omega---[N>]| + | | | | | + |Parents | | | | + | Bond: C%_N N_CA CA_C | + | Angle: CA%_C%_N C%_N_CA N_CA_C | + |Dihedral: N%_CA%_C%_N CA%_C%_N_CA C%_N_CA_C | + |--------------------------------------------------------------------| + |Bond: [C_N]% [N_CA] [CA_C] | + |Dihedral: psi% omega% phi | + |____________________________________________________________________| + """ + + if lengths is None: + lengths = torch.tensor( + [[self.lengths[key] for key in ["C_N", "N_CA", "CA_C"]]], + dtype=torch.float32, + ).to(device) + lengths = lengths.repeat(N_batch, N_residues) + + if angles is None: + angles = torch.tensor( + [[self.angles[key] for key in ["CA_C_N", "C_N_CA", "N_CA_C"]]], + dtype=torch.float32, + ).to(device) + angles = angles.repeat(N_batch, N_residues) + + if omega is None: + omega = self.angles["omega"] * torch.ones(linear_shape).to(device) + + # Compute un-rotated Cartesian coordinates in batch + dihedrals = torch.stack([psi, omega, phi], -1) + dihedrals = dihedrals.view([N_batch, 3 * N_residues]) + angles_comp = np.pi - angles + v = torch.stack( + [ + torch.cos(angles_comp), + torch.sin(angles_comp) * torch.cos(dihedrals), + torch.sin(angles_comp) * torch.sin(dihedrals), + ], + -1, + ) + + # Lengths + lengths_list = list(lengths.unsqueeze(-1).unbind(1)) + v_list = list(v.unbind(1)) + + if add_O: + # Build one extra appended residue + lengths_list += lengths_list[-3:] + v_list += v_list[-3:] + + def _build_x_i(v_i, l_i, x, u_minus_1, u_minus_2): + """Recurrence relation for placing atoms (NERF)""" + + # Build matrix encoding local reference frame + n_a_unnorm = torch.cross(u_minus_2, u_minus_1) + n_a = F.normalize(n_a_unnorm, dim=-1) + n_b = torch.cross(n_a, u_minus_1) + + # Matrix multiply version + R = torch.stack([u_minus_1, n_b, n_a], 2) + u_new = torch.matmul(R, v_i.unsqueeze(-1)).squeeze(-1) + + x_new = x + l_i * u_new + return x_new, u_new, u_minus_1 + + # Initialization + x_i = torch.zeros([N_batch, 3]).to(device) + u_i_minus_2 = torch.tensor([[1.0, 0, 0]] * N_batch, dtype=torch.float32).to( + device + ) + u_i_minus_1 = torch.tensor([[0, 1.0, 0]] * N_batch, dtype=torch.float32).to( + device + ) + + # Build chain via NERF recurrence + X = [] + for i, (v_i, l_i) in enumerate(zip(v_list, lengths_list)): + x_i, u_i_minus_1, u_i_minus_2 = _build_x_i( + v_i, l_i, x_i, u_i_minus_1, u_i_minus_2 + ) + X.append(x_i) + X = torch.stack(X, 1) + # [N,AL,3] => [N,L,A,3] + X = X.view([N_batch, -1, 3, 3]) + + if add_O: + # Build the oxygen vector using symmetry + u_1 = F.normalize(X[:, :-1, 2, :] - X[:, :-1, 1, :], dim=-1) # CA->C + u_2 = F.normalize(X[:, :-1, 2, :] - X[:, 1:, 0, :], dim=-1) # C<-N* + u = self.lengths["C_O"] * F.normalize(u_1 + u_2, dim=-1) + X = X[:, :-1, :, :] + X_O = X[:, :, 2, :] + u + X = torch.cat([X, X_O.unsqueeze(2)], 2) + + X = X - X.mean([1, 2, 3], keepdim=True) + return X + + +class FrameBuilder(nn.Module): + """Build protein backbones from rigid residue poses. + + Inputs: + R (torch.Tensor): Rotation of residue orientiations + with shape `(num_batch, num_residues, 3, 3)`. If `None`, + then `q` must be provided instead. + t (torch.Tensor): Translation of residue orientiations + with shape `(num_batch, num_residues, 3)`. This is the + location of the C-alpha coordinates. 
+ C (torch.Tensor): Chain map with shape `(num_batch, num_residues)`. + q (Tensor, optional): Quaternions representing residue orientiations + with shape `(num_batch, num_residues, 4)`. + + Outputs: + X (torch.Tensor): All-atom protein coordinates with shape + `(num_batch, num_residues, 4, 3)` + """ + + def __init__(self, distance_eps: float = 1e-3): + super().__init__() + + # Build idealized backbone fragment + t = torch.tensor( + [ + [1.459, 0.0, 0.0], # N-C via Engh & Huber is 1.459 + [0.0, 0.0, 0.0], # CA is origin + [-0.547, 0.0, -1.424], # C is placed 1.525 A @ 111 degrees from N + ], + dtype=torch.float32, + ).reshape([1, 1, 3, 3]) + R = torch.eye(3).reshape([1, 1, 1, 3, 3]) + self.register_buffer("_t_atom", t) + self.register_buffer("_R_atom", R) + + # Carbonyl geometry from CHARMM all36_prot ALA definition + self._length_C_O = 1.2297 + self._angle_CA_C_O = 122.5200 + self._dihedral_Np_CA_C_O = 180 + self.distance_eps = distance_eps + + def _build_O(self, X_chain: torch.Tensor, C: torch.LongTensor): + """Build backbone carbonyl oxygen.""" + # Build carboxyl groups + X_N, X_CA, X_C = X_chain.unbind(-2) + + # TODO: fix this behavior for termini + mask_next = (C > 0).float()[:, 1:].unsqueeze(-1) + X_N_next = F.pad( + mask_next + * X_N[ + :, + 1:, + ], + (0, 0, 0, 1), + ) + + num_batch, num_residues = C.shape + ones = torch.ones(list(C.shape), dtype=torch.float32, device=C.device) + X_O = geometry.extend_atoms( + X_N_next, + X_CA, + X_C, + self._length_C_O * ones, + self._angle_CA_C_O * ones, + self._dihedral_Np_CA_C_O * ones, + degrees=True, + ) + mask = (C > 0).float().reshape(list(C.shape) + [1, 1]) + X = mask * torch.stack([X_N, X_CA, X_C, X_O], dim=-2) + return X + + def forward( + self, + R: torch.Tensor, + t: torch.Tensor, + C: torch.LongTensor, + q: Optional[torch.Tensor] = None, + ): + assert q is None or R is None + + if R is None: + # (B,N,1,3,3) and (B,N,1,3) + R = geometry.rotations_from_quaternions( + q, normalize=True, eps=self.distance_eps + ) + + R = R.unsqueeze(-3) + t_frame = t.unsqueeze(-2) + X_chain = transforms.compose_translation(R, t_frame, self._t_atom) + X = self._build_O(X_chain, C) + return X + + def inverse( + self, X: torch.Tensor, C: torch.LongTensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Reconstruct transformations from poses. + + Inputs: + X (torch.Tensor): All-atom protein coordinates with shape + `(num_batch, num_residues, 4, 3)` + C (torch.Tensor): Chain map with shape `(num_batch, num_residues)`. + + Outputs: + R (torch.Tensor): Rotation of residue orientiations + with shape `(num_batch, num_residues, 3, 3)`. + t (torch.Tensor): Translation of residue orientiations + with shape `(num_batch, num_residues, 3)`. This is the + location of the C-alpha coordinates. + q (torch.Tensor): Quaternions representing residue orientiations + with shape `(num_batch, num_residues, 4)`. + """ + X_bb = X[:, :, :4, :] + R, t = geometry.frames_from_backbone(X_bb, distance_eps=self.distance_eps) + q = geometry.quaternions_from_rotations(R, eps=self.distance_eps) + mask = (C > 0).float().unsqueeze(-1) + R = mask.unsqueeze(-1) * R + t = mask * t + q = mask * q + return R, t, q + + +class GraphBackboneUpdate(nn.Module): + """Layer for updating backbone coordinates given graph embeddings. + + Args: + dim_nodes (int): Node dimension of graph input. + dim_edges (int): Edge dimension of graph input. + distance_scale (float): Coordinate scaling factor in angstroms. Default + is 10 angstroms per unit, which corresponds to nanometers. 
+ method (str): Method used for predicting coordinates. Options include + * `local`: Node-based relative transformations. + * `neighbor`: Inter-residue geometry. + * `neighbor_global`: Inter-residue geometry with virtual global edge. + * `neighbor_global_affine`: Inter-residue geometry with virtual + global edge, parameterized as a residual update. + * `none`: No transformation-based updates. + iterations (int): Number of method iteractions. + unconstrained (bool): If True, update sub-pose geometries beyond ideal + coordinates. + num_transform_weights (int): Number of uncertainty dimensions per residue + for neighbor-based updates. + black_hole_init (bool): If True, ignore initial geometry and initialize + poses at the origin as in AlphaFold2 (Jumper et al 2020). + + Inputs: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atoms, 3)`. + C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`. + node_h (torch.Tensor): Node features with shape + `(num_batch, num_residues, dim_nodes)`. + edge_h (torch.Tensor): Edge features with shape + `(num_batch, num_residues, num_neighbors, dim_edges)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_residues, num_neighbors)`. + mask_i (torch.Tensor): Node mask with shape `(num_batch, num_residues)`. + mask_ij (torch.Tensor): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. + + Outputs: + X_update (torch.Tensor): Updated backbone coordinates with shape + `(num_batch, num_residues, num_atoms, 3)`. + """ + + def __init__( + self, + dim_nodes: int, + dim_edges: int, + distance_scale: float = 10.0, + distance_eps: float = 1e-3, + method: str = "neighbor", + iterations: int = 1, + unconstrained: bool = True, + num_transform_weights: int = 1, + black_hole_init: bool = False, + ): + super(GraphBackboneUpdate, self).__init__() + self.distance_scale = distance_scale + self.distance_eps = distance_eps + self._eps = 1e-5 + + self.frame_builder = FrameBuilder(distance_eps=distance_eps) + self.method = method + self.iterations = iterations + self.unconstrained = unconstrained + self.num_transform_weights = num_transform_weights + self.black_hole_init = black_hole_init + + if self.method == "local": + self.W_q = nn.Linear(dim_nodes, 4) + self.W_t = nn.Linear(dim_nodes, 3) + elif self.method == "neighbor": + self.W_q = nn.Linear(dim_edges, 4) + self.W_t = nn.Linear(dim_edges, 3) + self.W_w = nn.Linear(dim_edges, self.num_transform_weights) + elif self.method == "neighbor_global": + self.W_q = nn.Linear(dim_edges, 4) + self.W_t = nn.Linear(dim_edges, 3) + self.W_w = nn.Linear(dim_edges, self.num_transform_weights) + self.W_q_global = nn.Linear(dim_nodes, 4) + self.W_t_global = nn.Linear(dim_nodes, 3) + self.W_w_global = nn.Linear(dim_nodes, self.num_transform_weights) + elif self.method == "neighbor_global_affine": + self.W_s_node = nn.Linear(dim_nodes, 2) + self.W_s_edge = nn.Linear(dim_edges, 2) + self.W_q = nn.Linear(dim_edges, 4) + self.W_t = nn.Linear(dim_edges, 3) + self.W_w = nn.Linear(dim_edges, self.num_transform_weights) + self.W_q_global = nn.Linear(dim_nodes, 4) + self.W_t_global = nn.Linear(dim_nodes, 3) + self.W_w_global = nn.Linear(dim_nodes, self.num_transform_weights) + if self.method == "none": + # None does no frame based updates + assert self.unconstrained + + if self.unconstrained: + self.W_t_local = nn.Linear(dim_nodes, 12) + return + + def _init_black_hole(self, X): + R = ( + torch.eye(3, device=X.device, dtype=X.dtype) + .reshape(1, 1, 3, 3) + 
.repeat(X.size(0), X.size(1), 1, 1) + ) + t = torch.zeros(X.size(0), X.size(1), 3, dtype=X.dtype, device=X.device) + return R, t + + def _update_local_transform(self, X, C, node_h, edge_h, edge_idx, mask_i, mask_ij): + """Update residue frames via transformation from self.""" + R_i, t_i, _ = self.frame_builder.inverse(X, C) + if self.black_hole_init: + R_i, t_i = self._init_black_hole(X) + + # Predict transforms + R = geometry.rotations_from_quaternions( + self.W_q(node_h), normalize=True, eps=self.distance_eps + ) + t = self.distance_scale * self.W_t(node_h) + + # Apply transformations + R_i_pred, t_i_pred = transforms.compose_transforms(R_i, t_i, R, t) + X_update = self.frame_builder(R_i_pred, t_i_pred, C) + return X_update, None, None, None + + def _update_neighbor_transform( + self, X, C, node_h, edge_h, edge_idx, mask_i, mask_ij + ): + """Update residue frames via weighted average transformation from neighbors.""" + + # Predict relative transformations from neighbors to self + R_ji = geometry.rotations_from_quaternions( + self.W_q(edge_h), normalize=True, eps=self.distance_eps + ) + t_ji = self.distance_scale * self.W_t(edge_h) + logit_ij = self.W_w(edge_h) + + # Compute predicted self locations from each neighbor + R_i, t_i, _ = self.frame_builder.inverse(X, C) + if self.black_hole_init: + R_i, t_i = self._init_black_hole(X) + + R_i, t_i = transforms.equilibrate_transforms( + R_i, + t_i, + R_ji, + t_ji, + logit_ij, + mask_ij, + edge_idx, + iterations=self.iterations, + ) + X_update = self.frame_builder(R_i, t_i, C) + + return X_update, R_ji, t_ji, None + + def _update_neighbor_global_transform( + self, X, C, node_h, edge_h, edge_idx, mask_i, mask_ij + ): + """Update residue frames via weighted average transformation from neighbors.""" + + # Predict relative transformations from neighbors to self + R_ji = geometry.rotations_from_quaternions( + self.W_q(edge_h), normalize=True, eps=self.distance_eps + ) + t_ji = self.distance_scale * self.W_t(edge_h) + logit_ji = self.W_w(edge_h) + + # Predict relative transformations to global frame + R_global_i = geometry.rotations_from_quaternions( + self.W_q_global(node_h), normalize=True, eps=self.distance_eps + ) + t_global_i = self.distance_scale * self.W_t_global(node_h) + logit_global_i = self.W_w_global(node_h) + + # Initialize global frame equivariantly + R_i, t_i, _ = self.frame_builder.inverse(X, C) + if self.black_hole_init: + R_i, t_i = self._init_black_hole(X) + + R_global, t_global = transforms.average_transforms( + R_i, t_i, mask_i[..., None], mask_i, dim=1, dither_eps=0.0 + ) + + # Compute predicted self locations from each neighbor + R_i, t_i = transforms.equilibrate_transforms( + R_i, + t_i, + R_ji, + t_ji, + logit_ji, + mask_ij, + edge_idx, + iterations=self.iterations, + R_global=R_global, + t_global=t_global, + R_global_i=R_global_i, + t_global_i=t_global_i, + logit_global_i=logit_global_i, + ) + X_update = self.frame_builder(R_i, t_i, C) + + return X_update, R_ji, t_ji, logit_ji + + def _update_neighbor_global_affine_transform( + self, X, C, node_h, edge_h, edge_idx, mask_i, mask_ij + ): + """Update residue frames via weighted average transformation from neighbors.""" + + # Compute interresidue geometries for current system + R_i_init, t_i_init, _ = self.frame_builder.inverse(X, C) + if self.black_hole_init: + R_i_init, t_i_init = self._init_black_hole(X) + + R_j_init, t_j_init = transforms.collect_neighbor_transforms( + R_i_init, t_i_init, edge_idx + ) + R_global, t_global = transforms.average_transforms( + R_i_init, t_i_init, 
mask_i[..., None], mask_i, dim=1, dither_eps=0.0 + ) + R_ji_init, t_ji_init = transforms.compose_inner_transforms( + R_j_init, t_j_init, R_i_init.unsqueeze(-3), t_i_init.unsqueeze(-2) + ) + R_gi_init, t_gi_init = transforms.compose_inner_transforms( + R_global.unsqueeze(1), t_global.unsqueeze(1), R_i_init, t_i_init + ) + q_ji_init = geometry.quaternions_from_rotations(R_ji_init) + q_gi_init = geometry.quaternions_from_rotations(R_gi_init) + + # Scale factor + s_node = torch.sigmoid(self.W_s_node(node_h)[..., None]).unbind(-2) + s_edge = torch.sigmoid(self.W_s_edge(edge_h)[..., None]).unbind(-2) + d_scale = self.distance_scale + + # Use edges to predict relative transformations from neighbors to self + q_ji = s_edge[0] * q_ji_init + (1.0 - s_edge[0]) * self.W_q(edge_h) + t_ji = s_edge[1] * t_ji_init + (1.0 - s_edge[1]) * d_scale * self.W_t(edge_h) + logit_ji = self.W_w(edge_h) + + # Use nodes to predict relative transformations to global frame + q_gi = s_node[0] * q_gi_init + (1.0 - s_node[0]) * self.W_q_global(node_h) + t_gi = s_node[1] * t_gi_init + (1.0 - s_node[1]) * d_scale * self.W_t_global( + node_h + ) + logit_gi = self.W_w_global(node_h) + + R_ji = geometry.rotations_from_quaternions( + q_ji, normalize=True, eps=self.distance_eps + ) + R_gi = geometry.rotations_from_quaternions( + q_gi, normalize=True, eps=self.distance_eps + ) + + # Compute predicted self locations from each neighbor + R_i, t_i = transforms.equilibrate_transforms( + R_i_init, + t_i_init, + R_ji, + t_ji, + logit_ji, + mask_ij, + edge_idx, + iterations=self.iterations, + R_global=R_global, + t_global=t_global, + R_global_i=R_gi, + t_global_i=t_gi, + logit_global_i=logit_gi, + ) + X_update = self.frame_builder(R_i, t_i, C) + return X_update, R_ji, t_ji, logit_ji + + def _inner_transforms(self, X, C, edge_idx): + R_i, t_i, _ = self.frame_builder.inverse(X, C) + R_ji, t_ji = transforms.collect_neighbor_inner_transforms(R_i, t_i, edge_idx) + return R_ji, t_ji + + def _transform_loss(self, R_ij_predict, t_ij_predict, X, C, edge_idx, mask_ij): + """Compute loss between transforms""" + R_ij, t_ij = self._inner_transforms(X, C, edge_idx) + R_ij_error = (R_ij_predict - R_ij).square().sum([-1, -2]) + t_ij_error = (t_ij_predict - t_ij).square().sum([-1]) + R_ij_mse = (mask_ij * R_ij_error).sum([1, 2]) / ( + mask_ij.sum([1, 2]) + self._eps + ) + t_ij_mse = (mask_ij * t_ij_error).sum([1, 2]) / ( + mask_ij.sum([1, 2]) + self._eps + ) + return R_ij_mse, t_ij_mse + + def forward( + self, + X: torch.Tensor, + C: torch.LongTensor, + node_h: torch.Tensor, + edge_h: torch.Tensor, + edge_idx: torch.LongTensor, + mask_i: torch.Tensor, + mask_ij: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + methods = { + "local": self._update_local_transform, + "neighbor": self._update_neighbor_transform, + "neighbor_global": self._update_neighbor_global_transform, + "neighbor_global_affine": self._update_neighbor_global_affine_transform, + "none": lambda *args: args[0], + } + method = methods[self.method] + + # Update frames with ideal geometry + X_update, R_ji, t_ji, logit_ji = method( + X, C, node_h, edge_h, edge_idx, mask_i, mask_ij + ) + + if self.unconstrained: + # Predict atomic updates as delta from ideal geometry + # R_i, t_i, _ = self.frame_builder.inverse(X, C) # NOTE: Old models did this which was a typo + R_i, t_i, _ = self.frame_builder.inverse(X_update, C) + t_local = self.W_t_local(node_h).reshape(X.shape) + + # Rotate atomic updates into local frame + R_i = R_i.unsqueeze(-3) + t_i = 
torch.zeros_like(t_i).unsqueeze(-2) + dX = transforms.compose_translation(R_i, t_i, t_local) + + if self.training: + # Randomly swap between ideal coordinates at train time + mask_drop = ( + torch.randint( + low=0, high=2, size=[C.shape[0], 1, 1, 1], device=X.device + ) + > 0 + ).float() + dX = mask_drop * dX + + X_update = X_update + dX + return X_update, R_ji, t_ji, logit_ji + + +class LossBackboneResidueDistance(nn.Module): + """Compute losses for training denoising diffusion models. + + Inputs: + X_mobile (torch.Tensor): Mobile coordinates with shape + `(num_source, num_atoms, 4, 3)`. + X_target (torch.Tensor): Target coordinates with shape + `(num_target, num_atoms, 4, 3)`. + C (torch.Tensor): Chain map with shape `(num_batch, num_residues)`. + + Outputs: + D_error (Tensor, optional): Per-site average distance errors with shape + `(num_batch)`. + """ + + def __init__(self, dist_eps: float = 1e-3): + super(LossBackboneResidueDistance, self).__init__() + self.dist_eps = dist_eps + + def _D(self, X): + """Compute distance matrix between center of mass""" + X_mean = X.mean(2) + D = ( + (X_mean[:, :, None, :] - X_mean[:, None, :, :]) + .square() + .sum(-1) + .add(self.dist_eps) + .sqrt() + ) + return D + + def forward( + self, X_mobile: torch.Tensor, X_target: torch.Tensor, C: torch.LongTensor + ) -> torch.Tensor: + mask = (C > 0).float() + mask_2D = mask[:, :, None] * mask[:, None, :] + D_error = (self._D(X_mobile) - self._D(X_target)).square() + D_error = (mask_2D * D_error).sum(-1) / (mask_2D.sum(-1) + self.dist_eps) + return D_error + + +def center_X(X: torch.Tensor, C: torch.LongTensor) -> torch.Tensor: + """Center each protein system at the origin. + + Args: + X (torch.Tensor): Backbone coordinates with shape + `(batch_size, num_residues, num_atoms, 3)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. + Returns: + X_centered (torch.Tensor): Centered backbone coordinates with shape + `(batch_size, num_residues, num_atoms, 3)`. + """ + mask_expand = ( + (C > 0).float().reshape(list(C.shape) + [1, 1]).expand([-1, -1, 4, -1]) + ) + X_mean = (mask_expand * X).sum([1, 2], keepdims=True) / ( + mask_expand.sum([1, 2], keepdims=True) + ) + X_centered = mask_expand * (X - X_mean) + return X_centered + + +def atomic_mean( + X_flat: torch.Tensor, mask: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute the mean across all 4 atom types. + + Args: + X (torch.Tensor): Flattened backbone coordinates with shape + `(batch_size, num_residues * num_atoms, 3)`. + mask (torch.Tensor): Mask with shape `(num_batch, num_residues)`. + Returns: + X_mean (torch.Tensor): System centers with shape `(batch_size, 3)`. + mask_atomic (torch.Tensor): Atomic mask with shape + `(batch_size, num_residues * num_atoms)`. + """ + mask_expand = mask.unsqueeze(-1).expand(-1, -1, 4) + mask_atomic = mask_expand.reshape(mask.shape[0], -1).unsqueeze(-1) + X_mean = torch.sum(mask_atomic * X_flat, 1, keepdims=True) / ( + torch.sum(mask_atomic, 1, keepdims=True) + ) + return X_mean, mask_atomic + + +def scale_around_mean( + X: torch.Tensor, C: torch.LongTensor, scale: float +) -> torch.Tensor: + """Scale coordinates around mean. + + Args: + X (torch.Tensor): Backbone coordinates with shape + `(batch_size, num_residues, num_atoms, 3)`. + C (LongTensor): Chain map with shape + `(num_batch, num_residues)`. + scale (float): Scalar factor by which to rescale + the coordinates. + + Returns: + X_scaled (torch.Tensor): Backbone coordinates with shape + `(batch_size, num_residues, num_atoms, 3)`. 
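+
+    Example (a minimal sketch; note that the implementation indexes `scale` per
+    batch element, so a tensor of shape `(num_batch,)` is expected despite the
+    scalar annotation):
+        >>> X_shrunk = scale_around_mean(X, C, torch.full((X.shape[0],), 0.5))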
+ """ + num_atoms = X.size(2) + mask_expand = ( + (C > 0).float().reshape(list(C.shape) + [1, 1]).expand([-1, -1, num_atoms, -1]) + ) + X_mean = (mask_expand * X).sum([1, 2], keepdims=True) / ( + mask_expand.sum([1, 2], keepdims=True) + ) + X_rescale = mask_expand * (scale[:, None, None, None] * (X - X_mean) + X_mean) + return X_rescale + + +def impute_masked_X(X: torch.Tensor, C: torch.LongTensor) -> torch.Tensor: + """Impute missing structure data to enforce chain contiguity. + + The posterior mean under a Brownian bridge is simply either the + nearest unclamped state or a linear interpolant between the two + nearest clamped endpoints along the chain. + + Args: + X (torch.Tensor): Backbone coordinates with shape + `(batch_size, num_residues, num_atoms, 3)`. + C (LongTensor): Chain map with shape + `(num_batch, num_residues)`. + + Returns: + X (torch.Tensor): Imputed backbone coordinates with shape + `(batch_size, num_residues, num_atoms, 3)`. + """ + X_flat = X.reshape(X.shape[0], -1, 3) + mask = (C > 0).type(torch.float32) + X_mean, mask_atomic = atomic_mean(X_flat, mask) + + # Expand chain map into atomic level masking + C_expand = C.unsqueeze(-1).expand(-1, -1, 4) + C_atomic = C_expand.reshape(C.shape[0], -1) + + # Find nearest unmasked positions to the left and right + ix = torch.arange(C_atomic.shape[1], device=X.device).reshape([1, -1]) + mask_atomic_extend = mask_atomic.squeeze(-1) + ix_mask = mask_atomic_extend * ix - (1 - mask_atomic_extend) + ix_left, _ = torch.cummax(ix_mask, dim=1) + ix_flip = torch.flip( + mask_atomic_extend * ix_mask + (1 - mask_atomic_extend) * C_atomic.shape[1], + [1], + ) + ix_right, _ = torch.cummin(ix_flip, dim=1) + ix_right = torch.flip(ix_right, [1]) + + ix_left = ix_left.long() + ix_right = ix_right.long() + + clamped_left = ix_left >= 0 + clamped_right = ix_right < C_atomic.shape[1] + ix_left[ix_left < 0] = 0 + ix_right[ix_right == C_atomic.shape[1]] = 0 + + X_left = torch.gather(X_flat, 1, ix_left.unsqueeze(-1).expand(-1, -1, 3)) + X_right = torch.gather(X_flat, 1, ix_right.unsqueeze(-1).expand(-1, -1, 3)) + + # Enfore that adjacent residues are same chain + C_abs = torch.abs(C_atomic) + C_left = torch.gather(C_abs, 1, ix_left) + C_right = torch.gather(C_abs, 1, ix_right) + clamped_left = clamped_left * (C_left == C_abs) + clamped_right = clamped_right * (C_right == C_abs) + + # Build linear interpolant + D_left = torch.abs(ix - ix_left) + D_right = torch.abs(ix_right - ix) + interp_theta = (D_right / (D_left + D_right + 1e-5)).unsqueeze(-1) + X_interp = interp_theta * X_left + (1 - interp_theta) * X_right + + clamped_left = clamped_left.unsqueeze(-1) + clamped_right = clamped_right.unsqueeze(-1) + X_imputed_flat = mask_atomic * X_flat + (1 - mask_atomic) * ( + clamped_left * clamped_right * X_interp + + clamped_right * (~clamped_left) * X_right + + (~clamped_right) * clamped_left * X_left + ) + + X_imputed = X_imputed_flat.reshape(X.shape) + return X_imputed + + +def expand_chain_map(C: torch.LongTensor) -> torch.Tensor: + """Expand an integer chain map into a binary chain mask. + + Args: + C (LongTensor): Chain map with shape + `(num_batch, num_residues)`. + + Returns: + mask_C (torch.Tensor): Expanded binary chain map with shape + `(num_batch, num_residue, num_chains)`. 
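+
+    Example (a minimal sketch):
+        >>> C = torch.tensor([[1, 1, 2, 2, 0]])
+        >>> expand_chain_map(C)[0, :, 0]  # membership mask for chain 1
+        tensor([1., 1., 0., 0., 0.])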
+ """ + + # Compute the per-chain averages of each feature within a chain, in place + num_batch, num_residues = list(C.shape) + num_chains = int(torch.max(C).item()) + + # Build a position == chain expanded mask (B,L,C) + C_expand = C.unsqueeze(-1).expand(-1, -1, num_chains) + idx = torch.arange(num_chains, device=C.device) + 1 + idx_expand = idx.view(1, 1, -1) + mask_C = (idx_expand == C_expand).type(torch.float32) + return mask_C diff --git a/prtm/models/chroma/structure/conditioners.py b/prtm/models/chroma/structure/conditioners.py new file mode 100644 index 00000000..a61eb173 --- /dev/null +++ b/prtm/models/chroma/structure/conditioners.py @@ -0,0 +1,1524 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layers for conditioning diffusion generative processes. +""" + +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from prtm.protein import ProteinBase +from prtm.models.chroma import chroma_utils, graph_classifier, procap +from prtm.models.chroma.graph_backbone import GraphBackbone +from prtm.models.chroma.graph_classifier import GraphClassifier +from prtm.models.chroma.graph_design import GraphDesign +from prtm.models.chroma.procap import ProteinCaption +from prtm.models.chroma.modeling import get_mask, protein_to_xcs +from prtm.models.chroma.structure import backbone, mvn, optimal_transport, symmetry +from prtm.models.chroma.structure.backbone import expand_chain_map +from prtm.models.chroma.system import System +from prtm.models.chroma.xcs import validate_XC +from scipy.sparse.csgraph import shortest_path +from torch import nn + + +class Conditioner(torch.nn.Module): + """ + A composable function for parameterizing protein design problems. + + Conditioners provide a general framework for expressing complex protein + design problems in terms of simpler, composable sub-conditions in + a way that enables automatic sampling. To accomplish this, Conditioners + parameterize time-dependent transformations to the global coordinate system + and total energy by mapping from unconstrained coordinates and energy to + potentially updated coordinates and energy. This convention can subsume + classifier guidance, bijective change-of-variables constrained MCMC, and + linear subspace constrained MCMC as special cases. + + A conditioner is implemented as a function which maps from state-energy pairs + at a time point `t` to updated state-energy pairs which may reflect hard constraints + (typically updates to coordinates and energy) and restraintes (updates just to + energy). Conditioners take in and return 5 arguments `X, C, O, U, t`, + where `X,C,O` is the protein complex in the `XCS` convention with `S` expressed as a + one-hot tensor `O`, `U` is the total system energy and `t` is the diffusion time. + Because conditioners have matched input and output types, they can be composed via + sequential chaining. 
Further examples and descriptions of Conditioners can be found + throughout this module. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Inputs: + X (torch.Tensor): Coordinates with shape `(batch_size, num_residues, 4, 3)`. + C (torch.LongTensor): Chain map with shape `(batch_size, num_residues)`. + O (torch.Tensor): One-hot sequence with shape + `(batch_size, num_residues, num_alphabet)`. + U (torch.Tensor): energy tensor with shape `(batch_size,)`. + t (Union[torch.Tensor, float]): Time tensor with shape `(batch_size,)` or + a scalar. + + Outputs: + X_out (torch.Tensor): Updated coordinates with shape + `(batch_size, num_residues_out, 4, 3)`. + C_out (torch.LongTensor): Updated chain map with shape + `(batch_size, num_residues_out)`. + O (torch.Tensor): Updated one-hot sequences with shape + `(batch_size, num_residues_out, num_alphabet)`. + U_out (torch.Tensor): Modified energy tensor with shape `(batch_size,)`. + t_out (Union[torch.Tensor, float]): Modified time tensor with shape + `(batch_size,)` or a scalar. + """ + + def __init__(self, *args, **kwargs): + super().__init__() + + def forward( + self, + X: torch.Tensor, + C: torch.LongTensor, + O: torch.Tensor, + U: torch.Tensor, + t: Union[torch.Tensor, float], + ) -> Tuple[ + torch.Tensor, + torch.LongTensor, + torch.Tensor, + torch.Tensor, + Union[torch.Tensor, float], + ]: + pass + + +class Identity(Conditioner): + def __init__(self): + super().__init__() + + @validate_XC() + def forward( + self, + X: torch.Tensor, + C: torch.LongTensor, + O: torch.Tensor, + U: torch.Tensor, + t: Union[torch.Tensor, float], + ) -> Tuple[ + torch.Tensor, + torch.LongTensor, + torch.Tensor, + torch.Tensor, + Union[torch.Tensor, float], + ]: + return X, C, O, U, t + + +class ComposedConditioner(Conditioner): + def __init__(self, conditioners): + super().__init__() + self.conditioners = nn.ModuleList(conditioners) + + @validate_XC() + def forward( + self, + X: torch.Tensor, + C: torch.LongTensor, + O: torch.Tensor, + U: torch.Tensor, + t: Union[torch.Tensor, float], + ) -> Tuple[ + torch.Tensor, + torch.LongTensor, + torch.Tensor, + torch.Tensor, + Union[torch.Tensor, float], + ]: + for conditioner in self.conditioners: + X, C, O, U, t = conditioner(X, C, O, U, t) + return X, C, O, U, t + + def _postprocessing_( + self, protein: ProteinBase, output_dict: Optional[dict] = None + ) -> Union[ProteinBase, Tuple[ProteinBase, dict]]: + for _conditioner in self.conditioners: + if hasattr(_conditioner, "_postprocessing_"): + if output_dict is None: + protein = _conditioner._postprocessing_(protein, output_dict) + else: + protein, output_dict = _conditioner._postprocessing_( + protein, output_dict + ) + + if output_dict is None: + return protein + else: + return protein, output_dict + + +class SubsequenceConditioner(Conditioner): + """ + SequenceConditioner: + A Chroma Conditioning module which, given a GraphDesign model and a subset of + residues for which sequence information is known, can add gradients to sampling + that bias the samples towards increased `log p(sequence | structure)` + + Args: + design_model (GraphDesign): Trained GraphDesign model. + S_condition (torch.Tensor): Of shape (1, num_residues) specifying sequence + information. + mask_condition (torch.Tensor, optional): Of shape (1, num_residues) specifying + which residues to include when computing `log p(sequence | structure)` + weight (float, optional): Overall weight to which the gradient is scaled. 
+ renormalize_grad (bool, optional): Whether to renormalize gradient to have + overall variance `weight`. + """ + + def __init__( + self, + design_model: GraphDesign, + protein: ProteinBase, + selection: str = "all", + weight: float = 1.0, + renormalize_grad: Optional[bool] = False, + ): + super().__init__() + self.design_model = design_model + + # Register sequence buffers + X, C, S = protein_to_xcs(protein) + mask_condition = get_mask(protein, selection) + self.register_buffer("S_condition", S) + self.register_buffer("mask_condition", mask_condition) + + self.weight = weight + self.renormalize_grad = renormalize_grad + + def _transform_gradient(self, grad, C, t): + # grad = clip_atomic_magnitudes_percentile(grad) + scale = self.weight / self.design_model.noise_perturb.noise_schedule.sigma( + t + ).to(C.device) + grad = scale * grad / grad.square().mean().sqrt() + return grad + + @validate_XC() + def forward( + self, + X: torch.Tensor, + C: torch.LongTensor, + O: torch.Tensor, + U: torch.Tensor, + t: Union[torch.Tensor, float], + ) -> Tuple[ + torch.Tensor, + torch.LongTensor, + torch.Tensor, + torch.Tensor, + Union[torch.Tensor, float], + ]: + X_input = X + 0.0 + + if self.renormalize_grad: + X_input.register_hook(lambda _X: self._transform_gradient(_X, C, t)) + if X.shape[2] == 4: + X_input = F.pad(X_input, [0, 0, 0, 10]) + + priority = None + if self.mask_condition is not None: + priority = 1.0 - self.mask_condition + out = self.design_model(X_input, C, self.S_condition, t, priority=priority) + logp_S = out["logp_S"] + + if self.mask_condition is not None: + logp_S = self.mask_condition * logp_S + U = U + self.weight * -logp_S.sum() + return X, C, O, U, t + + +class ShapeConditioner(Conditioner): + """Volumetric potential for optimizing towards arbitrary geometries. + + Args: + X_target (numpy array): Target point cloud with shape `(num_points, 3)`. + noise_schedule (GaussianNoiseSchedule): Diffusion time schedule for loss + scaling. + autoscale (bool): If True, automatically rescale target point cloud coordinates + such that they are approximately volume-scaled to a target protein size. + Volume is roughly estimated by converting the point cloud to a sphere cloud + with radii large enough to overlap with near neighbors and double counting + corrections via inclusion-exclusion. + autoscale_num_residues (int): Target protein size for auto-scaling. + autoscale_target_ratio (float): Scale factor for adjusting the target protein + volume. + scale_invariant (bool): If True, compute the loss in a size invariant manner + by dynamically renormalizing the point clouds to match Radii of gyration. + This approach can be more unstable to integrate and require more careful tuning. + shape_loss_weight (float): Scale factor for the overall restraint. + shape_loss_cutoff (float): Minimal distance deviation that is penalized in the loss, + e.g. to treat as a flat-bottom restraint below the cutoff. + sinkhorn_iterations (int): Number of Sinkhorn iterations for Optimal Transport + calculations. + sinkhorn_scale (float): Entropy regularization scaling parameter for Optimal + Transport calculations. + sinkhorn_iterations_gw (int): Number of Sinkhorn iterations for Gromov-Wasserstein + Optimal Transport calculations. + sinkhorn_scale_gw (float): Entropy regularization scaling parameter for + Gromov-Wasserstein Optimal Transport calculations. + gw_layout (bool): If True, use Gromov-Wasserstein Optimal Transport to compute + a point cloud correspondence assuming ideal protein distance scaling. 
+ gw_layout_coefficient (float): Scale factor with which to combine average + inter-point cloud distances according to OT (Wasserstein) versus + Gromov-Wasserstein couplings. + """ + + def __init__( + self, + X_target, + noise_schedule, + autoscale: bool = True, + autoscale_num_residues: int = 500, + autoscale_target_ratio: float = 0.4, + scale_invariant: bool = False, + shape_loss_weight: float = 20.0, + shape_loss_cutoff: float = 0.0, + shape_cutoff_D: float = 0.01, + scale_max_rg_ratio: float = 1.5, + sinkhorn_iterations: int = 10, + sinkhorn_scale: float = 1.0, + sinkhorn_scale_gw: float = 200.0, + sinkhorn_iterations_gw: int = 30, + gw_layout: bool = True, + gw_layout_coefficient: float = 0.4, + eps: float = 1e-3, + debug: bool = False, + ): + super().__init__() + + self.eps = eps + + self.noise_schedule = noise_schedule + + # Shape control parameters + self.shape_loss_weight = shape_loss_weight + self.shape_loss_cutoff = shape_loss_cutoff + self.scale_invariant = scale_invariant + self.shape_cutoff_D = shape_cutoff_D + self.scale_max_rg_ratio = scale_max_rg_ratio + + # Autoscale volumes (in units of cubic angstroms) + self.autoscale = autoscale + self.autoscale_num_residues = autoscale_num_residues + self.autoscale_target_ratio = autoscale_target_ratio + + self.sinkhorn_iterations = sinkhorn_iterations + self.sinkhorn_scale = sinkhorn_scale + self.sinkhorn_iterations_gw = sinkhorn_iterations_gw + self.sinkhorn_scale_gw = sinkhorn_scale_gw + + self.debug = debug + + if torch.is_tensor(X_target): + X_target = X_target.cpu().data.numpy() + + if self.autoscale: + X_target, self.shape_cutoff_D = chroma_utils.point_cloud_rescale( + X_target, + self.autoscale_num_residues, + scale_ratio=self.autoscale_target_ratio, + ) + + # Map coupling with Gromov Wasserstein optimal transport + self.gw_layout = gw_layout + self.gw_layout_coefficient = gw_layout_coefficient + if self.gw_layout: + self._map_gw_coupling_ideal_glob( + X_target, num_residues=autoscale_num_residues + ) + + X_target = torch.Tensor(X_target) + self.register_buffer("X_target", X_target[None, ...].clone().detach()) + + def _distance_knn(self, X, top_k=12, max_scale=10.0): + """Topology distance.""" + X_np = X.cpu().data.numpy() + D = np.sqrt( + ((X_np[:, :, np.newaxis, :] - X_np[:, np.newaxis, :, :]) ** 2).sum(-1) + ) + + # Distance cutoff + D_cutoff = np.mean(np.sort(D[0, :, :], axis=-1)[:, top_k]) + D[D > D_cutoff] = max_scale * np.max(D) + D = shortest_path(D[0, :, :])[np.newaxis, :, :] + D = torch.Tensor(D).float().to(X.device) + return D + + @torch.no_grad() + def _map_gw_coupling_ideal_glob(self, X_target, num_residues): + """Plan a layout using Gromov-Wasserstein Optimal transport""" + + X_target = torch.Tensor(X_target).float().unsqueeze(0) + if torch.cuda.is_available(): + X_target = X_target.to("cuda") + + chain_ix = torch.arange(4 * num_residues, device=X_target.device) / 4.0 + distance_1D = (chain_ix[None, :, None] - chain_ix[None, None, :]).abs() + # Scaling fit log-log to large scale single chain 6HYP + D_model = 7.21 * distance_1D**0.322 + D_model = D_model / D_model.mean([1, 2], keepdims=True) + + D_target = self._distance_knn(X_target) + D_target = D_target / D_target.mean([1, 2], keepdims=True) + + T_gw, D_gw = optimal_transport.optimize_couplings_gw( + D_model, + D_target, + scale=self.sinkhorn_scale_gw, + iterations_outer=self.sinkhorn_iterations_gw, + iterations_inner=self.sinkhorn_iterations, + ) + + self.register_buffer("T_gw", T_gw.clone().detach()) + return + + def _distance(self, X_i, X_j): + dX = 
X_i.unsqueeze(2) - X_j.unsqueeze(1) + D = torch.sqrt((dX**2).sum(-1) + self.eps) + return D + + @validate_XC() + def forward( + self, + X: torch.Tensor, + C: torch.LongTensor, + O: torch.Tensor, + U: torch.Tensor, + t: Union[torch.Tensor, float], + ) -> Tuple[ + torch.Tensor, + torch.LongTensor, + torch.Tensor, + torch.Tensor, + Union[torch.Tensor, float], + ]: + # Distance matrix is + # [Num_batch, Num_atoms_target, Num_atoms_model] + X_target = self.X_target + X_model = X.reshape([X.shape[0], -1, 3]) + + # Radius of gyration ceiling + num_residues = X.shape[1] + 2.0 * num_residues**0.333 + self.scale_max_rg_ratio * 2.0 * num_residues**0.4 + self.shape_cutoff_D + + def _center(_X): + _X = _X - _X.mean(1, keepdim=True) + return _X + + def _rg(_X): + _X = _X - _X.mean(1, keepdim=True) + rsq = _X.square().sum(2, keepdim=True) + rg = rsq.mean(1, keepdim=True).sqrt() + return rg + + X_model = _center(X_model) + X_target = _center(X_target) + + if self.scale_invariant: + + def _resize(_X, target_rg): + _X = _X - _X.mean(1, keepdim=True) + rsq = _X.square().sum(2, keepdim=True) + rg = rsq.mean(1, keepdim=True).sqrt() + return _X / rg * target_rg + + X_model = _resize(X_model, _rg(X_target)) + + # Compute interatomic distances + D_inter = self._distance(X_model, X_target) + + # Estimate Wasserstein Distance + cost = D_inter + T_w = optimal_transport.optimize_couplings_sinkhorn( + cost, scale=self.sinkhorn_scale, iterations=self.sinkhorn_iterations + ) + if self.gw_layout: + T_w = T_w + self.T_gw * self.gw_layout_coefficient + T_w = T_w / T_w.sum([-1, -2], keepdims=True) + D_w = (T_w * D_inter).sum([-1, -2]) + + # Scale by sqrt(SNR_t) and constant factor + scale_t = self.shape_loss_weight * self.noise_schedule.SNR(t).sqrt().clamp( + min=1e-3, max=3.0 + ) + neglogp = scale_t * F.softplus(D_w - self.shape_loss_cutoff) + U = U + neglogp + return X, C, O, U, t + + +class ProCapConditioner(Conditioner): + """Natural language conditioning for protein backbones. + + This conditioner uses an underlying `ProteinCaption` model to determine the + likelihood of a noised structure corresponding to a given caption. Captions + can be specified as corresopnding to a particular chain of the structure, or + to the entire complex. The encoded structures and captions are passed to the + model together, and the output loss that adjusts the energy is the masked + cross-entropy over the caption tokens. + + Args: + caption (str): Caption for the conditioner. Currently, a separate + conditioner should be constructed for each desired caption, even + with a single `ProteinCaption` model. + chain_id (int): The 1-indexed chain to which the caption corresponds, or + -1 for captions corresponding to the entire structure. The provided + checkpoints are trained with UniProt captions for chain_id > 0 and + PDB caption for chain_id = -1. Regardless of whether the caption is + specific to one chain, the conditioner acts on the entire structure. + weight (float): Overall factor by which the caption gradient is scaled. + model (generate.models.procap.ProteinCaption, optional): The + input model whose likelihoods are used. If not given, defaults to + the checkpoint used for the paper. + use_sequence (bool): Whether to use input sequence, default False. + device (str, optional): Device on which to store model. If not given, + GPU will be used when available. + Inputs: + X (torch.Tensor): Structure tensor with shape + `(batch_size, num_residues, 4, 3)`. 
+ C (torch.LongTensor): Chain map tensor with shape + `(batch_size, num_residues)` + O (torch.Tensor, optional): One-hot tensor allowing the input of + sequence information, of shape (1, num_residues, num_alphabet). + U (torch.Tensor): Energy tensor with shape `(batch_size,)`. + t (Union[torch.Tensor, float]): Time tensor with shape `(batch_size,)` + or a scalar. + + Outputs: + X_out (torch.Tensor): Unchanged structure tensor with shape + `(batch_size, num_residues, 4, 3)`. + C_out (torch.LongTensor): Unchanged chain map tensor with shape + `(batch_size, num_residues)`. + O_out (torch.Tensor, optional): One-hot tensor allowing the output of + sequence information, of shape (1, num_residues, num_alphabet). + U_out (torch.Tensor): Modified energy tensor with shape `(batch_size,)`. + t_out (Union[torch.Tensor, float]): Modified time tensor with shape + `(batch_size,)` or a scalar. + """ + + def __init__( + self, + caption: str, + chain_id: int, + weight: float = 10, + model: Union[ProteinCaption, str] = "named:public", + use_sequence: bool = False, + device: Optional[str] = None, + ) -> None: + super().__init__() + if isinstance(model, ProteinCaption): + self.model = model + elif isinstance(model, str): + self.model = procap.load_model( + model, device=device, strict_unexpected=False + ) + self.model.eval() + if device is None: + if torch.cuda.is_available(): + self.model.to("cuda") + else: + self.model.to(device) + self.caption = caption + self.register_buffer("chain_id", torch.Tensor([int(chain_id)])) + self.weight = weight + self.use_sequence = use_sequence + + @validate_XC() + def forward( + self, + X: torch.Tensor, + C: torch.LongTensor, + O: torch.Tensor, + U: torch.Tensor, + t: Union[torch.Tensor, float], + ) -> Tuple[ + torch.Tensor, + torch.LongTensor, + torch.Tensor, + torch.Tensor, + Union[torch.Tensor, float], + ]: + loss = self.model( + X, + C, + [self.caption] * X.shape[0], + self.chain_id.to(X.device).expand(X.shape[0], 1), + O=O if self.use_sequence else None, + add_noise=False, + t=t, + ).loss + U = U + self.weight * loss + return X, C, O, U, t + + +class ProClassConditioner(Conditioner): + """ + ProClassConditioner: + A Chroma Conditioning module which can specify chain level annotations for fold, + function, and organism. The current labels that can be conditioned on are: + + * cath: protein domain annotations from . Annotation + examples include `2`, `2.40`, `2.40.155`. + * funfam: domain level functional annotations. + * organism: the organism of origin of a protein. Annotation examples include `Homo + sapiens (Human)`, `Escherichia coli`, `Pseudomonas putida (Arthrobacter + siderocapsulatus)`, `Rattus norvegicus (Rat)` + * pfam: protein family annotations which represent domain level structural + characteristics. + + For a complete list of valid value label pairs import the value dictionary from the + `GraphClassifierLoader` in the zoo. + + Note: + This conditioner is a research preview. Conditioning with it can be inconsistent + and depends on the relative prevalence of a given label in the dataset. + With repeated tries it will produce successful results for more abundant labels. + Please see the supplement to the paper for details. This is currently not + recommended for production use. The most reproducible labels are C level + annotations in cath, (e.g. `1`,`2`,`3`). + + Args: + label (str): The annotation to condition on in the set [cath, funfam, pfam, + organism, secondary_structure]. + value (str, optional): The particular annotation string to use. 
For a complete + list of values for a given label use the static method + :meth:`possible_conditions`. Defaults to None. + model (GraphClassifier, optional): A ProClass instance to use for conditioning. + if None is provided the recommended model is automatically loaded. Defaults + to None. + weight (float, optional): The weighting of the conditioner relative to the + backbone model. Defaults to 1. + max_norm (float, optional): The maximum magnitude of the gradient, above which + the magnitude is clipped. Defaults to None. + renormalize_grad (bool, optional): Whether to renormalize gradient to have + overall variance `weight`. + use_sequence (bool, optional): Whether to use input sequence, default False. + device (str, optional): the device to put the conditioner on, accepts `cpu` + and `cuda`. If None is provided it will automatically try to put it on the + GPU if possible. Defaults to None. + debug (bool, optional): provides gradient values during optimization for + setting weights and debugging. + """ + + def __init__( + self, + label: str, + value: Union[Optional[str], torch.Tensor] = None, + model: Union[GraphClassifier, str] = "named:public", + weight: float = 5, + max_norm: Optional[float] = 20, + renormalize_grad: Optional[bool] = False, + use_sequence: bool = False, + device: Optional[str] = None, + debug: bool = False, + ) -> None: + super().__init__() + self.label = label + self.value = value + self.max_norm = max_norm + self.renormalize_grad = renormalize_grad + self.weight = weight + self.use_sequence = use_sequence + self.debug = debug + + if isinstance(model, str): + self.proclass_model = graph_classifier.load_model(model, device=device) + elif isinstance(model, GraphClassifier): + self.proclass_model = model + self.proclass_model.eval() + + # Move Model to the indicated device + if device is None: + if torch.cuda.is_available(): + self.proclass_model.to("cuda") + else: + self.proclass_model.to(device) + + self._transform_inputs() + self._validate_inputs() + + def _transform_inputs(self): + # Automatically handle heirarchical inputs in the format X.Y.Z.W + if self.label.lower() in ["cath", "funfam"]: + self.label = self.label.lower() + self.label += "_" + str(self.value.count(".")) + + # Correct Capitalization + if self.label.lower() == "organism": + self.label = "Organism" + + # Support Normative PFam IDs + if self.label.lower() == "pfam": + valid_values = self.proclass_model.class_config["pfam"]["tokens"] + if self.value.count(".") == 0: + valid_ids = [s for s in valid_values if self.value in s] + if len(valid_ids) == 1: + self.value = valid_ids[0] + else: + raise Exception(f"Invalid Value {self.value} for {self.label}.") + + def _validate_inputs(self): + # Check Labels + valid_labels = list(self.proclass_model.heads["chain"].keys()) + valid_labels += list(self.proclass_model.heads["first_order"]) + if self.label not in valid_labels: + valid_label_str = ", ".join(valid_labels) + raise Exception(f"Invalid Label. 
Label must be one of: {valid_label_str}.") + + # Check Values + if self.label in list(self.proclass_model.heads["chain"].keys()): + valid_values = self.proclass_model.class_config[self.label]["tokens"] + if self.value not in valid_values: + raise Exception(f"Invalid Value {self.value} for {self.label}.") + + def _proclass_neglogp(self, X, C, t, label, value=None, O=None, mask=None): + """ + Args: + X (torch.tensor): (batch, num_res, 4, 3) or (batch, num_res, 14, 3) + C (torch.tensor): (batch, num_res) + t (float): 0 < t <= 1 + label (string): class label to condition on, chosen from + `self.class_config.keys()` + mask (torch.tensor): (optional) bool tensor of shape (batch, num_res) for + first order scores, (batch, num_chains) for hain-based scores, and ( + batch, num_res, num_res) for second order scores. The order of your + score can be determined by inspecting self.class_config[label]['level'] + value (string): (optional) the token-based representation of the value you + would like to condition `label` on, you can select options from + `self.class_config[label]['tokens']` for all scores except `is_interface` + or `contact` for which you should leave a `value` of None. + O (torch.tensor): one-hot sequence tensor of size (batch, num_res, num_alphabet) + """ + self.proclass_model.eval() + _bak = self.proclass_model.encoder.checkpoint_gradients + self.proclass_model.encoder.checkpoint_gradients = False + + level = self.proclass_model.class_config[label]["level"] + head, pool = self.proclass_model.heads[level][label] + + node_h, edge_h, edge_idx, mask_i, mask_ij = self.proclass_model.encode( + X, C, O if self.use_sequence else None, t + ) + + if level == "chain": + node_h, c_mask = pool(node_h, C) + c_mask = c_mask + elif level == "first_order": + c_mask = C > 0 + elif level == "second_order": + c_mask = (C > 0).unsqueeze(-2) & (C > 0).unsqueeze(-1) + node_h = head(node_h) + + if mask is not None: + c_mask = mask & c_mask + + if self.proclass_model.class_config[label]["loss"] == "ce": + neglogp = node_h.log_softmax(dim=-1).mul(-1) + else: + neglogp = node_h.sigmoid().log().mul(-1) + + if level == "chain": + index = ( + self.proclass_model.class_config[label]["tokenizer"][value] + if value is not None + else 0 + ) + neglogp = neglogp[..., index][c_mask].sum() + elif level == "first_order": + if isinstance(value, str): + index = torch.LongTensor( + [ + self.proclass_model.class_config[label]["tokenizer"][v] + for v in value + ] + ).to(neglogp.device) + neglogp = torch.gather( + neglogp, 2, index.unsqueeze(0).unsqueeze(2) + ).sum() + elif isinstance( + value, torch.Tensor + ): # Mask Tensor is Passed for SS Conditioning + logp = -1 * neglogp + masked_log_probs = torch.where( + value > 0, logp, -float("inf") * torch.ones_like(logp) + ) + log_probs_sum = torch.logsumexp(masked_log_probs, dim=-1) + log_probs_sum = torch.where( + value.sum(-1) > 0, log_probs_sum, torch.zeros_like(log_probs_sum) + ) + neglogp = -1 * log_probs_sum.sum() + + self.proclass_model.encoder.checkpoint_gradients = _bak + return neglogp + + def _transform_gradient(self, grad, C, t): + if self.debug: + print("conditioning grad norm:", grad.norm().item()) + if grad.norm() > 1e-8: # Don't rescale zero gradients! 
+ # grad = clip_atomic_magnitudes_percentile(grad,percentile=0.95) + if self.renormalize_grad: + scale = ( + self.weight + / self.proclass_model.noise_perturb.noise_schedule.sigma(t).to( + C.device + ) + ) + grad = scale * (grad / grad.norm()) + else: + grad = self.weight * grad + + if self.max_norm is not None: + if grad.norm() > self.max_norm: + grad = self.max_norm * (grad / grad.norm()) + if self.debug: + print("output_grad_norm", grad.norm().item()) + return grad + + @validate_XC() + def forward( + self, + X: torch.Tensor, + C: torch.LongTensor, + O: torch.Tensor, + U: torch.Tensor, + t: Union[torch.Tensor, float], + ) -> Tuple[ + torch.Tensor, + torch.LongTensor, + torch.Tensor, + torch.Tensor, + Union[torch.Tensor, float], + ]: + X_input = X + 0.0 + X_input.register_hook(lambda _X: self._transform_gradient(_X, C, t)) + + neglogp = self._proclass_neglogp( + X_input, C, t, self.label, value=self.value, O=O + ) + if self.debug: + print("time", t.item(), "neglogp:", neglogp.item()) + return X, C, O, neglogp + U, t + + +class SubstructureConditioner(Conditioner): + """ + SubstructureConditioner: + A Chroma Conditioning module which can specifiy a subset of residues for which to + condition on absolute atomic coordinates, see supplementary section M for more + details. + + Args: + protein (generate.data.protein.Protein): Object containing structural + information to condition on. + backbone_model (generate.models.GraphBackbone): The `GraphBackbone` + object one is conditioning + selection (str): A string specifying the selection to condition on, will be + retrieved by `protein.get_mask(selection)` . The selection can be defined + from a set of residue indices `indices` by + `protein.sys.setSelection(indices, selection)`. + rg (bool, optional): Whether or not to add reconstruction guidance gradients, + see supplementary section M for a discussion. This can reduce incidence of + clashes / bond violations / discontinuities at the cost of inference time + and some stability. + weight (float, optional): Overall weight of the reconstruction guidance term + (untransformed). + tspan (Tuple[float, float], optional): Time interval over which to appl + y reconstruction guidance, can be helpful to turn off at times close to + zero. tspan[0] should be < tspan[1]. + weight_max (float, optional): Final rg gradient is rescaled to have `scale` + variance, where `scale` is clamped to have a maximum value of `max_weight`. + gamma (Optional[float]): Gamma inflates the translational degree of freedom + of the underlying conditional multivariate normal, making it easier for + Chroma to move the center of mass of the infilled samples. + Setting to [0.01, 0.1, 1.0] is a a plausible place to start to increase + sample Rg. 
+ center_init (Optional[bool]): Whether to center the input structural data + """ + + def __init__( + self, + protein: ProteinBase, + backbone_model: GraphBackbone, + selection: str, + rg: bool = False, + weight: float = 1.0, + tspan: Tuple[float, float] = (1e-1, 1), + weight_max: float = 3.0, + gamma: Optional[float] = None, + center_init: Optional[bool] = True, + ): + super().__init__() + self.protein = protein + self.backbone_model = backbone_model + X, C, S = protein.to_XCS() + X = X[:, :, :4, :] + if center_init: + X = backbone.center_X(X, C) + D = get_mask(protein, selection).bool() + self.base_distribution = self.backbone_model.noise_perturb.base_gaussian + self.noise_schedule = self.backbone_model.noise_perturb.noise_schedule + self.conditional_distribution = mvn.ConditionalBackboneMVNGlobular( + covariance_model=self.base_distribution.covariance_model, + complex_scaling=self.base_distribution.complex_scaling, + X=X, + C=C, + D=D, + gamma=gamma, + ) + X = self.conditional_distribution.sample(1) + self.tspan = tspan + self.weight = weight + self.weight_max = weight_max + self.rg = rg + self.register_buffer("X", X) + self.register_buffer("C", C) + self.register_buffer("S", S) + self.register_buffer("D", D) + + def _transform_gradient(self, grad, C, t): + mask = (t > self.tspan[0]) & (t < self.tspan[1]) + scale = ( + (self.weight / self.noise_schedule.sigma(t).to(C.device)) + .clamp(max=self.weight_max) + .masked_fill(~mask, 0.0) + ) + grad = scale * grad / grad.square().mean(dim=[1, 2, 3], keepdim=True).sqrt() + return grad + + def _rg_loss(self, X0, C): + C_clamp = torch.where(self.D, C, -C.abs()) + X0 = backbone.impute_masked_X(X0, C_clamp) + X_target = backbone.impute_masked_X(self.X.repeat(X0.size(0), 1, 1, 1), C_clamp) + loss = ( + self.base_distribution._multiply_R_inverse(X_target - X0, C).square().sum() + ) + return loss + + @validate_XC() + def forward( + self, + X: torch.Tensor, + C: torch.LongTensor, + O: torch.Tensor, + U: torch.Tensor, + t: Union[torch.Tensor, float], + ) -> Tuple[ + torch.Tensor, + torch.LongTensor, + torch.Tensor, + torch.Tensor, + Union[torch.Tensor, float], + ]: + loss = 0.0 + Z = self.base_distribution._multiply_R_inverse(X, C) + X = self.conditional_distribution.sample(Z=Z) + + # Reconstruction guidance + if self.rg: + X_input = X + 0.0 + X_input.register_hook(lambda _X: self._transform_gradient(_X, C, t)) + X0 = self.backbone_model.denoise(X_input, C, t) + loss = self._rg_loss(X0, C) + U = U + loss + return X, C, O, U, t + + +class SymmetryConditioner(Conditioner): + """A class that implements a symmetry conditioner for a protein structure. + + A symmetry conditioner applies a set of symmetry operations to a protein structure + and enforces constraints on the resulting conformations. It can be used to model + symmetric complexes or assemblies of proteins. + + Args: + G (torch.Tensor or str): A tensor of shape (n_sym, 3, 3) representing the symmetry + operations as rotation matrices. + num_chain_neighbors (int): The number of neighbors to consider for each chain in + the complex. + freeze_com (bool): Whether to freeze the center of mass of the complex during + optimization. + grad_com_surgery (bool): Whether to apply gradient surgery to remove the center + of mass component from the gradient. + interface_restraint (bool): Whether to apply a flat-bottom potential to restrain + the distance between neighboring chains in the complex. + restraint_grad (bool): Whether to include the restraint gradient in the total + gradient. 
+ enable_rigid_drift (bool): Whether to enable rigid body drift correction for + the complex. + canonicalize (bool): Whether to canonicalize the chain order and orientation + of the complex. + + Inputs: + X (torch.Tensor): Data tensor with shape `(batch_size, num_residues, 4, 3)`. + C (torch.LongTensor): Conditioning tensor with shape `(batch_size, + num_residues)`. + O (torch.Tensor): One-hot sequence with shape + `(batch_size, num_residues, num_alphabet)`. + U (torch.Tensor): Energy tensor with shape `(batch_size,)`. + t (Union[torch.Tensor, float]): Time tensor with shape `(batch_size,)` or a + scalar. + + Outputs: + X_out (torch.Tensor): Modified data tensor with shape `(batch_size, num_residues + , 4, 3)`. + C_out (torch.LongTensor): Modified conditioning tensor with shape `(batch_size, + num_residues)`. + O_out (torch.Tensor, optional): Modified one-hot tensor with sequence + of shape `(batch_size, num_residues, num_alphabet)`. + U_out (torch.Tensor): Modified Energy tensor with shape `(batch_size,)`. + t_out (Union[torch.Tensor, float]): Modified time tensor with shape `(batch_size + ,)` or a scalar. + """ + + def __init__( + self, + G, + num_chain_neighbors, + freeze_com=False, + grad_com_surgery=False, + interface_restraint=False, + restraint_grad=False, + enable_rigid_drift=True, + canonicalize=True, + seed_idx=None, + ): + super().__init__() + + if type(G) == str: + self.G = symmetry.get_point_group(G) + else: + self.G = G + + self.num_chain_neighbors = num_chain_neighbors + self.freeze_com = freeze_com + self.grad_com_surgery = grad_com_surgery + self.interface_restraint = interface_restraint + self.restraint_grad = restraint_grad + self.enable_rigid_drift = enable_rigid_drift + self.canonicalize = canonicalize + self.seed_idx = seed_idx + + if num_chain_neighbors > self.G.shape[0] - 1: + self.num_chain_neighbors = self.G.shape[0] - 1 + + self.potts_symmetry_order = self.num_chain_neighbors + 1 + + def flat_bottom_potential(self, r, r0, k, d): + condition = torch.abs(r - r0) < d + return torch.where( + condition, torch.zeros_like(r), k * (torch.abs(r - r0) - d) ** 2 + ) + + def translational_scaling(self, C): + """Compute parameters for enforcing Rg scaling""" + + # Build expanded map per chain + C_expand = C.unsqueeze(-1).expand(-1, -1, 4) + C_atomic = C_expand.reshape(C.shape[0], -1) + C_mask_all = backbone.expand_chain_map(torch.abs(C_atomic))[..., None] + + a = 1.5587407701549267 # TODO: this can change if our prior changed + nu = 2.0 / 5.0 + r = 2.0 / 3.0 + + # C_mask_all is () + # Monomer and complex sizes (batch, {chains}) + C_mask = C_mask_all.squeeze(-1) + N_per_chain = C_mask.sum(1) + N_per_complex = C_mask.sum([1, 2]) + + # Compute expected Rg^2 values per complex + Rg2_complex = (r**2) * N_per_complex ** (2.0 * nu) + Rg2_chain = (r**2) * N_per_chain ** (2.0 * nu) + + # Compute OU process parameters + N_per_chain = torch.clip(N_per_chain, 1, 1e6) + + # Compute size-weighted average Rg^2 per chain + Rg2_chain_avg = (N_per_chain * Rg2_chain).sum(1) / (N_per_chain.sum(1) + 1e-5) + Rg2_centers_of_mass = torch.clip(Rg2_complex - Rg2_chain_avg, min=1) + Rg_centers_of_mass = torch.sqrt(Rg2_centers_of_mass) + + N_chains_per_complex = (C_mask.sum(1) > 0).sum(1) + # Correct for the fact that we are sampling chains IID (not + # centered) but want to control centered Rg + std_correction = torch.sqrt( + N_chains_per_complex / (N_chains_per_complex - 1).clamp(min=1) + ) + marginal_COM_std = std_correction * Rg_centers_of_mass + + return marginal_COM_std + + def expand_C(self, C, 
k): + missing = C < 0 + Cs = [] + for i in range(k): + newC = C.abs() + C.unique().max() * i + newC[missing] = -newC[missing] + Cs += [newC] + C = torch.cat(Cs, dim=1) + return C + + def expand_S(self, S, k): + S = torch.cat([S] * k, dim=1) + return S + + def expand_au(self, X, C, G, scale=True): + n_atoms_per_res = X.shape[-2] + + C_au = C + # compute new chain mask + C = self.expand_C(C, G.shape[0]) + + # compute COM inflation due to tesselate + if scale: + if self.enable_rigid_drift: + translate_ratio = self.translational_scaling(C) / ( + self.translational_scaling(C_au) + * (self.num_chain_neighbors + 1) ** 0.5 + ) + + else: + translate_ratio = 1.0 + + mask_expand = ( + (C_au > 0) + .float() + .reshape(list(C_au.shape) + [1, 1]) + .expand([-1, -1, n_atoms_per_res, -1]) + ) + X_com = (mask_expand * X).sum([1, 2], keepdims=True) / ( + mask_expand.sum([1, 2], keepdims=True) + ) + + X_shifted_mean = X_com * translate_ratio + X = (X - X_com) + X_shifted_mean + + X = torch.einsum("gij,braj->bgrai", G, X).reshape(1, -1, n_atoms_per_res, 3) + + return X, C + + def _postprocessing_( + self, protein: ProteinBase, output_dict: Optional[dict] = None + ) -> Union[ProteinBase, Tuple[ProteinBase, dict]]: + X, C, S = protein_to_xcs(protein, all_atom=True) + X_sym, C_sym, S_sym = self.symmetrize_output(X, C, S) + + return protein + + def center_X(self, X, C): + mask_expand = ( + (C > 0).float().reshape(list(C.shape) + [1, 1]).expand([-1, -1, 4, -1]) + ) + + # compute mean based on backbone coordinates + X_mean = (mask_expand * X[:, :, :4, :]).sum([1, 2], keepdims=True) / ( + mask_expand.sum([1, 2], keepdims=True) + ) + X_centered = X - X_mean + + return X_centered + + def symmetrize_output(self, X, C, S): + if type(X) == torch.Tensor: + assert len(X.shape) == 4 + X = [X] + + n_chains = ( + self.num_chain_neighbors + 1 + if self.num_chain_neighbors + 1 < self.G.shape[0] + else self.G.shape[0] + ) + + C_seed = C.reshape(1, n_chains, -1)[:, 0] + S_seed = S.reshape(1, n_chains, -1)[:, 0] + + traj = [] + for each in X: + n_atoms_per_res = each.shape[-2] + X_seed = each.reshape(1, n_chains, -1, n_atoms_per_res, 3)[:, 0] + X_tess, C_tess = self.expand_au(X_seed, C_seed, self.G, scale=False) + S_tess = self.expand_S(S_seed, k=self.G.shape[0]) + X_tess = self.center_X(X_tess, C_tess) + traj.append(X_tess) + + if len(traj) == 1: + traj = traj[0] + + return traj, C_tess, S_tess + + @validate_XC() + def forward( + self, + X: torch.Tensor, + C: torch.LongTensor, + O: torch.Tensor, + U: torch.Tensor, + t: Union[torch.Tensor, float], + ) -> Tuple[ + torch.Tensor, + torch.LongTensor, + torch.Tensor, + torch.Tensor, + Union[torch.Tensor, float], + ]: + self.G = self.G.to(X.device) + + if self.grad_com_surgery or self.freeze_com: + X_tess, C_tess = self.expand_au(X, C, self.G, scale=False) + else: + X_tess, C_tess = self.expand_au(X, C, self.G, scale=True) + + X_subdomain, C_subdomain, subdomain_idx, seed_idx = symmetry.subsample( + X_tess, C_tess, self.G, self.num_chain_neighbors, seed_idx=self.seed_idx + ) + + if self.canonicalize: + X_canonical = torch.einsum( + "ij,barj->bari", self.G[seed_idx].inverse(), X_subdomain + ) + else: + X_canonical = X_subdomain + + def grad_surgery(dx): + if self.grad_com_surgery: + # inflate COM signal + translate_ratio = self.translational_scaling(C_tess) / ( + self.translational_scaling(C) + ) + + dx_com = dx.mean([0, 1, 2]) + dx_com_scale = dx_com * translate_ratio + dx = (dx - dx_com) + dx_com_scale + + if self.freeze_com: + dx = backbone.center_X(dx, C_subdomain) + + # averaging 
grad + dx = dx / (self.num_chain_neighbors + 1) + return dx + + X_canonical.register_hook(grad_surgery) + + # Tesselate sequence + symmetry_order = C_subdomain.shape[1] // C.shape[1] + O_subdomain = ( + O[:, None, :, :] + .expand([-1, symmetry_order, -1, -1]) + .reshape(list(C_subdomain.shape) + [O.shape[-1]]) + ) + return X_canonical, C_subdomain, O_subdomain, U, t + + +class ScrewConditioner(Conditioner): + """A class that implements a screw conditioner for a protein structure. + + A screw conditioner applies a screw transformation to a protein structure + and repeats it for a given number of times. It can be used to model + helical or cyclic symmetry of proteins. + + Attributes: + theta (float): The angle of rotation about the z-axis in radians. + tz (float): The translation along the z-axis. + M (int): The number of repetitions of the screw transformation. + + Methods: + prepare_transforms(N_repeat): Compute the rotation matrices and translation + vectors for the screw transformation. + expand_C(C, k): Expand a chain tensor C by duplicating each chain k times with + different labels. + rebuild(X, C, M): Rebuild a protein structure with the screw transformation. + forward(X, C, U, t): Apply the screw transformation to a protein structure and + return modified tensors. + + Inputs: + X (torch.Tensor): Data tensor with shape `(batch_size, num_residues, 4, 3)`. + C (torch.LongTensor): Chain tensor with shape `(batch_size, num_residues)`. + O (torch.Tensor): One-hot sequence with shape + `(batch_size, num_residues, num_alphabet)`. + U (torch.Tensor): Energy tensor with shape `(batch_size,)`. + t (Union[torch.Tensor, float]): Time tensor with shape `(batch_size,)` or a s + calar. + + Outputs: + X_out (torch.Tensor): Modified data tensor with shape `(batch_size, + num_residues * M, 4, 3)`. + C_out (torch.LongTensor): Modified chain tensor with shape `(batch_size, + num_residues * M)`. + O_out (torch.Tensor, optional): Modified one-hot tensor with sequence + of shape `(batch_size, num_residues, num_alphabet)`. + U_out (torch.Tensor): Modified energy tensor with shape `(batch_size,)`. + t_out (Union[torch.Tensor, float]): Modified time tensor with shape + `(batch_size,)` or a scalar. 
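+
+    Example (illustrative sketch; parameter values are arbitrary):
+
+        import math
+
+        conditioner = ScrewConditioner(theta=math.pi / 6, tz=4.0, M=10)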
+ """ + + def __init__(self, theta, tz, M): + super().__init__() + self.theta = torch.Tensor([theta]).squeeze() + self.tz = tz + self.M = M + self.Rs, self.ts = self.prepare_transforms(M) + + def prepare_transforms(self, N_repeat): + # Rotation matrix for rotation about the z-axis + R_base = torch.tensor( + [ + [torch.cos(self.theta), -torch.sin(self.theta), 0], + [torch.sin(self.theta), torch.cos(self.theta), 0], + [0, 0, 1], + ] + ) + + t_base = torch.tensor([0, 0, self.tz]) + + Rs = [] + ts = [] + + R = R_base + t = t_base + for _ in range(N_repeat): + R = R @ R_base + t = t + t_base + + Rs.append(R[None]) + ts.append(t[None]) + + Rs = torch.cat(Rs, dim=0) + ts = torch.cat(ts, dim=0) + + return Rs, ts + + def expand_C(self, C, k): + Cs = [] + for i in range(k): + newC = C + C.unique().max() * i + Cs += [newC] + C = torch.cat(Cs, dim=1) + return C + + def rebuild(self, X, C, M, au_len): + Rs, ts = self.prepare_transforms(M) + X = torch.einsum("mji,bari->bmarj", Rs.to(X.device), X[:, :au_len]) + X_screw = X + ts.to(X.device)[None][:, :, None, None, :] + C_screw = self.expand_C(C[:, :au_len], Rs.shape[0]) + X_screw = X_screw.reshape(1, -1, 4, 3) + return X_screw, C_screw + + @validate_XC() + def forward(self, X, C, O, U, t): + X.requires_grad = True + X = torch.einsum("mji,bari->bmarj", self.Rs.to(X.device), X) + X_screw = X + self.ts.to(X.device)[None][:, :, None, None, :] + C_screw = self.expand_C(C, self.M) + + def grad_surgery(dx): + dx = dx / (self.M) + return dx + + X.register_hook(grad_surgery) + X_screw = X_screw.reshape(1, -1, 4, 3) + + # Tesselate sequence + symmetry_order = C_screw.shape[1] // C.shape[1] + O_screw = ( + O[:, None, :, :] + .expand([-1, symmetry_order, -1, -1]) + .reshape(list(C_screw.shape) + [O.shape[-1]]) + ) + return X_screw, C_screw, O_screw, U, t + + +class InflateConditioner(Conditioner): + """Inflate conditioner + + This class inherits from the Conditioner class and defines a specific conditioner + that inflates shift the COM of X based on a vector v and a scalar. + + Args: + v (torch.Tensor): Vector to add to X with shape `(num_residues, 4, 3)`. + scale (float): Scale factor for v. + + Inputs: + X (torch.Tensor): Data tensor with shape `(batch_size, num_residues, 4, 3)`. + C (torch.LongTensor): Conditioning tensor with shape `(batch_size, + num_residues)`. + O (torch.Tensor): One-hot sequence with shape + `(batch_size, num_residues, num_alphabet)`. + U (torch.Tensor): Noise tensor with shape `(batch_size,)`. + t (Union[torch.Tensor, float]): Time tensor with shape `(batch_size,)` or a + scalar. + + Outputs: + X_out (torch.Tensor): Modified data tensor with shape `(batch_size, num_residues, + 4, 3)`. + C_out (torch.LongTensor): Modified conditioning tensor with shape `(batch_size, + num_residues)`. + O_out (torch.Tensor, optional): Modified one-hot tensor with sequence + of shape `(batch_size, num_residues, num_alphabet)`. + U_out (torch.Tensor): Modified noise tensor with shape `(batch_size,)`. + t_out (Union[torch.Tensor, float]): Modified time tensor with shape + `(batch_size,)` or a scalar. 
+ """ + + def __init__(self, v: torch.Tensor, scale: float): + super().__init__() + self.v = v / v.norm() + self.scale = scale + + @validate_XC() + def forward( + self, + X: torch.Tensor, + C: torch.LongTensor, + O: torch.Tensor, + U: torch.Tensor, + t: Union[torch.Tensor, float], + ) -> Tuple[ + torch.Tensor, + torch.LongTensor, + torch.Tensor, + torch.Tensor, + Union[torch.Tensor, float], + ]: + X.requires_grad = True + X = X + self.v.to(X.device) * self.scale + return X, C, O, U, t + + +class RgConditioner(Conditioner): + """Conditioners that penalized backbones for having Rg deviated from the expected Rg + Scaling. The penalty function takes the form of a flat bottom potential + + penalty = || ReLU( || Rg(X, C) - Rg_ceiling_scale * expected_Rg(C) || ) ||^2 + + Args: + scale (float): Scale factor for the penalty + Rg_ceiling_scale (float): the flat bottom potentialy width, needs to be larger + than 1. + complex_rg (bool): whether compute expected Rg based on the complex Rg scaling. + If True, expected Rg will be computed by treating the entire complex as if + it is a single cahin. If False, expected Rg will be computed for individual + chains + + Inputs: + X (torch.Tensor): Data tensor with shape `(batch_size, num_residues, 4, 3)`. + C (torch.LongTensor): Conditioning tensor with shape `(batch_size, + num_residues)`. + O (torch.Tensor): One-hot sequence with shape + `(batch_size, num_residues, num_alphabet)`. + U (torch.Tensor): Noise tensor with shape `(batch_size,)`. + t (Union[torch.Tensor, float]): Time tensor with shape `(batch_size,)` or a + scalar. + + Outputs: + X_out (torch.Tensor): Modified data tensor with shape `(batch_size, num_residues, + 4, 3)`. + C_out (torch.LongTensor): Modified conditioning tensor with shape `(batch_size, + num_residues)`. + O_out (torch.Tensor, optional): Modified one-hot tensor with sequence + of shape `(batch_size, num_residues, num_alphabet)`. + U_out (torch.Tensor): Modified noise tensor with shape `(batch_size,)`. + t_out (Union[torch.Tensor, float]): Modified time tensor with shape + `(batch_size,)` or a scalar. 
+ """ + + def __init__( + self, + scale=1.0, + Rg_ceiling_scale=1.5, + complex_rg=False, + ): + super().__init__() + self.eps = 1e-5 + self.scale = scale + self.Rg_ceiling_scale = Rg_ceiling_scale + self.complex_rg = complex_rg + + def means_per_chain(self, _X, _C, eps=1e-5): + """Compute center of mass for each chain in a complex""" + # (B,N) => (B,N,C) => (B,N,C,A,X) + mask_chains = (expand_chain_map(_C) > 0).float() + mask_chains_expand = mask_chains[..., None, None] + X_masked = mask_chains_expand * _X.unsqueeze(2) + # Compute per chain means + X_mean_chains = X_masked.sum([1, 3], keepdims=True) / ( + 4 * mask_chains_expand.sum([1, 3], keepdims=True) + eps + ) + # Compute per complex mean + X_mean_complex = X_masked.sum([1, 2, 3], keepdims=True) / ( + 4 * mask_chains_expand.sum([1, 2, 3], keepdims=True) + eps + ) + return X_masked, X_mean_chains, X_mean_complex, mask_chains + + def expected_Rg(self, N): + """compute expected Rg""" + nu = 2.0 / 5.0 + r = 2.0 + + return ((r**2) * N ** (2.0 * nu)) ** 0.5 + + def compute_Rg( + self, + X, + C, + ): + """compute Rg with X and C""" + X.requires_grad = True + X_masked, X_mean_chains, X_mean_complex, mask_chains = self.means_per_chain( + X, C + ) + + mask_chains_expand = mask_chains[..., None] + r2_i = mask_chains_expand * (X_masked - X_mean_chains).square().sum(-1) + + r2_i_mean = (r2_i + self.eps).mean(-1).sum(1) / (mask_chains.sum(1) + self.eps) + + r_i_rms = torch.sqrt(r2_i_mean + self.eps) + + return r_i_rms + + @validate_XC() + def forward(self, X, C, O, U, t): + if self.complex_rg: + C_tmp = torch.ones_like(C) + else: + C_tmp = C + + # Compute expected Rg + N_chain = expand_chain_map(torch.abs(C_tmp)).sum(1) + r_i_rms_expected = self.expected_Rg(N_chain) + + true_rg = self.compute_Rg(X, C_tmp) + U_Rg = F.relu(true_rg - self.Rg_ceiling_scale * r_i_rms_expected).square() + + U = U + self.scale * U_Rg.sum() + return X, C, O, U, t + + +def clip_atomic_magnitudes_percentile(dX, percentile=0.9): + D = dX.square().sum(-1, keepdims=True).add(1e-5).sqrt() + D_max = D.quantile(percentile) + dX_adjust = dX * D.clamp(max=D_max) / D + return dX_adjust diff --git a/prtm/models/chroma/structure/diffusion.py b/prtm/models/chroma/structure/diffusion.py new file mode 100644 index 00000000..78b3ca53 --- /dev/null +++ b/prtm/models/chroma/structure/diffusion.py @@ -0,0 +1,1959 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layers for perturbing protein structure with noise. + +This module contains pytorch layers for perturbing protein structure with noise, +which can be useful both for data augmentation, benchmarking, or denoising based +training. 
+""" + + +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from prtm.constants.residue_constants import alphabetical_restypes +from prtm.models.chroma import basic, sde +from prtm.models.chroma.structure import backbone, hbonds, mvn, rmsd +from prtm.models.chroma.xcs import validate_XC +from torch.autograd import grad +from tqdm.auto import tqdm + + +class GaussianNoiseSchedule: + """ + A general noise schedule for the General Gaussian Forward Path, where noise is added + to the input signal. + + The noise is modeled as Gaussian noise with mean `alpha_t x_0` and variance + `sigma_t^2`, with 'x_0 ~ p(x_0)' The time range of the noise schedule is + parameterized with a user-specified logarithmic signal-to-noise ratio (SNR) range, + where `snr_t = alpha_t^2 / sigma_t^2` is the SNR at time `t`. + + In addition, the object defines a quantity called the scaled signal-to-noise ratio + (`ssnr_t`), which is given by `ssnr_t = alpha_t^2 / (alpha_t^2 + sigma_t^2)` + and is a helpful quantity for analyzing the performance of signal processing + algorithms under different noise conditions. + + This object implements a few standard noise schedule: + + 'log_snr': variance-preserving process with linear log SNR schedule + (https://arxiv.org/abs/2107.00630) + + 'ot_linear': OT schedule + (https://arxiv.org/abs/2210.02747) + + 've_log_snr': variance-exploding process with linear log SNR s hedule + (https://arxiv.org/abs/2011.13456 with log SNR noise schedule) + + User can also implement their own schedules by specifying alpha_func, sigma_func + and compute_t_range. + + """ + + def __init__( + self, + log_snr_range: Tuple[float, float] = (-7.0, 13.5), + kind: str = "log_snr", + ) -> None: + super().__init__() + + if kind not in ["log_snr", "ot_linear", "ve_log_snr"]: + raise NotImplementedError( + f"noise type {kind} is not implemented, only" + " log_snr and ot_linear are supported " + ) + self.kind = kind + self.log_snr_range = log_snr_range + + l_min, l_max = self.log_snr_range + + # map t \in [0, 1] to match the prescribed log_snr range + self.t_max = self.compute_t_range(l_min) + self.t_min = self.compute_t_range(l_max) + self._eps = 1e-5 + + def t_map(self, t: Union[float, torch.Tensor]) -> torch.Tensor: + """map t in [0, 1] to [t_min, t_max] + + Args: + t (Union[float, torch.Tensor]): time + + Returns: + torch.Tensor: mapped time + """ + if not isinstance(t, torch.Tensor): + t = torch.Tensor([t]).float() + + t_max = self.t_max.to(t.device) + t_min = self.t_min.to(t.device) + t_tilde = t_min + (t_max - t_min) * t + + return t_tilde + + def derivative(self, t: torch.Tensor, func: Callable) -> torch.Tensor: + """compute derivative of a function, it supports bached single variable inputs + + Args: + t (torch.Tensor): time variable at which derivatives are taken + func (Callable): function for derivative calculation + + Returns: + torch.Tensor: derivative that is detached from the computational graph + """ + with torch.enable_grad(): + t.requires_grad_(True) + derivative = grad(func(t).sum(), t, create_graph=False)[0].detach() + t.requires_grad_(False) + return derivative + + def tensor_check(self, t: Union[float, torch.Tensor]) -> torch.Tensor: + """convert input to torch.Tensor if it is a float + + Args: + t ( Union[float, torch.Tensor]): input + + Returns: + torch.Tensor: converted torch.Tensor + """ + if not isinstance(t, torch.Tensor): + t = torch.Tensor([t]).float() + return t + + def 
alpha_func(self, t: Union[float, torch.Tensor]) -> torch.Tensor: + """alpha function that scales the mean, usually goes from 1. to 0. + + Args: + t (Union[float, torch.Tensor]): time in [0, 1] + + Returns: + torch.Tensor: alpha value + """ + + t = self.tensor_check(t) + + if self.kind == "log_snr": + l_min, l_max = self.log_snr_range + t_min, t_max = self.t_min, self.t_max + log_snr = (1 - t) * l_max + t * l_min + + log_alpha = 0.5 * (log_snr - F.softplus(log_snr)) + alpha = log_alpha.exp() + return alpha + + elif self.kind == "ve_log_snr": + return 1 - torch.relu(-t) # make this differentiable + + elif self.kind == "ot_linear": + return 1 - t + + def sigma_func(self, t: Union[float, torch.Tensor]) -> torch.Tensor: + """sigma function that scales the standard deviation, usually goes from 0. to 1. + + Args: + t (Union[float, torch.Tensor]): time in [0, 1] + + Returns: + torch.Tensor: sigma value + """ + t = self.tensor_check(t) + l_min, l_max = self.log_snr_range + + if self.kind == "log_snr": + alpha = self.alpha(t) + return (1 - alpha.pow(2)).sqrt() + + elif self.kind == "ve_log_snr": + # compute sigma value given snr range + + l_min, l_max = self.log_snr_range + t_min, t_max = self.t_min, self.t_max + log_snr = (1 - t) * l_max + t * l_min + return torch.exp(-log_snr / 2) + + elif self.kind == "ot_linear": + return t + + def alpha(self, t: Union[float, torch.Tensor]) -> torch.Tensor: + """compute alpha value for the mapped time in [t_min, t_max] + + Args: + t (Union[float, torch.Tensor]): time in [0, 1] + + Returns: + torch.Tensor: alpha value + """ + return self.alpha_func(self.t_map(t)) + + def sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor: + """compute sigma value for mapped time in [t_min, t_max] + + Args: + t (Union[float, torch.Tensor]): time in [0, 1] + + Returns: + torch.Tensor: sigma value + """ + return self.sigma_func(self.t_map(t)) + + def alpha_deriv(self, t: Union[float, torch.Tensor]) -> torch.Tensor: + """compute alpha derivative for mapped time in [t_min, t_max] + + Args: + t (Union[float, torch.Tensor]): time in [0, 1] + + Returns: + torch.Tensor: time derivative of alpha_func + """ + t_tilde = self.t_map(t) + alpha_deriv_t = self.derivative(t_tilde, self.alpha_func).detach() + return alpha_deriv_t + + def sigma_deriv(self, t: Union[float, torch.Tensor]) -> torch.Tensor: + """compute sigma derivative for the mapped time in [t_min, t_max] + + Args: + t (Union[float, torch.Tensor]): time in [0, 1] + + Returns: + torch.Tensor: sigma derivative + """ + t_tilde = self.t_map(t) + sigma_deriv_t = self.derivative(t_tilde, self.sigma_func).detach() + return sigma_deriv_t + + def beta(self, t: Union[float, torch.Tensor]) -> torch.Tensor: + """compute the drift coefficient for the OU process of the form + $dx = -\frac{1}{2} \beta(t) x dt + g(t) dw_t$ + + Args: + t (Union[float, torch.Tensor]): t in [0, 1] + + Returns: + torch.Tensor: beta(t) + """ + # t = self.t_map(t) + alpha = self.alpha(t).detach() + t_map = self.t_map(t) + alpha_deriv_t = self.alpha_deriv(t) + beta = -2.0 * alpha_deriv_t / alpha + + return beta + + def g(self, t: Union[float, torch.Tensor]) -> torch.Tensor: + """compute drift coefficient for the OU process: + $dx = -\frac{1}{2} \beta(t) x dt + g(t) dw_t$ + + Args: + t (Union[float, torch.Tensor]): t in [0, 1] + + Returns: + torch.Tensor: g(t) + """ + if self.kind == "log_snr": + t = self.t_map(t) + g = self.beta(t).sqrt() + + else: + alpha_deriv = self.alpha_deriv(t) + alpha_prime_div_alpha = alpha_deriv / self.alpha(t) + sigma_deriv = 
self.sigma_deriv(t) + sigma_prime_div_sigma = sigma_deriv / self.sigma(t) + + g_sq = ( + 2 + * (sigma_deriv - alpha_prime_div_alpha * self.sigma(t)) + * self.sigma(t) + ) + g = g_sq.sqrt() + + return g + + def SNR(self, t: Union[float, torch.Tensor]) -> torch.Tensor: + """Signal-to-noise ratio (SNR) mapped to the allowed log_SNR range + + Args: + t (Union[float, torch.Tensor]): time in [0, 1] + + Returns: + torch.Tensor: SNR value + """ + t = self.tensor_check(t) + + if self.kind == "log_snr": + SNR = self.log_SNR(t).exp() + else: + SNR = self.alpha(t).pow(2) / (self.sigma(t).pow(2)) + + return SNR + + def log_SNR(self, t: Union[float, torch.Tensor]) -> torch.Tensor: + """log SNR value + + Args: + t (Union[float, torch.Tensor]): time in [0, 1] + + Returns: + torch.Tensor: log SNR value + """ + t = self.tensor_check(t) + + if self.kind == "log_snr": + l_min, l_max = self.log_snr_range + log_snr = (1 - t) * l_max + t * l_min + + elif self.kind == "ot_linear": + log_snr = self.SNR(t).log() + + return log_snr + + def compute_t_range(self, log_snr: Union[float, torch.Tensor]) -> torch.Tensor: + """Given a log(SNR) value in the range [l_min, l_max], compute the corresponding time. + Hand-derivation is required for specific noise schedules. + This function is essentially the inverse of logSNR(t). + + Args: + log_snr (Union[float, torch.Tensor]): logSNR value + + Returns: + torch.Tensor: the time corresponding to the given logSNR value + """ + log_snr = self.tensor_check(log_snr) + l_min, l_max = self.log_snr_range + + if self.kind == "log_snr": + t = (1 / (l_min - l_max)) * (log_snr - l_max) + + elif self.kind == "ot_linear": + t = ((0.5 * log_snr).exp() + 1).reciprocal() + + elif self.kind == "ve_log_snr": + t = (1 / (l_min - l_max)) * (log_snr - l_max) + + return t + + def SNR_derivative(self, t: Union[float, torch.Tensor]) -> torch.Tensor: + """the derivative of SNR(t) + + Args: + t (Union[float, torch.Tensor]): t in [0, 1] + + Returns: + torch.Tensor: SNR derivative + """ + t = self.tensor_check(t) + + if self.kind == "log_snr": + snr_deriv = self.SNR(t) * (self.log_snr_range[0] - self.log_snr_range[1]) + + elif self.kind == "ot_linear": + snr_deriv = self.derivative(t, self.SNR) + return snr_deriv + + def SSNR(self, t: Union[float, torch.Tensor]) -> torch.Tensor: + """Signal to Signal+Noise Ratio (SSNR) = alpha^2 / (alpha^2 + sigma^2) + SSNR monotonically goes from 1 to 0 as t goes from 0 to 1. + + Args: + t (Union[float, torch.Tensor]): time in [0, 1] + + Returns: + torch.Tensor: SSNR value + """ + t = self.tensor_check(t) + return self.SNR(t) / (self.SNR(t) + 1) + + def SSNR_inv(self, ssnr: torch.Tensor) -> torch.Tensor: + """the inverse of SSNR + + Args: + ssnr (torch.Tensor): ssnr in [0, 1] + + Returns: + torch.Tensor: time in [0, 1] + """ + l_min, l_max = self.log_snr_range + if self.kind == "log_snr": + return ((ssnr / (1 - ssnr)).log() - l_max) / (l_min - l_max) + elif self.kind == "ot_linear": + # the value of SSNR_inv(ssnr=0.5) needs to be determined with L'Hôpital's rule + # the inverse SSNR function is solved analytically: + # see Wolfram Alpha result: https://tinyurl.com/bdh4es5a + singularity_check = (ssnr - 0.5).abs() < self._eps + ssnr_mask = singularity_check.float() + ssnr = ssnr_mask * (0.5 + self._eps) + (1.0 - ssnr_mask) * ssnr + + return (ssnr + (-ssnr * (ssnr - 1)).sqrt() - 1) / (2 * ssnr - 1) + + def SSNR_inv_deriv(self, ssnr: Union[float, torch.Tensor]) -> torch.Tensor: + """SSNR_inv derivative. 
SSNR_inv is a CDF like quantity, so its derivative is a PDF-like quantity + + Args: + ssnr (Union[float, torch.Tensor]): SSNR in [0, 1] + + Returns: + torch.Tensor: derivative of SSNR + """ + ssnr = self.tensor_check(ssnr) + deriv = self.derivative(ssnr, self.SSNR_inv) + return deriv + + def prob_SSNR(self, ssnr: Union[float, torch.Tensor]) -> torch.Tensor: + """compute prob (SSNR(t)), the minus sign is accounted for the inversion of integration range + + Args: + ssnr (Union[float, torch.Tensor]): SSNR value + + Returns: + torch.Tensor: Prob(SSNR) + """ + return -self.SSNR_inv_deriv(ssnr) + + def linear_logsnr_grid(self, N: int, tspan: Tuple[float, float]) -> torch.Tensor: + """Map uniform time grid to respect logSNR schedule + + Args: + N (int): number of steps + tspan (Tuple[float, float]): time span (t_start, t_end) + + Returns: + torch.Tensor: time grid as torch.Tensor + """ + + logsnr_noise = GaussianNoiseSchedule( + kind="log_snr", log_snr_range=self.log_snr_range + ) + + ts = torch.linspace(tspan[0], tspan[1], N + 1) + SSNR_vp = logsnr_noise.SSNR(ts) + grid = self.SSNR_inv(SSNR_vp) + + # map from t_tilde back to t + grid = (grid - self.t_min) / (self.t_max - self.t_min) + + return grid + + +class NoiseTimeEmbedding(nn.Module): + """ + A class that implements a noise time embedding layer. + + Args: + dim_embedding (int): The dimension of the output embedding vector. + noise_schedule (GaussianNoiseSchedule): A GaussianNoiseSchedule object that + defines the noise schedule function. + rff_scale (float, optional): The scaling factor for the random Fourier features. + Default is 0.8. + feature_type (str, optional): The type of feature to use for the time embedding. + Either "t" or "log_snr". Default is "log_snr". + + Inputs: + t (float): time in (1.0, 0.0). + log_alpha (torch.Tensor, optional): A tensor of log alpha values with + shape `(batch_size,)`. + + Outputs: + time_h (torch.Tensor): A tensor of noise time embeddings with shape + `(batch_size, dim_embedding)`. + """ + + def __init__( + self, + dim_embedding: int, + noise_schedule: GaussianNoiseSchedule, + rff_scale: float = 0.8, + feature_type: str = "log_snr", + ) -> None: + super(NoiseTimeEmbedding, self).__init__() + self.noise_schedule = noise_schedule + self.feature_type = feature_type + self.fourier_features = basic.FourierFeaturization( + d_input=1, d_model=dim_embedding, trainable=False, scale=rff_scale + ) + + def forward( + self, t: torch.Tensor, log_alpha: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if not isinstance(t, torch.Tensor): + t = torch.Tensor([t]).float().to(self.fourier_features.B.device) + if t.dim() == 0: + t = t[None] + + h = {"t": lambda: t, "log_snr": lambda: self.noise_schedule.log_SNR(t)}[ + self.feature_type + ]() + + time_h = self.fourier_features(h[:, None, None]) + return time_h + + +class DiffusionChainCov(nn.Module): + def __init__( + self, + log_snr_range: Tuple[float, float] = (-7.0, 13.5), + noise_schedule: str = "log_snr", + sigma_translation: float = 1.0, + covariance_model: str = "brownian", + complex_scaling: bool = False, + **kwargs, + ) -> None: + """Diffusion backbone noise, with chain-structured covariance. + + This class implements a diffusion backbone noise model. The model uses a + chain-structured covariance matrix capturing the spatial correlations between + residues along the backbone. The model also supports different noise schedules + and integration schemes for the stochastic differential equation (SDE) that + defines the diffusion process. 
This class also implements various inference + algorithms by reversing the forward diffusion with a user-specified + conditioner program. + + Args: + log_snr_range (tuple, optional): log SNR range. Defaults to (-7.0, 13.5). + noise_schedule (str, optional): noise schedule type. Defaults to "log_snr". + sigma_translation (float, optional): Scaling factor for the translation + component of the covariance matrix. Defaults to 1.0. + covariance_model (str, optional): covariance model. Defaults to "brownian". + complex_scaling (bool, optional): Whether to scale the complex component + of the covariance matrix by the translation component. Defaults to False. + **kwargs: Additional arguments for the base Gaussian distribution and + the SDE integration. + """ + super().__init__() + + self.noise_schedule = GaussianNoiseSchedule( + log_snr_range=log_snr_range, + kind=noise_schedule, + ) + + if covariance_model in ["brownian", "globular"]: + self.base_gaussian = mvn.BackboneMVNGlobular( + sigma_translation=sigma_translation, + covariance_model=covariance_model, + complex_scaling=complex_scaling, + ) + elif covariance_model == "residue_gas": + self.base_gaussian = mvn.BackboneMVNResidueGas() + + self.loss_rmsd = rmsd.BackboneRMSD() + self._eps = 1e-5 + self.sde_funcs = { + "langevin": self.langevin, + "reverse_sde": self.reverse_sde, + "ode": self.ode, + } + self.integrate_funcs = { + "euler_maruyama": sde.sde_integrate, + "heun": sde.sde_integrate_heun, + } + + def sample_t( + self, + C: torch.LongTensor, + t: Optional[torch.Tensor] = None, + inverse_CDF: Optional[Callable] = None, + ) -> torch.Tensor: + """Sample a random time index for each batch element + + Inputs: + C (torch.LongTensor): Chain tensor with shape `(batch_size, num_residues)`. + t (torch.Tensor, optional): Time index with shape `(batch_size,)`. + If not given, a random time index will be sampled. Defaults to None. + + Outputs: + t (torch.Tensor): Time index with shape `(batch_size,)`. + """ + if t is not None: + if not isinstance(t, torch.Tensor): + t = torch.Tensor([t]).float() + return t + + num_batch = C.size(0) + if self.training: + # Sample correlated but marginally uniform t + # for variance reduction (Kingma et al 2021) + u = torch.rand([]) + ix = torch.arange(num_batch) / num_batch + t = torch.remainder(u + ix, 1) + else: + t = torch.rand([num_batch]) + if inverse_CDF is not None: + t = inverse_CDF(t) + t = t.to(C.device) + return t + + def sde_forward(self, X, C, t, Z=None): + """Sample an Euler-Maruyama step on the forward SDE. + + That is to say, Euler-Maruyama integration would + correspond to the update + `X_new = X + dt * f + sqrt(dt) * gZ`. + + Args: + X (torch.Tensor): Backbone coordinates. + C (torch.LongTensor): Chain map. + t (torch.Tensor): Diffusion time. + Z (torch.Tensor, optional): Optional pre-drawn noise. + + Returns: + f (Tensor): Drift term with the same shape as `X`. + gZ (Tensor): Diffusion term with the same shape as `X`. + """ + + # Sample random perturbation + if Z is None: + Z = torch.randn_like(X) + Z = Z.reshape(X.shape[0], -1, 3) + R_Z = self.base_gaussian._multiply_R(Z, C).reshape(X.shape) + + X = backbone.center_X(X, C) + beta = self.noise_schedule.beta(t) + f = -beta * X / 2.0 + gZ = self.noise_schedule.g(t)[:, None, None] * R_Z + + return f, gZ + + def _schedule_coefficients( + self, + t: torch.Tensor, + inverse_temperature: float = 1.0, + langevin_isothermal: bool = True, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + """ + A method that computes the schedule coefficients for sampling in reverse time + + Args: + t (float): time in (1.0, 0.0). 
+ inverse_temperature (float, optional): The inverse temperature parameter for + the Langevin dynamics. Default is 1.0. + langevin_isothermal (bool, optional): A flag that indicates whether to use + isothermal or non-isothermal Langevin dynamics. Default is True. + + Returns: + alpha (torch.Tensor): A tensor of alpha values with shape `(batch_size, 1, 1)`. + sigma (torch.Tensor): A tensor of sigma values with shape `(batch_size, 1, 1)`. + beta (torch.Tensor): A tensor of beta values with shape `(batch_size, 1, 1)`. + g (torch.Tensor): A tensor of g values with shape `(batch_size, 1, 1)`. + lambda_t (torch.Tensor): A tensor of lambda_t values with shape `(batch_size, 1, 1)`. + lambda_langevin (torch.Tensor): A tensor of lambda_langevin values with + shape `(batch_size, 1, 1)`. + """ + + # Schedule coefficients + alpha = self.noise_schedule.alpha(t)[:, None, None].to(t.device) + sigma = self.noise_schedule.sigma(t)[:, None, None].to(t.device) + beta = self.noise_schedule.beta(t)[:, None, None].to(t.device) + g = self.noise_schedule.g(t)[:, None, None].to(t.device) + + # Temperature coefficients + lambda_t = ( + inverse_temperature + * (sigma.pow(2) + alpha.pow(2)) + / (inverse_temperature * sigma.pow(2) + alpha.pow(2)) + ) + lambda_langevin = inverse_temperature if langevin_isothermal else lambda_t + return alpha, sigma, beta, g, lambda_t, lambda_langevin + + @validate_XC() + def langevin( + self, + X: torch.Tensor, + X0_func: Callable, + C: torch.LongTensor, + t: Union[torch.Tensor, float], + conditioner: Callable = None, + Z: Union[torch.Tensor, None] = None, + inverse_temperature: float = 1.0, + langevin_factor: float = 0.0, + langevin_isothermal: bool = True, + align_X0: bool = True, + ): + """Return the drift and diffusion components of the Langevin dynamics for the + reverse process + + Args: + X (torch.Tensor): A tensor of protein backbone structure with shape + `(batch_size, num_residues, 4, 3)`. + X0_func (Callable): A denoising function for protein backbone + geometry. + C (torch.LongTensor): A chain map tensor with shape `(batch_size, num_residues)`. + t (float): time in (1.0, 0.0). + conditioner (Callable, optional): A conditioner that performs constrained + transformation (see examples in chroma.layers.structure.conditioners). + Z (torch.Tensor, optional): A tensor of random noise with + shape `(batch_size, num_residues, 4, 3)`. Default is None. + inverse_temperature (float, optional): The inverse temperature parameter + for the Langevin dynamics. Default is 1.0. + langevin_factor (float, optional): The scaling factor for the Langevin noise. + Default is 0.0. + langevin_isothermal (bool, optional): A flag that indicates whether to use + isothermal or non-isothermal Langevin dynamics. Default is True. + align_X0 (bool, optional): A flag that indicates whether to align the noised + X and denoised X for score function calculation. + + Returns: + f (torch.Tensor): A tensor of drift terms with shape + `(batch_size, num_residues, 4, 3)`. + gZ (torch.Tensor): A tensor of diffusion terms with shape + `(batch_size, num_residues, 4, 3)`. 
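+ + Note: as a sketch of the update implemented below, the returned terms are + `f = -(g(t)**2 * lambda_langevin * langevin_factor / 2) * Sigma @ score` and + `gZ = g(t) * sqrt(langevin_factor) * R @ Z`, where `Sigma` is the chain-structured + covariance of the base Gaussian and `R` is its square-root factor.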
+ """ + + alpha, sigma, beta, g, lambda_t, lambda_langevin = self._schedule_coefficients( + t, + inverse_temperature=inverse_temperature, + langevin_isothermal=langevin_isothermal, + ) + + Z = torch.randn_like(X) if Z is None else Z + + score = self.score(X, X0_func, C, t, conditioner, align_X0=align_X0) + score_transformed = self.base_gaussian.multiply_covariance(score, C) + f = -g.pow(2) * lambda_langevin * langevin_factor / 2.0 * score_transformed + gZ = g * np.sqrt(langevin_factor) * self.base_gaussian._multiply_R(Z, C) + return f, gZ + + @validate_XC() + def reverse_sde( + self, + X: torch.Tensor, + X0_func: Callable, + C: torch.LongTensor, + t: Union[torch.Tensor, float], + conditioner: Callable = None, + Z: Union[torch.Tensor, None] = None, + inverse_temperature: float = 1.0, + langevin_factor: float = 0.0, + langevin_isothermal: bool = True, + align_X0: bool = True, + ): + """Return the drift and diffusion components of the reverse SDE. + + Args: + X (torch.Tensor): A tensor of protein backbone structure with shape + `(batch_size, num_residues, 4, 3)`. + X0_func (Callable): A function a denoising function for the protein backbone + geometry. + C (torch.LongTensor): A tensor of condition features with shape + `(batch_size, num_residues)`. + t (float): time in (1.0, 0.0). + conditioner (Callable, optional): A conditioner the performs constrained + transformation (see examples in chroma.layers.structure.conditioners). + Z (torch.Tensor, optional): A tensor of random noise with shape + `(batch_size, num_residues, 4, 3)`. Default is None. + inverse_temperature (float, optional): The inverse temperature parameter + for the Langevin dynamics. Default is 1.0. + langevin_factor (float, optional): The scaling factor for the Langevin noise. + Default is 0.0. + langevin_isothermal (bool, optional): A flag that indicates whether to use + isothermal or non-isothermal Langevin dynamics. Default is True. + align_X0 (bool, optional): A flag that indicates whether to align the noised + X and denoised X for score function calculation. + + Returns: + f (torch.Tensor): A tensor of drift terms with shape + `(batch_size, num_residues, 4, 3)`. + gZ (torch.Tensor): A tensor of diffusion terms with shape + `(batch_size, num_residues, 4, 3)`. + """ + + # Schedule management + alpha, sigma, beta, g, lambda_t, lambda_langevin = self._schedule_coefficients( + t, + inverse_temperature=inverse_temperature, + langevin_isothermal=langevin_isothermal, + ) + score_scale_t = lambda_t + lambda_langevin * langevin_factor / 2.0 + + # Impute missing data + Z = torch.randn_like(X) if Z is None else Z + + # X = backbone.center_X(X, C) + score = self.score(X, X0_func, C, t, conditioner, align_X0=align_X0) + score_transformed = self.base_gaussian.multiply_covariance(score, C) + + f = ( + beta * (-1 / 2) * backbone.center_X(X, C) + - g.pow(2) * score_scale_t * score_transformed + ) + gZ = g * np.sqrt(1.0 + langevin_factor) * self.base_gaussian._multiply_R(Z, C) + return f, gZ + + @validate_XC() + def ode( + self, + X: torch.Tensor, + X0_func: Callable, + C: torch.LongTensor, + t: Union[torch.Tensor, float], + conditioner: Callable = None, + Z: Union[torch.Tensor, None] = None, + inverse_temperature: float = 1.0, + langevin_factor: float = 0.0, + langevin_isothermal: bool = True, + align_X0: bool = True, + detach_X0: bool = True, + ): + """Return the drift and diffusion components of the probability flow ODE. + + Args: + X (torch.Tensor): A tensor of protein backbone structure with shape + `(batch_size, num_residues, 4, 3)`. 
+ X0_func (Callable): A denoising function that returns a protein backbone + geometry `(batch_size, num_residues, 4, 3)`. + C (torch.LongTensor): A tensor of condition features with shape + `(batch_size, num_residues)`. + t (float): time in (1.0, 0.0). + conditioner (Callable, optional): A conditioner the performs constrained + transformation (see examples in chroma.layers.structure.conditioners). + Z (torch.Tensor, optional): A tensor of random noise with shape + `(batch_size, num_residues, 4, 3)`. Default is None. + inverse_temperature (float, optional): The inverse temperature parameter + for the Langevin dynamics. Default is 1.0. + langevin_factor (float, optional): The scaling factor for the Langevin + noise. Default is 0.0. + langevin_isothermal (bool, optional): A flag that indicates whether to use + isothermal or non-isothermal Langevin dynamics. Default is True. + align_X0 (bool, optional): A flag that indicates whether to align + the noised X and denoised X for score function calculation. + + Returns: + f (torch.Tensor): A tensor of drift terms with shape + `(batch_size, num_residues, 4, 3)`. + gZ (torch.Tensor): A tensor of diffusion terms with shape + `(batch_size, num_residues, 4, 3)`. + """ + + # Schedule management + alpha, sigma, beta, g, lambda_t, lambda_langevin = self._schedule_coefficients( + t, + inverse_temperature=inverse_temperature, + langevin_isothermal=langevin_isothermal, + ) + + # Impute missing data + X = backbone.center_X(X, C) + score = self.score( + X, X0_func, C, t, conditioner, align_X0=align_X0, detach_X0=detach_X0 + ) + score_transformed = self.base_gaussian.multiply_covariance(score, C) + f = (-1 / 2) * beta * X - 0.5 * lambda_langevin * g.pow(2) * score_transformed + gZ = torch.zeros_like(f) + return f, gZ + + @validate_XC() + def energy( + self, + X: torch.Tensor, + X0_func: Callable, + C: torch.Tensor, + t: torch.Tensor, + detach_X0: bool = True, + align_X0: bool = True, + ) -> torch.Tensor: + """Compute the diffusion energy as a function of denoised X + + Args: + X (torch.Tensor): A tensor of protein backbone coordinates with shape + `(batch_size, num_residues, 4, 3)`. + X0_func (Callable): A function a denoising function for protein backbone + geometry. + C (torch.LongTensor): A tensor of condition features with shape + `(batch_size, num_residues)`. + t (float): time in (1.0, 0.0). + detach_X0 (bool, optional): A flag that indicates whether to detach the + denoise X for score function evaluation + align_X0 (bool, optional): A flag that indicates whether to align the + noised X and denoised X for score function calculation. + + Returns: + U_diffusion (torch.Tensor): A tensor of diffusion energy values with + shape `(batch_size,)`. 
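+ + Note: concretely, the energy computed below is + `U(x_t) = 0.5 * || R^{-1} (x_t - mu(x0_hat, alpha_t)) / sigma_t ||**2`, + i.e. the squared whitened residual between the noised coordinates and the + diffusion-kernel mean built from the (optionally aligned) denoised estimate + `x0_hat = X0_func(x_t, C, t)`.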
+ """ + + X = backbone.impute_masked_X(X, C) + alpha = self.noise_schedule.alpha(t).to(X.device) + sigma = self.noise_schedule.sigma(t).to(X.device) + if detach_X0: + with torch.no_grad(): + X0 = X0_func(X, C, t=t) + else: + X0 = X0_func(X, C, t=t) + if align_X0: + X0, _ = self.loss_rmsd.align(X0, X, C, align_unmasked=True) + if detach_X0: + X0 = X0.detach() + Z = self._X_to_Z(X, X0, C, alpha, sigma) + U_diffusion = (0.5 * (Z**2)).sum([1, 2, 3]) + return U_diffusion + + @validate_XC() + def score( + self, + X: torch.Tensor, + X0_func: Callable, + C: torch.Tensor, + t: Union[torch.Tensor, float], + conditioner: Callable = None, + detach_X0: bool = True, + align_X0: bool = True, + U_traj: List = [], + ) -> torch.Tensor: + """Compute the score function + + Args: + X (torch.Tensor): A tensor of protein back geometry with shape + `(batch_size, num_residues, 4, 3)`. + X0_func (Callable): A function a denoising function for protein backbone + geometry. + C (torch.LongTensor): A tensor of chain map with shape + `(batch_size, num_residues)`. + t (Union[torch.Tensor, float]): time in (1.0, 0.0). + conditioner (Callable, optional): A conditioner the performs constrained + transformation (see examples in chroma.layers.structure.conditioners). + detach_X0 (bool, optional): A flag that indicates whether to detach the + denoised X for score function evaluation + align_X0 (bool, optional): A flag that indicates whether to align the + noised X and denoised X for score function calculation. + U_traj (List, optional): Record diffusion energy as a list. + + Returns: + score (torch.Tensor): A tensor of score values with shape + `(batch_size, num_residues, 4, 3)`. + """ + + X = backbone.impute_masked_X(X, C) + with torch.enable_grad(): + X = X.detach().clone() + X.requires_grad = True + + # Apply optional conditioner transformations to state and energy + Xt, Ct, U_conditioner = X, C, 0.0 + St = torch.zeros(Ct.shape, device=Xt.device).long() + Ot = F.one_hot(St, len(alphabetical_restypes)).float() + if conditioner is not None: + Xt, Ct, _, U_conditioner, _ = conditioner(X, C, Ot, U_conditioner, t) + U_conditioner = torch.as_tensor(U_conditioner) + + # Compute system energy + U_diffusion = self.energy( + Xt, X0_func, Ct, t, detach_X0=detach_X0, align_X0=align_X0 + ) + + U_traj.append(U_diffusion.detach().cpu()) + + # Compute score function as negative energy gradient + U_total = U_diffusion.sum() + U_conditioner.sum() + U_total.backward() + score = -X.grad + score = score.masked_fill((C <= 0)[..., None, None], 0.0) + return score + + def elbo(self, X0_pred, X0, C, t): + """ITD ELBO as a weighted average of denoising error, + inspired by https://arxiv.org/abs/2302.03792""" + if not isinstance(t, torch.Tensor): + t = torch.Tensor([t]).float().to(X0.device) + + # Interpolate missing data with Brownian Bridge posterior + X0 = backbone.impute_masked_X(X0, C) + X0_pred = backbone.impute_masked_X(X0_pred, C) + + # Compute whitened residual + dX = (X0 - X0_pred).reshape([X0.shape[0], -1, 3]) + R_inv_dX = self.base_gaussian._multiply_R_inverse(dX, C) + + # Average per atom, including over "missing" positions that we filled in + + weight = 0.5 * self.noise_schedule.SNR_derivative(t)[:, None, None, None] + snr = self.noise_schedule.SNR(t)[:, None, None, None] + loss_itd = ( + weight * (R_inv_dX.pow(2) - 1 / (1 + snr)) + - 0.5 * np.log(np.pi * 2.0 * np.e) + ).reshape(X0.shape) + + # Compute average per-atom loss (including over missing regions) + mask = (C != 0).float() + mask_atoms = mask.reshape(mask.shape + (1, 
1)).expand([-1, -1, 4, 1]) + + # Per-complex + elbo_gap = (mask_atoms * loss_itd).sum([1, 2, 3]) + logdet = self.base_gaussian.log_determinant(C) + elbo_unnormalized = elbo_gap - logdet + + # Normalize per atom + elbo = elbo_unnormalized / (mask_atoms.sum([1, 2, 3]) + self._eps) + + # Compute batch average + weights = mask_atoms.sum([1, 2, 3]) + elbo_batch = (weights * elbo).sum() / (weights.sum() + self._eps) + return elbo, elbo_batch + + def pseudoelbo(self, loss_per_residue, C, t): + """Compute pseudo-ELBOs as weighted averages of other errors.""" + if not isinstance(t, torch.Tensor): + t = torch.Tensor([t]).float().to(C.device) + + # Average per atom, including over x"missing" positions that we filled in + weight = 0.5 * self.noise_schedule.SNR_derivative(t)[:, None] + loss = weight * loss_per_residue + + # Compute average loss + mask = (C > 0).float() + pseudoelbo = (mask * loss).sum(-1) / (mask.sum(-1) + self._eps) + pseudoelbo_batch = (mask * loss).sum() / (mask.sum() + self._eps) + return pseudoelbo, pseudoelbo_batch + + def _baoab_sample_step( + self, + _x, + p, + C, + t, + dt, + score_func, + gamma=2.0, + kT=1.0, + n_equil=1, + ode_boost=True, + langevin_isothermal=False, + ): + gamma = torch.Tensor([gamma]).to(_x.device) + ( + alpha, + sigma, + beta, + g, + lambda_t, + lambda_langevin, + ) = self._schedule_coefficients( + t, + inverse_temperature=1 / kT, + langevin_isothermal=langevin_isothermal, + ) + + def baoab_step(_x, p, t): + Z = torch.randn_like(_x) + c1 = torch.exp(-gamma * dt) + c3 = torch.sqrt((1 / lambda_t) * (1 - c1**2)) + + # BAOAB scheme + p_half = p + score_func(t, C, _x) * dt / 2 # B + _x_half = ( + _x + + g.pow(2) * self.base_gaussian.multiply_covariance(p_half, C) * dt / 2 + ) # A + p_half2 = c1 * p_half + c3 * ( + 1 / g + ) * self.base_gaussian._multiply_R_inverse_transpose( + Z, C + ) # O + _x = ( + _x_half + + g.pow(2) * self.base_gaussian.multiply_covariance(p_half2, C) * dt / 2 + ) # A + p = p_half2 + score_func(t, C, _x) * dt / 2 # B + + return _x, p + + def ode_step(t, _x): + score = score_func(t, C, _x) + score_transformed = self.base_gaussian.multiply_covariance(score, C) + _x = _x + 0.5 * (_x + score_transformed) * g.pow(2) * dt + return _x + + for i in range(n_equil): + _x, p = baoab_step(_x, p, t) + + if ode_boost: + _x = ode_step(t, _x) + + return _x, p + + @torch.no_grad() + def sample_sde( + self, + X0_func: Callable, + C: torch.LongTensor, + X_init: Optional[torch.Tensor] = None, + conditioner: Optional[Callable] = None, + N: int = 100, + tspan: Tuple[float, float] = (1.0, 0.001), + inverse_temperature: float = 1.0, + langevin_factor: float = 0.0, + langevin_isothermal: bool = True, + sde_func: str = "reverse_sde", + integrate_func: str = "euler_maruyama", + initialize_noise: bool = True, + remap_time: bool = False, + remove_drift_translate: bool = False, + remove_noise_translate: bool = False, + align_X0: bool = True, + ) -> Dict[str, torch.Tensor]: + """Sample from the SDE using a numerical integration scheme. + + This function samples from the stochastic differential equation (SDE) defined + by the model using a numerical integration scheme such as Euler-Maruyama or + huen. The SDE can be either in the forward or reverse direction. The function + also supports optional conditioning on external variables and adding Langevin + noise to the SDE dynamics. + + Args: + X0_func (Callable): A denoising function that maps `(X, C, t)` to `X0`. + C (torch.LongTensor): Conditioner tensor with shape `(num_batch, + num_residues)`. 
+ X_init (torch.Tensor, optional): Initial state tensor with shape `(num_batch + , num_residues, 4 ,3)` or None. + If None, a zero tensor will be used as the initial state. + conditioner (Callable, optional): A function that transforms X, C, U, t. + If None, no conditioning will be applied. + N (int): Number of integration steps. + tspan (Tuple[float,float]): Time span for integration. + inverse_temperature (float): Inverse temperature parameter for SDE. + langevin_factor (float): Langevin factor for adding noise to SDE. + langevin_isothermal (bool): Whether to use isothermal or adiabatic Langevin + dynamics. + sde_func (str): Which SDE function to use ('reverse_sde', 'langevin' or 'ode'). + integrate_func (str): Which integration function to use ('euler_maruyama' + or 'heun'). + initialize_noise (bool): Whether to initialize the state with noise. + remap_time (bool): Whether to remap the time grid according to the noise + schedule. + remove_drift_translate (bool): Whether to remove the net translational + component from the drift term. + remove_noise_translate (bool): Whether to remove the net translational + component from the noise term. + align_X0 (bool): Whether to Kabsch align X0 with X before computing SDE terms. + + Returns: + outputs (Dict[str, torch.Tensor]): A dictionary of output tensors with the + following keys: + - 'C': The conditioned tensor with shape `(num_batch,num_residues)`. + - 'X_sample': The final sampled state tensor with shape `(num_batch, + num_residues ,4 ,3)`. + - 'X_trajectory': A list of state tensors along the trajectory with + shape `(num_batch,num_residues ,4 ,3)` each. + - 'Xhat_trajectory': A list of transformed state tensors along the + trajectory with shape `(num_batch,num_residues ,4 ,3)` each. + - 'Xunc_trajectory': A list of unconstrained state tensors along the + trajectory with shape `(num_batch,num_residues ,4 ,3)` each. 
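+ + Example (illustrative sketch; assumes `diffusion` is a `DiffusionChainCov` + instance and `X0_func` is a trained denoiser; the 100-residue single-chain + map below is made up for demonstration): + >>> C = torch.ones(1, 100, dtype=torch.long) + >>> out = diffusion.sample_sde(X0_func, C, N=200) + >>> out["X_sample"].shape # (1, 100, 4, 3)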
+ """ + + # Setup SDE integration + integrate_func = self.integrate_funcs[integrate_func] + sde_func = self.sde_funcs[sde_func] + T_grid = ( + self.noise_schedule.linear_logsnr_grid(N=N, tspan=tspan).to(C.device) + if remap_time + else torch.linspace(tspan[0], tspan[1], N + 1).to(C.device) + ) + + # Intercept the X0 function for tracking Xt and Xhat + Xhat_trajectory = [] + Xt_trajectory = [] + U_trajectory = [] + + def _X0_func(_X, _C, t): + _X0 = X0_func(_X, _C, t) + Xt_trajectory.append(_X.detach()) + Xhat_trajectory.append(_X0.detach()) + return _X0 + + def sdefun(_t, _X): + f, gZ = sde_func( + _X, + _X0_func, + C, + _t, + conditioner=conditioner, + inverse_temperature=inverse_temperature, + langevin_factor=langevin_factor, + langevin_isothermal=langevin_isothermal, + align_X0=align_X0, + ) + # Remove net translational component + if remove_drift_translate: + f = backbone.center_X(f, C) + if remove_noise_translate: + gZ = backbone.center_X(gZ, C) + return f, gZ + + # Initialization + if initialize_noise and X_init is not None: + X_init = self.forward(X_init, C, t=tspan[0]).detach() + elif X_init is None: + X_init = torch.zeros(list(C.shape) + [4, 3], device=C.device) + X_init = self.forward(X_init, C, t=tspan[0]).detach() + + # Determine output shape via a test forward pass + if conditioner: + with torch.enable_grad(): + X_init_test = X_init.clone() + X_init_test.requires_grad = True + S_test = torch.zeros(C.shape, device=X_init.device).long() + O_test = F.one_hot(S_test, len(alphabetical_restypes)).float() + U_test = 0.0 + t_test = torch.tensor([0.0], device=X_init.device) + _, Ct, _, _, _ = conditioner(X_init_test, C, O_test, U_test, t_test) + else: + Ct = C + + # Integrate + X_trajectory = integrate_func(sdefun, X_init, tspan, N=N, T_grid=T_grid) + + # Return constrained coordinates + outputs = { + "C": Ct, + "X_sample": Xt_trajectory[-1], + "X_trajectory": [Xt_trajectory[-1]] + Xt_trajectory, + "Xhat_trajectory": Xhat_trajectory, + "Xunc_trajectory": X_trajectory, + } + return outputs + + @torch.no_grad() + def estimate_pseudoelbo_X( + self, + X0_func, + X, + C, + num_samples=50, + deterministic_seed=0, + return_elbo_t=False, + noise=True, + ): + with torch.random.fork_rng(): + torch.random.manual_seed(deterministic_seed) + + mask = (C > 0).float() + mask_atoms = mask.reshape(list(mask.shape) + [1, 1]).expand([-1, -1, 4, 1]) + + elbo = [] + T = np.linspace(1e-4, 1.0, num_samples) + for t in tqdm(T.tolist()): + X_noise = self.forward(X, C, t=t) if noise else X + X_denoise = X0_func(X_noise, C, t) + + elbo_t = -self.noise_schedule.SNR_derivative(t).to(X.device) * ( + ((mask_atoms * (X_denoise - X) / 10.0) ** 2).sum([1, 2, 3]) + / mask_atoms.sum([1, 2, 3]) + ) + elbo.append(elbo_t) + elbo = torch.stack(elbo, 0) + if not return_elbo_t: + elbo = elbo.mean(0) + return elbo + + def _score_direct( + self, + Xt, + X0_func, + C, + t, + align_X0=True, + ): + X0 = X0_func(Xt, C, t) + + """Compute the score function directly. 
(Sometimes numerically unstable)""" + + alpha = self.noise_schedule.alpha(t).to(Xt.device) + sigma = self.noise_schedule.sigma(t).to(Xt.device) + + # Impute sensibly behaved values in masked regions for numerical stability + # X0 = backbone.impute_masked_X(X0, C) + Xt = backbone.impute_masked_X(Xt, C) + + if align_X0: + X0, _ = self.loss_rmsd.align(X0, Xt, C, align_unmasked=True) + + # Compute mean + X_mu = self._mean(X0, C, alpha) + X_mu = backbone.impute_masked_X(X_mu, C) + dX = Xt - X_mu + + Ci_dX = self.base_gaussian.multiply_inverse_covariance(dX, C) + score = -Ci_dX / sigma.pow(2)[:, None, None, None] + + # Mask + score = score.masked_fill((C <= 0)[..., None, None], 0.0) + + return score + + def estimate_logp( + self, + X0_func: Callable, + X_sample: torch.Tensor, + C: torch.LongTensor, + N: int, + return_trace_t: bool = False, + ): + """Estimate the model logP for given protein backboones + (num_batch, num_residues, 4, 3) by the Continuous Normalizing Flow formalism + + Reference: + https://arxiv.org/abs/1810.01367 + https://arxiv.org/abs/1806.07366 + + Args: + X0_func (Callable): A function that returns the initial protein backboone + (num) features given a condition. + X_sample (torch.Tensor): A tensor of protein backboone (num) features with + shape + `(batch_size, num_residues, 4, 3)`. + C (torch.Tensor): A tensor of condition features with shape `(batch_size, + num_residues)`. + N (int, optional): number of ode integration steps + return_trace_t (bool, optional): A flag that indicates whether to return the + log |df / dx| for each time step for the integrated log Jacobian trance. + Default is False. + + Returns: + elbo (torch.Tensor): A tensor of logP value + if return_elbo_t is False, or `(N)` if return_elbo_t + is True. + """ + + def divergence(fn, x, t): + """Calculate Divergance with Stochastic Trace Estimator""" + vec_eps = torch.randn_like(x) + fn_out, eps_J_prod = torch.autograd.functional.vjp( + fn, (t, x), vec_eps, create_graph=False + ) + eps_J_eps = ( + (eps_J_prod[1] * vec_eps).reshape(x.shape[0], -1).sum(-1).unsqueeze(-1) + ) + return fn_out, eps_J_eps + + def flow_gradient( + X, + X0_func, + C, + t, + ): + """Compute the time gradient from the probability flow ODE.""" + + _, _, beta, g, _, _ = self._schedule_coefficients(t) + score = self._score_direct(X, X0_func, C, t) + dXdt = (-1 / 2) * beta * X - 0.5 * g.pow(2) * score + + return dXdt + + def odefun(_t, _X): + _t = _t.detach() + f = flow_gradient( + _X, + X0_func, + C, + _t, + ) + return f + + # foward integration to noise + X_sample = backbone.center_X(X_sample, C) + X_sample = backbone.impute_masked_X(X_sample, C) + C = C.abs() + + out = self.sample_sde( + X0_func=X0_func, + C=C, + X_init=X_sample, + N=N, + sde_func="ode", + tspan=(0, 1.0), + inverse_temperature=1.0, + langevin_factor=0.0, + initialize_noise=False, + align_X0=False, + ) + + X_flow = out["X_trajectory"][1:] + + # get ode function + ddlogp = [] + + for i, t in enumerate(tqdm(torch.linspace(1e-2, 1.0, len(X_flow)))): + with torch.enable_grad(): + dlogP = divergence(odefun, X_flow[i], t[None].to(C.device))[1] + ddlogp.append(dlogP.item()) + + logp_x1 = self.base_gaussian.log_prob(X_flow[-1], C).item() + + if return_trace_t: + return np.array(ddlogp) / ((C > 0).float().sum().item() * 4) + else: + return (logp_x1 + np.array(ddlogp).mean()) / ( + (C > 0).float().sum().item() * 4 + ) + + @torch.no_grad() + @validate_XC(all_atom=False) + def estimate_elbo( + self, + X0_func: Callable, + X: torch.Tensor, + C: torch.LongTensor, + num_samples: int = 50, 
+ deterministic_seed: int = 0, + return_elbo_t: bool = False, + grad_logprob_Y_func: Optional[Callable] = None, + ) -> torch.Tensor: + """Estimate the evidence lower bound (ELBO) for given protein backboones + (num_batch, num_residues, 4, 3) and condition. + + Args: + X0_func (Callable): A function that returns the initial protein backboone + (num) features given a condition. + X (torch.Tensor): A tensor of protein backboone (num) features with shape + `(batch_size, num_residues, 4, 3)`. + C (torch.Tensor): A tensor of condition features with shape `(batch_size, + num_residues)`. + num_samples (int, optional): The number of time steps to sample for + estimating the ELBO. Default is 50. + deterministic_seed (int, optional): The seed for generating random noise. + Default is 0. + return_elbo_t (bool, optional): A flag that indicates whether to return the + ELBO for each time step or the average ELBO. Default is False. + grad_logprob_Y_func (Optional[Callable], optional): A function that returns + the gradient of the log probability of the observed protein backboone (num) + given a time step and a noisy image. Default is None. + + Returns: + elbo (torch.Tensor): A tensor of ELBO values with shape `(batch_size,)` + if return_elbo_t is False, or `(num_samples, batch_size)` if return_elbo_t + is True. + """ + X = backbone.impute_masked_X(X, C) + + with torch.random.fork_rng(): + torch.random.manual_seed(deterministic_seed) + mask = (C > 0).float() + mask_atoms = mask.reshape(list(mask.shape) + [1, 1]).expand([-1, -1, 4, 1]) + + elbo = [] + T = np.linspace(1e-4, 1.0, num_samples) + for t in tqdm(T.tolist()): + X_noise = self.forward(X, C, t=t) + X_denoise = X0_func(X_noise, C, t) + + # Adjust X-hat estimate with aux-grad + if grad_logprob_Y_func is not None: + with torch.random.fork_rng(): + grad = grad_logprob_Y_func(t, X_noise) + sigma_square = ( + self.noise_schedule.sigma(t).square().to(X.device) + ) + dXhat = sigma_square * self.base_gaussian.multiply_covariance( + grad, C + ) + dXhat = backbone.center_X(dXhat, C) + X_denoise = X_denoise + dXhat + + elbo_t, _ = self.elbo(X_denoise, X, C, t) + + elbo.append(elbo_t) + + elbo_t = torch.stack(elbo, 0) + + if return_elbo_t: + return elbo_t + else: + return elbo_t.mean(0) + + def conditional_X0( + self, X0: torch.Tensor, score: torch.Tensor, C: torch.tensor, t: torch.Tensor + ) -> torch.Tensor: + """Use Bayes theorem and Tweedie formula to obtain a conditional X0 given + prior X0 and a conditional score \nabla_x p( y | x) + X0 <- X0 + \frac{sigma_t**2}{alpha_t} \Sigma score + Args: + X0 (torch.Tensor): backbone coordinates of size (batch, num_residues, 4, 3) + score (torch.Tensor): of size (batch, num_residues, 4, 3) + C (torch.Tensor): of size (batch, num_residues) + t (torch.Tensor): of size (batch,) + + Returns: + X0 (torch.Tensor): updated conditional X0 of size (batch, num_residues, 4, 3) + """ + alpha, sigma, _, _, _, _ = self._schedule_coefficients(t) + X_update = sigma.pow(2).div(alpha)[ + ..., None + ] * self.base_gaussian.multiply_covariance(score, C) + return X0 + X_update + + def _mean(self, X, C, alpha): + """Build the diffusion kernel mean given alpha""" + # Compute the MVN mean + X_mu = backbone.scale_around_mean(X, C, alpha) + return X_mu + + def _X_to_Z(self, X_sample, X, C, alpha, sigma): + """Convert from output space to standardized space""" + + # Impute missing data with conditional means + X = backbone.impute_masked_X(X, C) + X_sample = backbone.impute_masked_X(X_sample, C) + + # sigma = 
self.noise_schedule.sigma(t).to(X.device) + + # Step 4. [Inverse] Add mean + X_mu = self._mean(X, C, alpha) + X_mu = backbone.impute_masked_X(X_mu, C) + X_noise = (X_sample - X_mu).reshape(X.shape[0], -1, 3) + + # Step 3. [Inverse] Scale noise by sigma + X_noise = X_noise / sigma[:, None, None] + + # Step 1 & 2. Multiply Z by inverse square root of covariance + Z = self.base_gaussian._multiply_R_inverse(X_noise, C) + + return Z + + def _Z_to_X(self, Z, X, C, alpha, sigma): + """Convert from standardized space to output space""" + + # Step 1 & 2. Multiply Z by square root of covariance + dX = self.base_gaussian._multiply_R(Z, C) + + # Step 3. Scale noise by alpha + dX = sigma[:, None, None, None] * dX.reshape(X.shape) + + # Step 4. Add mean + X_mu = self._mean(X, C, alpha) + X_sample = X_mu + dX + + return X_sample + + def sample_conditional( + self, X: torch.Tensor, C: torch.LongTensor, t: torch.Tensor, s: torch.Tensor + ) -> torch.Tensor: + """ + Samples from the forward process q(x_{t} | x_{s}) for t > s. + See appendix A.1 in [https://arxiv.org/pdf/2107.00630.pdf]. `forward` does this for s = 0. + Args: + X (torch.Tensor): Input coordinates with shape `(batch_size, num_residues, + 4, 3)` at time `t0`. + C (torch.Tensor): Chain tensor with shape `(batch_size, num_residues)`. + t (torch.Tensor): Time index with shape `(batch_size,)`. + s (torch.Tensor): Time index with shape `(batch_size,)`. + + Returns: + X_sample (torch.Tensor): Sampled coordinates from the forward diffusion + marginals with shape `(batch_size, num_residues, 4, 3)`. + """ + assert (t > s).all() + X = backbone.impute_masked_X(X, C) + # Do we need this? + X = backbone.center_X(X, C) + alpha_ts = self.noise_schedule.alpha(t) / self.noise_schedule.alpha(s) + sigma_ts = ( + self.noise_schedule.sigma(t).pow(2) + - alpha_ts.pow(2) * self.noise_schedule.sigma(s).pow(2) + ).sqrt() + + X_sample = alpha_ts * X + sigma_ts * self.base_gaussian.sample(C) + # Do we need this? + X_sample = backbone.center_X(X_sample - X, C) + X + return X_sample + + @validate_XC(all_atom=False) + def forward( + self, X: torch.Tensor, C: torch.LongTensor, t: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Sample from the forwards diffusion marginals at time t + + Inputs: + X (torch.Tensor): Input coordinates with shape `(batch_size, num_residues, + 4, 3)`. + C (torch.LongTensor): Chain tensor with shape `(batch_size, num_residues)`. + t (torch.Tensor, optional): Time index with shape `(batch_size,)`. If not + given, a random time index will be sampled. Defaults to None. + + Outputs: + X_sample (torch.Tensor): Sampled coordinates from the forward diffusion + marginals with shape `(batch_size, num_residues, 4, 3)`. + t (torch.Tensor, optional): Time index with shape `(batch_size,)`. Only + returned if t is not given as input. + """ + + # Draw a sample from the prior + X_prior = self.base_gaussian.sample(C) + + # Sample time if not given + t_input = t + t = self.sample_t(C, t) + + alpha = self.noise_schedule.alpha(t)[:, None, None, None].to(X.device) + sigma = self.noise_schedule.sigma(t)[:, None, None, None].to(X.device) + + X_sample = alpha * X + sigma * X_prior + X_sample = backbone.center_X(X_sample - X, C) + X + + if t_input is None: + return X_sample, t + else: + return X_sample + + +class ReconstructionLosses(nn.Module): + """Compute diffusion reconstruction losses for protein backbones. + + Args: + diffusion (DiffusionChainCov): Diffusion object parameterizing a + forwards diffusion over protein backbones. 
+ loss_scale (float): Length scale parameter used for setting loss error + scaling in units of Angstroms. Default is 10, which corresponds to + using units of nanometers. + rmsd_method (str): Method used for computing RMSD superpositions. Can + be "symeig" (default) or "power" for power iteration. + + Inputs: + X0_pred (torch.Tensor): Denoised coordinates with shape + `(num_batch, num_residues, 4, 3)`. + X (torch.Tensor): Unperturbed coordinates with shape + `(num_batch, num_residues, 4, 3)`. + C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`. + t (torch.Tensor): Diffusion time with shape `(batch_size,)`. + Should be on [0,1]. + + Outputs: + losses (dict): Dictionary of reconstructions computed across different + metrics. Metrics prefixed with `batch_` will be batch-averaged scalars + while other metrics should be per batch member with shape + `(num_batch, ...)`. + """ + + def __init__( + self, + diffusion: DiffusionChainCov, + loss_scale: float = 10.0, + rmsd_method: str = "symeig", + ): + super().__init__() + self.noise_perturb = diffusion + self.loss_scale = loss_scale + self._loss_eps = 1e-5 + + # Auxiliary losses + self.loss_rmsd = rmsd.BackboneRMSD(method=rmsd_method) + self.loss_fragment = rmsd.LossFragmentRMSD(method=rmsd_method) + self.loss_fragment_pair = rmsd.LossFragmentPairRMSD(method=rmsd_method) + self.loss_neighborhood = rmsd.LossNeighborhoodRMSD(method=rmsd_method) + self.loss_hbond = hbonds.LossBackboneHBonds() + self.loss_distance = backbone.LossBackboneResidueDistance() + + self.loss_functions = { + "elbo": self._loss_elbo, + "rmsd": self._loss_rmsd, + "pseudoelbo": self._loss_pseudoelbo, + "fragment": self._loss_fragment, + "pair": self._loss_pair, + "neighborhood": self._loss_neighborhood, + "distance": self._loss_distance, + "hbonds": self._loss_hbonds, + } + + def _batch_average(self, loss, C): + weights = (C > 0).float().sum(-1) + return (weights * loss).sum() / (weights.sum() + self._loss_eps) + + def _loss_elbo(self, losses, X0_pred, X, C, t, w=None, X_t_2=None): + losses["elbo"], losses["batch_elbo"] = self.noise_perturb.elbo(X0_pred, X, C, t) + + def _loss_rmsd(self, losses, X0_pred, X, C, t, w=None, X_t_2=None): + _, rmsd_denoise = self.loss_rmsd.align(X, X0_pred, C) + _, rmsd_noise = self.loss_rmsd.align(X, X_t_2, C) + rmsd_ratio_per_item = w * rmsd_denoise / (rmsd_noise + self._loss_eps) + global_mse_normalized = ( + w + * self.loss_scale + * rmsd_denoise.square() + / (rmsd_noise.square() + self._loss_eps) + ) + losses["rmsd_ratio"] = self._batch_average(rmsd_ratio_per_item, C) + losses["global_mse"] = global_mse_normalized + losses["batch_global_mse"] = self._batch_average(global_mse_normalized, C) + + def _loss_pseudoelbo(self, losses, X0_pred, X, C, t, w=None, X_t_2=None): + # Unaligned residual pseudoELBO + unaligned_mse = ((X - X0_pred) / self.loss_scale).square().sum(-1).mean(-1) + losses["elbo_X"], losses["batch_pseudoelbo_X"] = self.noise_perturb.pseudoelbo( + unaligned_mse, C, t + ) + + def _loss_fragment(self, losses, X0_pred, X, C, t, w=None, X_t_2=None): + # Aligned Fragment MSE loss + mask = (C > 0).float() + rmsd_fragment = self.loss_fragment(X0_pred, X, C) + rmsd_fragment_noise = self.loss_fragment(X_t_2, X, C) + fragment_mse_normalized = ( + self.loss_scale + * w + * ( + (mask * rmsd_fragment.square()).sum(1) + / ((mask * rmsd_fragment_noise.square()).sum(1) + self._loss_eps) + ) + ) + losses["fragment_mse"] = fragment_mse_normalized + losses["batch_fragment_mse"] = self._batch_average(fragment_mse_normalized, C) + + def 
_loss_pair(self, losses, X0_pred, X, C, t, w=None, X_t_2=None): + # Aligned Pair MSE loss + rmsd_pair, mask_ij_pair = self.loss_fragment_pair(X0_pred, X, C) + rmsd_pair_noise, mask_ij_pair = self.loss_fragment_pair(X_t_2, X, C) + pair_mse_normalized = ( + self.loss_scale + * w + * ( + (mask_ij_pair * rmsd_pair.square()).sum([1, 2]) + / ( + (mask_ij_pair * rmsd_pair_noise.square()).sum([1, 2]) + + self._loss_eps + ) + ) + ) + losses["pair_mse"] = pair_mse_normalized + losses["batch_pair_mse"] = self._batch_average(pair_mse_normalized, C) + + def _loss_neighborhood(self, losses, X0_pred, X, C, t, w=None, X_t_2=None): + # Neighborhood MSE + rmsd_neighborhood, mask = self.loss_neighborhood(X0_pred, X, C) + rmsd_neighborhood_noise, mask = self.loss_neighborhood(X_t_2, X, C) + neighborhood_mse_normalized = ( + self.loss_scale + * w + * ( + (mask * rmsd_neighborhood.square()).sum(1) + / ((mask * rmsd_neighborhood_noise.square()).sum(1) + self._loss_eps) + ) + ) + losses["neighborhood_mse"] = neighborhood_mse_normalized + losses["batch_neighborhood_mse"] = self._batch_average( + neighborhood_mse_normalized, C + ) + + def _loss_distance(self, losses, X0_pred, X, C, t, w=None, X_t_2=None): + # Distance MSE + mask = (C > 0).float() + distance_mse = self.loss_distance(X0_pred, X, C) + distance_mse_noise = self.loss_distance(X_t_2, X, C) + distance_mse_normalized = self.loss_scale * ( + w + * (mask * distance_mse).sum(1) + / ((mask * distance_mse_noise).sum(1) + self._loss_eps) + ) + losses["distance_mse"] = distance_mse_normalized + losses["batch_distance_mse"] = self._batch_average(distance_mse_normalized, C) + + def _loss_hbonds(self, losses, X0_pred, X, C, t, w=None, X_t_2=None): + # HBond recovery + outs = self.loss_hbond(X0_pred, X, C) + hb_local, hb_nonlocal, error_co = [w * o for o in outs] + + losses["batch_hb_local"] = self._batch_average(hb_local, C) + losses["hb_local"] = hb_local + losses["batch_hb_nonlocal"] = self._batch_average(hb_nonlocal, C) + losses["hb_nonlocal"] = hb_nonlocal + losses["batch_hb_contact_order"] = self._batch_average(error_co, C) + + @torch.no_grad() + @validate_XC(all_atom=False) + def estimate_metrics( + self, + X0_func: Callable, + X: torch.Tensor, + C: torch.LongTensor, + num_samples: int = 50, + deterministic_seed: int = 0, + use_noise: bool = True, + return_samples: bool = False, + tspan: Tuple[float] = (1e-4, 1.0), + ): + """Estimate time-averaged reconstruction losses of protein backbones. + + Args: + X0_func (Callable): A denoising function that maps `(X, C, t)` to `X0`. + X (torch.Tensor): A tensor of protein backbone coordinates with shape + `(batch_size, num_residues, 4, 3)`. + C (torch.Tensor): A chain map tensor with shape `(batch_size, + num_residues)`. + num_samples (int, optional): The number of time steps to sample for + estimating the metrics. Default is 50. + use_noise (bool): If True, add noise to each structure before denoising. + Default is True. When False this can be used for estimating + whether structures are fixed points of the denoiser across time. + deterministic_seed (int, optional): The seed for generating random noise. + Default is 0. + return_samples (bool): If True, include intermediate sampled + values for each metric. Default is False. + tspan (tuple[float]): Tuple of floats indicating the diffusion + times between which to integrate. + + Returns: + metrics (dict): A dictionary of reconstruction metrics averaged over + time. + metrics_samples (dict, optional): A dictionary of the per-timestep + metric samples used to compute the time averages. 
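+ + Example (illustrative sketch; assumes `losses` is a `ReconstructionLosses` + instance, `X0_func` is a trained denoiser, and `X`, `C` are a batch of + backbone coordinates and chain maps): + >>> metrics = losses.estimate_metrics(X0_func, X, C, num_samples=20) + >>> metrics["global_mse"] # time-averaged, per batch member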
+ """ + # + X = backbone.impute_masked_X(X, C) + with torch.random.fork_rng(): + torch.random.manual_seed(deterministic_seed) + T = np.linspace(1e-4, 1.0, num_samples) + losses = [] + for t in tqdm(T.tolist(), desc="Integrating diffusion metrics"): + X_noise = self.noise_perturb(X, C, t=t) if use_noise else X + X_denoise = X0_func(X_noise, C, t) + losses_t = self.forward(X_denoise, X, C, t) + + # Discard batch estimated objects + losses_t = { + k: v + for k, v in losses_t.items() + if not k.startswith("batch_") and k != "rmsd_ratio" + } + losses.append(losses_t) + + # Transpose list of dicts to a dict of lists + metrics_samples = {k: [d[k] for d in losses] for k in losses[0].keys()} + + # Average final metrics across time + metrics = { + k: torch.stack(v, 0).mean(0) + for k, v in metrics_samples.items() + if isinstance(v[0], torch.Tensor) + } + if return_samples: + return metrics, metrics_samples + else: + return metrics + + @validate_XC() + def forward( + self, + X0_pred: torch.Tensor, + X: torch.Tensor, + C: torch.LongTensor, + t: torch.Tensor, + ): + # Collect all losses and tensors for metric tracking + losses = {"t": t, "X": X, "X0_pred": X0_pred} + X_t_2 = self.noise_perturb(X, C, t=t) + + # Per complex weights + ssnr = self.noise_perturb.noise_schedule.SSNR(t).to(X.device) + prob_ssnr = self.noise_perturb.noise_schedule.prob_SSNR(ssnr) + importance_weights = 1 / prob_ssnr + + for _loss in self.loss_functions.values(): + _loss(losses, X0_pred, X, C, t, w=importance_weights, X_t_2=X_t_2) + return losses + + +def _debug_viz_gradients( + pml_file, X_list, dX_list, C, S, arrow_length=2.0, name="gradient", color="red" +): + """ """ + lines = [ + "from pymol.cgo import *", + "from pymol import cmd", + f'color_1 = list(pymol.cmd.get_color_tuple("{color}"))', + 'color_2 = list(pymol.cmd.get_color_tuple("blue"))', + ] + + with open(pml_file, "w") as f: + for model_ix, X in enumerate(X_list): + print(model_ix) + lines = lines + ["obj_1 = []"] + + dX = dX_list[model_ix] + scale = dX.norm(dim=-1).mean().item() + X_i = X + X_j = X + arrow_length * dX / scale + + for a_ix in range(4): + for i in range(X.size(1)): + x_i = X_i[0, i, a_ix, :].tolist() + x_j = X_j[0, i, a_ix, :].tolist() + lines = lines + [ + f"obj_1 = obj_1 + [CYLINDER] + {x_i} + {x_j} + [0.15]" + " + color_1 + color_1" + ] + lines = lines + [f'cmd.load_cgo(obj_1, "{name}", {model_ix+1})'] + f.write("\n" + "\n".join(lines)) + lines = [] + + +def _debug_viz_XZC(X, Z, C, rgb=True): + from matplotlib import pyplot as plt + + if len(X.shape) > 3: + X = X.reshape(X.shape[0], -1, 3) + if len(Z.shape) > 3: + Z = Z.reshape(Z.shape[0], -1, 3) + if C.shape[1] != X.shape[1]: + C_expand = C.unsqueeze(-1).expand(-1, -1, 4) + C = C_expand.reshape(C.shape[0], -1) + + # C_mask = expand_chain_map(torch.abs(C)) + # X_expand = torch.einsum('nix,nic->nicx', X, C_mask) + # plt.plot(X_expand[0,:,:,0].data.numpy()) + N = X.shape[1] + Ymax = torch.max(X[0, :, 0]).item() + plt.figure(figsize=[12, 4]) + plt.subplot(2, 1, 1) + + plt.bar( + np.arange(0, N), + (C[0, :].data.numpy() < 0) * Ymax, + width=1.0, + edgecolor=None, + color="lightgrey", + ) + if rgb: + plt.plot(X[0, :, 0].data.numpy(), "r", linewidth=0.5) + plt.plot(X[0, :, 1].data.numpy(), "g", linewidth=0.5) + plt.plot(X[0, :, 2].data.numpy(), "b", linewidth=0.5) + plt.xlim([0, N]) + plt.grid() + plt.title("X") + plt.xticks([]) + plt.subplot(2, 1, 2) + plt.plot(Z[0, :, 0].data.numpy(), "r", linewidth=0.5) + plt.plot(Z[0, :, 1].data.numpy(), "g", linewidth=0.5) + plt.plot(Z[0, :, 2].data.numpy(), "b", 
linewidth=0.5) + plt.plot(C[0, :].data.numpy(), "orange") + plt.xlim([0, N]) + plt.grid() + plt.title("RInverse @ [X]") + plt.xticks([]) + plt.savefig("xzc.pdf") + else: + plt.plot(X[0, :, 0].data.numpy(), "k", linewidth=0.5) + plt.xlim([0, N]) + plt.grid() + plt.title("X") + plt.xticks([]) + plt.subplot(2, 1, 2) + plt.plot(Z[0, :, 0].data.numpy(), "k", linewidth=0.5) + plt.plot(C[0, :].data.numpy(), "orange") + plt.xlim([0, N]) + plt.grid() + plt.title("Inverse[X]") + plt.xticks([]) + plt.savefig("xzc.pdf") + exit() diff --git a/prtm/models/chroma/structure/geometry.py b/prtm/models/chroma/structure/geometry.py new file mode 100644 index 00000000..93b8a478 --- /dev/null +++ b/prtm/models/chroma/structure/geometry.py @@ -0,0 +1,681 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layers for measuring and building atomic geometries in proteins. + +This module contains pytorch layers for computing common geometric features of +protein backbones in a differentiable way and for converting between internal +and Cartesian coordinate representations. +""" + +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Distances(nn.Module): + """Euclidean distance layer (pairwise). + + This layer computes batched pairwise Euclidean distances, where the input + tensor is treated as a batch of vectors with the final dimension as the + feature dimension and the dimension for pairwise expansion can be specified. + + Args: + distance_eps (float, optional): Small parameter to adde to squared + distances to make gradients smooth near 0. + + Inputs: + X (tensor): Input coordinates with shape `([...], length, [...], 3)`. + dim (int, optional): Dimension upon which to expand to pairwise + distances. Defaults to -2. + mask (tensor, optional): Masking tensor with shape + `([...], length, [...])`. + + Outputs: + D (tensor): Distances with shape `([...], length, length, [...])` + """ + + def __init__(self, distance_eps=1e-3): + super(Distances, self).__init__() + self.distance_eps = distance_eps + + def forward( + self, X: torch.Tensor, mask: Optional[torch.Tensor] = None, dim: float = -2 + ) -> torch.Tensor: + dim_expand = dim if dim < 0 else dim + 1 + dX = X.unsqueeze(dim_expand - 1) - X.unsqueeze(dim_expand) + D_square = torch.sum(dX**2, -1) + D = torch.sqrt(D_square + self.distance_eps) + if mask is not None: + mask_expand = mask.unsqueeze(dim) * mask.unsqueeze(dim + 1) + D = mask_expand * D + return D + + +class VirtualAtomsCA(nn.Module): + """Virtual atoms layer, branching from backbone C-alpha carbons. + + This layer places virtual atom coordinates relative to backbone coordinates + in a differentiable way. + + Args: + virtual_type (str, optional): Type of virtual atom to place. Currently + supported types are `dicons`, a virtual placement that was + optimized to predict potential rotamer interactions, and `cbeta` + which places a virtual C-beta carbon assuming ideal geometry. 
+ distance_eps (float, optional): Small parameter to add to squared + distances to make gradients smooth near 0. + + Inputs: + X (Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atom_types, 3)`. + C (Tensor): Chain map tensor with shape `(num_batch, num_residues)`. + + Outputs: + X_virtual (Tensor): Virtual coordinates with shape + `(num_batch, num_residues, 3)`. + """ + + def __init__(self, virtual_type="dicons", distance_eps=1e-3): + super(VirtualAtomsCA, self).__init__() + self.distance_eps = distance_eps + + """ + Geometry specifications + dicons + Length CA-X: 2.3866 + Angle N-CA-X: 111.0269 + Dihedral C-N-CA-X: -138.886412 + + cbeta + Length CA-X: 1.532 (Engh and Huber, 2001) + Angle N-CA-X: 109.5 (tetrahedral geometry) + Dihedral C-N-CA-X: -125.25 (109.5 / 2 - 180) + """ + self.virtual_type = virtual_type + virtual_geometries = { + "dicons": [2.3866, 111.0269, -138.8864122], + "cbeta": [1.532, 109.5, -125.25], + } + self.virtual_geometries = virtual_geometries + self.distance_eps = distance_eps + + def geometry(self): + bond, angle, dihedral = self.virtual_geometries[self.virtual_type] + return bond, angle, dihedral + + def forward(self, X: torch.Tensor, C: torch.LongTensor) -> torch.Tensor: + bond, angle, dihedral = self.geometry() + + ones = torch.ones([1, 1], device=X.device) + bonds = bond * ones + angles = angle * ones + dihedrals = dihedral * ones + + # Build reference frame + # 1.C -> 2.N -> 3.CA -> 4.X + X_N, X_CA, X_C, X_O = X.unbind(2) + X_virtual = extend_atoms( + X_C, + X_N, + X_CA, + bonds, + angles, + dihedrals, + degrees=True, + distance_eps=self.distance_eps, + ) + + # Mask missing positions + mask = (C > 0).type(torch.float32).unsqueeze(-1) + X_virtual = mask * X_virtual + return X_virtual + + +def normed_vec(V: torch.Tensor, distance_eps: float = 1e-3) -> torch.Tensor: + """Normalized vectors with distance smoothing. + + This normalization is computed as `U = V / sqrt(|V|^2 + eps)` to avoid cusps + and gradient discontinuities. + + Args: + V (Tensor): Batch of vectors with shape `(..., num_dims)`. + distance_eps (float, optional): Distance smoothing parameter for + for computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`. + Default: 1E-3. + + Returns: + U (Tensor): Batch of normalized vectors with shape `(..., num_dims)`. + """ + # Unit vector from i to j + mag_sq = (V**2).sum(dim=-1, keepdim=True) + mag = torch.sqrt(mag_sq + distance_eps) + U = V / mag + return U + + +def normed_cross( + V1: torch.Tensor, V2: torch.Tensor, distance_eps: float = 1e-3 +) -> torch.Tensor: + """Normalized cross product between vectors. + + This normalization is computed as `U = V / sqrt(|V|^2 + eps)` to avoid cusps + and gradient discontinuities. + + Args: + V1 (Tensor): Batch of vectors with shape `(..., 3)`. + V2 (Tensor): Batch of vectors with shape `(..., 3)`. + distance_eps (float, optional): Distance smoothing parameter for + for computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`. + Default: 1E-3. + + Returns: + C (Tensor): Batch of cross products `v_1 x v_2` with shape `(..., 3)`. + """ + C = normed_vec(torch.cross(V1, V2, dim=-1), distance_eps=distance_eps) + return C + + +def lengths( + atom_i: torch.Tensor, atom_j: torch.Tensor, distance_eps: float = 1e-3 +) -> torch.Tensor: + """Batched bond lengths given batches of atom i and j. + + Args: + atom_i (Tensor): Atom `i` coordinates with shape `(..., 3)`. + atom_j (Tensor): Atom `j` coordinates with shape `(..., 3)`. 
+ distance_eps (float, optional): Distance smoothing parameter for + for computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`. + Default: 1E-3. + + Returns: + L (Tensor): Elementwise bond lengths `||x_i - x_j||` with shape `(...)`. + """ + # Bond length of i-j + dX = atom_j - atom_i + L = torch.sqrt((dX**2).sum(dim=-1) + distance_eps) + return L + + +def angles( + atom_i: torch.Tensor, + atom_j: torch.Tensor, + atom_k: torch.Tensor, + distance_eps: float = 1e-3, + degrees: bool = False, +) -> torch.Tensor: + """Batched bond angles given atoms `i-j-k`. + + Args: + atom_i (Tensor): Atom `i` coordinates with shape `(..., 3)`. + atom_j (Tensor): Atom `j` coordinates with shape `(..., 3)`. + atom_k (Tensor): Atom `k` coordinates with shape `(..., 3)`. + distance_eps (float, optional): Distance smoothing parameter for + for computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`. + Default: 1E-3. + degrees (bool, optional): If True, convert to degrees. Default: False. + + Returns: + A (Tensor): Elementwise bond angles with shape `(...)`. + """ + # Bond angle of i-j-k + U_ji = normed_vec(atom_i - atom_j, distance_eps=distance_eps) + U_jk = normed_vec(atom_k - atom_j, distance_eps=distance_eps) + inner_prod = torch.einsum("bix,bix->bi", U_ji, U_jk) + inner_prod = torch.clamp(inner_prod, -1, 1) + A = torch.acos(inner_prod) + if degrees: + A = A * 180.0 / np.pi + return A + + +def dihedrals( + atom_i: torch.Tensor, + atom_j: torch.Tensor, + atom_k: torch.Tensor, + atom_l: torch.Tensor, + distance_eps: float = 1e-3, + degrees: bool = False, +) -> torch.Tensor: + """Batched bond dihedrals given atoms `i-j-k-l`. + + Args: + atom_i (Tensor): Atom `i` coordinates with shape `(..., 3)`. + atom_j (Tensor): Atom `j` coordinates with shape `(..., 3)`. + atom_k (Tensor): Atom `k` coordinates with shape `(..., 3)`. + atom_l (Tensor): Atom `l` coordinates with shape `(..., 3)`. + distance_eps (float, optional): Distance smoothing parameter for + for computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`. + Default: 1E-3. + degrees (bool, optional): If True, convert to degrees. Default: False. + + Returns: + D (Tensor): Elementwise bond dihedrals with shape `(...)`. + """ + U_ij = normed_vec(atom_j - atom_i, distance_eps=distance_eps) + U_jk = normed_vec(atom_k - atom_j, distance_eps=distance_eps) + U_kl = normed_vec(atom_l - atom_k, distance_eps=distance_eps) + normal_ijk = normed_cross(U_ij, U_jk, distance_eps=distance_eps) + normal_jkl = normed_cross(U_jk, U_kl, distance_eps=distance_eps) + # _inner_product = lambda a, b: torch.einsum("bix,bix->bi", a, b) + _inner_product = lambda a, b: (a * b).sum(-1) + cos_dihedrals = _inner_product(normal_ijk, normal_jkl) + angle_sign = _inner_product(U_ij, normal_jkl) + cos_dihedrals = torch.clamp(cos_dihedrals, -1, 1) + D = torch.sign(angle_sign) * torch.acos(cos_dihedrals) + if degrees: + D = D * 180.0 / np.pi + return D + + +def extend_atoms( + X_1: torch.Tensor, + X_2: torch.Tensor, + X_3: torch.Tensor, + lengths: torch.Tensor, + angles: torch.Tensor, + dihedrals: torch.Tensor, + distance_eps: float = 1e-3, + degrees: bool = False, +) -> torch.Tensor: + """Place atom `X_4` given `X_1`, `X_2`, `X_3` and internal coordinates. + + ___________________ + | X_1 - X_2 | + | | | + | X_3 - [X_4] | + |___________________| + + This uses a similar approach as NERF: + Parsons et al, Computational Chemistry (2005). + https://doi.org/10.1002/jcc.20237 + See the reference for further explanation about converting from internal + coordinates to Cartesian coordinates. 
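# Illustrative sketch (assumes the module path created by this diff): a NERF-style
# round trip. Place a fourth atom from three reference atoms and internal coordinates
# with `extend_atoms`, then read the torsion back with `dihedrals`. The coordinates
# and internal-coordinate values below are made-up example numbers.
import torch
from prtm.models.chroma.structure.geometry import extend_atoms, dihedrals

X1 = torch.tensor([[[0.0, 0.0, 0.0]]])
X2 = torch.tensor([[[1.5, 0.0, 0.0]]])
X3 = torch.tensor([[[2.0, 1.4, 0.0]]])
bond = torch.full((1, 1), 1.33)
angle = torch.full((1, 1), 120.0)
torsion = torch.full((1, 1), -60.0)

X4 = extend_atoms(X1, X2, X3, bond, angle, torsion, degrees=True)
print(dihedrals(X1, X2, X3, X4, degrees=True))  # ~ -60, up to the eps smoothing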
+ + Args: + X_1 (Tensor): First atom coordinates with shape `(..., 3)`. + X_2 (Tensor): Second atom coordinates with shape `(..., 3)`. + X_3 (Tensor): Third atom coordinates with shape `(..., 3)`. + lengths (Tensor): Bond lengths `X_3-X_4` with shape `(...)`. + angles (Tensor): Bond angles `X_2-X_3-X_4` with shape `(...)`. + dihedrals (Tensor): Bond dihedrals `X_1-X_2-X_3-X_4` with shape `(...)`. + distance_eps (float, optional): Distance smoothing parameter for + for computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`. + This preserves differentiability for zero distances. Default: 1E-3. + degrees (bool, optional): If True, inputs are treated as degrees. + Default: False. + + Returns: + X_4 (Tensor): Placed atom with shape `(..., 3)`. + """ + if degrees: + angles *= np.pi / 180.0 + dihedrals *= np.pi / 180.0 + + r_32 = X_2 - X_3 + r_12 = X_2 - X_1 + n_1 = normed_vec(r_32, distance_eps=distance_eps) + n_2 = normed_cross(n_1, r_12, distance_eps=distance_eps) + n_3 = normed_cross(n_1, n_2, distance_eps=distance_eps) + + lengths = lengths.unsqueeze(-1) + cos_angle = torch.cos(angles).unsqueeze(-1) + sin_angle = torch.sin(angles).unsqueeze(-1) + cos_dihedral = torch.cos(dihedrals).unsqueeze(-1) + sin_dihedral = torch.sin(dihedrals).unsqueeze(-1) + + X_4 = X_3 + lengths * ( + cos_angle * n_1 + + (sin_angle * sin_dihedral) * n_2 + + (sin_angle * cos_dihedral) * n_3 + ) + return X_4 + + +class InternalCoords(nn.Module): + """Internal coordinates layer. + + This layer computes internal coordinates (ICs) from a batch of protein + backbones. To make the ICs differentiable everywhere, this layer replaces + distance calculations of the form `sqrt(sum_sq)` with smooth, non-cusped + approximation `sqrt(sum_sq + eps)`. + + Args: + distance_eps (float, optional): Small parameter to add to squared + distances to make gradients smooth near 0. + + Inputs: + X (Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atom_types, 3)`. + C (Tensor): Chain map tensor with shape + `(num_batch, num_residues)`. 
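# Illustrative sketch (assumes the module path created by this diff): the
# internal-coordinate layer returns per-residue dihedral/angle/length tensors, each
# with a trailing dimension of 4 (the three backbone slots plus the carbonyl-O branch).
import torch
from prtm.models.chroma.structure.geometry import InternalCoords

X = torch.randn(2, 10, 4, 3)             # (batch, residues, N/CA/C/O, xyz)
C = torch.ones(2, 10, dtype=torch.long)  # single-chain chain map
D, A, L = InternalCoords()(X, C)
print(D.shape, A.shape, L.shape)         # each torch.Size([2, 10, 4])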
+ + Outputs: + dihedrals (Tensor): Backbone dihedral angles with shape + `(num_batch, num_residues, 4)` + angles (Tensor): Backbone bond lengths with shape + `(num_batch, num_residues, 4)` + lengths (Tensor): Backbone bond lengths with shape + `(num_batch, num_residues, 4)` + """ + + def __init__(self, distance_eps=1e-3): + super(InternalCoords, self).__init__() + self.distance_eps = distance_eps + + def forward( + self, + X: torch.Tensor, + C: Optional[torch.Tensor] = None, + return_masks: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + mask = (C > 0).float() + X_chain = X[:, :, :3, :] + num_batch, num_residues, _, _ = X_chain.shape + X_chain = X_chain.reshape(num_batch, 3 * num_residues, 3) + + # This function historically returns the angle complement + _lengths = lambda Xi, Xj: lengths(Xi, Xj, distance_eps=self.distance_eps) + _angles = lambda Xi, Xj, Xk: np.pi - angles( + Xi, Xj, Xk, distance_eps=self.distance_eps + ) + _dihedrals = lambda Xi, Xj, Xk, Xl: dihedrals( + Xi, Xj, Xk, Xl, distance_eps=self.distance_eps + ) + + # Compute internal coordinates associated with -[N]-[CA]-[C]- + NCaC_L = _lengths(X_chain[:, 1:, :], X_chain[:, :-1, :]) + NCaC_A = _angles(X_chain[:, :-2, :], X_chain[:, 1:-1, :], X_chain[:, 2:, :]) + NCaC_D = _dihedrals( + X_chain[:, :-3, :], + X_chain[:, 1:-2, :], + X_chain[:, 2:-1, :], + X_chain[:, 3:, :], + ) + + # Compute internal coordinates associated with [C]=[O] + _, X_CA, X_C, X_O = X.unbind(dim=2) + X_N_next = X[:, 1:, 0, :] + O_L = _lengths(X_C, X_O) + O_A = _angles(X_CA, X_C, X_O) + O_D = _dihedrals(X_N_next, X_CA[:, :-1, :], X_C[:, :-1, :], X_O[:, :-1, :]) + + if C is None: + C = torch.zeros_like(mask) + + # Mask nonphysical bonds and angles + # Note: this could probably also be expressed as a Conv, unclear + # which is faster and this probably not rate-limiting. + C = C * (mask.type(torch.long)) + ii = torch.stack(3 * [C], dim=-1).view([num_batch, -1]) + L0, L1 = ii[:, :-1], ii[:, 1:] + A0, A1, A2 = ii[:, :-2], ii[:, 1:-1], ii[:, 2:] + D0, D1, D2, D3 = ii[:, :-3], ii[:, 1:-2], ii[:, 2:-1], ii[:, 3:] + + # Mask for linear backbone + mask_L = torch.eq(L0, L1) + mask_A = torch.eq(A0, A1) * torch.eq(A0, A2) + mask_D = torch.eq(D0, D1) * torch.eq(D0, D2) * torch.eq(D0, D3) + mask_L = mask_L.type(torch.float32) + mask_A = mask_A.type(torch.float32) + mask_D = mask_D.type(torch.float32) + + # Masks for branched oxygen + mask_O_D = torch.eq(C[:, :-1], C[:, 1:]) + mask_O_D = mask_O_D.type(torch.float32) + mask_O_A = mask + mask_O_L = mask + + def _pad_pack(D, A, L, O_D, O_A, O_L): + # Pad and pack together the components + D = F.pad(D, (1, 2)) + A = F.pad(A, (0, 2)) + L = F.pad(L, (0, 1)) + O_D = F.pad(O_D, (0, 1)) + D, A, L = [x.reshape(num_batch, num_residues, 3) for x in [D, A, L]] + _pack = lambda a, b: torch.cat([a, b.unsqueeze(-1)], dim=-1) + L = _pack(L, O_L) + A = _pack(A, O_A) + D = _pack(D, O_D) + return D, A, L + + D, A, L = _pad_pack(NCaC_D, NCaC_A, NCaC_L, O_D, O_A, O_L) + mask_D, mask_A, mask_L = _pad_pack( + mask_D, mask_A, mask_L, mask_O_D, mask_O_A, mask_O_L + ) + mask_expand = mask.unsqueeze(-1) + mask_D = mask_expand * mask_D + mask_A = mask_expand * mask_A + mask_L = mask_expand * mask_L + + D = mask_D * D + A = mask_A * A + L = mask_L * L + + if not return_masks: + return D, A, L + else: + return D, A, L, mask_D, mask_A, mask_L + + +class VirtualAtomsCA(nn.Module): + """Virtual atoms layer, branching from backbone C-alpha carbons. 
+ + This layer places virtual atom coordinates relative to backbone coordinates + in a differentiable way. + + Args: + virtual_type (str, optional): Type of virtual atom to place. Currently + supported types are `dicons`, a virtual placement that was + optimized to predict potential rotamer interactions, and `cbeta` + which places a virtual C-beta carbon assuming ideal geometry. + distance_eps (float, optional): Small parameter to add to squared + distances to make gradients smooth near 0. + + Inputs: + X (Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atom_types, 3)`. + C (Tensor): Chain map tensor with shape `(num_batch, num_residues)`. + + Outputs: + X_virtual (Tensor): Virtual coordinates with shape + `(num_batch, num_residues, 3)`. + """ + + def __init__(self, virtual_type="dicons", distance_eps=1e-3): + super(VirtualAtomsCA, self).__init__() + self.distance_eps = distance_eps + + """ + Geometry specifications + dicons + Length CA-X: 2.3866 + Angle N-CA-X: 111.0269 + Dihedral C-N-CA-X: -138.886412 + + cbeta + Length CA-X: 1.532 (Engh and Huber, 2001) + Angle N-CA-X: 109.5 (tetrahedral geometry) + Dihedral C-N-CA-X: -125.25 (109.5 / 2 - 180) + """ + self.virtual_type = virtual_type + virtual_geometries = { + "dicons": [2.3866, 111.0269, -138.8864122], + "cbeta": [1.532, 109.5, -125.25], + } + self.virtual_geometries = virtual_geometries + self.distance_eps = distance_eps + + def geometry(self): + bond, angle, dihedral = self.virtual_geometries[self.virtual_type] + return bond, angle, dihedral + + def forward(self, X: torch.Tensor, C: torch.LongTensor) -> torch.Tensor: + bond, angle, dihedral = self.geometry() + + ones = torch.ones([1, 1], device=X.device) + bonds = bond * ones + angles = angle * ones + dihedrals = dihedral * ones + + # Build reference frame + # 1.C -> 2.N -> 3.CA -> 4.X + X_N, X_CA, X_C, X_O = X.unbind(2) + X_virtual = extend_atoms( + X_C, + X_N, + X_CA, + bonds, + angles, + dihedrals, + degrees=True, + distance_eps=self.distance_eps, + ) + + # Mask missing positions + mask = (C > 0).type(torch.float32).unsqueeze(-1) + X_virtual = mask * X_virtual + return X_virtual + + +def quaternions_from_rotations(R: torch.Tensor, eps: float = 1e-3) -> torch.Tensor: + """Convert a batch of rotation matrices to quaternions. + + See en.wikipedia.org/wiki/Quaternions_and_spatial_rotation for further + details on converting between quaternions and rotation matrices. + + Args: + R (tensor): Batch of rotation matrices with shape `(..., 3, 3)`. + + Returns: + q (tensor): Batch of quaternion vectors with shape `(..., 4)`. Quaternion + is in the order `[angle, axis_x, axis_y, axis_z]`. + """ + + batch_dims = list(R.shape)[:-2] + R_flat = R.reshape(batch_dims + [9]) + R00, R01, R02, R10, R11, R12, R20, R21, R22 = R_flat.unbind(-1) + + # Quaternion possesses both an axis and angle of rotation + _sqrt = lambda r: torch.sqrt(F.relu(r) + eps) + q_angle = _sqrt(1 + R00 + R11 + R22).unsqueeze(-1) + magnitudes = _sqrt( + 1 + torch.stack([R00 - R11 - R22, -R00 + R11 - R22, -R00 - R11 + R22], -1) + ) + signs = torch.sign(torch.stack([R21 - R12, R02 - R20, R10 - R01], -1)) + q_axis = signs * magnitudes + + # Normalize (for safety and a missing factor of 2) + q_unc = torch.cat((q_angle, q_axis), -1) + q = normed_vec(q_unc, distance_eps=eps) + return q + + +def rotations_from_quaternions( + q: torch.Tensor, normalize: bool = False, eps: float = 1e-3 +) -> torch.Tensor: + """Convert a batch of quaternions to rotation matrices. 
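# Illustrative sketch (assumes the module path created by this diff): a rotation ->
# quaternion -> rotation round trip should reproduce the input matrix up to the eps
# smoothing used inside `normed_vec`.
import torch
from prtm.models.chroma.structure.geometry import (
    quaternions_from_rotations,
    rotations_from_quaternions,
)

# Build a valid rotation from a random matrix via QR (sign-fix the determinant).
A = torch.randn(3, 3)
Q, _ = torch.linalg.qr(A)
if torch.det(Q) < 0:
    Q[:, 0] = -Q[:, 0]
R = Q.unsqueeze(0)                      # (batch, 3, 3)

q = quaternions_from_rotations(R)       # (batch, 4), ordered [angle, x, y, z]
R_rec = rotations_from_quaternions(q)   # (batch, 3, 3)
print((R - R_rec).abs().max())          # small, bounded by the eps terms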
+ + See en.wikipedia.org/wiki/Quaternions_and_spatial_rotation for further + details on converting between quaternions and rotation matrices. + + Returns: + q (tensor): Batch of quaternion vectors with shape `(..., 4)`. Quaternion + is in the order `[angle, axis_x, axis_y, axis_z]`. + normalize (boolean, optional): Option to normalize the quaternion before + conversion. + + Args: + R (tensor): Batch of rotation matrices with shape `(..., 3, 3)`. + """ + batch_dims = list(q.shape)[:-1] + if normalize: + q = normed_vec(q, distance_eps=eps) + + a, b, c, d = q.unbind(-1) + a2, b2, c2, d2 = a**2, b**2, c**2, d**2 + R = torch.stack( + [ + a2 + b2 - c2 - d2, + 2 * b * c - 2 * a * d, + 2 * b * d + 2 * a * c, + 2 * b * c + 2 * a * d, + a2 - b2 + c2 - d2, + 2 * c * d - 2 * a * b, + 2 * b * d - 2 * a * c, + 2 * c * d + 2 * a * b, + a2 - b2 - c2 + d2, + ], + dim=-1, + ) + + R = R.view(batch_dims + [3, 3]) + return R + + +def frames_from_backbone(X: torch.Tensor, distance_eps: float = 1e-3): + """Convert a backbone into local reference frames. + + Args: + X (Tensor): Backbone coordinates with shape `(..., 4, 3)`. + distance_eps (float, optional): Distance smoothing parameter for + for computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`. + Default: 1E-3. + + Returns: + R (Tensor): Reference frames with shape `(..., 3, 3)`. + X_CA (Tensor): C-alpha coordinates with shape `(..., 3)` + """ + X_N, X_CA, X_C, X_O = X.unbind(-2) + u_CA_N = normed_vec(X_N - X_CA, distance_eps) + u_CA_C = normed_vec(X_C - X_CA, distance_eps) + n_1 = u_CA_N + n_2 = normed_cross(n_1, u_CA_C, distance_eps) + n_3 = normed_cross(n_1, n_2, distance_eps) + R = torch.stack([n_1, n_2, n_3], -1) + return R, X_CA + + +def hat(omega: torch.Tensor) -> torch.Tensor: + """ + Maps [x,y,z] to [[0,-z,y], [z,0,-x], [-y, x, 0]] + Args: + omega (torch.tensor): of size (*, 3) + Returns: + hat{omega} (torch.tensor): of size (*, 3, 3) skew symmetric element in so(3) + """ + target = torch.zeros(*omega.size()[:-1], 9, device=omega.device) + index1 = torch.tensor([7, 2, 3], device=omega.device).expand( + *target.size()[:-1], -1 + ) + index2 = torch.tensor([5, 6, 1], device=omega.device).expand( + *target.size()[:-1], -1 + ) + return ( + target.scatter(-1, index1, omega) + .scatter(-1, index2, -omega) + .reshape(*target.size()[:-1], 3, 3) + ) + + +def V(omega: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + I = torch.eye(3, device=omega.device).expand(*omega.size()[:-1], 3, 3) + theta = omega.pow(2).sum(dim=-1, keepdim=True).add(eps).sqrt()[..., None] + omega_hat = hat(omega) + M1 = ((1 - theta.cos()) / theta.pow(2)) * (omega_hat) + M2 = ((theta - theta.sin()) / theta.pow(3)) * (omega_hat @ omega_hat) + return I + M1 + M2 diff --git a/prtm/models/chroma/structure/hbonds.py b/prtm/models/chroma/structure/hbonds.py new file mode 100644 index 00000000..5f3a4b5b --- /dev/null +++ b/prtm/models/chroma/structure/hbonds.py @@ -0,0 +1,256 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
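# Illustrative sketch (not part of the diff): the `hat` helper defined at the end of
# geometry.py above builds the skew-symmetric matrix of a 3-vector, so multiplying by
# it is the same as taking a cross product with that vector.
import torch
from prtm.models.chroma.structure.geometry import hat

omega = torch.randn(5, 3)
v = torch.randn(5, 3)
lhs = torch.einsum("bij,bj->bi", hat(omega), v)
rhs = torch.cross(omega, v, dim=-1)
print(torch.allclose(lhs, rhs, atol=1e-6))  # True: hat(w) @ v == w x v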
+ +"""Layers for annotating hydrogen bonds in protein structures. +""" + +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from prtm.models.chroma.graph import collect_neighbors +from prtm.models.chroma.structure import protein_graph +from prtm.models.chroma.structure.geometry import normed_vec + + +class BackboneHBonds(nn.Module): + """Compute hydrogen bonds from protein backbones. + + We use the simple electrostatic model for calling hydrogen + bonds of DSSP, which is described at + https://en.wikipedia.org/wiki/DSSP_(algorithm). After + placing virtual hydrogens on all backbone nitrogens, + we consider potential hydrogen bonds with carbonyl groups + on the backbone with residue distance |i-j| > 2. The + picture is: + + -0.20e +0.20e -0.42e +0.42e + [N_i]-----[H_i] ::: [O_j]=====[C_j] + + Args: + cutoff_energy (float, optional): Cutoff energy with + default value -0.5 (DSSP). + cutoff_distance (float, optional): Max distance + between `N_i` and `O_j` with default value 3.6 angstroms. + cutoff_gap (float, optional): Minimum tolerated residue + distance, i.e. `|i-j| >= cutoff_gap`. + Default value of 3. + + Inputs: + X (Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atom_types, 3)`. + C (LongTensor): Chain map tensor with shape `(num_batch, num_residues)`. + edge_idx (LongTensor): Edge indices for neighbors with shape + `(num_batch, num_residues, num_neighbors)`. + mask_ij (Tensor): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. + + Outputs: + hbonds (Tensor): Binary matrix annotating backbone hydrogen bonds + with shape `(num_batch, num_nodes, num_neighbors)`. + mask_hb_ij (Tensor): Hydrogen bond mask with shape + `(num_batch, num_nodes, num_neighbors)`. + H_i (Tensor): Virtual hydrogen coordinates with shape + `(num_batch, num_nodes, 3)`. + """ + + def __init__( + self, + cutoff_energy: float = -0.5, + cutoff_distance: float = 3.6, + cutoff_gap: float = 3, + distance_eps: float = 1e-3, + ) -> None: + super(BackboneHBonds, self).__init__() + self.cutoff_energy = cutoff_energy + self.cutoff_distance = cutoff_distance + self.cutoff_gap = cutoff_gap + self._coefficient = 0.42 * 0.2 * 332 + self._eps = distance_eps + + # Lishan Yao et al. 
JACS 2008, NMR data + self._length_NH = 1.015 + return + + def forward( + self, + X: torch.Tensor, + C: torch.LongTensor, + edge_idx: torch.LongTensor, + mask_ij: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + num_batch, num_residues, _, _ = X.shape + # Collect coordinates at i and j + X_flat = X.reshape([num_batch, num_residues, -1]) + X_j_flat = collect_neighbors(X_flat, edge_idx) + X_j = X_j_flat.reshape([num_batch, num_residues, -1, 4, 3]) + + # Get amide [N-H] atoms at i by + # by placing virtual H from C_{i-1}-N-Ca neg bisector + X_prev = F.pad(X, [0, 0, 0, 0, 1, 0], mode="replicate")[:, :-1, :, :] + C_prev_i = X_prev[:, :, 2, :] + N_i = X[:, :, 0, :] + Ca_i = X[:, :, 1, :] + u_CprevN_i = normed_vec(N_i - C_prev_i) + u_CaN_i = normed_vec(N_i - Ca_i) + u_NH_i = normed_vec(u_CprevN_i + u_CaN_i) + H_i = N_i + self._length_NH * u_NH_i + # Add broadcasting dimensions + N_i = N_i[:, :, None, :] + H_i = H_i[:, :, None, :] + + # Get carbonyl [C=O] atoms at j + O_j = X_j[:, :, :, 3, :] + C_j = X_j[:, :, :, 2, :] + + _invD = ( + lambda Xi, Xj: (Xi - Xj).square().sum(-1).add(self._eps).sqrt().reciprocal() + ) + U_ij = self._coefficient * ( + _invD(N_i, O_j) - _invD(N_i, C_j) + _invD(H_i, C_j) - _invD(H_i, O_j) + ) + + # Mask any bonds exceeding donor/acceptor cutoff distance + D_nonhydrogen = (N_i - O_j).square().sum(-1).add(self._eps).sqrt() + mask_ij_cutoff_D = (D_nonhydrogen < self.cutoff_distance).float() + + # Mask hbonds on same chain with |i-j| < gap_cutoff + mask_ij_nonlocal = 1.0 - _locality_mask(C, edge_idx, cutoff=self.cutoff_gap) + + # Ignore N terminal hydrogen bonding because of ambiguous hydrogen placement + C_prev = F.pad(C, [1, 0], "constant")[:, 1:] + mask_i = ((C > 0) * (C == C_prev)).float() + mask_j = collect_neighbors(C[..., None], edge_idx)[..., 0] + mask_ij_internal = mask_i[..., None] * (mask_j > 0).float() + + mask_hb_ij = mask_ij * mask_ij_nonlocal * mask_ij_cutoff_D * mask_ij_internal + + # Call hydrogen bonds + hbonds = mask_hb_ij * (U_ij < self.cutoff_energy).float() + return hbonds, mask_hb_ij, H_i + + +class LossBackboneHBonds(nn.Module): + """Score hydrogen bond recovery from protein backbones. + + Args: + See `BackboneHBonds`. + + Inputs: + X (Tensor): Backbone coordinates to score with shape + `(num_batch, num_residues, 4, 3)`. + X_target (Tensor): Reference coordinates to compare to with shape + `(num_batch, num_residues, 4, 3)`. + C (LongTensor): Chain map tensor with shape `(num_batch, num_residues)`. + + Outputs: + recovery_local (Tensor): Local hydrogen bond recovery with shape + `(num_batch)`. + recovery_nonlocal (Tensor): Nonlocal hydrogen bond recovery with shape + `(num_batch)`. 
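# Illustrative sketch (not from the diff): the DSSP-style electrostatic energy used by
# BackboneHBonds above, written out for a single candidate bond. The distances below
# are made-up example values in angstroms.
def dssp_energy(d_NO, d_NC, d_HC, d_HO):
    # E = 0.42 * 0.20 * 332 * (1/d(N,O) - 1/d(N,C) + 1/d(H,C) - 1/d(H,O)) kcal/mol
    coeff = 0.42 * 0.20 * 332.0
    return coeff * (1.0 / d_NO - 1.0 / d_NC + 1.0 / d_HC - 1.0 / d_HO)

E = dssp_energy(d_NO=2.9, d_NC=3.9, d_HC=3.2, d_HO=2.0)
print(E, E < -0.5)  # an energy below the -0.5 cutoff is called a hydrogen bond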
+ error_co (Tensor): Absolute error in terms of contact order recovery + """ + + def __init__( + self, + cutoff_local: float = 8, + cutoff_energy: float = -0.5, + cutoff_distance: float = 3.6, + cutoff_gap: float = 3, + distance_eps: float = 1e-3, + num_neighbors: int = 30, + ) -> None: + super(LossBackboneHBonds, self).__init__() + self.cutoff_local = cutoff_local + self.cutoff_energy = cutoff_energy + self.cutoff_distance = cutoff_distance + self.cutoff_gap = cutoff_gap + self._eps = 1e-3 + + self.graph_builder = protein_graph.ProteinGraph(num_neighbors=num_neighbors) + self.hbonds = BackboneHBonds( + cutoff_energy=cutoff_energy, + cutoff_distance=cutoff_distance, + cutoff_gap=cutoff_gap, + ) + + def forward( + self, + X: torch.Tensor, + X_target: torch.Tensor, + C: torch.LongTensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Build Graph + edge_idx, mask_ij = self.graph_builder(X_target, C) + hb_target, mask_hb, H_i = self.hbonds(X_target, C, edge_idx, mask_ij) + hb_current, _, _ = self.hbonds(X, C, edge_idx, mask_ij) + + # Split into local and long range hbonds + mask_local = _locality_mask(C, edge_idx, cutoff=self.cutoff_local) + hb_target_local = mask_local * hb_target + hb_target_nonlocal = (1 - mask_local) * hb_target + + # Compute per complex + recovery_local = (hb_current * hb_target_local).sum([1, 2]) / ( + hb_target_local.sum([1, 2]) + self._eps + ) + recovery_nonlocal = (hb_current * hb_target_nonlocal).sum([1, 2]) / ( + hb_target_nonlocal.sum([1, 2]) + self._eps + ) + + # Compute contact order + co_target = _contact_order(hb_target, C, edge_idx) + co_current = _contact_order(hb_current, C, edge_idx) + + error_co = (co_target - co_current).abs() + return recovery_local, recovery_nonlocal, error_co + + +def _ij_distance( + C: torch.LongTensor, + edge_idx: torch.LongTensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + C_i = C[..., None] + C_j = collect_neighbors(C_i, edge_idx)[..., 0] + ix = torch.arange(C.shape[1], device=C.device)[None, :, None].expand( + C.shape[0], -1, -1 + ) + jx = collect_neighbors(ix, edge_idx)[..., 0] + dij = (jx - ix).abs() + mask_same_chain = C_i.eq(C_j).float() + return dij, mask_same_chain + + +def _contact_order( + contacts: torch.Tensor, + C: torch.LongTensor, + edge_idx: torch.LongTensor, + eps: float = 1e-3, +) -> torch.Tensor: + """Compute contact order""" + dij, mask_same_chain = _ij_distance(C, edge_idx) + mask_ij = mask_same_chain * contacts + CO = (mask_ij * dij).sum([1, 2]) / (mask_ij + eps).sum([1, 2]) + return CO + + +def _locality_mask( + C: torch.LongTensor, + edge_idx: torch.LongTensor, + cutoff: float, +) -> torch.Tensor: + dij, mask_same_chain = _ij_distance(C, edge_idx) + mask_ij_local = ((dij < cutoff) * mask_same_chain).float() + return mask_ij_local diff --git a/prtm/models/chroma/structure/mvn.py b/prtm/models/chroma/structure/mvn.py new file mode 100644 index 00000000..d15e018a --- /dev/null +++ b/prtm/models/chroma/structure/mvn.py @@ -0,0 +1,886 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layers for multivariate normal models of protein structure. + +This module contains pytorch layers for perturbing protein structure with noise, +which can be useful both for data augmentation, benchmarking, or denoising based +training. +""" + +import math +from typing import Optional + +import numpy as np +import torch +import torch.nn.functional as F +from prtm.models.chroma import conv +from prtm.models.chroma.structure import backbone + + +class BackboneMVNGlobular(torch.nn.Module): + """ + Gaussian model for protein backbones. + """ + + def __init__( + self, + covariance_model="brownian", + complex_scaling=False, + sigma_translation=1.0, + **kwargs, + ): + super().__init__() + + # These constant was derived from fitting uniform phi,psi chains + self._scale = 1.5587407701549267 + + # These parameterize the scaling law, *per xyz dimension* + # Rg = Rg0 * N_atoms ^ nu + self._nu = 2.0 / 5.0 + self._rg_0_1D = 2.0 / 3.0 + + # Exact solution for Rg0, agrees with above to 2 decimals. + # We divide the literature prefactor for a per-residue + # scaling law (John J Tanner 2016) by two terms: + # 1. a conversion factor from residues to atoms (4^nu) + # 2. the sqrt(3) to account for isotropic dimensional (xyz) Rg contributions + # self._rg_0_1D = 2.0 / (4 ** self._nu * np.sqrt(3)) + + self.covariance_model = covariance_model + self.complex_scaling = complex_scaling + self.sigma_translation = sigma_translation + + def _atomic_mean(self, X_flat, mask): + """Compute the mean across all 4 atom types by mask expansion""" + mask_expand = mask.unsqueeze(-1).expand(-1, -1, 4) + mask_atomic = mask_expand.reshape(mask.shape[0], -1).unsqueeze(-1) + X_mean = torch.sum(mask_atomic * X_flat, 1, keepdims=True) / ( + torch.sum(mask_atomic, 1, keepdims=True) + ) + return X_mean, mask_atomic + + def _C_atomic(self, C): + # Expand chain map into atomic level masking + C_expand = C.unsqueeze(-1).expand(-1, -1, 4) + C_atomic = C_expand.reshape(C.shape[0], -1) + return C_atomic + + def _globular_parameters(self, C_mask_all, translation_inflation=None): + """Compute parameters for enforcing Rg scaling""" + # Rg scaling constants + nu = self._nu + a = self._scale + r = self._rg_0_1D + + # C_mask_all is () + # Monomer and complex sizes (batch, {chains}) + C_mask = C_mask_all.squeeze(-1) + N_per_chain = C_mask.sum(1) + N_per_complex = C_mask.sum([1, 2]) + + # Compute expected Rg^2 values per complex + Rg2_complex = (r**2) * N_per_complex ** (2.0 * nu) + Rg2_chain = (r**2) * N_per_chain ** (2.0 * nu) + + # Compute OU process parameters + N_per_chain = torch.clip(N_per_chain, 1, 1e6) + + # Decay parameter B is related to global spring coefficient + # as k = (1-B)^2 + B = (3.0 / N_per_chain) + N_per_chain ** (-nu) * torch.sqrt( + N_per_chain ** (2 * (nu - 1)) * (N_per_chain**2 + 9) - (a / r) ** 2 + ) + B = torch.clip(B, 1e-4, 1.0 - 1e-4) + + # OU process equilibrium standard deviation warm-starts process + x_init_std = torch.sqrt(1.0 / (1.0 - B**2)) + + # Compute size-weighted average Rg^2 per chain + Rg2_chain_avg = (N_per_chain * Rg2_chain).sum(1) / (N_per_chain.sum(1) + 1e-5) + Rg2_centers_of_mass = torch.clip(Rg2_complex - Rg2_chain_avg, min=1) + Rg_centers_of_mass = torch.sqrt(Rg2_centers_of_mass) + + # Mean scaling parameter deflates equilibrium variance to unity and + # optionally re-inflates the per center of mass variance to implement + # complex scaling + if translation_inflation is not None: + # This 
argument overrides default per-chain translational scaling + # to [translation_inflation * Complex Rg] + marginal_COM_std = translation_inflation * Rg2_complex.sqrt() + mean_shift_scale = (x_init_std - marginal_COM_std[..., None]) / (x_init_std) + elif self.complex_scaling: + N_chains_per_complex = (C_mask.sum(1) > 0).sum(1) + # Correct for the fact that we are sampling chains IID (not + # centered) but want to control centered Rg + std_correction = torch.sqrt( + N_chains_per_complex / (N_chains_per_complex - 1).clamp(min=1) + ) + marginal_COM_std = std_correction * Rg_centers_of_mass + mean_shift_scale = (x_init_std - marginal_COM_std[..., None]) / (x_init_std) + else: + mean_shift_scale = (x_init_std - 1.0) / (x_init_std) + + return B[..., None], x_init_std[..., None], mean_shift_scale[..., None] + + def _expand_masks(self, C): + C_atomic = self._C_atomic(C) + C_mask_all = backbone.expand_chain_map(torch.abs(C_atomic))[..., None] + C_mask_present = backbone.expand_chain_map(C_atomic)[..., None] + return C_mask_all, C_mask_present + + def _expand_per_chain(self, Z, C): + """Build augmented [num_batch, 4*num_residues, num_chains, 3] system""" + # Build masks and augmented [B,4N,C,3] system + C_mask_all, C_mask_present = self._expand_masks(C) + Z_expand = C_mask_all * Z[..., None, :] + return C_mask_all, C_mask_present, Z_expand + + def _shift_means(self, X_expand, C_mask_mean, C_mask_apply, scale, shift=None): + """Inflate or deflate per-chain means by a scale factor.""" + X_chain_mean = (C_mask_mean * X_expand).sum(1, keepdims=True) / ( + C_mask_mean.sum(1, keepdims=True) + 1e-5 + ) + shift = shift if shift is not None else 0 + shift = shift + scale * X_chain_mean + X_expand = C_mask_apply * (X_expand + shift) + return X_expand + + def _translate_by_x1(self, X_expand, C_mask, scale_mean, scale_x1): + """Shift mean to mean <- mean + scale_mean * mean + scale_x1 * x1""" + X_1 = self._gather_chain_init(X_expand, C_mask) + X_expand = self._shift_means( + X_expand, C_mask, C_mask, scale=scale_mean, shift=X_1 * scale_x1 + ) + return X_expand + + def _translate_by_x1_transpose(self, X_expand, C_mask, scale_mean, scale_x1): + """Transpose of _translate_by_x1.""" + + # Shift mean (Symmetric under transpose) + X_chain_sum = (C_mask * X_expand).sum(1, keepdims=True) + X_chain_mean = X_chain_sum / (C_mask.sum(1, keepdims=True) + 1e-5) + X_expand = C_mask * (X_expand + scale_mean * X_chain_mean) + + # Update to X_init + # The transpose of updating all by X_init is updating X_init by all + first_index = torch.max(C_mask, 1, keepdim=True)[1] + first_index_expand = first_index.expand(-1, -1, -1, 3) + X_init = torch.gather(X_expand, 1, first_index_expand) + X_init_update = X_init + scale_x1 * X_chain_sum + X_expand = X_expand.scatter(1, first_index_expand, X_init_update) + return X_expand + + def _gather_chain_init(self, X_expand, C_mask): + """Extract first coordinates, per chain""" + first_index = torch.max(C_mask, 1, keepdim=True)[1] + first_index_expand = first_index.expand(-1, -1, -1, 3) + X_init = torch.gather(X_expand, 1, first_index_expand) + return X_init + + def _multiply_R(self, Z, C): + """Multiply by the square root of the covariance matrix""" + if Z.dim() == 4: + Z = Z.reshape(Z.shape[0], -1, 3) + + C_mask_all, C_mask_present, Z_expand = self._expand_per_chain(Z, C) + + if self.covariance_model == "brownian": + # Step 1. Scaled cumsum along each chain (including missing residues) + # [B,4N,3] -> [B,4N,C,3] + R_Z_expand = C_mask_all * torch.cumsum(Z_expand, 1) * self._scale + + # Step 2. 
Translate by rescaled X_1 + R_Z_expand = self._translate_by_x1( + R_Z_expand, C_mask_all, scale_mean=-1, scale_x1=self.sigma_translation + ) + elif self.covariance_model == "globular": + # Build coefficients per chain as as [B,C,1] + B, x_init_std, mean_shift_scale = self._globular_parameters(C_mask_all) + + # Step 1. R_init + # Scale z_1 to have equilibrium variance + # z_1 will be the position where (1 - mask_{i-1}) = mask_i + C_mask_prev = F.pad(C_mask_all[:, :-1, :, :], (0, 0, 0, 0, 1, 0)) + mask_init = (1.0 - C_mask_prev) * C_mask_all + # Inflate z_1 by the equilibrium variance + Z_expand = (1.0 - mask_init) * Z_expand + mask_init * x_init_std[ + :, None, ... + ] * Z_expand + + # Step 2. R_sum + # Apply linear recurrence `x_i = z_i + b * x_{i-1}` + # Repack all independent signals and coeffs (B,C,3) in 1D + num_B, num_N, num_C, _ = Z_expand.shape + # [B,4N,C,3] => [B,C,3,4N] => [BC3, 4N] + Z_1D = Z_expand.permute([0, 2, 3, 1]).reshape([-1, num_N]) + # [B,C,1] => [BC,1] => [BC,3] => [BC3] + B_1D = B.reshape([-1, 1]).expand([-1, 3]).reshape([-1]) + R_Z_1D = self._scale * conv.filter1D_linear_decay(Z_1D, B_1D) + # [BC3,4N] -> [B,C,3,4N] -> [B,4N,C,3] + R_Z_expand = R_Z_1D.reshape([num_B, num_C, 3, num_N]).permute([0, 3, 1, 2]) + R_Z_expand = C_mask_all * R_Z_expand + + # Step 3. R_center + # Rescale translational variance + scale = -mean_shift_scale[:, None, ...] + R_Z_expand = self._shift_means( + R_Z_expand, C_mask_all, C_mask_all, scale=scale + ) + + # Collapse out chain dimension + R_Z = R_Z_expand.sum(2).reshape(Z.shape[0], -1, 4, 3) + return R_Z + + def _multiply_R_transpose(self, Z, C): + """Multiply by the square root of the covariance matrix (transpose)""" + if Z.dim() == 4: + Z = Z.reshape(Z.shape[0], -1, 3) + + # Inflate chain dimension [B,4N,C,3] + C_mask_all, C_mask_present, Z_expand = self._expand_per_chain(Z, C) + + if self.covariance_model == "brownian": + # Step 2. [Transpose of] Translate by rescaled X_1 + Z_expand = self._translate_by_x1_transpose( + Z_expand, C_mask_all, scale_mean=-1, scale_x1=self.sigma_translation + ) + + # Step 1. [Transpose of] Scaled cumsum along each chain + Rt_Z_expand = torch.flip(torch.cumsum(torch.flip(Z_expand, [1]), 1), [1]) + Rt_Z_expand = C_mask_all * Rt_Z_expand * self._scale + + elif self.covariance_model == "globular": + # Build coefficients per chain as as [B,C,1] + B, x_init_std, mean_shift_scale = self._globular_parameters(C_mask_all) + Rt_Z_expand = Z_expand + + # Step 3. R_center_transpose = R_center (by symmetry) + scale = -mean_shift_scale[:, None, ...] + Rt_Z_expand = self._shift_means( + Rt_Z_expand, C_mask_all, C_mask_all, scale=scale + ) + + # Step 2. R_sum_transpose = R_sum @ R_flip + # Apply linear recurrence `x_i = z_i + b * x_{i-1}` + # Repack all independent signals and coeffs (B,C,3) in 1D + num_B, num_N, num_C, _ = Rt_Z_expand.shape + # [B,4N,C,3] => [B,C,3,4N] => [BC3, 4N] + Z_1D = Rt_Z_expand.permute([0, 2, 3, 1]).reshape([-1, num_N]) + Z_1D_reverse = torch.flip(Z_1D, [1]) + # [B,C,1] => [BC,1] => [BC,3] => [BC3] + B_1D = B.reshape([-1, 1]).expand([-1, 3]).reshape([-1]) + Rt_Z_1D_reverse = self._scale * conv.filter1D_linear_decay( + Z_1D_reverse, B_1D + ) + Rt_Z_1D = torch.flip(Rt_Z_1D_reverse, [1]) + # [BC3,4N] -> [B,C,3,4N] -> [B,4N,C,3] + Rt_Z_expand = Rt_Z_1D.reshape([num_B, num_C, 3, num_N]).permute( + [0, 3, 1, 2] + ) + Rt_Z_expand = C_mask_all * Rt_Z_expand + + # Step 1. 
R_init_transpose = R_init (by symmetry) + # Scale z_1 to have equilibrium variance + # z_1 will be the position where (1 - mask_{i-1}) = mask_i + C_mask_prev = F.pad(C_mask_all[:, :-1, :, :], (0, 0, 0, 0, 1, 0)) + mask_init = (1.0 - C_mask_prev) * C_mask_all + # Inflate z_1 by the equilibrium variance + Rt_Z_expand = (1.0 - mask_init) * Rt_Z_expand + mask_init * x_init_std[ + :, None, ... + ] * Rt_Z_expand + + # Collapse out chain dimension + Rt_Z = Rt_Z_expand.sum(2).reshape(Z.shape[0], -1, 4, 3) + return Rt_Z + + def _multiply_R_inverse(self, X, C): + """Multiply by the inverse of the square root of the covariance matrix""" + if X.dim() == 4: + X = X.reshape(X.shape[0], -1, 3) + + # Inflate chain dimension [B,4N,C,3] + C_mask_all, C_mask_present, X_expand = self._expand_per_chain(X, C) + + if self.covariance_model == "brownian": + # Step 2. [Inverse of] Translate by rescaled X_1 + X_expand = self._translate_by_x1( + X_expand, C_mask_all, scale_mean=1 / self.sigma_translation, scale_x1=-1 + ) + + # Step 1. [Inverse of] Scaled cumsum per chain [X_i - X_(i-1)] + Ri_X_expand = X_expand - F.pad(X_expand[:, :-1, :, :], (0, 0, 0, 0, 1, 0)) + Ri_X_expand = C_mask_all * Ri_X_expand / self._scale + + elif self.covariance_model == "globular": + # Build coefficients per chain as as [B,C,1] + B, x_init_std, mean_shift_scale = self._globular_parameters(C_mask_all) + + # Step 3. R_center_inverse + # Rescale translational variance + mean_shift_scale_inverse = mean_shift_scale / (1 - mean_shift_scale) + scale = mean_shift_scale_inverse[:, None, ...] + X_expand = self._shift_means(X_expand, C_mask_all, C_mask_all, scale=scale) + + # Step 2. R_sum_inverse + # Apply linear recurrence `x_i = z_i + b * x_{i-1}` + X_prev = F.pad(X_expand[:, :-1, :, :], (0, 0, 0, 0, 1, 0)) + Ri_X_expand = ( + C_mask_all * (X_expand - B[:, None, ...] * X_prev) / self._scale + ) + + # Step 1. R_init_inverse + # Scale z_1 to have equilibrium variance + # z_1 will be the position where (1 - mask_{i-1}) = mask_i + C_mask_prev = F.pad(C_mask_all[:, :-1, :, :], (0, 0, 0, 0, 1, 0)) + mask_init = (1.0 - C_mask_prev) * C_mask_all + Ri_X_expand = ( + 1.0 - mask_init + ) * Ri_X_expand + mask_init * Ri_X_expand / x_init_std[:, None, ...] + + # Collapse out chain dimension + Ri_X = Ri_X_expand.sum(2).reshape(X.shape[0], -1, 4, 3) + return Ri_X + + def _multiply_R_inverse_transpose(self, X, C): + """Multiply by the inverse trasnpose of the square root of the + covariance matrix + """ + if X.dim() == 4: + X = X.reshape(X.shape[0], -1, 3) + + C_mask_all, C_mask_present, X_expand = self._expand_per_chain(X, C) + + if self.covariance_model == "brownian": + # Step 1. [Inverse transpose of] Scaled cumsum per chain [X_i - X_(i+1)] + Rit_X_expand = X_expand - F.pad(X_expand[:, 1:, :, :], (0, 0, 0, 0, 0, 1)) + Rit_X_expand = C_mask_all * Rit_X_expand / self._scale + + # Step 2. [Inverse transpose of] Translate by rescaled X_1 + Rit_X_expand = self._translate_by_x1_transpose( + Rit_X_expand, + C_mask_all, + scale_mean=1 / self.sigma_translation, + scale_x1=-1, + ) + elif self.covariance_model == "globular": + # Build coefficients per chain as as [B,C,1] + B, x_init_std, mean_shift_scale = self._globular_parameters(C_mask_all) + Rit_X_expand = X_expand + + # Step 1. 
R_init_inverse_transpose = R_init_inverse (by symmetry) + # Scale z_1 to have equilibrium variance + # z_1 will be the position where (1 - mask_{i-1}) = mask_i + C_mask_prev = F.pad(C_mask_all[:, :-1, :, :], (0, 0, 0, 0, 1, 0)) + mask_init = (1.0 - C_mask_prev) * C_mask_all + Rit_X_expand = ( + 1.0 - mask_init + ) * Rit_X_expand + mask_init * Rit_X_expand / x_init_std[:, None, ...] + + # Step 2. R_sum_inverse_transpose + # Apply linear recurrence `x_i = z_i + b * x_{i-1}` + X_future = F.pad(Rit_X_expand[:, 1:, :, :], (0, 0, 0, 0, 0, 1)) + Rit_X_expand = ( + C_mask_all * (Rit_X_expand - B[:, None, ...] * X_future) / self._scale + ) + + # Step 3. R_center_inverse_transpose = R_center_inverse (by symmetry) + # Rescale translational variance + mean_shift_scale_inverse = mean_shift_scale / (1 - mean_shift_scale) + scale = mean_shift_scale_inverse[:, None, ...] + Rit_X_expand = self._shift_means( + Rit_X_expand, C_mask_all, C_mask_all, scale=scale + ) + + Rit_X = Rit_X_expand.sum(2).reshape(X.shape[0], -1, 4, 3) + return Rit_X + + def multiply_covariance(self, dX, C): + """Multiply by the covariance matrix. + + Args: + dX (Tensor): Backbone tensor with dimensions + `(num_batch, num_residues, 4, 3)`. + (Note: this will typically be a gradient or direction vector, + such as the score function. Not absolute coordinates). + C (Tensor): Chain map with dimensions + + returns: + C_dX (Tensor): The matrix-vector product resulting from + left-multiplying by the covariance matrix. + """ + # Covariance C = G @ G.T + dX_flat = dX.reshape([dX.shape[0], -1, 3]) + Rt_dX = self._multiply_R_transpose(dX_flat, C) + C_dX = self._multiply_R(Rt_dX, C) + C_dX = C_dX.reshape(dX.shape) + return C_dX + + def multiply_inverse_covariance(self, dX, C): + """Multiply by the inverse covariance matrix. + + Args: + dX (Tensor): Backbone tensor with dimensions + `(num_batch, num_residues, 4, 3)`. + C (Tensor): Chain map with dimensions + + returns: + Ci_dX (Tensor): The matrix-vector product resulting from + left-multiplying by the inverse covariance matrix. 
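# Illustrative sketch (assumes the module path created by this diff): draw a backbone
# from the globular prior and check that multiplying by the covariance and then by its
# inverse is (approximately) the identity; `log_prob` returns one value per batch item.
import torch
from prtm.models.chroma.structure.mvn import BackboneMVNGlobular

mvn = BackboneMVNGlobular(covariance_model="globular")
C = torch.ones(1, 50, dtype=torch.long)        # one chain, 50 residues
X = mvn.sample(C)                              # (1, 50, 4, 3)
X_rt = mvn.multiply_inverse_covariance(mvn.multiply_covariance(X, C), C)
print((X - X_rt).abs().max())                  # close to zero (floating point only)
print(mvn.log_prob(X, C).shape)                # torch.Size([1])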
+ """ + # Covariance C = G @ G.T + dX_flat = dX.reshape([dX.shape[0], -1, 3]) + Ri_dX = self._multiply_R_inverse(dX_flat, C) + Ci_dX = self._multiply_R_inverse_transpose(Ri_dX, C) + Ci_dX = Ci_dX.reshape(dX.shape) + return Ci_dX + + def log_determinant(self, C): + """Compute log determinant of the covariance matrix""" + C_mask_all, C_mask_present = self._expand_masks(C) + + B, x_init_std, xi = self._globular_parameters(C_mask_all) + a = self._scale + B = B[..., 0] + xi = xi[..., 0] + + # Compute determinants per chain + N_chain = C_mask_all.sum([1, 3]) + logdet_chain = ( + N_chain * np.log(a) + torch.log(1.0 - xi) - 0.5 * torch.log(1.0 - B**2) + ) + + # We pick up one determinant per chain per spatial dimension (xyz) + logdet = 3.0 * logdet_chain.sum(-1) + return logdet + + def log_prob(self, X: torch.Tensor, C: torch.Tensor) -> torch.Tensor: + """ + Compute log probability for Backbone MVN as follows: + + term1 = -n/2 log(2π) + term2 = -1/2 log|Σ| + term3 = -1/2 ∑_{i=1}^{n} (x_i - μ)^T Σ^-1 (x_i - μ) + logP = term1 + term2 + term3 + + Args: + X (torch.Tensor): of size (batch, num_residues, 4, 3) + C (torch.Tensor): of size (batch, num_residues) + + Returns: + logp (torch.Tensor): of size (batch,) + """ + term1 = -(C.shape[1] * 4 * 3) / 2 * np.log(2 * np.pi) + term2 = -1 / 2 * self.log_determinant(C) + term3 = -1 / 2 * (X * self.multiply_inverse_covariance(X, C)).sum([1, 2, 3]) + logp = term1 + term2 + term3 + + return logp + + def sample( + self, + C: torch.Tensor, + ddX: Optional[torch.Tensor] = None, + Z: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Draws samples from the MVN. + + Args: + C: (torch.Tensor): specifying the shape of the samples. + ddX (torch.Tensor, optional): Optionally can specify a shift ddX which will be transformed (R^t * ddX) and used to shift Z. + Z (torch.Tensor, optional): Optionally can specify random normal samples to transform into samples from the backbone MVN. + + Returns: + X (torch.Tensor): of size (C.size(0), C.size(1), 4, 3) samples from the MVN. + """ + num_batch = C.shape[0] + num_residues = C.shape[1] + + if Z is None: + Z = torch.randn([num_batch, num_residues * 4, 3], device=C.device) + if ddX is None: + X_flat = self._multiply_R(Z, C) + else: + RtddX = self._multiply_R_transpose(ddX, C) + X_flat = self._multiply_R(Z + RtddX, C) + + X = X_flat.reshape([num_batch, num_residues, 4, 3]) + return X + + +class ConditionalBackboneMVNGlobular(BackboneMVNGlobular): + """ + A conditional MVN distribution where some subset of the atomic coordinates are known. + Args: + covariance_model (str): Specifying the covariance_model of the base distribution (which respect to which we are conditioning). + complex_scaling (bool): Specifying the complex_scaling of the base distribution (which respect to which we are conditioning). + sigma_translation (float): Specifying the sigma_translation of the base distribution (which respect to which we are conditioning). + X (torch.Tensor): of size (1, num_residues, 4, 3) containing atomic coordinates + C (torch.Tensor): of size (1, num_residues) specifying chain specification + D (torch.Tensor): of size (1, num_residues) containing 1s or 0s, (castable as Byte or Bool) where 1 indicates a residue's structural information + is to be conditioned on. + gamma (float): This inflates variance of the center of mass of the samples generated by the CMVN. 
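# Illustrative sketch (standard Gaussian conditioning, not Chroma-specific): for a
# joint Gaussian over (x1, x2), conditioning on x2 gives
#   mean: mu1 + S12 @ inv(S22) @ (x2 - mu2)
#   cov : S11 - S12 @ inv(S22) @ S21   (the Schur complement used in this class)
import torch

S = torch.tensor([[2.0, 0.8], [0.8, 1.0]])     # toy joint covariance, 1x1 blocks
S11, S12, S21, S22 = S[:1, :1], S[:1, 1:], S[1:, :1], S[1:, 1:]
x2 = torch.tensor([1.5])
mu1 = mu2 = torch.zeros(1)

mu_cond = mu1 + S12 @ torch.linalg.inv(S22) @ (x2 - mu2)
S_cond = S11 - S12 @ torch.linalg.inv(S22) @ S21
print(mu_cond, S_cond)   # tensor([1.2000]) and tensor([[1.3600]])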
+ """ + + def __init__( + self, + covariance_model: str = "brownian", + complex_scaling: bool = False, + sigma_translation: float = 1.0, + X: Optional[torch.Tensor] = None, + C: Optional[torch.Tensor] = None, + D: Optional[torch.Tensor] = None, + gamma: Optional[float] = None, + **kwargs, + ): + super().__init__(covariance_model, complex_scaling, sigma_translation, **kwargs) + assert D.shape[0] == 1 + assert C.shape[0] == 1 + assert X.shape[0] == 1 + self.gamma = gamma + self.register_buffer("X", X) + self.register_buffer("C", C) + self.register_buffer("D", D.float()) + + self._check_C(self.C.abs()) + R, RRt = self._materialize_RRt(self.C) + self.register_buffer("R", R) + self.register_buffer("RRt", RRt) + + R_clamp, RRt_clamp = self._condition_RRt(self.RRt, self.D) + self.register_buffer("R_clamp", R_clamp) + self.register_buffer("RRt_clamp", RRt_clamp) + + R_clamp_inverse = torch.linalg.pinv(self.R_clamp) + self.register_buffer("R_clamp_inverse", R_clamp_inverse) + self.register_buffer("mu_sample", self.mu(X)) + + def _center_of_mass(self, X, C): + mask_expand = ( + (C > 0).float().reshape(list(C.shape) + [1, 1]).expand([-1, -1, 4, -1]) + ) + X_mean = (mask_expand * X).sum([1, 2], keepdims=True) / ( + mask_expand.sum([1, 2], keepdims=True) + ) + return X_mean + + def mu(self, X: torch.Tensor): + """ + Returns the mean of the conditional distribution obtained by conditioning on atomic coordinates specified in `X` at + the residues specified in `self.D`. + Args: + X (torch.Tensor): of size (B, num_residues, 4, 3) + + Returns: + X_mu (torch.Tensor): of size (B, num_residues, 4, 3) + """ + B, _, _, _ = X.size() + loc = self._center_of_mass(X, self.D).squeeze().reshape(B, 1, 3) + m = (self.D_atom[..., None] > 0).repeat(B, 1, 3) + X_flat = X.reshape(B, -1, 3) + X_restricted = X_flat[m].reshape(B, -1, 3) + mu = loc + (self.S12 @ torch.linalg.pinv(self.S22) @ (X_restricted - loc)) + X_mu = X_flat.scatter(1, self.zero_indices[None, ..., None].repeat(B, 1, 3), mu) + return X_mu.reshape(B, -1, 4, 3) + + def sample( + self, + num_batch: int = 1, + Z: Optional[torch.Tensor] = None, + mu_X: Optional[torch.Tensor] = None, + ): + """ + Draws samples from the conditional MVN with mean `mu_X`. + + Args: + num_batch (int): Number of samples to draw. + Z (torch.Tensor, optional): of size (batch, num_residues, 4, 3) random standard normal samples can be specified (that are transformed into samples from the CMVN) + mu_X (torch.Tensor, optional): of size (batch, num_residues, 4, 3) optionally take the mean with respect to a different `X` tensor than was used to instantiate the class. 
+ + Returns: + samples (torch.Tensor): of size (num_batch, num_residues, 4, 3) + """ + if Z is not None: + num_batch = Z.shape[0] + C_expand = self.C.repeat(num_batch, 1, 1, 1) + mu = self.mu_sample.repeat(num_batch, 1, 1, 1) + if mu_X is not None: + mu = self.mu(mu_X) + if Z is None: + Z = torch.randn_like(mu) + return mu + self._multiply_R(Z, C_expand) + + def _scatter(self, A, index, source): + J = torch.zeros_like(A) + J[index[:, None], index[None, :]] = source + return J + + def _materialize_RRt(self, C): + """As in C.4 of `https://www.biorxiv.org/content/10.1101/2022.12.01.518682v1.full.pdf`""" + a = self._scale + bs, sl = C.size() + Z = torch.randn(bs, 4 * sl, 3).to(C.device) + C_mask_all, C_mask_present, Z_expand = self._expand_per_chain(Z, C) + + gamma = self.gamma if self.gamma is None or self.gamma > 0.0 else None + b, x_init_std, xi = self._globular_parameters( + C_mask_all, translation_inflation=gamma + ) + + C_atom = self._C_atomic(C.abs()) + R_center = self._build_R_center(C_atom, xi) + R_sum = self._build_R_sum(C_atom, b) + R_init = self._build_R_init(C_atom, b) + + R = a * R_center @ R_sum @ R_init + RRt = R @ R.t() + return R, RRt + + def _check_C(self, C): + _C = C[0][:-1] - C[0][1:] + if (_C > 0).any(): + raise ValueError("Chain map needs to be increasing in this class!") + + def _build_R_center(self, C_atom, xi): + chain_indices = C_atom.unique() + blocks = [] + for chain_index, _xi in zip(chain_indices, xi[0]): + N = C_atom[C_atom == chain_index].numel() + blocks.append( + ( + torch.eye(N, device=_xi.device) + - (_xi / N) * torch.ones(N, N, device=_xi.device) + ) + ) + return torch.block_diag(*blocks) + + def _build_R_sum(self, C_atom, b): + chain_indices = C_atom.unique() + blocks = [] + for chain_index, _b in zip(chain_indices, b[0]): + N = C_atom[C_atom == chain_index].numel() + blocks.append( + ( + _b + ** ( + torch.arange(N, device=_b.device).unsqueeze(0) + - torch.arange(N, device=_b.device).unsqueeze(-1) + ) + .tril() + .abs() + ).tril() + ) + return torch.block_diag(*blocks) + + def _build_R_init(self, C_atom, b): + indices = [(C_atom == k).float().argmax(1).item() for k in C_atom.unique()] + N = C_atom.numel() + P3 = torch.eye(N).to(C_atom.device) + for index, _b in zip(indices, b[0]): + P3.diagonal().data[index] = 1 / math.sqrt(1 - _b**2) + return P3 + + def _condition_RRt(self, RRt, D): + """ + Args: + RRt (torch.tensor): of size (N x N) the original full covariance + D (torch.tensor): of dtype float and size (1xN) containing 1.0 for known indices else 0.0. 
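# Illustrative sketch (not from the diff): the dense matrix built by _build_R_sum
# above, with entries b**(i-j) on and below the diagonal, is the materialized form of
# the linear recurrence x_i = z_i + b * x_{i-1}.
import torch

b, N = 0.9, 6
idx = torch.arange(N)
R_sum = (b ** (idx[:, None] - idx[None, :]).clamp(min=0).float()).tril()

z = torch.randn(N)
x = torch.zeros(N)
for i in range(N):                      # run the recurrence directly
    x[i] = z[i] + (b * x[i - 1] if i > 0 else 0.0)
print(torch.allclose(R_sum @ z, x, atol=1e-6))  # True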
+ """ + self.register_buffer("D_atom", self._C_atomic(D)) + self.register_buffer("zero_indices", torch.nonzero((1 - self.D_atom[0]))[:, 0]) + self.register_buffer("nonzero_indices", torch.nonzero(self.D_atom[0])[:, 0]) + + self.register_buffer("S11", RRt[self.zero_indices][:, self.zero_indices]) + self.register_buffer("S12", RRt[self.zero_indices][:, self.nonzero_indices]) + self.register_buffer("S21", RRt[self.nonzero_indices][:, self.zero_indices]) + self.register_buffer("S22", RRt[self.nonzero_indices][:, self.nonzero_indices]) + + S_clamp = self.S11 - ((self.S12 @ torch.linalg.pinv(self.S22) @ self.S21)) + R_clamp = torch.linalg.cholesky(S_clamp) + self.register_buffer("RRt_clamp_restricted", R_clamp @ R_clamp.t()) + RRt_clamp = self._scatter( + torch.zeros_like(RRt), self.zero_indices, self.RRt_clamp_restricted + ) + R_clamp = self._scatter(torch.zeros_like(RRt), self.zero_indices, R_clamp) + return R_clamp, RRt_clamp + + def _multiply_R(self, Z, C): + Z_flat = Z.reshape([Z.shape[0], -1, 3]) + return (self.R_clamp @ Z_flat).reshape(Z.shape) + + def _multiply_R_transpose(self, Z, C): + Z_flat = Z.reshape([Z.shape[0], -1, 3]) + return (self.R_clamp.t() @ Z_flat).reshape(Z.shape) + + def _multiply_R_inverse(self, X, C): + X_flat = X.reshape([X.shape[0], -1, 3]) + return (self.R_clamp_inverse @ X_flat).reshape(X.shape) + + def _multiply_R_inverse_transpose(self, X, C): + X_flat = X.reshape([X.shape[0], -1, 3]) + return (self.R_clamp_inverse.t() @ X_flat).reshape(X.shape) + + def multiply_covariance(self, dX, C): + dX_flat = dX.reshape([dX.shape[0], -1, 3]) + return (self.RRt_clamp @ dX_flat).reshape(dX.shape) + + def multiply_inverse_covariance(self, dX, C): + dX_flat = dX.reshape([dX.shape[0], -1, 3]) + return (self.RRt_clamp_inverse @ dX_flat).reshape(dX.shape) + + +class BackboneMVNResidueGas(torch.nn.Module): + """ + Gaussian model for protein backbones. 
+ """ + + def __init__(self, stddev_CA=10.0, stddev_atoms=1.0, **kwargs): + super().__init__() + self.stddev_CA = stddev_CA + self.stddev_atoms = stddev_atoms + + # The full R matrix factorizes into a block diagonal of 4x4 matrices + s1 = stddev_CA + s2 = stddev_atoms + # Atoms are N-CA-C=O + R_local = torch.tensor( + [[s2, s1, 0, 0], [0, s1, 0, 0], [0, s1, s2, 0], [0, s1, 0, s2]] + ).float() + self.register_buffer("R_local", R_local) + self.register_buffer("Ri_local", torch.linalg.inv(R_local).detach()) + + def _unflatten(self, Z): + if len(Z.shape) == 3: + num_batch, num_atoms, _ = Z.shape + num_residues = num_atoms // 4 + Z_unflat = Z.reshape([num_batch, num_residues, 4, 3]) + return Z_unflat + else: + return Z + + def _multiply_R(self, Z, C): + """Multiply by the square root of the covariance matrix""" + Z_unflat = self._unflatten(Z) + R_Z_unflat = torch.einsum("biax,ca->bicx", Z_unflat, self.R_local) + R_Z = R_Z_unflat.reshape(Z.shape) + return R_Z + + def _multiply_R_transpose(self, Z, C): + """Multiply by the square root of the covariance matrix (transpose)""" + Z_unflat = self._unflatten(Z) + Rt_Z_unflat = torch.einsum("biax,ac->bicx", Z_unflat, self.R_local) + Rt_Z = Rt_Z_unflat.reshape(Z.shape) + return Rt_Z + + def _multiply_R_inverse(self, X, C): + """Multiply by the inverse of the square root of the covariance matrix""" + X_unflat = self._unflatten(X) + Ri_X = torch.einsum("biax,ca->bicx", X_unflat, self.Ri_local) + return Ri_X.reshape(X.shape) + + def _multiply_R_inverse_transpose(self, X, C): + """Multiply by the inverse trasnpose of the square root of the + covariance matrix + """ + X_unflat = self._unflatten(X) + Rit_X = torch.einsum("biax,ac->bicx", X_unflat, self.Ri_local) + return Rit_X.reshape(X.shape) + + def multiply_covariance(self, dX, C): + """Multiply by the covariance matrix. + + Args: + dX (Tensor): Backbone tensor with dimensions + `(num_batch, num_residues, 4, 3)`. + (Note: this will typically be a gradient or direction vector, + such as the score function. Not absolute coordinates). + C (Tensor): Chain map with dimensions + + returns: + C_dX (Tensor): The matrix-vector product resulting from + left-multiplying by the covariance matrix. + """ + # Covariance C = G @ G.T + dX_flat = dX.reshape([dX.shape[0], -1, 3]) + Rt_dX = self._multiply_R_transpose(dX_flat, C) + C_dX = self._multiply_R(Rt_dX, C) + C_dX = C_dX.reshape(dX.shape) + return C_dX + + def multiply_inverse_covariance(self, dX, C): + """Multiply by the inverse covariance matrix. + + Args: + dX (Tensor): Backbone tensor with dimensions + `(num_batch, num_residues, 4, 3)`. + C (Tensor): Chain map with dimensions + + returns: + Ci_dX (Tensor): The matrix-vector product resulting from + left-multiplying by the inverse covariance matrix. 
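# Illustrative sketch (not from the diff): the per-residue 4x4 square-root factor used
# by BackboneMVNResidueGas, and the determinant s1 * s2**3 quoted in log_determinant
# below.
import torch

s1, s2 = 10.0, 1.0   # stddev_CA, stddev_atoms
R_local = torch.tensor(
    [[s2, s1, 0, 0], [0, s1, 0, 0], [0, s1, s2, 0], [0, s1, 0, s2]]
)
print(torch.det(R_local))         # 10.0 == s1 * s2**3
print(torch.linalg.inv(R_local))  # the Ri_local buffer registered above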
+ """ + # Covariance C = G @ G.T + dX_flat = dX.reshape([dX.shape[0], -1, 3]) + Ri_dX = self._multiply_R_inverse(dX_flat, C) + Ci_dX = self._multiply_R_inverse_transpose(Ri_dX, C) + Ci_dX = Ci_dX.reshape(dX.shape) + return Ci_dX + + def log_determinant(self, C): + """Compute log determinant of the covariance matrix""" + log_s1 = np.log(self.stddev_CA) + log_s2 = np.log(self.stddev_atoms) + num_residues = C.ne(0).float().sum(1) + """ We have + det([s2,s1, 0, 0], + [0, s1, 0, 0], + [0, s1,s2, 0], + [0, s1, 0,s2]) + = + det([s1, 0, 0, 0], + [s1, s2, 0, 0], + [s1, 0,s2, 0], + [s1, 0, 0,s2]) + = s1 * s2^3 + And we pick up one determinant per residue per xyz dimension + """ + logdet = 3 * num_residues * (log_s1 + 3.0 * log_s2) + return logdet + + def log_prob( + self, X: torch.Tensor, C: torch.Tensor, *, normalized: bool = False + ) -> torch.Tensor: + raise NotImplementedError + + def sample( + self, + C: torch.Tensor, + ddX: Optional[torch.Tensor] = None, + Z: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Sample from the Gaussian.""" + num_batch = C.shape[0] + num_residues = C.shape[1] + + if Z is None: + Z = torch.randn([num_batch, num_residues * 4, 3], device=C.device) + if ddX is None: + X_flat = self._multiply_R(Z, C) + else: + RtddX = self._multiply_R_transpose(ddX, C) + X_flat = self._multiply_R(Z + RtddX, C) + + X = X_flat.reshape([num_batch, num_residues, 4, 3]) + return X diff --git a/prtm/models/chroma/structure/optimal_transport.py b/prtm/models/chroma/structure/optimal_transport.py new file mode 100644 index 00000000..bca5e1dd --- /dev/null +++ b/prtm/models/chroma/structure/optimal_transport.py @@ -0,0 +1,100 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layers for comparing and mapping point clouds via optimal transport. + +This module contains minimalist implementations of basic optimal transport +routines which can be used to, for example, measure similarities between +point clouds of different shapes by computing optimal mappings between them. +For more information see the excellent book by Peyre, +https://arxiv.org/pdf/1803.00567.pdf +""" + +import numpy as np +import torch + + +def optimize_couplings_sinkhorn(C, scale=1.0, iterations=10): + """Solve entropy regularized optimized transport via Sinkhorn iteration. + + This method uses the log-domain for numerical stability. + + Args: + C (Tensor): Batch of cost matrices with with shape `(B, I, J)`. + scale (float, optional): Entropy regularization parameter for + rescaling the cost matrix. + iterations (int, optional): Number of Sinkhorn iterations. + + Returns: + T (Tensor): Couplings map with shape `(B, I, J)`. 
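# Illustrative sketch (assumes the module path created by this diff): entropy-
# regularized couplings from a random cost matrix. With the uniform marginals used in
# this function, rows should sum to ~1/I and columns to ~1/J after a few iterations.
import torch
from prtm.models.chroma.structure.optimal_transport import optimize_couplings_sinkhorn

C_cost = torch.rand(2, 6, 8)                       # (batch, I, J)
T = optimize_couplings_sinkhorn(C_cost, scale=10.0, iterations=50)
print(T.shape)                                     # torch.Size([2, 6, 8])
print(T.sum(dim=2)[0])                             # ~1/6 per row
print(T.sum(dim=1)[0])                             # ~1/8 per column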
+ """ + log_T = -C * scale + + # Initialize normalizers + B, I, J = log_T.shape + log_u = torch.zeros((B, I), device=log_T.device) + log_v = torch.zeros((B, J), device=log_T.device) + log_a = log_u - np.log(I) + log_b = log_v - np.log(J) + + # Iterate normalizers + for j in range(iterations): + log_u = log_a - torch.logsumexp(log_T + log_v.unsqueeze(1), 2) + log_v = log_b - torch.logsumexp(log_T + log_u.unsqueeze(2), 1) + log_T = log_T + log_v.unsqueeze(1) + log_u.unsqueeze(2) + T = torch.exp(log_T) + return T + + +def optimize_couplings_gw( + D_a, + D_b, + scale=200.0, + iterations_outer=30, + iterations_inner=10, +): + """Gromov-Wasserstein Optimal Transport. + https://arxiv.org/pdf/1905.07645.pdf + + Args: + D_a (Tensor): Distance matrix describing objects in set `a` with shape `(B, I, I)`. + D_b (Tensor): Distance matrix describing objects in set `b` with shape `(B, J, J)`. + scale (float, optional): Entropy regularization parameter for + rescaling the cost matrix. + iterations_outer (int, optional): Number of outer GW iterations. + iterations_inner (int, optional): Number of inner Sinkhorn iterations. + + Returns: + T (Tensor): Couplings map with shape `(B, I, J)`. + + """ + + # Gromov-Wasserstein Distance + N_a = D_a.shape[1] + N_b = D_b.shape[1] + p_a = torch.ones_like(D_a[:, :, 0]) / N_a + p_b = torch.ones_like(D_b[:, :, 0]) / N_b + C_ab = ( + torch.einsum("bij,bj->bi", D_a**2, p_a)[:, :, None] + + torch.einsum("bij,bj->bi", D_b**2, p_b)[:, None, :] + ) + T_gw = torch.einsum("bi,bj->bij", p_a, p_b) + for i in range(iterations_outer): + cost = C_ab - 2.0 * torch.einsum("bik,bkl,blj->bij", D_a, T_gw, D_b) + T_gw = optimize_couplings_sinkhorn(cost, scale, iterations=iterations_inner) + + # Compute cost + cost = C_ab - 2.0 * torch.einsum("bik,bkl,blj->bij", D_a, T_gw, D_b) + D_gw = (T_gw * cost).sum([-1, -2]).abs().sqrt() + return T_gw, D_gw diff --git a/prtm/models/chroma/structure/params/centering_2g3n.params b/prtm/models/chroma/structure/params/centering_2g3n.params new file mode 100644 index 00000000..63372726 --- /dev/null +++ b/prtm/models/chroma/structure/params/centering_2g3n.params @@ -0,0 +1,2 @@ +2g3n;[["internal_coords", {"log_lengths": true}]];["distances_2mer", "orientations_2mer", "orientations_chain", "distances_chain"] {"[\"internal_coords\", {\"log_lengths\": true}]": [-0.003, -0.001, -0.981, -0.994, 0.359, 0.447, 0.521, 0.51, -0.736, 0.178, 0.006, -0.01, 0.932, 0.893, 0.852, 0.86, 0.383, 0.428, 0.285, 0.216], "\"distances_2mer\"": [0.999, 0.39, 0.021, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.147, 0.863, 0.684, 0.073, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.004, 0.17, 0.893, 0.642, 0.063, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.036, 0.411, 0.884, 0.406, 0.035, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.017, 0.008, 0.014, 0.035, 0.045, 0.058, 0.071, 0.08, 0.097, 0.111, 0.102, 0.088, 0.077, 0.065, 0.058, 0.052, 0.047, 0.043, 0.04, 0.037, 0.003, 0.017, 0.027, 0.02, 0.045, 0.071, 0.07, 0.084, 0.101, 0.111, 0.103, 0.087, 0.072, 0.064, 0.059, 0.052, 0.046, 0.043, 0.04, 0.037, 0.003, 0.018, 0.028, 0.031, 0.04, 0.059, 0.074, 0.081, 0.099, 0.11, 0.106, 0.088, 0.074, 0.065, 0.058, 0.052, 0.046, 0.043, 0.04, 0.037, 0.0, 0.006, 0.031, 0.048, 0.039, 0.047, 0.074, 0.09, 0.096, 0.104, 0.1, 0.091, 0.079, 0.067, 0.058, 0.052, 0.047, 0.043, 0.04, 0.037, 0.147, 0.863, 0.684, 0.073, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.999, 0.39, 0.021, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.123, 0.818, 0.738, 0.09, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.006, 0.195, 0.925, 0.594, 0.052, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003, 0.017, 0.027, 0.02, 0.045, 0.071, 0.07, 0.084, 0.1, 0.11, 0.103, 0.087, 0.073, 0.065, 0.058, 0.052, 0.047, 0.043, 0.04, 0.037, 0.017, 0.007, 0.003, 0.025, 0.049, 0.069, 0.084, 0.081, 0.1, 0.116, 0.105, 0.08, 0.07, 0.064, 0.058, 0.052, 0.046, 0.042, 0.04, 0.037, 0.002, 0.017, 0.028, 0.02, 0.046, 0.069, 0.071, 0.088, 0.103, 0.113, 0.102, 0.084, 0.071, 0.064, 0.059, 0.052, 0.046, 0.043, 0.04, 0.036, 0.0, 0.005, 0.029, 0.042, 0.043, 0.052, 0.073, 0.094, 0.108, 0.101, 0.097, 0.089, 0.075, 0.065, 0.058, 0.052, 0.047, 0.043, 0.039, 0.037, 0.004, 0.17, 0.893, 0.642, 0.063, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.123, 0.818, 0.738, 0.09, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.999, 0.39, 0.021, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.255, 0.972, 0.502, 0.035, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003, 0.018, 0.028, 0.031, 0.04, 0.059, 0.074, 0.081, 0.099, 0.109, 0.105, 0.088, 0.075, 0.065, 0.058, 0.051, 0.046, 0.043, 0.04, 0.037, 0.002, 0.017, 0.028, 0.02, 0.046, 0.069, 0.071, 0.088, 0.103, 0.113, 0.102, 0.084, 0.071, 0.063, 0.059, 0.052, 0.046, 0.042, 0.039, 0.037, 0.017, 0.007, 0.011, 0.036, 0.045, 0.061, 0.072, 0.084, 0.105, 0.115, 0.102, 0.084, 0.071, 0.064, 0.057, 0.051, 0.046, 0.042, 0.039, 0.036, 0.004, 0.017, 0.016, 0.037, 0.055, 0.05, 0.068, 0.093, 0.106, 0.106, 0.099, 0.086, 0.075, 0.065, 0.057, 0.052, 0.047, 0.043, 0.039, 0.036, 0.0, 0.036, 0.411, 0.884, 0.406, 0.035, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.006, 0.195, 0.925, 0.594, 0.052, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.255, 0.972, 0.502, 0.035, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.999, 0.39, 0.021, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.006, 0.031, 0.048, 0.039, 0.047, 0.074, 0.09, 0.096, 0.103, 0.1, 0.091, 0.079, 0.068, 0.058, 0.052, 0.047, 0.043, 0.04, 0.037, 0.0, 0.005, 0.029, 0.042, 0.043, 0.052, 0.073, 0.094, 0.108, 0.102, 0.097, 0.089, 0.075, 0.065, 0.058, 0.052, 0.047, 0.043, 0.039, 0.037, 0.004, 0.017, 0.016, 0.037, 0.055, 0.05, 0.068, 0.093, 0.106, 0.106, 0.099, 0.086, 0.075, 0.065, 0.057, 0.052, 0.047, 0.043, 0.039, 0.036, 0.017, 0.007, 0.005, 0.027, 0.059, 0.064, 0.064, 0.086, 0.105, 0.109, 0.097, 0.085, 0.076, 0.066, 0.058, 0.052, 0.047, 0.043, 0.039, 0.036, 0.017, 0.008, 0.014, 0.035, 0.045, 0.058, 0.071, 0.08, 0.097, 0.111, 0.102, 0.088, 0.077, 0.065, 0.058, 0.052, 0.047, 0.043, 0.04, 0.037, 0.003, 0.017, 0.027, 0.02, 0.045, 0.071, 0.07, 0.084, 0.1, 0.11, 0.103, 0.087, 0.073, 0.065, 0.058, 0.052, 0.047, 0.043, 0.04, 0.037, 0.003, 0.018, 0.028, 0.031, 0.04, 0.059, 0.074, 0.081, 0.099, 0.109, 0.105, 0.088, 0.075, 0.065, 0.058, 0.051, 0.046, 0.043, 0.04, 0.037, 0.0, 0.006, 0.031, 0.048, 0.039, 0.047, 0.074, 0.09, 0.096, 0.103, 0.1, 0.091, 0.079, 0.068, 0.058, 0.052, 0.047, 0.043, 0.04, 0.037, 0.999, 0.39, 0.021, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.147, 0.863, 
0.684, 0.073, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.004, 0.17, 0.893, 0.642, 0.063, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.036, 0.414, 0.883, 0.403, 0.035, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003, 0.017, 0.027, 0.02, 0.045, 0.071, 0.07, 0.084, 0.101, 0.111, 0.103, 0.087, 0.072, 0.064, 0.059, 0.052, 0.046, 0.043, 0.04, 0.037, 0.017, 0.007, 0.003, 0.025, 0.049, 0.069, 0.084, 0.081, 0.1, 0.116, 0.105, 0.08, 0.07, 0.064, 0.058, 0.052, 0.046, 0.042, 0.04, 0.037, 0.002, 0.017, 0.028, 0.02, 0.046, 0.069, 0.071, 0.088, 0.103, 0.113, 0.102, 0.084, 0.071, 0.063, 0.059, 0.052, 0.046, 0.042, 0.039, 0.037, 0.0, 0.005, 0.029, 0.042, 0.043, 0.052, 0.073, 0.094, 0.108, 0.102, 0.097, 0.089, 0.075, 0.065, 0.058, 0.052, 0.047, 0.043, 0.039, 0.037, 0.147, 0.863, 0.684, 0.073, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.999, 0.39, 0.021, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.123, 0.818, 0.738, 0.09, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.006, 0.195, 0.925, 0.594, 0.052, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003, 0.018, 0.028, 0.031, 0.04, 0.059, 0.074, 0.081, 0.099, 0.11, 0.106, 0.088, 0.074, 0.065, 0.058, 0.052, 0.046, 0.043, 0.04, 0.037, 0.002, 0.017, 0.028, 0.02, 0.046, 0.069, 0.071, 0.088, 0.103, 0.113, 0.102, 0.084, 0.071, 0.064, 0.059, 0.052, 0.046, 0.043, 0.04, 0.036, 0.017, 0.007, 0.011, 0.036, 0.045, 0.061, 0.072, 0.084, 0.105, 0.115, 0.102, 0.084, 0.071, 0.064, 0.057, 0.051, 0.046, 0.042, 0.039, 0.036, 0.004, 0.017, 0.016, 0.037, 0.055, 0.05, 0.068, 0.093, 0.106, 0.106, 0.099, 0.086, 0.075, 0.065, 0.057, 0.052, 0.047, 0.043, 0.039, 0.036, 0.004, 0.17, 0.893, 0.642, 0.063, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.123, 0.818, 0.738, 0.09, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.999, 0.39, 0.021, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.255, 0.972, 0.502, 0.035, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.006, 0.031, 0.048, 0.039, 0.047, 0.074, 0.09, 0.096, 0.104, 0.1, 0.091, 0.079, 0.067, 0.058, 0.052, 0.047, 0.043, 0.04, 0.037, 0.0, 0.005, 0.029, 0.042, 0.043, 0.052, 0.073, 0.094, 0.108, 0.101, 0.097, 0.089, 0.075, 0.065, 0.058, 0.052, 0.047, 0.043, 0.039, 0.037, 0.004, 0.017, 0.016, 0.037, 0.055, 0.05, 0.068, 0.093, 0.106, 0.106, 0.099, 0.086, 0.075, 0.065, 0.057, 0.052, 0.047, 0.043, 0.039, 0.036, 0.017, 0.007, 0.005, 0.027, 0.059, 0.064, 0.064, 0.086, 0.105, 0.109, 0.097, 0.085, 0.076, 0.066, 0.058, 0.052, 0.047, 0.043, 0.039, 0.036, 0.0, 0.036, 0.414, 0.883, 0.403, 0.035, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.006, 0.195, 0.925, 0.594, 0.052, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.255, 0.972, 0.502, 0.035, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.999, 0.39, 0.021, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -3.179, 0.383, 0.903, 1.145, 2.666, 2.725, 2.721, 2.735, 0.383, -3.179, 0.428, 0.879, 2.726, 2.671, 2.724, 2.736, 0.903, 0.428, -3.179, 0.216, 2.721, 2.724, 2.662, 2.724, 1.145, 0.879, 0.216, -3.179, 2.736, 2.736, 2.724, 2.673, 2.666, 2.726, 2.721, 2.736, -3.179, 0.383, 0.903, 
1.143, 2.725, 2.671, 2.724, 2.736, 0.383, -3.179, 0.428, 0.879, 2.721, 2.724, 2.662, 2.724, 0.903, 0.428, -3.179, 0.216, 2.735, 2.736, 2.724, 2.673, 1.143, 0.879, 0.216, -3.179], "\"orientations_2mer\"": [0.0, 0.0, 0.0, -0.955, -0.0, 0.0, -0.79, 0.0, -0.543, -0.663, 0.07, -0.613, -0.035, 0.003, -0.074, -0.045, 0.0, -0.073, -0.043, -0.006, -0.081, -0.049, -0.007, -0.082, 0.955, 0.0, -0.0, 0.0, 0.0, 0.0, -0.343, 0.0, -0.863, -0.307, 0.073, -0.778, 0.057, 0.0, -0.079, 0.044, -0.001, -0.076, 0.042, -0.005, -0.088, 0.037, -0.008, -0.086, 0.79, -0.0, 0.543, 0.343, -0.0, 0.863, 0.0, 0.0, 0.0, -0.164, 0.138, -0.424, 0.082, -0.001, 0.005, 0.075, -0.001, 0.01, 0.072, -0.004, 0.0, 0.066, -0.006, -0.006, 0.663, -0.07, 0.613, 0.307, -0.073, 0.778, 0.164, -0.138, 0.424, 0.0, 0.0, 0.0, 0.092, -0.01, 0.039, 0.085, -0.011, 0.039, 0.083, -0.015, 0.036, 0.078, -0.016, 0.03, 0.035, -0.003, 0.074, -0.057, -0.0, 0.079, -0.082, 0.001, -0.005, -0.092, 0.01, -0.039, 0.0, 0.0, 0.0, -0.032, -0.012, -0.018, -0.037, -0.017, -0.026, -0.039, -0.021, -0.028, 0.045, -0.0, 0.073, -0.044, 0.001, 0.076, -0.075, 0.001, -0.01, -0.085, 0.011, -0.039, 0.032, 0.012, 0.018, 0.0, 0.0, 0.0, -0.029, -0.016, -0.024, -0.034, -0.026, -0.027, 0.043, 0.006, 0.081, -0.042, 0.005, 0.088, -0.072, 0.004, -0.0, -0.083, 0.015, -0.036, 0.037, 0.017, 0.026, 0.029, 0.016, 0.024, 0.0, 0.0, 0.0, -0.03, -0.03, -0.022, 0.049, 0.007, 0.082, -0.037, 0.008, 0.086, -0.066, 0.006, 0.006, -0.078, 0.016, -0.03, 0.039, 0.021, 0.028, 0.034, 0.026, 0.027, 0.03, 0.03, 0.022, 0.0, 0.0, 0.0], "\"orientations_chain\"": [0.213, 0.019, 0.0, -0.0, 0.889, 0.0, 0.0, 0.0, 3.244, 0.016, -0.0, -0.0, 0.501, 0.003, -0.002, -0.04, 2.717, 0.044, -0.001, -0.076, 0.457, 0.0, -0.0, -0.003], "\"distances_chain\"": [0.221, 2.826, -0.013]} +2g3n;[["internal_coords", {"log_lengths": true}]];["orientations_2mer", "distances_chain", "orientations_chain", ["random_fourier_2mer", {"dim_embedding": 192, "trainable": false, "scale": 0.1}]] {"[\"internal_coords\", {\"log_lengths\": true}]": [-0.003, -0.001, -0.981, -0.994, 0.359, 0.447, 0.521, 0.51, -0.736, 0.178, 0.006, -0.01, 0.932, 0.893, 0.852, 0.86, 0.383, 0.428, 0.285, 0.216], "\"orientations_2mer\"": [0.0, 0.0, 0.0, -0.955, -0.0, 0.0, -0.79, 0.0, -0.543, -0.663, 0.07, -0.613, -0.029, 0.004, -0.062, -0.036, 0.002, -0.061, -0.035, -0.002, -0.067, -0.038, -0.003, -0.067, 0.955, 0.0, -0.0, 0.0, 0.0, 0.0, -0.343, 0.0, -0.863, -0.307, 0.073, -0.778, 0.047, 0.002, -0.065, 0.038, 0.001, -0.063, 0.037, -0.002, -0.071, 0.034, -0.004, -0.07, 0.79, -0.0, 0.543, 0.343, -0.0, 0.863, 0.0, 0.0, 0.0, -0.164, 0.138, -0.424, 0.069, 0.001, 0.005, 0.065, 0.001, 0.008, 0.063, -0.001, 0.002, 0.059, -0.002, -0.002, 0.663, -0.07, 0.613, 0.307, -0.073, 0.778, 0.164, -0.138, 0.424, 0.0, 0.0, 0.0, 0.078, -0.007, 0.033, 0.074, -0.008, 0.033, 0.073, -0.01, 0.031, 0.069, -0.011, 0.027, 0.029, -0.004, 0.062, -0.047, -0.002, 0.065, -0.069, -0.001, -0.005, -0.078, 0.007, -0.033, 0.0, 0.0, 0.0, -0.024, -0.008, -0.013, -0.028, -0.011, -0.018, -0.029, -0.014, -0.02, 0.036, -0.002, 0.061, -0.038, -0.001, 0.063, -0.065, -0.001, -0.008, -0.074, 0.008, -0.033, 0.024, 0.008, 0.013, 0.0, 0.0, 0.0, -0.021, -0.009, -0.017, -0.025, -0.017, -0.019, 0.035, 0.002, 0.067, -0.037, 0.002, 0.071, -0.063, 0.001, -0.002, -0.073, 0.01, -0.031, 0.028, 0.011, 0.018, 0.021, 0.009, 0.017, 0.0, 0.0, 0.0, -0.021, -0.021, -0.016, 0.038, 0.003, 0.067, -0.034, 0.004, 0.07, -0.059, 0.002, 0.002, -0.069, 0.011, -0.027, 0.029, 0.014, 0.02, 0.025, 0.017, 0.019, 0.021, 0.021, 0.016, 0.0, 0.0, 0.0], 
"\"distances_chain\"": [0.261, 2.862, -0.013], "\"orientations_chain\"": [0.413, 0.022, -0.0, -0.0, 0.869, 0.0, -0.0, -0.0, 3.273, 0.02, -0.0, -0.0, 0.501, 0.002, -0.003, -0.036, 2.887, 0.038, 0.001, -0.063, 0.455, 0.0, -0.0, -0.002], "[\"random_fourier_2mer\", {\"dim_embedding\": 192, \"trainable\": false, \"scale\": 0.1}]": [-0.006, 0.022, 0.006, 0.007, 0.023, 0.0, 0.021, 0.017, 0.008, 0.018, 0.001, 0.065, 0.019, -0.001, 0.003, 0.02, 0.019, 0.016, 0.012, -0.008, -0.01, -0.002, 0.014, -0.004, 0.022, 0.038, 0.038, -0.001, 0.023, 0.016, -0.001, -0.002, 0.025, 0.031, 0.17, 0.014, 0.02, -0.021, 0.006, 0.022, -0.009, -0.005, 0.017, 0.012, 0.016, 0.01, -0.018, 0.009, 0.048, 0.009, -0.0, 0.019, 0.016, -0.017, 0.016, -0.009, 0.028, 0.005, 0.01, 0.009, 0.011, 0.009, 0.014, 0.0, -0.008, -0.015, 0.016, 0.002, 0.02, 0.022, 0.189, 0.008, -0.007, 0.004, 0.008, 0.012, 0.016, 0.04, 0.024, 0.017, -0.0, 0.015, 0.006, -0.007, 0.027, -0.003, 0.031, -0.002, 0.026, -0.004, -0.004, 0.033, 0.038, 0.027, 0.05, -0.003, -0.033, -0.03, -0.023, -0.009, 0.025, -0.022, -0.021, 0.012, 0.016, -0.02, -0.003, 0.073, -0.04, 0.006, 0.01, 0.0, 0.011, 0.007, 0.03, -0.002, -0.008, 0.004, -0.023, -0.007, 0.021, 0.017, 0.018, -0.019, 0.008, 0.028, -0.006, 0.005, 0.077, -0.014, 0.134, -0.003, 0.037, 0.0, 0.003, -0.015, 0.01, 0.003, 0.005, 0.003, 0.007, 0.008, 0.008, 0.016, 0.005, 0.021, -0.012, 0.017, -0.024, -0.002, -0.003, -0.006, -0.004, 0.008, -0.023, -0.028, 0.005, 0.001, -0.008, 0.018, 0.024, -0.02, 0.009, -0.027, 0.013, -0.0, 0.041, 0.002, -0.013, 0.001, 0.035, -0.01, 0.003, 0.001, 0.006, 0.019, 0.017, -0.029, 0.002, -0.0, -0.01, -0.012, -0.009, -0.017, -0.001, 0.015, 0.005, 0.005, 0.01, 0.018, 0.011, 0.009]} diff --git a/prtm/models/chroma/structure/potts.py b/prtm/models/chroma/structure/potts.py new file mode 100644 index 00000000..9e0cbe16 --- /dev/null +++ b/prtm/models/chroma/structure/potts.py @@ -0,0 +1,1169 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layers for building Potts models. + +This module contains layers for parameterizing Potts models from +graph embeddings. +""" + +from typing import Callable, List, Literal, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from prtm.models.chroma import graph +from tqdm.auto import tqdm + + +class GraphPotts(nn.Module): + """Conditional Random Field (conditional Potts model) layer on a graph. + + Arguments: + dim_nodes (int): Hidden dimension of node tensor. + dim_edges (int): Hidden dimension of edge tensor. + num_states (int): Size of the vocabulary. + parameterization (str): Parameterization choice in + `{'linear', 'factor', 'score', 'score_zsum', 'score_scale'}`, or + any of those suffixed with `_beta`, which will add in a globally + learnable temperature scaling parameter. + symmetric_J (bool): If True enforce symmetry of Potts model i.e. + `J_ij(s_i, s_j) = J_ji(s_j, s_i)`. 
+ init_scale (float): Scale factor for the weights and couplings at + initialization. + dropout (float): Probability of per-dimension dropout on `[0,1]`. + label_smoothing (float): Label smoothing probability on for when + per token likelihoods. + num_factors (int): Number of factors to use for the `factor` + parameterization mode. + beta_init (float): Initial temperature scaling factor for parameterizations + with the `_beta` suffix. + + Inputs: + node_h (torch.Tensor): Node features with shape + `(num_batch, num_nodes, dim_nodes)`. + edge_h (torch.Tensor): Edge features with shape + `(num_batch, num_nodes, num_neighbors, dim_edges)`. + edge_idx (torch.LongTensor): Edge indices with shape + `(num_batch, num_nodes, num_neighbors)`. + mask_i (torch.Tensor): Node mask with shape `(num_batch, num_nodes)` + mask_ij (torch.Tensor): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)` + + Outputs: + h (torch.Tensor): Potts model fields :math:`h_i(s_i)` with shape + `(num_batch, num_nodes, num_states)`. + J (Tensor): Potts model couplings :math:`J_{ij}(s_i, s_j)` with shape + `(num_batch, num_nodes, num_neighbors, num_states, num_states)`. + """ + + def __init__( + self, + dim_nodes: int, + dim_edges: int, + num_states: int, + parameterization: str = "score", + symmetric_J: bool = True, + init_scale: float = 0.1, + dropout: float = 0.0, + label_smoothing: float = 0.0, + num_factors: Optional[int] = None, + beta_init: float = 10.0, + ): + super(GraphPotts, self).__init__() + self.dim_nodes = dim_nodes + self.dim_edges = dim_edges + self.num_states = num_states + + self.label_smoothing = label_smoothing + + # Beta parameterization support temperature learning + self.scale_beta = False + if parameterization.endswith("_beta"): + parameterization = parameterization.split("_beta")[0] + self.scale_beta = True + self.log_beta = nn.Parameter(np.log(beta_init) * torch.ones(1)) + + self.init_scale = init_scale + self.parameterization = parameterization + self.symmetric_J = symmetric_J + if self.parameterization == "linear": + self.log_scale = nn.Parameter(np.log(init_scale) * torch.ones(1)) + self.W_h = nn.Linear(self.dim_nodes, self.num_states, bias=True) + self.W_J = nn.Linear(self.dim_edges, self.num_states**2, bias=True) + elif self.parameterization == "factor": + self.log_scale = nn.Parameter(np.log(init_scale) * torch.ones(1)) + self.W_h = nn.Linear(self.dim_nodes, self.num_states, bias=True) + self.W_J_left = nn.Linear(self.dim_edges, self.num_states**2, bias=True) + self.W_J_right = nn.Linear(self.dim_edges, self.num_states**2, bias=True) + elif self.parameterization == "score": + if num_factors is None: + num_factors = dim_edges + self.num_factors = num_factors + self.log_scale = nn.Parameter(np.log(init_scale) * torch.ones(1)) + self.W_h_bg = nn.Linear(self.dim_nodes, 1) + self.W_J_bg = nn.Linear(self.dim_edges, 1) + self.W_h = nn.Linear(self.dim_nodes, self.num_states, bias=True) + self.W_J_left = nn.Linear( + self.dim_edges, self.num_states * num_factors, bias=True + ) + self.W_J_right = nn.Linear( + self.dim_edges, self.num_states * num_factors, bias=True + ) + elif self.parameterization == "score_zsum": + if num_factors is None: + num_factors = dim_edges + self.num_factors = num_factors + self.log_scale = nn.Parameter(np.log(init_scale) * torch.ones(1)) + self.W_h = nn.Linear(self.dim_nodes, self.num_states, bias=True) + self.W_J_left = nn.Linear( + self.dim_edges, self.num_states * num_factors, bias=True + ) + self.W_J_right = nn.Linear( + self.dim_edges, self.num_states * 
num_factors, bias=True + ) + elif self.parameterization == "score_scale": + if num_factors is None: + num_factors = dim_edges + self.num_factors = num_factors + self.W_h_bg = nn.Linear(self.dim_nodes, 1) + self.W_J_bg = nn.Linear(self.dim_edges, 1) + self.W_h_log_scale = nn.Linear(self.dim_nodes, 1) + self.W_J_log_scale = nn.Linear(self.dim_edges, 1) + self.W_h = nn.Linear(self.dim_nodes, self.num_states) + self.W_J_left = nn.Linear(self.dim_edges, self.num_states * num_factors) + self.W_J_right = nn.Linear(self.dim_edges, self.num_states * num_factors) + else: + print(f"Unknown potts parameterization: {parameterization}") + raise NotImplementedError + self.dropout = nn.Dropout(dropout) + + def _mask_J(self, edge_idx, mask_i, mask_ij): + # Remove self edges + device = edge_idx.device + ii = torch.arange(edge_idx.shape[1]).view((1, -1, 1)).to(device) + not_self = torch.ne(edge_idx, ii).type(torch.float32) + + # Remove missing edges + self_present = mask_i.unsqueeze(-1) + neighbor_present = graph.collect_neighbors(self_present, edge_idx) + neighbor_present = neighbor_present.squeeze(-1) + + mask_J = not_self * self_present * neighbor_present + if mask_ij is not None: + mask_J = mask_ij * mask_J + return mask_J + + def forward( + self, + node_h: torch.Tensor, + edge_h: torch.Tensor, + edge_idx: torch.LongTensor, + mask_i: torch.Tensor, + mask_ij: torch.Tensor, + ): + mask_J = self._mask_J(edge_idx, mask_i, mask_ij) + + if self.parameterization == "linear": + # Compute site params (h) from node embeddings + # Compute coupling params (J) from edge embeddings + scale = torch.exp(self.log_scale) + h = scale * mask_i.unsqueeze(-1) * self.W_h(node_h) + J = scale * mask_J.unsqueeze(-1) * self.W_J(edge_h) + J = J.view(list(edge_h.size())[:3] + ([self.num_states] * 2)) + elif self.parameterization == "factor": + scale = torch.exp(self.log_scale) + h = scale * mask_i.unsqueeze(-1) * self.W_h(node_h) + mask_J = scale * mask_J.unsqueeze(-1) + shape_J = list(edge_h.size())[:3] + ([self.num_states] * 2) + J_left = (mask_J * self.W_J_left(edge_h)).view(shape_J) + J_right = (mask_J * self.W_J_right(edge_h)).view(shape_J) + J = torch.matmul(J_left, J_right) + J = self.dropout(J) + # Zero-sum gauge + h = h - h.mean(-1, keepdim=True) + J = ( + J + - J.mean(-1, keepdim=True) + - J.mean(-2, keepdim=True) + + J.mean(dim=[-1, -2], keepdim=True) + ) + elif self.parameterization == "score": + node_h = self.dropout(node_h) + edge_h = self.dropout(edge_h) + + scale = torch.exp(self.log_scale) + mask_h = scale * mask_i.unsqueeze(-1) + mask_J = scale * mask_J.unsqueeze(-1) + h = mask_h * self.W_h(node_h) + + shape_J_prefix = list(edge_h.size())[:3] + J_left = (mask_J * self.W_J_left(edge_h)).view( + shape_J_prefix + [self.num_states, self.num_factors] + ) + J_right = (mask_J * self.W_J_right(edge_h)).view( + shape_J_prefix + [self.num_factors, self.num_states] + ) + J = torch.matmul(J_left, J_right) + + # Zero-sum gauge + h = h - h.mean(-1, keepdim=True) + J = ( + J + - J.mean(-1, keepdim=True) + - J.mean(-2, keepdim=True) + + J.mean(dim=[-1, -2], keepdim=True) + ) + + # Background components + h = h + mask_h * self.W_h_bg(node_h) + J = J + (mask_J * self.W_J_bg(edge_h)).unsqueeze(-1) + elif self.parameterization == "score_zsum": + node_h = self.dropout(node_h) + edge_h = self.dropout(edge_h) + + scale = torch.exp(self.log_scale) + mask_h_scale = scale * mask_i.unsqueeze(-1) + mask_J_scale = scale * mask_J.unsqueeze(-1) + h = mask_h_scale * self.W_h(node_h) + + shape_J_prefix = list(edge_h.size())[:3] + J_left = 
(mask_J_scale * self.W_J_left(edge_h)).view( + shape_J_prefix + [self.num_states, self.num_factors] + ) + J_right = (mask_J_scale * self.W_J_right(edge_h)).view( + shape_J_prefix + [self.num_factors, self.num_states] + ) + J = torch.matmul(J_left, J_right) + J = self.dropout(J) + + # Zero-sum gauge + J = ( + J + - J.mean(-1, keepdim=True) + - J.mean(-2, keepdim=True) + + J.mean(dim=[-1, -2], keepdim=True) + ) + + # Subtract off J background average + mask_J = mask_J.view(list(mask_J.size()) + [1, 1]) + J_i_avg = J.sum(dim=[1, 2], keepdim=True) / mask_J.sum([1, 2], keepdim=True) + J = mask_J * (J - J_i_avg) + elif self.parameterization == "score_scale": + node_h = self.dropout(node_h) + edge_h = self.dropout(edge_h) + + mask_h = mask_i.unsqueeze(-1) + mask_J = mask_J.unsqueeze(-1) + h = mask_h * self.W_h(node_h) + + shape_J_prefix = list(edge_h.size())[:3] + J_left = (mask_J * self.W_J_left(edge_h)).view( + shape_J_prefix + [self.num_states, self.num_factors] + ) + J_right = (mask_J * self.W_J_right(edge_h)).view( + shape_J_prefix + [self.num_factors, self.num_states] + ) + J = torch.matmul(J_left, J_right) + + # Zero-sum gauge + h = h - h.mean(-1, keepdim=True) + J = ( + J + - J.mean(-1, keepdim=True) + - J.mean(-2, keepdim=True) + + J.mean(dim=[-1, -2], keepdim=True) + ) + + # Background components + log_scale = np.log(self.init_scale) + h_scale = torch.exp(self.W_h_log_scale(node_h) + log_scale) + J_scale = torch.exp(self.W_J_log_scale(edge_h) + 2 * log_scale).unsqueeze( + -1 + ) + h_bg = mask_h * self.W_h_bg(node_h) + J_bg = (mask_J * self.W_J_bg(edge_h)).unsqueeze(-1) + h = h_scale * (h + h_bg) + J = J_scale * (J + J_bg) + + if self.symmetric_J: + J = self._symmetrize_J(J, edge_idx, mask_ij) + + if self.scale_beta: + beta = torch.exp(self.log_beta) + h = beta * h + J = beta * J + return h, J + + def _symmetrize_J_serial(self, J, edge_idx, mask_ij): + """Enforce symmetry of J matrices, serial version.""" + num_batch, num_residues, num_k, num_states, _ = list(J.size()) + + # Symmetrization based on raw indexing - extremely slow; for debugging + import time + + _start = time.time() + J_symm = torch.zeros_like(J) + for b in range(J.size(0)): + for i in range(J.size(1)): + for k_i in range(J.size(2)): + for k_j in range(J.size(2)): + j = edge_idx[b, i, k_i] + if edge_idx[b, j, k_j] == i: + J_symm[b, i, k_i, :, :] = ( + J[b, i, k_i, :, :] + + J[b, j, k_j, :, :].transpose(-1, -2) + ) / 2.0 + speed = J.size(0) * J.size(1) / (time.time() - _start) + print(f"symmetrized at {speed} residue/s") + return J_symm + + def _symmetrize_J(self, J, edge_idx, mask_ij): + """Enforce symmetry of J matrices via adding J_ij + J_ji^T""" + num_batch, num_residues, num_k, num_states, _ = list(J.size()) + + # Flatten and gather J_ji matrices using transpose indexing + J_flat = J.view(num_batch, num_residues, num_k, -1) + J_flat_transpose, mask_ji = graph.collect_edges_transpose( + J_flat, edge_idx, mask_ij + ) + J_transpose = J_flat_transpose.view( + num_batch, num_residues, num_k, num_states, num_states + ) + # Transpose J_ji matrices to symmetrize as (J_ij + J_ji^T)/2 + J_transpose = J_transpose.transpose(-2, -1) + mask_ji = (0.5 * mask_ji).view(num_batch, num_residues, num_k, 1, 1) + J_symm = mask_ji * (J + J_transpose) + return J_symm + + def energy( + self, + S: torch.LongTensor, + h: torch.Tensor, + J: torch.Tensor, + edge_idx: torch.LongTensor, + ) -> torch.Tensor: + """Compute Potts model energy from sequence. + + Inputs: + S (torch.LongTensor): Sequence with shape `(num_batch, num_nodes)`. 
+ h (torch.Tensor): Potts model fields :math:`h_i(s_i)` with shape + `(num_batch, num_nodes, num_states)`. + J (Tensor): Potts model couplings :math:`J_{ij}(s_i, s_j)` with shape + `(num_batch, num_nodes, num_neighbors, num_states, num_states)`. + edge_idx (torch.LongTensor): Edge indices with shape + `(num_batch, num_nodes, num_neighbors)`. + + Outputs: + U (torch.Tensor): Potts total energies with shape `(num_batch)`. + Lower energies are more favorable. + """ + # Gather J [Batch,i,j,A_i,A_j] => J_ij(:,A_j) [Batch,i,j,A_i] + S_j = graph.collect_neighbors(S.unsqueeze(-1), edge_idx) + S_j = S_j.unsqueeze(-1).expand(-1, -1, -1, self.num_states, -1) + J_ij = torch.gather(J, -1, S_j).squeeze(-1) + + # Sum out J contributions + J_i = J_ij.sum(2) / 2.0 + r_i = h + J_i + + U_i = torch.gather(r_i, 2, S.unsqueeze(-1)) + U = U_i.sum([1, 2]) + return U + + def pseudolikelihood( + self, + S: torch.LongTensor, + h: torch.Tensor, + J: torch.Tensor, + edge_idx: torch.LongTensor, + ) -> torch.Tensor: + """Compute Potts pseudolikelihood from sequence + + Inputs: + S (torch.LongTensor): Sequence with shape `(num_batch, num_nodes)`. + h (torch.Tensor): Potts model fields :math:`h_i(s_i)` with shape + `(num_batch, num_nodes, num_states)`. + J (Tensor): Potts model couplings :math:`J_{ij}(s_i, s_j)` with shape + `(num_batch, num_nodes, num_neighbors, num_states, num_states)`. + edge_idx (torch.LongTensor): Edge indices with shape + `(num_batch, num_nodes, num_neighbors)`. + + Outputs: + log_probs (torch.Tensor): Potts log-pseudolihoods with shape + `(num_batch, num_nodes, num_states)`. + """ + + # Gather J [Batch,i,j,A_i,A_j] => J_ij(:,A_j) [Batch,i,j,A_i] + S_j = graph.collect_neighbors(S.unsqueeze(-1), edge_idx) + S_j = S_j.unsqueeze(-1).expand(-1, -1, -1, self.num_states, -1) + J_ij = torch.gather(J, -1, S_j).squeeze(-1) + + # Sum out J contributions + J_i = J_ij.sum(2) + + logits = h + J_i + log_probs = F.log_softmax(-logits, dim=-1) + return log_probs + + def log_composite_likelihood( + self, + S: torch.LongTensor, + h: torch.Tensor, + J: torch.Tensor, + edge_idx: torch.LongTensor, + mask_i: torch.Tensor, + mask_ij: torch.Tensor, + smoothing_alpha: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute Potts pairwise composite likelihoods from sequence. + + Inputs: + S (torch.LongTensor): Sequence with shape `(num_batch, num_nodes)`. + h (torch.Tensor): Potts model fields :math:`h_i(s_i)` with shape + `(num_batch, num_nodes, num_states)`. + J (Tensor): Potts model couplings :math:`J_{ij}(s_i, s_j)` with shape + `(num_batch, num_nodes, num_neighbors, num_states, num_states)`. + edge_idx (torch.LongTensor): Edge indices with shape + `(num_batch, num_nodes, num_neighbors)`. + mask_i (torch.Tensor): Node mask with shape `(num_batch, num_nodes)` + mask_ij (torch.Tensor): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. + smoothing_alpha (float): Label smoothing probability on `(0,1)`. + + Outputs: + logp_ij (torch.Tensor): Potts pairwise composite likelihoods evaluated + for the current sequence with shape + `(num_batch, num_nodes, num_neighbors)`. + mask_p_ij (torch.Tensor): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. 
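+
+        Example (hypothetical shapes shown for illustration only; in practice
+        `h`, `J`, and `edge_idx` come from `self.forward` on real embeddings):
+            >>> potts = GraphPotts(dim_nodes=64, dim_edges=64, num_states=20)
+            >>> B, N, K, Q = 1, 8, 4, 20
+            >>> S = torch.randint(0, Q, (B, N))
+            >>> h, J = torch.randn(B, N, Q), torch.randn(B, N, K, Q, Q)
+            >>> edge_idx = torch.randint(0, N, (B, N, K))
+            >>> mask_i, mask_ij = torch.ones(B, N), torch.ones(B, N, K)
+            >>> logp_ij, mask_p_ij = potts.log_composite_likelihood(
+            ...     S, h, J, edge_idx, mask_i, mask_ij
+            ... )
+            >>> logp_ij.shape  # one composite log likelihood per edge
+            torch.Size([1, 8, 4])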
+ """ + num_batch, num_residues, num_k, num_states, _ = list(J.size()) + + # Gather J clamped at j + # [Batch,i,j,A_i,A_j] => J_ij(:,A_j) [Batch,i,j,A_i] + S_j = graph.collect_neighbors(S.unsqueeze(-1), edge_idx) + S_j = S_j.unsqueeze(-1).expand(-1, -1, -1, num_states, -1) + # (B,i,j,A_i) + J_clamp_j = torch.gather(J, -1, S_j).squeeze(-1) + + # Gather J clamped at i + S_i = S.view(num_batch, num_residues, 1, 1, 1) + S_i = S_i.expand(-1, -1, num_k, num_states, num_states) + # (B,i,j,1,A_j) + J_clamp_i = torch.gather(J, -2, S_i) + + # Compute background per-site contributions that sum out J + # (B,i,j,A_i) => (B,i,A_i) + r_i = h + J_clamp_j.sum(2) + r_j = graph.collect_neighbors(r_i, edge_idx) + + # Remove J_ij from the i contributions + # (B,i,A_i) => (B,i,:,A_i,:) + r_i = r_i.view([num_batch, num_residues, 1, num_states, 1]) + r_i_minus_ij = r_i - J_clamp_j.unsqueeze(-1) + + # Remove J_ji from the j contributions + # (B,j,A_j) => (B,:,j,:,A_j) + r_j = r_j.view([num_batch, num_residues, num_k, 1, num_states]) + r_j_minus_ji = r_j - J_clamp_i + + # Composite likelihood (B,i,j,A_i,A_j) + logits_ij = r_i_minus_ij + r_j_minus_ji + J + logits_ij = logits_ij.view([num_batch, num_residues, num_k, -1]) + logp = F.log_softmax(-logits_ij, dim=-1) + logp = logp.view([num_batch, num_residues, num_k, num_states, num_states]) + + # Score the current sequence under + # (B,i,j,A_i,A_j) => (B,i,j,A_i) => (B,i,j) + logp_j = torch.gather(logp, -1, S_j).squeeze(-1) + S_i = S.view(num_batch, num_residues, 1, 1).expand(-1, -1, num_k, -1) + logp_ij = torch.gather(logp_j, -1, S_i).squeeze(-1) + + # Optional label smoothing (scaled assuming per-token smoothing ) + if smoothing_alpha > 0.0: + # Foreground probability + num_bins = num_states**2 + prob_no_smooth = (1.0 - smoothing_alpha) ** 2 + prob_background = (1.0 - prob_no_smooth) / float(num_bins - 1) + # The second term corrects for double counting in background sum + p_foreground = prob_no_smooth - prob_background + logp_ij = p_foreground * logp_ij + prob_background * logp.sum([-2, -1]) + + mask_p_ij = self._mask_J(edge_idx, mask_i, mask_ij) + logp_ij = mask_p_ij * logp_ij + return logp_ij, mask_p_ij + + def loss( + self, + S: torch.LongTensor, + node_h: torch.Tensor, + edge_h: torch.Tensor, + edge_idx: torch.LongTensor, + mask_i: torch.Tensor, + mask_ij: torch.Tensor, + ) -> torch.Tensor: + """Compute per-residue losses given a sequence. + + Inputs: + S (torch.LongTensor): Sequence with shape `(num_batch, num_nodes)`. + node_h (torch.Tensor): Node features with shape + `(num_batch, num_nodes, dim_nodes)`. + edge_h (torch.Tensor): Edge features with shape + `(num_batch, num_nodes, num_neighbors, dim_edges)`. + edge_idx (torch.LongTensor): Edge indices with shape + `(num_batch, num_nodes, num_neighbors)`. + mask_i (torch.Tensor): Node mask with shape `(num_batch, num_nodes)` + mask_ij (torch.Tensor): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)` + + Outputs: + logp_i (torch.Tensor): Potts per-residue normalized composite + log likelihoods with shape`(num_batch, num_nodes)`. 
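+
+        Example (a minimal sketch with hypothetical dimensions; real node and
+        edge embeddings would come from an upstream structure encoder):
+            >>> potts = GraphPotts(dim_nodes=128, dim_edges=128, num_states=20)
+            >>> B, N, K = 2, 10, 6
+            >>> S = torch.randint(0, 20, (B, N))
+            >>> node_h = torch.randn(B, N, 128)
+            >>> edge_h = torch.randn(B, N, K, 128)
+            >>> edge_idx = torch.randint(0, N, (B, N, K))
+            >>> mask_i, mask_ij = torch.ones(B, N), torch.ones(B, N, K)
+            >>> logp_i = potts.loss(S, node_h, edge_h, edge_idx, mask_i, mask_ij)
+            >>> logp_i.shape  # per-residue composite log likelihoods
+            torch.Size([2, 10])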
+ """ + + # Compute parameters + h, J = self.forward(node_h, edge_h, edge_idx, mask_i, mask_ij) + + # Log composite likelihood + logp_ij, mask_p_ij = self.log_composite_likelihood( + S, + h, + J, + edge_idx, + mask_i, + mask_ij, + smoothing_alpha=self.label_smoothing if self.training else 0.0, + ) + + # Map into approximate local likelihoods + logp_i = ( + mask_i + * torch.sum(mask_p_ij * logp_ij, dim=-1) + / (2.0 * torch.sum(mask_p_ij, dim=-1) + 1e-3) + ) + return logp_i + + def sample( + self, + node_h: torch.Tensor, + edge_h: torch.Tensor, + edge_idx: torch.LongTensor, + mask_i: torch.Tensor, + mask_ij: torch.Tensor, + S: Optional[torch.LongTensor] = None, + mask_sample: Optional[torch.Tensor] = None, + num_sweeps: int = 100, + temperature: float = 0.1, + temperature_init: float = 1.0, + penalty_func: Optional[Callable[[torch.LongTensor], torch.Tensor]] = None, + differentiable_penalty: bool = True, + rejection_step: bool = False, + proposal: Literal["dlmc", "chromatic"] = "dlmc", + verbose: bool = False, + edge_idx_coloring: Optional[torch.LongTensor] = None, + mask_ij_coloring: Optional[torch.Tensor] = None, + symmetry_order: Optional[int] = None, + ) -> Tuple[torch.LongTensor, torch.Tensor]: + """Sample from Potts model with Chromatic Gibbs sampling. + + Args: + node_h (torch.Tensor): Node features with shape + `(num_batch, num_nodes, dim_nodes)`. + edge_h (torch.Tensor): Edge features with shape + `(num_batch, num_nodes, num_neighbors, dim_edges)`. + edge_idx (torch.LongTensor): Edge indices with shape + `(num_batch, num_nodes, num_neighbors)`. + mask_i (torch.Tensor): Node mask with shape `(num_batch, num_nodes)`. + mask_ij (torch.Tensor): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. + S (torch.LongTensor, optional): Sequence for initialization with + shape `(num_batch, num_nodes)`. + mask_sample (torch.Tensor, optional): Binary sampling mask indicating + positions which are free to change with shape + `(num_batch, num_nodes)` or which tokens are acceptable at each position + with shape `(num_batch, num_nodes, alphabet)`. + num_sweeps (int): Number of sweeps of Chromatic Gibbs to perform, + i.e. the depth of sampling as measured by the number of times + every position has had an opportunity to update. + temperature (float): Final sampling temperature. + temperature_init (float): Initial sampling temperature, which will + be linearly interpolated to `temperature` over the course of + the burn in phase. + penalty_func (Callable, optional): An optional penalty function which + takes a sequence `S` and outputes a `(num_batch)` shaped tensor + of energy adjustments, for example as regularization. + differentiable_penalty (bool): If True, gradients of penalty function + will be used to adjust the proposals. + rejection_step (bool): If True, perform a Metropolis-Hastings + rejection step. + proposal (str): MCMC proposal for Potts sampling. Currently implemented + proposals are `dlmc` for Discrete Langevin Monte Carlo [1] or `chromatic` + for Gibbs sampling with graph coloring. + [1] Sun et al. Discrete Langevin Sampler via Wasserstein Gradient Flow (2023). + verbose (bool): If True print verbose output during sampling. + edge_idx_coloring (torch.LongerTensor, optional): Alternative + graph dependency structure that can be provided for the + Chromatic Gibbs algorithm when it performs initial graph + coloring. Has shape + `(num_batch, num_nodes, num_neighbors_coloring)`. 
+ mask_ij_coloring (torch.Tensor): Edge mask for the alternative dependency + structure with shape `(num_batch, num_nodes, num_neighbors_coloring)`. + symmetry_order (int, optional): Optional integer argument to enable + symmetric sequence decoding under `symmetry_order`-order symmetry. + The first `(num_nodes // symmetry_order)` states will be free to + move, and all consecutively tiled sets of states will be locked + to these during decoding. Internally this is accomplished by + summing the parameters Potts model under a symmetry constraint + into this reduced sized system and then back imputing at the end. + + Returns: + S (torch.LongTensor): Sampled sequences with + shape `(num_batch, num_nodes)`. + U (torch.Tensor): Sampled energies with shape `(num_batch)`. Lower + is more favorable. + """ + B, N, _ = node_h.shape + + # Compute parameters + h, J = self.forward(node_h, edge_h, edge_idx, mask_i, mask_ij) + + if symmetry_order is not None: + h, J, edge_idx, mask_i, mask_ij = fold_symmetry( + symmetry_order, h, J, edge_idx, mask_i, mask_ij + ) + S = S[:, : (N // symmetry_order)] + if mask_sample is not None: + mask_sample = mask_sample[:, : (N // symmetry_order)] + + S_sample, U_sample = sample_potts( + h, + J, + edge_idx, + mask_i, + mask_ij, + S=S, + mask_sample=mask_sample, + num_sweeps=num_sweeps, + temperature=temperature, + temperature_init=temperature_init, + penalty_func=penalty_func, + differentiable_penalty=differentiable_penalty, + rejection_step=rejection_step, + proposal=proposal, + verbose=verbose, + edge_idx_coloring=edge_idx_coloring, + mask_ij_coloring=mask_ij_coloring, + ) + + if symmetry_order is not None: + assert N % symmetry_order == 0 + S_sample = ( + S_sample[:, None, :].expand([-1, symmetry_order, -1]).reshape([B, N]) + ) + return S_sample, U_sample + + +def compute_potts_energy( + S: torch.LongTensor, + h: torch.Tensor, + J: torch.Tensor, + edge_idx: torch.LongTensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute Potts model energies from sequence. + + Args: + S (torch.LongTensor): Sequence with shape `(num_batch, num_nodes)`. + h (torch.Tensor): Potts model fields :math:`h_i(s_i)` with shape + `(num_batch, num_nodes, num_states)`. + J (Tensor): Potts model couplings :math:`J_{ij}(s_i, s_j)` with shape + `(num_batch, num_nodes, num_neighbors, num_states, num_states)`. + edge_idx (torch.LongTensor): Edge indices with shape + `(num_batch, num_nodes, num_neighbors)`. + + Returns: + U (torch.Tensor): Potts total energies with shape `(num_batch)`. + Lower energies are more favorable. + U_i (torch.Tensor): Potts local conditional energies with shape + `(num_batch, num_nodes, num_states)`. + """ + S_j = graph.collect_neighbors(S.unsqueeze(-1), edge_idx) + S_j = S_j.unsqueeze(-1).expand(-1, -1, -1, h.shape[-1], -1) + J_ij = torch.gather(J, -1, S_j).squeeze(-1) + + # Sum out J contributions to yield local conditionals + J_i = J_ij.sum(2) + U_i = h + J_i + + # Correct for double counting in total energy + S[..., None] + U = ( + torch.gather(U_i, -1, S[..., None]) - 0.5 * torch.gather(J_i, -1, S[..., None]) + ).sum((1, 2)) + return U, U_i + + +def fold_symmetry( + symmetry_order: int, + h: torch.Tensor, + J: torch.Tensor, + edge_idx: torch.LongTensor, + mask_i: torch.Tensor, + mask_ij: torch.Tensor, + normalize=True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Fold Potts model symmetrically. 
+ + Args: + symmetry_order (int): The order of symmetry by which to fold the Potts + model such that the first `(num_nodes // symmetry_order)` states + represent the entire system and all fields and couplings to and + among other copies of this base system are collected together in + single reduced Potts model. + h (torch.Tensor): Potts model fields :math:`h_i(s_i)` with shape + `(num_batch, num_nodes, num_states)`. + J (Tensor): Potts model couplings :math:`J_{ij}(s_i, s_j)` with shape + `(num_batch, num_nodes, num_neighbors, num_states, num_states)`. + edge_idx (torch.LongTensor): Edge indices with shape + `(num_batch, num_nodes, num_neighbors)`. + mask_i (torch.Tensor): Node mask with shape `(num_batch, num_nodes)`. + mask_ij (torch.Tensor): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. + normalize (bool): If True (default), aggregate the Potts model as an average + energy across asymmetric units instead of as a sum. + + Returns: + h_fold (torch.Tensor): Potts model fields :math:`h_i(s_i)` with shape + `(num_batch, num_nodes_folded, num_states)`, where + `num_nodes_folded = num_nodes // symmetry_order`. + J_fold (Tensor): Potts model couplings :math:`J_{ij}(s_i, s_j)` with shape + `(num_batch, num_nodes_folded, num_neighbors, num_states, num_states)`. + edge_idx_fold (torch.LongTensor): Edge indices with shape + `(num_batch, num_nodes_folded, num_neighbors)`. + mask_i_fold (torch.Tensor): Node mask with shape `(num_batch, num_nodes_folded)`. + mask_ij_fold (torch.Tensor): Edge mask with shape + `(num_batch, num_nodes_folded, num_neighbors)`. + + """ + B, N, K, Q, _ = J.shape + device = h.device + + N_asymmetric = N // symmetry_order + # Fold edges by densifying the assymetric unit and averaging + edge_idx_au = torch.remainder(edge_idx, N_asymmetric).clamp(max=N_asymmetric - 1) + + def _pairwise_fold(_T): + # Fold-sum along neighbor dimension + shape = list(_T.shape) + shape[2] = N_asymmetric + _T_au_expand = torch.zeros(shape, device=device).float() + extra_dims = len(_T.shape) - len(edge_idx_au.shape) + edge_idx_au_expand = edge_idx_au.reshape( + list(edge_idx_au.shape) + [1] * extra_dims + ).expand([-1, -1, -1] + [Q] * extra_dims) + _T_au_expand.scatter_add_(2, edge_idx_au_expand, _T.float()) + + # Fold-mean along self dimension + shape_out = [shape[0], -1, N_asymmetric, N_asymmetric] + shape[3:] + _T_au = _T_au_expand.reshape(shape_out).sum(1) + return _T_au + + J_fold = _pairwise_fold(J) + mask_ij_fold = (_pairwise_fold(mask_ij) > 0).float() + edge_idx_fold = ( + torch.arange(N_asymmetric, device=device) + .long()[None, None, :] + .expand(mask_ij_fold.shape) + ) + + # Drop unused edges + mask_ij_fold.sum(2).max().item() + _, sort_ix = torch.sort(mask_ij_fold, dim=2, descending=True) + sort_ix_J = sort_ix[..., None, None].expand(list(sort_ix.shape) + [Q, Q]) + edge_idx_fold = torch.gather(edge_idx_fold, 2, sort_ix) + mask_ij_fold = torch.gather(mask_ij_fold, 2, sort_ix) + J_fold = torch.gather(J_fold, 2, sort_ix_J) + + # Fold-mean along self dimension + h_fold = h.reshape([B, -1, N_asymmetric, Q]).sum(1) + mask_i_fold = (mask_i.reshape([B, -1, N_asymmetric]).sum(1) > 0).float() + if normalize: + h_fold = h_fold / symmetry_order + J_fold = J_fold / symmetry_order + return h_fold, J_fold, edge_idx_fold, mask_i_fold, mask_ij_fold + + +@torch.no_grad() +def _color_graph(edge_idx, mask_ij, max_iter=100): + """Stochastic graph coloring.""" + # Randomly assign initial colors + B, N, K = edge_idx.shape + # By Brooks we only need K + 1, but one extra color aids convergence + 
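+    # A proper coloring assigns each node a color shared by none of its
+    # neighbors; since every node has at most K neighbors, K + 1 colors are
+    # always enough, and the extra color simply gives the clash-resampling
+    # loop below more room to converge quickly. Nodes of the same color form
+    # an independent set, so they can later be Gibbs-updated in parallel.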
num_colors = K + 2 + S = torch.randint(0, num_colors, (B, N), device=edge_idx.device) + + # Ignore self-attachement + ix = torch.arange(edge_idx.shape[1], device=edge_idx.device)[None, ..., None] + mask_ij = (mask_ij * torch.ne(edge_idx, ix).float())[..., None] + + # Iteratively replace clashing sites with an available color + i = 0 + total_clashes = 1 + while total_clashes > 0 and i < max_iter: + # Tabulate available colors in neighborhood + O_i = F.one_hot(S, num_colors).float() + N_i = (mask_ij * graph.collect_neighbors(O_i, edge_idx)).sum(2) + clashes = (O_i * N_i).sum(-1) + N_i = torch.where(N_i > 0, -float("inf") * torch.ones_like(N_i), N_i) + + # Resample from this distribution where clashing + S_new = torch.distributions.categorical.Categorical(logits=N_i).sample() + S = torch.where(clashes > 0, S_new, S) + i += 1 + total_clashes = clashes.sum().item() + return S + + +@torch.no_grad() +def sample_potts( + h: torch.Tensor, + J: torch.Tensor, + edge_idx: torch.LongTensor, + mask_i: torch.Tensor, + mask_ij: torch.Tensor, + S: Optional[torch.LongTensor] = None, + mask_sample: Optional[torch.Tensor] = None, + num_sweeps: int = 100, + temperature: float = 1.0, + temperature_init: float = 1.0, + annealing_fraction: float = 0.8, + penalty_func: Optional[Callable[[torch.LongTensor], torch.Tensor]] = None, + differentiable_penalty: bool = True, + rejection_step: bool = False, + proposal: Literal["dlmc", "chromatic"] = "dlmc", + verbose: bool = True, + return_trajectory: bool = False, + thin_sweeps: int = 3, + edge_idx_coloring: Optional[torch.LongTensor] = None, + mask_ij_coloring: Optional[torch.Tensor] = None, +) -> Union[ + Tuple[torch.LongTensor, torch.Tensor], + Tuple[torch.LongTensor, torch.Tensor, List[torch.LongTensor], List[torch.Tensor]], +]: + """Sample from Potts model with Chromatic Gibbs sampling. + + Args: + h (torch.Tensor): Potts model fields :math:`h_i(s_i)` with shape + `(num_batch, num_nodes, num_states)`. + J (Tensor): Potts model couplings :math:`J_{ij}(s_i, s_j)` with shape + `(num_batch, num_nodes, num_neighbors, num_states, num_states)`. + edge_idx (torch.LongTensor): Edge indices with shape + `(num_batch, num_nodes, num_neighbors)`. + mask_i (torch.Tensor): Node mask with shape `(num_batch, num_nodes)`. + mask_ij (torch.Tensor): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. + S (torch.LongTensor, optional): Sequence for initialization with + shape `(num_batch, num_nodes)`. + mask_sample (torch.Tensor, optional): Binary sampling mask indicating + positions which are free to change with shape + `(num_batch, num_nodes)` or which tokens are acceptable at each position + with shape `(num_batch, num_nodes, alphabet)`. + num_sweeps (int): Number of sweeps of Chromatic Gibbs to perform, + i.e. the depth of sampling as measured by the number of times + every position has had an opportunity to update. + temperature (float): Final sampling temperature. + temperature_init (float): Initial sampling temperature, which will + be linearly interpolated to `temperature` over the course of + the burn in phase. + annealing_fraction (float): Fraction of the total sampling run during + which temperature annealing occurs. + penalty_func (Callable, optional): An optional penalty function which + takes a sequence `S` and outputes a `(num_batch)` shaped tensor + of energy adjustments, for example as regularization. + differentiable_penalty (bool): If True, gradients of penalty function + will be used to adjust the proposals. 
+ rejection_step (bool): If True, perform a Metropolis-Hastings + rejection step. + proposal (str): MCMC proposal for Potts sampling. Currently implemented + proposals are `dlmc` for Discrete Langevin Monte Carlo [1] or `chromatic` + for Gibbs sampling with graph coloring. + [1] Sun et al. Discrete Langevin Sampler via Wasserstein Gradient Flow (2023). + verbose (bool): If True print verbose output during sampling. + return_trajectory (bool): If True, also output the sampling trajectories + of `S` and `U`. + thin_sweeps (int): When returning trajectories, only save every `thin_sweeps` + state to reduce memory usage. + edge_idx_coloring (torch.LongerTensor, optional): Alternative + graph dependency structure that can be provided for the + Chromatic Gibbs algorithm when it performs initial graph + coloring. Has shape + `(num_batch, num_nodes, num_neighbors_coloring)`. + mask_ij_coloring (torch.Tensor): Edge mask for the alternative dependency + structure with shape `(num_batch, num_nodes, num_neighbors_coloring)`. + + Returns: + S (torch.LongTensor): Sampled sequences with + shape `(num_batch, num_nodes)`. + U (torch.Tensor): Sampled energies with shape `(num_batch)`. Lower is more + favorable. + S_trajectory (List[torch.LongTensor]): List of sampled sequences through + time each with shape `(num_batch, num_nodes)`. + U_trajectory (List[torch.Tensor]): List of sampled energies through time + each with shape `(num_batch)`. + """ + # Initialize masked proposals and mask h + mask_S, mask_mutatable, S = init_sampling_masks(-h, mask_sample, S) + h_numerical_zero = h.max() + 1e3 * max(1.0, temperature) + h = torch.where(mask_S > 0, h, h_numerical_zero * torch.ones_like(h)) + + # Block update schedule + if proposal == "chromatic": + if edge_idx_coloring is None: + edge_idx_coloring = edge_idx + if mask_ij_coloring is None: + mask_ij_coloring = mask_ij + schedule = _color_graph(edge_idx_coloring, mask_ij_coloring) + num_colors = schedule.max() + 1 + num_iterations = num_colors * num_sweeps + else: + num_iterations = num_sweeps + + num_iterations_annealing = int(annealing_fraction * num_iterations) + temperatures = np.linspace( + temperature_init, temperature, num_iterations_annealing + ).tolist() + [temperature] * (num_iterations - num_iterations_annealing) + + if proposal == "chromatic": + _energy_proposal = lambda _S, _T: _potts_proposal_gibbs( + _S, + h, + J, + edge_idx, + T=_T, + penalty_func=penalty_func, + differentiable_penalty=differentiable_penalty, + ) + elif proposal == "dlmc": + _energy_proposal = lambda _S, _T: _potts_proposal_dlmc( + _S, + h, + J, + edge_idx, + T=_T, + penalty_func=penalty_func, + differentiable_penalty=differentiable_penalty, + ) + else: + raise NotImplementedError + + cumulative_sweeps = 0 + if return_trajectory: + S_trajectory = [] + U_trajectory = [] + for i, T_i in enumerate(tqdm(temperatures, desc="Potts Sampling")): + # Cycle through Gibbs updates random sites to the update with fixed prob + if proposal == "chromatic": + mask_update = schedule.eq(i % num_colors) + else: + mask_update = torch.ones_like(S) > 0 + if mask_mutatable is not None: + mask_update = mask_update * (mask_mutatable > 0) + + # Compute current energy and local conditionals + U, logp = _energy_proposal(S, T_i) + + # Propose + S_new = torch.distributions.categorical.Categorical(logits=logp).sample() + S_new = torch.where(mask_update, S_new, S) + + # Metropolis-Hastings adjusment + if rejection_step: + + def _flux(_U, _logp, _S): + logp_transition = torch.gather(_logp, -1, _S[..., None]) + 
_logp_ij = (mask_update.float() * logp_transition[..., 0]).sum(1) + flux = -_U / T_i + _logp_ij + return flux + + U_new, logp_new = _energy_proposal(S_new, T_i) + + _flux_backward = _flux(U_new, logp_new, S) + _flux_forward = _flux(U, logp, S_new) + acc_ratio = torch.exp((_flux_backward - _flux_forward)).clamp(max=1.0) + if verbose: # and i % 100 == 0: + print( + f"{(U_new - U).mean().item():0.2f}" + f"\t{(_flux_backward - _flux_forward).mean().item():0.2f}" + f"\t{acc_ratio.mean().item():0.2f}" + ) + u = torch.bernoulli(acc_ratio)[..., None] + S = torch.where(u > 0, S_new, S) + cumulative_sweeps += (u * mask_update).sum(1).mean().item() / S.shape[1] + else: + S = S_new + cumulative_sweeps += (mask_update).float().sum(1).mean().item() / S.shape[1] + + if return_trajectory and i % (thin_sweeps) == 0: + S_trajectory.append(S) + U_trajectory.append(U) + + U, _ = compute_potts_energy(S, h, J, edge_idx) + + if verbose: + print(f"Effective number of sweeps: {cumulative_sweeps}") + if return_trajectory: + return S, U, S_trajectory, U_trajectory + else: + return S, U + + +def init_sampling_masks( + logits_init: torch.Tensor, + mask_sample: Optional[torch.Tensor] = None, + S: Optional[torch.LongTensor] = None, + ban_S: Optional[List[int]] = None, +): + """Parse sampling masks and an initial sequence. + + Args: + logits_init (torch.Tensor): Logits for sequence initialization with shape + `(num_batch, num_nodes, alphabet)`. + mask_sample (torch.Tensor, optional): Binary sampling mask indicating which + positions are free to change with shape `(num_batch, num_nodes)` or which + tokens are valid at each position with shape + `(num_batch, num_nodes, alphabet)`. In the latter case, `mask_sample` will + take priority over `S` except for positions in which `mask_sample` is + all zero. + S (torch.LongTensor optional): Initial sequence with shape + `(num_batch, num_nodes)`. + ban_S (list of int, optional): Optional list of alphabet indices to ban from + all positions during sampling. + + Returns: + mask_sample (torch.Tensor): Finalized position specific mask with shape + `(num_batch, num_nodes, alphabet)`. + S (torch.Tensor): Self-consistent initial `S` with shape + `(num_batch, num_nodes)`. 
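+
+    Example (hypothetical tensors, illustrating the 2D position-mask
+    convention; the function returns the per-token mask, a 1D mutability
+    mask, and a self-consistent initial sequence):
+        >>> logits = torch.zeros(1, 4, 20)
+        >>> S0 = torch.randint(0, 20, (1, 4))
+        >>> free = torch.tensor([[1.0, 1.0, 0.0, 0.0]])  # only the first two positions may mutate
+        >>> mask_S, mask_1d, S_init = init_sampling_masks(logits, free, S0)
+        >>> mask_S.shape, mask_1d.shape, S_init.shape
+        (torch.Size([1, 4, 20]), torch.Size([1, 4]), torch.Size([1, 4]))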
+ """ + + if S is None and mask_sample is not None: + raise Exception("To use masked sampling, please provide an initial S") + + if mask_sample is None: + mask_S = torch.ones_like(logits_init) + elif mask_sample.dim() == 2: + # Position-restricted sampling + mask_sample_expand = mask_sample[..., None].expand(logits_init.shape) + O_init = F.one_hot(S, logits_init.shape[-1]).float() + mask_S = mask_sample_expand + (1 - mask_sample_expand) * O_init + elif mask_sample.dim() == 3: + O_init = F.one_hot(S, logits_init.shape[-1]).float() + # Mutation-restricted sampling + mask_zero = (mask_sample.sum(-1, keepdim=True) == 0).float() + mask_S = ((mask_zero * O_init + mask_sample) > 0).float() + else: + raise NotImplementedError + if ban_S is not None: + mask_S[:, :, ban_S] = 0.0 + mask_S_1D = (mask_S.sum(-1) > 1).float() + + logits_init_masked = 1000 * mask_S + logits_init + S = torch.distributions.categorical.Categorical(logits=logits_init_masked).sample() + return mask_S, mask_S_1D, S + + +def _potts_proposal_gibbs( + S, h, J, edge_idx, T=1.0, penalty_func=None, differentiable_penalty=True +): + U, U_i = compute_potts_energy(S, h, J, edge_idx) + + if penalty_func is not None: + if differentiable_penalty: + with torch.enable_grad(): + S_onehot = F.one_hot(S, h.shape[0 - 1]).float() + S_onehot.requires_grad = True + U_penalty = penalty_func(S_onehot) + U_i_adjustment = torch.autograd.grad(U_penalty.sum(), [S_onehot])[ + 0 + ].detach() + U_penalty = U_penalty.detach() + U_i = U_i + 0.5 * U_i_adjustment + else: + U_penalty = penalty_func(S_onehot) + U = U + U_penalty + + logp_i = F.log_softmax(-U_i / T, dim=-1) + return U, logp_i + + +def _potts_proposal_dlmc( + S, + h, + J, + edge_idx, + T=1.0, + penalty_func=None, + differentiable_penalty=True, + dt=0.1, + autoscale=True, + balancing_func="sigmoid", +): + # Compute energy gap + U, U_i = compute_potts_energy(S, h, J, edge_idx) + U_i = U_i + if penalty_func is not None: + O = F.one_hot(S, h.shape[0 - 1]).float() + if differentiable_penalty: + with torch.enable_grad(): + O.requires_grad = True + U_penalty = penalty_func(O) + U_i_adjustment = torch.autograd.grad(U_penalty.sum(), [O])[0].detach() + U_penalty = U_penalty.detach() + U_i_adjustment = U_i_adjustment - torch.gather( + U_i_adjustment, -1, S[..., None] + ) + + U_i - torch.gather(U_i, -1, S[..., None]) + U_i = U_i + U_i_adjustment + else: + U_penalty = penalty_func(O) + U = U + U_penalty + + # Compute local equilibrium distribution + logP_j = F.log_softmax(-U_i / T, dim=-1) + + # Compute transition log probabilities + O = F.one_hot(S, h.shape[0 - 1]).float() + logP_i = torch.gather(logP_j, -1, S[..., None]) + if balancing_func == "sqrt": + log_Q_ij = 0.5 * (logP_j - logP_i) + elif balancing_func == "sigmoid": + log_Q_ij = F.logsigmoid(logP_j - logP_i) + else: + raise NotImplementedError + + rate = torch.exp(log_Q_ij - logP_j) + + # Compute transition probability + logP_ij = logP_j + (-(-dt * rate).expm1()).log() + p_flip = ((1.0 - O) * logP_ij.exp()).sum(-1, keepdim=True) + + # DEBUG: + # flux = ((1. 
- O) * torch.exp(log_Q_ij)).mean([1,2], keepdim=True) + # print(f" ->Flux is {flux.item():0.2f}, FlipProb is {p_flip.mean():0.2f}") + + logP_ii = (1.0 - p_flip).clamp(1e-5).log() + logP_ij = (1.0 - O) * logP_ij + O * logP_ii + return U, logP_ij diff --git a/prtm/models/chroma/structure/protein_graph.py b/prtm/models/chroma/structure/protein_graph.py new file mode 100644 index 00000000..e0b4ecc4 --- /dev/null +++ b/prtm/models/chroma/structure/protein_graph.py @@ -0,0 +1,1596 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layers for building graph representations of protein structure. + +This module contains pytorch layers for representing protein structure as a +graph with node and edge features based on geometric information. The graph +features are differentiable with respect to input coordinates and can be used +for building protein scoring functions and optimizing protein geometries +natively in pytorch. +""" + +import json +import os +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from prtm.models.chroma import graph +from prtm.models.chroma.basic import FourierFeaturization, PositionalEncoding +from prtm.models.chroma.protein import Protein +from prtm.models.chroma.structure import backbone, geometry, transforms + + +class ProteinFeatureGraph(nn.Module): + """Graph featurizer for protein chains and complexes. + + This module builds graph representations of protein structures that are + differentiable with respect to input coordinates and invariant with respect + to global rotations and translations. It takes as input a batch of + protein backbones (single chains or complexes), constructs a sparse graph + with residues as nodes, and featurizes the backbones in terms of node and + edge feature tensors. + + The graph representation has 5 components: + 1. Node features `node_h` representing residues in the protein. + 2. Edge features `edge_h` representing relationships between residues. + 3. Index map `edge_idx` representing graph topology. + 4. Node mask `mask_i` that specifies which nodes are present. + 5. Edge mask `mask_ij` that specifies which edges are present. + + Criteria for constructing the graph currently include k-Nearest Neighbors or + distance-weighted edge sampling. + + Node and edge features are specified as tuples to make it simpler to add + additional features and options while retaining backwards compatibility. + Specifically, each node or edge feature type can be added to the list either + in default configuration by a `'feature_name'` keyword, or in modified form + with a `('feature_name', feature_kwargs)` tuple. 
+
+ Example usage:
+ graph = ProteinFeatureGraph(
+ graph_type='knn',
+ node_features=('dihedrals',),
+ edge_features=[
+ 'chain_distance',
+ ('dmat_6mer', {'D_function': 'log'})
+ ]
+ )
+ node_h, edge_h, edge_idx, mask_i, mask_ij = graph(X, C)
+
+ This builds a kNN graph with dihedral angles as node
+ features and 6mer interatomic distance matrices as edge features, where
+ the options for post-processing the 6mer distances are passed as a kwargs dict.
+
+ Args:
+ dim_nodes (int): Hidden dimension of node features.
+ dim_edges (int): Hidden dimension of edge features.
+ num_neighbors (int): Maximum degree of the graph.
+ graph_kwargs (dict): Arguments for graph construction. Default is None.
+ node_features (list): List of node feature strings and optional args.
+ Valid feature strings are `{internal_coords}`.
+ edge_features (list): List of edge feature strings and optional args.
+ Valid feature strings are `{'distances_6mer','distances_chain'}`.
+ centered (boolean): Flag for enabling feature centering. If `True`,
+ the features will be centered by subtracting an empirical mean
+ that was computed on the reference PDB `centered_pdb`. The statistics
+ are per-dimension of every node and edge feature. If they have not
+ previously been computed, the PDB will be downloaded, featurized,
+ and aggregated into local statistics that are cached in the repo.
+ centered_pdb (str): PDB code for the reference PDB to compute some
+ empirical feature statistics from.
+
+ Inputs:
+ X (torch.Tensor): Backbone coordinates with shape
+ `(num_batch, num_residues, 4, 3)`. The standard atom indices for
+ the third dimension are PDB order (`[N, CA, C, O]`).
+ C (LongTensor, optional): Chain map with shape
+ `(num_batch, num_residues)`. The chain map codes positions as `0`
+ when masked, positive integers for chain indices, and negative
+ integers to represent missing residues of the corresponding
+ positive integers.
+ custom_D (Tensor, optional): Pre-computed custom distance map
+ for graph construction `(num_batch, num_residues, num_residues)`.
+ If present, this will override the behavior of `graph_type` and be used
+ as the distances for k-nearest neighbor graph construction.
+ custom_mask_2D (Tensor, optional): Custom 2D mask to apply to `custom_D`
+ with shape `(num_batch, num_residues, num_residues)`.
+
+ Outputs:
+ node_h (torch.Tensor): Node features with shape
+ `(num_batch, num_residues, dim_nodes)`.
+ edge_h (torch.Tensor): Edge features with shape
+ `(num_batch, num_residues, num_neighbors, dim_edges)`.
+ edge_idx (torch.LongTensor): Edge indices for neighbors with shape
+ `(num_batch, num_residues, num_neighbors)`.
+ mask_i (torch.Tensor): Node mask with shape `(num_batch, num_residues)`.
+ mask_ij (torch.Tensor): Edge mask with shape
+ `(num_batch, num_nodes, num_neighbors)`.
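For reference, the feature keys registered by the constructor below are `'internal_coords'`, `'cartesian_coords'`, and `'radii'` for nodes, and `'distances_6mer'`, `'distances_2mer'`, `'orientations_2mer'`, `'position_2mer'`, `'distances_chain'`, `'orientations_chain'`, `'cartesian_coords'`, and `'random_fourier_2mer'` for edges. A minimal sketch of a call using these keys; the hidden dimensions are illustrative choices, not values taken from this diff:

    featurizer = ProteinFeatureGraph(
        dim_nodes=128,                        # illustrative sizes
        dim_edges=128,
        num_neighbors=30,
        node_features=("internal_coords",),
        edge_features=[
            "distances_chain",
            ("distances_6mer", {"feature": "log"}),
        ],
    )
    node_h, edge_h, edge_idx, mask_i, mask_ij = featurizer(X, C)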
+ """ + + def __init__( + self, + dim_nodes: int, + dim_edges: int, + num_neighbors: int = 30, + graph_kwargs: dict = None, + node_features: tuple = ("internal_coords",), + edge_features: tuple = ("distances_6mer", "distances_chain"), + centered: bool = True, + centered_pdb: str = "2g3n", + ): + super(ProteinFeatureGraph, self).__init__() + + self.dim_nodes = dim_nodes + self.dim_edges = dim_edges + self.num_neighbors = num_neighbors + graph_kwargs = graph_kwargs if graph_kwargs is not None else {} + self.graph_builder = ProteinGraph(num_neighbors, **graph_kwargs) + self.node_features = node_features + self.edge_features = edge_features + + def _init_layer(layer_dict, features): + # Parse option string + custom_args = not isinstance(features, str) + key = features[0] if custom_args else features + kwargs = features[1] if custom_args else {} + return layer_dict[key](**kwargs) + + # Node feature compilation + node_dict = { + "internal_coords": NodeInternalCoords, + "cartesian_coords": NodeCartesianCoords, + "radii": NodeRadii, + } + self.node_layers = nn.ModuleList( + [_init_layer(node_dict, option) for option in self.node_features] + ) + # Edge feature compilation + edge_dict = { + "distances_6mer": EdgeDistance6mer, + "distances_2mer": EdgeDistance2mer, + "orientations_2mer": EdgeOrientation2mer, + "position_2mer": EdgePositionalEncodings, + "distances_chain": EdgeDistanceChain, + "orientations_chain": EdgeOrientationChain, + "cartesian_coords": EdgeCartesianCoords, + "random_fourier_2mer": EdgeRandomFourierFeatures2mer, + } + self.edge_layers = nn.ModuleList( + [_init_layer(edge_dict, option) for option in self.edge_features] + ) + + # Load feature centering params as buffers + self.centered = centered + self.centered_pdb = centered_pdb.lower() + if self.centered: + self._load_centering_params(self.centered_pdb) + + """ + Storing separate linear transformations for each layer, rather than concat + one + large linear, provides a more even weighting of the different input + features when used with standard weight initialization. It has the + specific effect actually re-weighting the weight variance based on + the number of input features for each feature type. Otherwise, the + relative importance of each feature goes with the number of feature + dimensions. 
+ """ + self.node_linears = nn.ModuleList( + [nn.Linear(l.dim_out, self.dim_nodes) for l in self.node_layers] + ) + self.edge_linears = nn.ModuleList( + [nn.Linear(l.dim_out, self.dim_edges) for l in self.edge_layers] + ) + return + + def forward( + self, + X: torch.Tensor, + C: torch.Tensor, + edge_idx: Optional[torch.LongTensor] = None, + mask_ij: torch.Tensor = None, + custom_D: Optional[torch.Tensor] = None, + custom_mask_2D: Optional[torch.Tensor] = None, + ) -> Tuple[ + torch.Tensor, torch.Tensor, torch.LongTensor, torch.Tensor, torch.Tensor + ]: + mask_i = chain_map_to_mask(C) + if mask_ij is None or edge_idx is None: + edge_idx, mask_ij = self.graph_builder( + X, C, custom_D=custom_D, custom_mask_2D=custom_mask_2D + ) + + # Aggregate node layers + node_h = None + for i, layer in enumerate(self.node_layers): + node_h_l = layer(X, edge_idx, C) + if self.centered: + node_h_l = node_h_l - self.__getattr__(f"node_means_{i}") + node_h_l = self.node_linears[i](node_h_l) + node_h = node_h_l if node_h is None else node_h + node_h_l + if node_h is None: + node_h = torch.zeros(list(X.shape[:2]) + [self.dim_nodes], device=X.device) + + # Aggregate edge layers + edge_h = None + for i, layer in enumerate(self.edge_layers): + edge_h_l = layer(X, edge_idx, C) + if self.centered: + edge_h_l = edge_h_l - self.__getattr__(f"edge_means_{i}") + edge_h_l = self.edge_linears[i](edge_h_l) + edge_h = edge_h_l if edge_h is None else edge_h + edge_h_l + if edge_h is None: + edge_h = torch.zeros(list(X.shape[:2]) + [self.dim_nodes], device=X.device) + + # Apply masks + node_h = mask_i.unsqueeze(-1) * node_h + edge_h = mask_ij.unsqueeze(-1) * edge_h + + return node_h, edge_h, edge_idx, mask_i, mask_ij + + def _load_centering_params(self, reference_pdb: str): + basepath = os.path.dirname(os.path.abspath(__file__)) + "/params/" + if not os.path.exists(basepath): + os.makedirs(basepath) + + filename = f"centering_{reference_pdb}.params" + self.centering_file = basepath + filename + key = ( + reference_pdb + + ";" + + json.dumps(self.node_features) + + ";" + + json.dumps(self.edge_features) + ) + + # Attempt to load saved centering params, otherwise compute and cache + json_line = None + with open(self.centering_file, "a+") as f: + prefix = key + "\t" + f.seek(0) + for line in f: + if line.startswith(prefix): + json_line = line.split(prefix)[1] + break + + if json_line is not None: + print("Loaded from cache") + param_dictionary = json.loads(json_line) + else: + print(f"Computing reference stats for {reference_pdb}") + param_dictionary = self._reference_stats(reference_pdb) + json_line = json.dumps(param_dictionary) + f.write(prefix + "\t" + json_line + "\n") + + for i, layer in enumerate(self.node_layers): + key = json.dumps(self.node_features[i]) + tensor = torch.tensor(param_dictionary[key], dtype=torch.float32) + tensor = tensor.view(1, 1, -1) + self.register_buffer(f"node_means_{i}", tensor) + + for i, layer in enumerate(self.edge_layers): + key = json.dumps(self.edge_features[i]) + tensor = torch.tensor(param_dictionary[key], dtype=torch.float32) + tensor = tensor.view(1, 1, -1) + self.register_buffer(f"edge_means_{i}", tensor) + return + + def _reference_stats(self, reference_pdb): + X, C, _ = Protein.from_PDBID(reference_pdb).to_XCS() + stats_dict = self._feature_stats(X, C) + return stats_dict + + def _feature_stats(self, X, C, verbose=False, center=False): + mask_i = chain_map_to_mask(C) + edge_idx, mask_ij = self.graph_builder(X, C) + + def _masked_stats(feature, mask, dims, verbose=False): + mask = 
mask.unsqueeze(-1) + feature = mask * feature + sum_mask = mask.sum() + mean = feature.sum(dims, keepdim=True) / sum_mask + var = torch.sum(mask * (feature - mean) ** 2, dims) / sum_mask + std = torch.sqrt(var) + mean = mean.view(-1) + std = std.view(-1) + + if verbose: + frac = (100.0 * std**2 / (mean**2 + std**2)).type(torch.int32) + print(f"Fraction of raw variance: {frac}") + return mean, std + + # Collect statistics + stats_dict = {} + + # Aggregate node layers + for i, layer in enumerate(self.node_layers): + node_h = layer(X, edge_idx, C) + if center: + node_h = node_h - self.__getattr__(f"node_means_{i}") + mean, std = _masked_stats(node_h, mask_i, dims=[0, 1]) + + # Store in dictionary + key = json.dumps(self.node_features[i]) + stats_dict[key] = mean.tolist() + + # Aggregate node layers + for i, layer in enumerate(self.edge_layers): + edge_h = layer(X, edge_idx, C) + if center: + edge_h = edge_h - self.__getattr__(f"edge_means_{i}") + mean, std = _masked_stats(edge_h, mask_ij, dims=[0, 1, 2]) + + # Store in dictionary + key = json.dumps(self.edge_features[i]) + stats_dict[key] = mean.tolist() + + # Round to small number of decimal places + stats_dict = {k: [round(f, 3) for f in v] for k, v in stats_dict.items()} + return stats_dict + + +class ProteinGraph(nn.Module): + """Build a graph topology given a protein backbone. + + Args: + num_neighbors (int): Maximum number of neighbors in the graph. + distance_atom_type (int): Atom type for computing residue-residue + distances for graph construction. Negative values will specify + centroid across atom types. Default is `-1` (centroid). + cutoff (float): Cutoff distance for graph construction. If not None, + mask any edges further than this cutoff. Default is `None`. + mask_interfaces (Boolean): Restrict connections only to within chains, + excluding-between chain interactions. Default is `False`. + criterion (string, optional): Method used for building graph from distances. + Currently supported methods are `{knn, random_log, random_linear}`. + Default is `knn`. + random_alpha (float, optional): Length scale parameter for random graph + generation. Default is 3. + random_temperature (float, optional): Temperature parameter for + random graph sampling. Between 0 and 1 this value will interpolate + between a normal k-NN graph and sampling from the graph generation + process. Default is 1.0. + + Inputs: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, 4, 3)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. + custom_D (torch.Tensor, optional): Optional external distance map, for example + based on other distance metrics, with shape + `(num_batch, num_residues, num_residues)`. + custom_mask_2D (torch.Tensor, optional): Optional mask to apply to distances + before computing dissimilarities with shape + `(num_batch, num_residues, num_residues)`. + + Outputs: + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_residues, num_neighbors)`. + mask_ij (torch.Tensor): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. 
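As a free-standing illustration of the random criteria listed above (`random_log`, `random_linear`, `random_uniform`), edge selection amounts to a Gumbel-top-k over perturbed log-propensities. The helper below is a hedged sketch with illustrative names, not a function defined in this module:

    import torch

    def sample_edge_idx(D, k, temperature=1.0):
        # Inverse-cubic propensity, as in the 'random_log' criterion;
        # the small epsilon guards against log(0) on the self-distance.
        logp = -3.0 * torch.log(D + 1e-3)
        # Adding Gumbel noise makes top-k over the perturbed scores a
        # sample without replacement proportional to softmax(logp / T).
        gumbel = -torch.log(-torch.log(torch.rand_like(D)))
        keys = logp / temperature + gumbel
        return torch.topk(keys, k, dim=-1).indices

    X_CA = torch.randn(2, 64, 3)
    D = torch.cdist(X_CA, X_CA)          # (num_batch, num_residues, num_residues)
    edge_idx = sample_edge_idx(D, k=30)  # (2, 64, 30)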
+ """ + + def __init__( + self, + num_neighbors: int = 30, + distance_atom_type: int = -1, + cutoff: Optional[float] = None, + mask_interfaces: bool = False, + criterion: str = "knn", + random_alpha: float = 3.0, + random_temperature: float = 1.0, + random_min_local: float = 20, + deterministic: bool = False, + deterministic_seed: int = 10, + ): + super(ProteinGraph, self).__init__() + self.num_neighbors = num_neighbors + self.distance_atom_type = distance_atom_type + self.cutoff = cutoff + self.mask_interfaces = mask_interfaces + self.distances = geometry.Distances() + self.knn = kNN(k_neighbors=num_neighbors) + + self.criterion = criterion + self.random_alpha = random_alpha + self.random_temperature = random_temperature + self.random_min_local = random_min_local + self.deterministic = deterministic + self.deterministic_seed = deterministic_seed + + def _mask_distances(self, X, C, custom_D=None, custom_mask_2D=None): + mask_1D = chain_map_to_mask(C) + mask_2D = mask_1D.unsqueeze(2) * mask_1D.unsqueeze(1) + if self.distance_atom_type > 0: + X_atom = X[:, :, self.distance_atom_type, :] + else: + X_atom = X.mean(dim=2) + if custom_D is None: + D = self.distances(X_atom, dim=1) + else: + D = custom_D + + if custom_mask_2D is None: + if self.mask_interfaces: + mask_2D = torch.eq(C.unsqueeze(1), C.unsqueeze(2)) + mask_2D = mask_2D * mask_2D.type(torch.float32) + if self.cutoff is not None: + mask_cutoff = (D <= self.cutoff).type(torch.float32) + mask_2D = mask_cutoff * mask_2D + else: + mask_2D = custom_mask_2D + return D, mask_1D, mask_2D + + def _perturb_distances(self, D): + # Replace distance by log-propensity + if self.criterion == "random_log": + logp_edge = -3 * torch.log(D) + elif self.criterion == "random_linear": + logp_edge = -D / self.random_alpha + elif self.criterion == "random_uniform": + logp_edge = D * 0 + else: + return D + + if not self.deterministic: + Z = torch.rand_like(D) + else: + with torch.random.fork_rng(): + torch.random.manual_seed(self.deterministic_seed) + Z_shape = [1] + list(D.shape)[1:] + Z = torch.rand(Z_shape, device=D.device) + + # Sample Gumbel noise + G = -torch.log(-torch.log(Z)) + + # Negate because are doing argmin instead of argmax + D_key = -(logp_edge / self.random_temperature + G) + + return D_key + + def forward( + self, + X: torch.Tensor, + C: torch.LongTensor, + custom_D: Optional[torch.Tensor] = None, + custom_mask_2D: Optional[torch.Tensor] = None, + ) -> Tuple[torch.LongTensor, torch.Tensor]: + D, mask_1D, mask_2D = self._mask_distances(X, C, custom_D, custom_mask_2D) + + if self.criterion != "knn": + if self.random_min_local > 0: + # Build first k-NN graph (local) + self.knn.k_neighbors = self.random_min_local + edge_idx_local, _, mask_ij_local = self.knn(D, mask_1D, mask_2D) + + # Build mask exluding these first ones + mask_ij_remaining = 1.0 - mask_ij_local + mask_2D_remaining = torch.ones_like(mask_2D).scatter( + 2, edge_idx_local, mask_ij_remaining + ) + mask_2D = mask_2D * mask_2D_remaining + + # Build second k-NN graph (random) + self.knn.k_neighbors = self.num_neighbors - self.random_min_local + D = self._perturb_distances(D) + edge_idx_random, _, mask_ij_random = self.knn(D, mask_1D, mask_2D) + edge_idx = torch.cat([edge_idx_local, edge_idx_random], 2) + mask_ij = torch.cat([mask_ij_local, mask_ij_random], 2) + + # Handle small proteins + k = min(self.num_neighbors, D.shape[-1]) + edge_idx = edge_idx[:, :, :k] + mask_ij = mask_ij[:, :, :k] + + self.knn.k_neighbors = self.num_neighbors + return edge_idx.contiguous(), 
mask_ij.contiguous() + else: + D = self._perturb_distances(D) + + edge_idx, edge_D, mask_ij = self.knn(D, mask_1D, mask_2D) + return edge_idx, mask_ij + + +class kNN(nn.Module): + """Build a k-nearest neighbors graph given a dissimilarity matrix. + + Args: + k_neighbors (int): Number of nearest neighbors to include as edges of + each node in the graph. + + Inputs: + D (torch.Tensor): Dissimilarity matrix with shape + `(num_batch, num_nodes, num_nodes)`. + mask (torch.Tensor, optional): Node mask with shape `(num_batch, num_nodes)`. + mask_2D (torch.Tensor, optional): Edge mask with shape + `(num_batch, num_nodes, num_nodes)`. + + Outputs: + edge_idx (torch.LongTensor): Edge indices with shape + `(num_batch, num_nodes, k)`. The slice `edge_idx[b,i,:]` contains + the indices `{j in N(i)}` of the k nearest neighbors of node `i` + in object `b`. + edge_D (torch.Tensor): Distances to each neighbor with shape + `(num_batch, num_nodes, k)`. + mask_ij (torch.Tensor): Edge mask with shape + `(num_batch, num_nodes, num_neighbors)`. + """ + + def __init__(self, k_neighbors: int): + super(kNN, self).__init__() + self.k_neighbors = k_neighbors + + def forward( + self, + D: torch.Tensor, + mask: Optional[torch.Tensor] = None, + mask_2D: Optional[torch.Tensor] = None, + ) -> Tuple[torch.LongTensor, torch.Tensor, torch.Tensor]: + mask_full = None + if mask is not None: + mask_full = mask.unsqueeze(2) * mask.unsqueeze(1) + if mask_2D is not None: + mask_full = mask_2D if mask_full is None else mask_full * mask_2D + if mask_full is not None: + max_float = np.finfo(np.float32).max + D = mask_full * D + (1.0 - mask_full) * max_float + + k = min(self.k_neighbors, D.shape[-1]) + edge_D, edge_idx = torch.topk(D, int(k), dim=-1, largest=False) + + mask_ij = None + if mask_full is not None: + mask_ij = graph.collect_edges(mask_full.unsqueeze(-1), edge_idx) + mask_ij = mask_ij.squeeze(-1) + return edge_idx, edge_D, mask_ij + + +class NodeInternalCoords(nn.Module): + """Node features representing internal coordinates. + + Args: + include_ideality (Boolean): Whether or not to include ideality features + along with direct geometry. + + Attributes: + dim_out (int): Number of dimensions of the output features. + + Inputs: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atom_types, 3)`. + edge_idx (torch.LongTensor): Graph indices for expansion with shape + `(num_batch, num_residues, num_neighbors)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. 
+ + Outputs: + node_h (torch.Tensor): Edge distance matrix features with shape + `(num_batch, num_residues, 20)` + """ + + def __init__( + self, + include_ideality: bool = False, + distance_eps: float = 0.01, + log_lengths: bool = False, + ): + super(NodeInternalCoords, self).__init__() + self.internal_coords = geometry.InternalCoords() + self.distance_eps = distance_eps + self.include_ideality = include_ideality + self.dim_out = 28 if self.include_ideality else 20 + self.log_lengths = log_lengths + + # Engh and Huber Ideal Geometry + ideal_lengths = [1.459, 1.525, 1.336, 1.229] + ideal_angles = [111.0, 117.2, 121.7, 120.0] + # Angles are output as complement in radians + ideal_angles = [np.pi - degrees * np.pi / 180.0 for degrees in ideal_angles] + + if self.include_ideality: + ideal_lengths = torch.as_tensor(ideal_lengths).view([1, 1, -1]) + self.register_buffer("ideal_lengths", ideal_lengths) + + ideal_angles = torch.as_tensor(ideal_angles).view([1, 1, -1]) + self.register_buffer("ideal_angles", ideal_angles) + + def forward( + self, + X: torch.Tensor, + edge_idx: Optional[torch.LongTensor] = None, + C: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + outs = self.internal_coords(X, C=C, return_masks=True) + dihedrals, angles, lengths = outs[:3] + mask_dihedrals, mask_angles, mask_lengths = outs[3:] + angle_stack = torch.cat([dihedrals, angles], dim=-1) + mask = chain_map_to_mask(C).unsqueeze(-1) + + if self.log_lengths: + lengths = torch.log(lengths + self.distance_eps) + + feature_list = [torch.cos(angle_stack), torch.sin(angle_stack), lengths] + + # Ideality scores + if self.include_ideality: + # Mask angle features + mask_stack = torch.cat([mask_dihedrals, mask_angles], dim=-1) + feature_list[0] = mask_stack * feature_list[0] + feature_list[1] = mask_stack * feature_list[1] + + _D_fun = lambda D: torch.log(D + self.distance_eps) + length_scores = (_D_fun(lengths) - _D_fun(self.ideal_lengths)) ** 2 + angle_scores = torch.cos(angles - self.ideal_angles) + length_scores = mask_lengths * length_scores + angle_scores = mask_angles * angle_scores + feature_list = feature_list + [length_scores, angle_scores] + node_h = mask * torch.cat(feature_list, dim=-1) + return node_h + + +class NodeRadii(nn.Module): + """Node features representing radii in the larger complex. + + Args: + length_scale (float): Typical length scale for normalizing distances. + + Attributes: + dim_out (int): Number of dimensions of the output features. (4) + + Inputs: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atom_types, 3)`. + edge_idx (torch.LongTensor): Graph indices for expansion with shape + `(num_batch, num_residues, num_neighbors)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. 
+ + Outputs: + node_h (torch.Tensor): Node radii features with shape + `(num_batch, num_residues, 4)` + """ + + def __init__(self, length_scale: float = 100.0): + super(NodeRadii, self).__init__() + self.dim_out = 4 + self.length_scale = length_scale + + def forward( + self, + X: torch.Tensor, + edge_idx: Optional[torch.LongTensor] = None, + C: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + num_batch, num_residues = list(C.shape) + mask_i = (C > 0).float() + mask_i = mask_i.reshape([num_batch, num_residues, 1, 1]).expand(X.shape) + X_center = (mask_i * X).sum([1, 2], keepdim=True) / mask_i.sum( + [1, 2], keepdim=True + ) + + node_h = (mask_i * ((X - X_center) / self.length_scale) ** 2).sum(-1) + return node_h + + +class Edge6mers(nn.Module): + """Build concatenation of 3mer coordinates on graph edges. + + This layer assembles the pairwise concatenations of the coordinates + `{X_a for a in {i-1,i,i+1,j-1,j,j+1}}` along every edge in a graph. This can + be used for stitching of '6mer PairTERMs'. + + Args: + require_contiguous (boolean, optional): Whether to enforce that + `{i-1,i,i+1}` and`{j-1,j,j+1}` are each made up of contiguous + residues from the same protein chain. Default is `True`. + + Inputs: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atom_types, 3)`. + edge_idx (torch.LongTensor): Graph indices for expansion with shape + `(num_batch, num_residues, num_neighbors)`. + mask (Tensor, optional): Mask tensor with shape + `(num_batch, num_residues)`. + + Outputs: + X_ij (torch.Tensor): Pairwise-concatenated 3mers with shape + `(num_batch, num_residues, num_neighbors, 2*num_atom_types, 3)`. + mask_ij (Tensor, if mask): Propagated mask tensor for edges with shape + `(num_batch, num_residues, num_neighbors)`. 
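A shape sketch for the default four backbone atoms per residue, assuming this module is importable; the sizes are illustrative. Note that the atom axis concatenates the backbone atoms of all six residues `{i-1, i, i+1, j-1, j, j+1}`:

    import torch

    B, L, K = 2, 50, 30
    X = torch.randn(B, L, 4, 3)                 # N, CA, C, O coordinates
    C = torch.ones(B, L, dtype=torch.long)      # one chain, no missing residues
    edge_idx = torch.randint(0, L, (B, L, K))

    X_ij, mask_ij = Edge6mers()(X, edge_idx, C=C)
    # X_ij stacks 6 * 4 = 24 atoms per edge: X_ij.shape == (B, L, K, 24, 3)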
+ """ + + def __init__(self, require_contiguous: bool = True): + super(Edge6mers, self).__init__() + self.require_contiguous = require_contiguous + + def forward( + self, + X: torch.Tensor, + edge_idx: torch.LongTensor, + C: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + def _pair_expand(h, collate_fun): + # Build local neighborhoods [i-1, i, i+1] + h_left = F.pad(h[:, :-1, :], (0, 0, 1, 0), "constant", 0) + h_middle = h[:, :, :] + h_right = F.pad(h[:, 1:, :], (0, 0, 0, 1), "constant", 0) + h_i = collate_fun((h_left, h_middle, h_right)) + + # Concatenate [j-1, j, j+1] of neighbors + h_j = graph.collect_neighbors(h_i, edge_idx) + h_i_tile = h_i.unsqueeze(-2).expand(h_j.size()) + h_ij = collate_fun((h_i_tile, h_j)) + return h_ij + + # Concatenation collation function for stitching + _cat = lambda hs: torch.cat(hs, dim=-1) + + # Cumulative product collation function for mask propagation + def _mul(hs): + result = hs[0] + for h_i in hs[1:]: + result = result * h_i + return result + + # Element-wise enforce values are greater than 0 and equal + def _nonzero_and_equal(hs): + entry_0 = hs[0] + result = (hs[0] > 0.0).type(torch.float32) + for h_i in hs[1:]: + result = result * (entry_0 == h_i).type(torch.float32) + return result + + # Build local neighborhoods [i-1, i, i+1] + # X [batch, position, atom, xyz] + X_flat = X.reshape(X.size(0), X.size(1), -1) + X_ij = _pair_expand(X_flat, collate_fun=_cat) + X_ij = X_ij.view(list(X_ij.size())[:-1] + [-1, 3]) + + if C is not None: + if self.require_contiguous: + mask_ij = _pair_expand(C.unsqueeze(-1), collate_fun=_nonzero_and_equal) + else: + mask = chain_map_to_mask(C) + mask_ij = _pair_expand(mask.unsqueeze(-1), collate_fun=_mul) + + return X_ij, mask_ij + else: + return X_ij + + +class Edge2mers(nn.Module): + """Build concatenation of 1mer coordinates on graph edges. + + Inputs: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atom_types, 3)`. + edge_idx (torch.LongTensor): Graph indices for expansion with shape + `(num_batch, num_residues, num_neighbors)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. + + Outputs: + X_ij (torch.Tensor): Pairwise-concatenated 3mers with shape + `(num_batch, num_residues, num_neighbors, 2*num_atom_types, 3)`. + mask_ij (Tensor, if mask): Propagated mask tensor for edges with shape + `(num_batch, num_residues, num_neighbors)`. + """ + + def __init__(self): + super(Edge2mers, self).__init__() + + def forward( + self, + X: torch.Tensor, + edge_idx: torch.LongTensor, + C: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + num_batch = edge_idx.shape[0] + num_residues = edge_idx.shape[1] + num_neighbors = edge_idx.shape[2] + num_atom_types = X.shape[2] + shape_X = [num_batch, num_residues, num_neighbors, num_atom_types * 3] + X_flat = X.reshape(num_batch, num_residues, -1) + X_i = X_flat.unsqueeze(2).expand(shape_X) + X_j = graph.collect_neighbors(X_flat, edge_idx).expand(shape_X) + X_ij = torch.cat([X_i, X_j], -1) + X_ij = X_ij.reshape( + num_batch, num_residues, num_neighbors, 2 * num_atom_types, 3 + ) + if C is not None: + mask_i = chain_map_to_mask(C).unsqueeze(-1) + mask_j = graph.collect_neighbors(mask_i, edge_idx) + mask_ij = mask_i.unsqueeze(2) * mask_j + return X_ij, mask_ij + else: + return X_ij + + +class EdgeDistance6mer(nn.Module): + """Edge features based on chain distance matrices along each i,j 6mer. 
+ + Args: + feature (str, optional): Option string in {'log', 'inverse', 'raw'} + specifying how to process the raw distance features. + Defaults to 'log'. + distance_eps (float, optional): Smoothing parameter to prevent feature + explosion at small distances. Can be thought of as a 'minimum length + scale'. Defaults to 0.01. + require_contiguous (boolean, optional): Whether to enforce that each + 3mer, `{i-1,i,i+1}` and`{j-1,j,j+1}`, is made up of contiguous + residues from the same protein chain. Default is `False` for + backwards compatibility, but `True` is recommended as best practice. + + Attributes: + dim_out (int): Number of dimensions of the output features. + + Inputs: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atom_types, 3)`. + edge_idx (torch.LongTensor): Graph indices for expansion with shape + `(num_batch, num_residues, num_neighbors)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. + + Outputs: + edge_h (torch.Tensor): Edge distance matrix features with shape + `(num_batch, num_residues, num_neighbors, (6 * num_atom_types)**2)` + """ + + def __init__( + self, + feature: str = "log", + distance_eps: float = 0.01, + num_atom_types: int = 4, + require_contiguous: bool = False, + ): + super(EdgeDistance6mer, self).__init__() + self.feature = feature + self.distance_eps = distance_eps + self.num_atom_types = num_atom_types + self.layer_6mers = Edge6mers(require_contiguous=require_contiguous) + self.layer_distance = geometry.Distances() + + # Public attribute + self.dim_out = (6 * num_atom_types) ** 2 + + self.feature = feature + feature_functions = { + "log": self.log_func, + "inverse": self.inverse_func, + "raw": self.raw_func, + } + self.feature_function = feature_functions[feature] + + def forward( + self, + X: torch.Tensor, + edge_idx: torch.LongTensor, + C: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + X_ij, mask_ij = self.layer_6mers(X, edge_idx, C=C) + D_ij = self.layer_distance(X_ij, dim=-2) + feature_ij = self.feature_function(D_ij) + feature_ij_flat = feature_ij.reshape(list(D_ij.shape[:3]) + [-1]) + edge_h = mask_ij * feature_ij_flat + # debug_plot_edge6merdist(edge_h, feature=self.feature) + return edge_h + + def log_func(self, D): + return torch.log(D + self.distance_eps) + + def inverse_func(self, D): + return 1.0 / (D + self.distance_eps) + + def raw_func(self, D): + return D + + +class EdgeDistance2mer(nn.Module): + """Edge features based on chain distance matrices along each i,j 2mer. + + Args: + feature (str, optional): Option string in {'log', 'inverse', 'raw'} + specifying how to process the raw distance features. + Defaults to 'log'. + distance_eps (float, optional): Smoothing parameter to prevent feature + explosion at small distances. Can be thought of as a 'minimum length + scale'. Defaults to 0.01. + + Attributes: + dim_out (int): Number of dimensions of the output features. + + Inputs: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atom_types, 3)`. + edge_idx (torch.LongTensor): Graph indices for expansion with shape + `(num_batch, num_residues, num_neighbors)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. 
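For intuition, the default `'rbf+log'` featurization applied to a block of flattened atom-pair distances looks roughly like the sketch below; the tensors and sizes are illustrative and mirror the defaults `rbf_min=0`, `rbf_max=20`, `rbf_count=20`, `distance_eps=0.01`:

    import torch

    D = 20.0 * torch.rand(2, 50, 30, 64)        # (2 * 4) ** 2 atom-pair distances per edge
    rbf_centers = torch.linspace(0.0, 20.0, 20)
    std = (rbf_centers[1] - rbf_centers[0]).item()
    rbf = torch.exp(-(((D.unsqueeze(-1) - rbf_centers) / std) ** 2))
    log_D = torch.log(D + 0.01)
    edge_h = torch.cat([rbf.flatten(-2), log_D], dim=-1)
    # edge_h.shape == (2, 50, 30, 64 * 20 + 64)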
+ + Outputs: + edge_h (torch.Tensor): Edge distance matrix features with shape + `(num_batch, num_residues, num_neighbors, (6 * num_atom_types)**2)` + """ + + def __init__( + self, + features: str = "rbf+log", + distance_eps: float = 0.01, + num_atom_types: int = 4, + rbf_min: float = 0.0, + rbf_max: float = 20.0, + rbf_count: int = 20, + ): + super(EdgeDistance2mer, self).__init__() + self.distance_eps = distance_eps + self.num_atom_types = num_atom_types + self.layer_2mers = Edge2mers() + self.layer_distance = geometry.Distances() + + features = features.split("+") + if not isinstance(features, list): + features = [features] + self.features = features + if "rbf" in self.features: + self.rbf_function = RBFExpansion(rbf_min, rbf_max, rbf_count) + dim_base = (2 * num_atom_types) ** 2 + feature_dims = { + "log": dim_base, + "inverse": dim_base, + "raw": dim_base, + "rbf": dim_base * rbf_count, + } + + # Public attribute + self.dim_out = sum([feature_dims[d] for d in features]) + + self.feature_funcs = { + "log": lambda D: torch.log(D + self.distance_eps), + "inverse": lambda D: 1.0 / (D + self.distance_eps), + "raw": lambda D: D, + "rbf": lambda D: self.rbf_function(D), + } + + def featurize(self, D): + h_list = [] + for feature in self.features: + h = self.feature_funcs[feature](D) + h_list.append(h) + h = torch.cat(h_list, -1) + return h + + def forward( + self, + X: torch.Tensor, + edge_idx: torch.LongTensor, + C: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + X_ij, mask_ij = self.layer_2mers(X, edge_idx, C=C) + D_ij = self.layer_distance(X_ij, dim=-2) + shape_flat = list(D_ij.shape[:3]) + [-1] + D_ij = D_ij.reshape(shape_flat) + feature_ij = self.featurize(D_ij) + + # DEBGUG + # _debug_plot_edges(edge_idx, feature_ij, unravel=True) + # exit(0) + edge_h = mask_ij * feature_ij + return edge_h + + +class EdgeOrientation2mer(nn.Module): + """Edge features based on chain distance matrices along each i,j 2mer. + + Args: + feature (str, optional): Option string in {'log', 'inverse', 'raw'} + specifying how to process the raw distance features. + Defaults to 'log'. + distance_eps (float, optional): Smoothing parameter to prevent feature + explosion at small distances. Can be thought of as a 'minimum length + scale'. Defaults to 0.01. + + Attributes: + dim_out (int): Number of dimensions of the output features. + + Inputs: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atom_types, 3)`. + edge_idx (torch.LongTensor): Graph indices for expansion with shape + `(num_batch, num_residues, num_neighbors)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. 
+ + Outputs: + edge_h (torch.Tensor): Edge distance matrix features with shape + `(num_batch, num_residues, num_neighbors, (6 * num_atom_types)**2)` + """ + + def __init__(self, distance_eps: float = 0.1, num_atom_types: int = 4): + super(EdgeOrientation2mer, self).__init__() + self.distance_eps = distance_eps + self.num_atom_types = num_atom_types + self.layer_2mers = Edge2mers() + + # Public attribute + self.dim_out = 3 * (2 * num_atom_types) ** 2 + + def _normed_vec(self, V): + # Unit vector from i to j + mag_sq = (V**2).sum(dim=-1, keepdim=True) + mag = torch.sqrt(mag_sq + self.distance_eps) + V_norm = V / mag + return V_norm + + def forward( + self, + X: torch.Tensor, + edge_idx: torch.LongTensor, + C: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + X_ij, mask_ij = self.layer_2mers(X, edge_idx, C=C) + + # Build direction vectors + U_ij = self._normed_vec(X_ij.unsqueeze(3) - X_ij.unsqueeze(4)) + + # Build reference frame + X_N, X_CA, X_C, X_O = X.unbind(2) + _normed_cross = lambda U_a, U_b: self._normed_vec(torch.cross(U_a, U_b, dim=-1)) + u_CA_N = self._normed_vec(X_N - X_CA) + u_CA_C = self._normed_vec(X_C - X_CA) + n_1 = u_CA_N + n_2 = _normed_cross(n_1, u_CA_C) + n_3 = _normed_cross(n_1, n_2) + R = torch.stack([n_1, n_2, n_3], -1) + + U_ij = torch.einsum("nijabx,nixy->nijaby", U_ij, R) + + # DEBUG: + # _debug_plot_edges(edge_idx, U_ij[:,:,:,1,5,:]) + + feature_ij = U_ij.view(list(edge_idx.shape)[:3] + [-1]) + edge_h = mask_ij * feature_ij + return edge_h + + +class EdgeOrientationChain(nn.Module): + """Edge features encoding the relative orientations of chains and chain atoms. + + Args: + feature (str, optional): Option string in {'log', 'inverse', 'raw'} + specifying how to process the raw distance features. + Defaults to 'log'. + distance_eps (float, optional): Smoothing parameter to prevent feature + explosion at small distances. Can be thought of as a 'minimum length + scale'. Defaults to 0.1. + distance_eps (float, optional): Like `distance_eps`, but for orientation + calculations. Can be thought of as a 'minimum length scale' + Defaults to 1E-5. + + Attributes: + dim_out (int): Number of dimensions of the output features. + + Inputs: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atom_types, 3)`. + edge_idx (torch.LongTensor): Graph indices for expansion with shape + `(num_batch, num_residues, num_neighbors)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. 
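The per-residue reference frames used below are built from the N, CA, and C atoms by Gram-Schmidt-style orthogonalization. A minimal free-standing sketch, using `F.normalize` in place of the module's epsilon-smoothed `_normed_vec`:

    import torch
    import torch.nn.functional as F

    def backbone_frames(X):
        # X: (..., 4, 3) backbone atoms in N, CA, C, O order
        X_N, X_CA, X_C, _ = X.unbind(-2)
        n_1 = F.normalize(X_N - X_CA, dim=-1)
        u_C = F.normalize(X_C - X_CA, dim=-1)
        n_2 = F.normalize(torch.cross(n_1, u_C, dim=-1), dim=-1)
        n_3 = F.normalize(torch.cross(n_1, n_2, dim=-1), dim=-1)
        # Columns of R are the local frame axes
        return torch.stack([n_1, n_2, n_3], dim=-1)     # (..., 3, 3)

    R = backbone_frames(torch.randn(2, 50, 4, 3))       # (2, 50, 3, 3)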
+ + Outputs: + edge_h (torch.Tensor): Edge distance matrix features with shape + `(num_batch, num_residues, num_neighbors, 24)` + """ + + def __init__( + self, feature: str = "log", distance_eps: float = 0.1, norm_eps: float = 1e-1 + ): + super(EdgeOrientationChain, self).__init__() + self.distance_eps = distance_eps + self.norm_eps = norm_eps + + self.feature = feature + feature_functions = { + "log": lambda D: torch.log(D + self.distance_eps), + "inverse": lambda D: 1.0 / (D + self.distance_eps), + "raw": lambda D: D, + } + self.feature_function = feature_functions[feature] + + # Public attribute + self.dim_out = 24 + + def _normed_vec(self, V): + # Unit vector from i to j + mag_sq = (V**2).sum(dim=-1, keepdim=True) + mag = torch.sqrt(mag_sq + self.norm_eps) + V_norm = V / mag + return V_norm + + def _reference_frames(self, X): + # Build reference frames at each i + X_N, X_CA, X_C, X_O = X.unbind(2) + _normed_cross = lambda U_a, U_b: self._normed_vec(torch.cross(U_a, U_b, dim=-1)) + u_CA_N = self._normed_vec(X_N - X_CA) + u_CA_C = self._normed_vec(X_C - X_CA) + n_1 = u_CA_N + n_2 = _normed_cross(n_1, u_CA_C) + n_3 = _normed_cross(n_1, n_2) + R = torch.stack([n_1, n_2, n_3], -1) + return R + + def _reference_frames_chain(self, X, C): + # Build reference frames at each i + X_N, X_CA, X_C, X_O = X.unbind(2) + _normed_cross = lambda U_a, U_b: self._normed_vec(torch.cross(U_a, U_b, dim=-1)) + u_CA_N = self._normed_vec(X_N - X_CA) + u_CA_C = self._normed_vec(X_C - X_CA) + + u_CA_N_avg = self._chain_average(u_CA_N, C) + u_CA_C_avg = self._chain_average(u_CA_C, C) + + n_1 = self._normed_vec(u_CA_N_avg) + n_2 = _normed_cross(n_1, self._normed_vec(u_CA_C_avg)) + n_3 = _normed_cross(n_1, n_2) + R = torch.stack([n_1, n_2, n_3], -1) + return R + + def _chain_average(self, node_h, C, eps=1e-5): + # Compute the per-chain averages of each feature within a chain, in place + num_batch, num_residues = list(C.shape) + num_chains = int(torch.max(C).item()) + + # Build a position == chain expanded mask (B,L,C) + C_expand = C.unsqueeze(-1).expand(-1, -1, num_chains) + idx = torch.arange(num_chains, device=C.device) + 1 + idx_expand = idx.view(1, 1, -1) + mask_expand = (idx_expand == C_expand).type(torch.float32) + mask_expand = mask_expand.unsqueeze(-1) + + # Masked reduction + node_h_expand = node_h.unsqueeze(2).expand(-1, -1, num_chains, -1) + node_h_chain_average = (mask_expand * node_h_expand).sum(1, keepdim=True) / ( + (mask_expand).sum(1, keepdim=True) + eps + ) + + # Back-expand (B,C,K) => (B,L,3) + node_h_chain_average = (mask_expand * node_h_chain_average).sum(2) + return node_h_chain_average + + def _R_neighbors(self, R_i, edge_idx): + num_batch, num_residues, num_k = list(edge_idx.shape) + R_flat_i = R_i.reshape(num_batch, num_residues, 9) + R_flat_j = graph.collect_neighbors(R_flat_i, edge_idx) + R_j = R_flat_j.reshape(num_batch, num_residues, num_k, 3, 3) + return R_j + + def _transformation_features(self, X_i, X_j, R_i, R_j, edge_idx, edges=True): + # Distance and direction + dX = X_j - X_i.unsqueeze(2).contiguous() + L = torch.sqrt((dX**2).sum(-1, keepdim=True) + self.distance_eps) + u_ij = torch.einsum("niab,nija->nijb", R_i, dX / L) + + # Relative orientation + R_relative_ij = torch.einsum("niab,nijac->nijbc", R_i, R_j) + q_ij = geometry.quaternions_from_rotations(R_relative_ij) + + h = torch.cat((self.feature_function(L), u_ij, q_ij), dim=-1) + return h + + def forward( + self, + X: torch.Tensor, + edge_idx: torch.LongTensor, + C: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + 
num_batch, num_residues, num_k = list(edge_idx.shape) + + # Compute local positions (C-alpha) and frames (B, L, 4) + R_i = self._reference_frames(X) + R_chain_i = self._reference_frames_chain(X, C) + + # X chain + X_i = X[:, :, 1, :] + X_j = graph.collect_neighbors(X_i, edge_idx) + X_chain_i = self._chain_average(X_i, C) + X_chain_j = graph.collect_neighbors(X_chain_i, edge_idx) + + # Relative chain features + R_chain_j = self._R_neighbors(R_chain_i, edge_idx) + R_j = self._R_neighbors(R_i, edge_idx) + + h_chain_to_chain = self._transformation_features( + X_chain_i, X_chain_j, R_chain_i, R_chain_j, edge_idx + ) + h_chain_to_node = self._transformation_features( + X_chain_i, X_j, R_chain_i, R_j, edge_idx + ) + h_node_to_node = self._transformation_features(X_i, X_j, R_i, R_j, edge_idx) + edge_h = torch.cat((h_chain_to_chain, h_chain_to_node, h_node_to_node), -1) + + # DEBUG: + # h = h_node_to_node + # _debug_plot_edges(edge_idx, h[:,:,:,0].unsqueeze(-1)) + # _debug_plot_edges(edge_idx, h[:,:,:,1:4]) + # _debug_plot_edges(edge_idx, h[:,:,:,5:9]) + + mask_i = chain_map_to_mask(C).unsqueeze(-1) + mask_j = graph.collect_neighbors(mask_i, edge_idx) + mask_ij = mask_i.unsqueeze(2) * mask_j + edge_h = mask_ij * edge_h + return edge_h + + +class EdgeDistanceChain(nn.Module): + """Edge features based on distance matrices along each i,j 6mer. + + These feature capture (signed) intra-chain distances as well as distinguish + between same vs. different chain. + + Args: + + Attributes: + dim_out (int): Number of dimensions of the output features. + + Inputs: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atom_types, 3)`. + edge_idx (torch.LongTensor): Graph indices for expansion with shape + `(num_batch, num_residues, num_neighbors)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. + + Outputs: + edge_h (torch.Tensor): Edge chain distance features with shape + `(num_batch, num_residues, num_neighbors, 2)` + """ + + def __init__(self): + super(EdgeDistanceChain, self).__init__() + + # Public attribute + self.dim_out = 3 + + def forward( + self, + X: torch.Tensor, + edge_idx: torch.LongTensor, + C: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + # Is the edge intra-chain or inter-chain? + chain_i = C.unsqueeze(-1) + chain_j = graph.collect_neighbors(chain_i, edge_idx).squeeze(-1) + is_interface = torch.ne(chain_i, chain_j).type(torch.float32) + + # If it is intra-chain, what is the chain distance? + residue_i = torch.arange(edge_idx.shape[1], device=X.device).view((1, -1, 1)) + residue_j = edge_idx + D_signed = (residue_j - residue_i).type(torch.float32) + D_residue = torch.abs(D_signed) + D_intra = (1.0 - is_interface) * torch.log(D_residue + 1.0) + D_intra_sign = (1.0 - is_interface) * torch.sign(D_signed) + + edge_h = torch.stack([is_interface, D_intra, D_intra_sign], dim=-1) + return edge_h + + +class EdgePositionalEncodings(nn.Module): + """Edge features based on positional encodings of chain distance |i-j|. + + Args: + dim_embeddings (int): Embedding dimension. + + Attributes: + dim_out (int): Number of dimensions of the output features. + + Inputs: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atom_types, 3)`. + edge_idx (torch.LongTensor): Graph indices for expansion with shape + `(num_batch, num_residues, num_neighbors)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. 
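For orientation, one standard sinusoidal construction over a `(1, 1000)` period range is sketched below; the actual `PositionalEncoding` layer imported from `basic.py` may differ in detail, so treat this only as an illustration of encoding signed chain offsets:

    import math
    import torch

    def sinusoidal_encoding(d, d_model=128, period_range=(1.0, 1000.0)):
        # d: (...,) signed chain offsets j - i
        log_periods = torch.linspace(
            math.log(period_range[0]), math.log(period_range[1]), d_model // 2
        )
        angular = 2.0 * math.pi * torch.exp(-log_periods)   # 2*pi / period
        angles = d.unsqueeze(-1) * angular
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    h = sinusoidal_encoding(torch.arange(-5.0, 6.0))         # (11, 128)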
+ + Outputs: + edge_h (torch.Tensor): Edge chain distance features with shape + `(num_batch, num_residues, num_neighbors, 2)` + """ + + def __init__(self, dim_embedding: int = 128, period_range: tuple = (1.0, 1000.0)): + super(EdgePositionalEncodings, self).__init__() + + # Public attribute + self.dim_out = dim_embedding + self.encoding = PositionalEncoding( + d_model=dim_embedding, period_range=period_range + ) + + def forward( + self, + X: torch.Tensor, + edge_idx: torch.LongTensor, + C: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + # Is the edge intra-chain or inter-chain? + chain_i = C.unsqueeze(-1) + chain_j = graph.collect_neighbors(chain_i, edge_idx).squeeze(-1) + mask_intrachain = torch.eq(chain_i, chain_j).float() + + # If it is intra-chain, what is the chain distance? + residue_i = torch.arange(edge_idx.shape[1], device=X.device).view((1, -1, 1)) + residue_j = edge_idx + D_signed = (residue_j - residue_i).float() + edge_h = mask_intrachain[..., None] * self.encoding(D_signed[..., None]) + return edge_h + + +class EdgeRandomFourierFeatures2mer(nn.Module): + """For edge-ij computes a random fourier projection of the SE3-invariant feature t_ji + pointing from i to j in the local frame of residue i, optionally including the projection + of the associated quaternion representation of R_ji the rotation from taking you from frame i to frame j + Features are decayed exponentially at rate alpha. + Args: + dim_embedding (int): dimension of embedding + trainable (bool): Whether to train the weight matrix of the fourier features + scale (float): The scale (standard deviation) to sample random weights from + use_quaternion (bool): Whether to embed the quaternion representation as well + + Inputs: + X (torch.tensor): of size (batch, length, (4 or 14), 3) + edge_idx (torch.LongTensor): of size (batch, length, num_neighbors) + C (torch.tensor): of size (batch, length) + + Outputs: + edge_h (torch.tensor): of size (batch, length, num_neighbors, dim_embedding) + """ + + def __init__( + self, + dim_embedding: int = 128, + trainable: bool = False, + scale: float = 1.0, + use_quaternion: bool = False, + seed: int = 10, + ): + super().__init__() + + self._seed = seed + with torch.random.fork_rng(): + torch.random.manual_seed(self._seed) + + self.vector_f = FourierFeaturization( + 3, dim_embedding, trainable=trainable, scale=scale + ) + self.distance_f = FourierFeaturization( + 64, dim_embedding, trainable=trainable, scale=scale + ) + + self.use_quaternion = use_quaternion + if self.use_quaternion: + self.quat_f = FourierFeaturization( + 4, dim_embedding, trainable=trainable, scale=scale + ) + + self.layer_2mers = Edge2mers() + self.layer_distance = geometry.Distances() + self.frame_builder = backbone.FrameBuilder() + self.dim_out = dim_embedding + + def forward( + self, + X: torch.Tensor, + edge_idx: torch.LongTensor, + C: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + X_ij, mask_ij = self.layer_2mers(X, edge_idx, C=C) + D_ij = self.layer_distance(X_ij, dim=-2) + D_ij = D_ij.reshape(*D_ij.size()[:-2], -1) + + R_i, t_i, _ = self.frame_builder.inverse(X, C) + R_j, t_j = transforms.collect_neighbor_transforms(R_i, t_i, edge_idx) + R_ji, t_ji = transforms.compose_inner_transforms( + R_j, t_j, R_i.unsqueeze(-3), t_i.unsqueeze(-2) + ) + + edge_h = self.vector_f(t_ji) + self.distance_f(D_ij) + + if self.use_quaternion: + Q_ji = geometry.quaternions_from_rotations(R_ji) + edge_h = edge_h + self.quat_f(Q_ji) + + return edge_h + + +class RBFExpansion(nn.Module): + def __init__( + self, + 
value_min: float, + value_max: float, + num_rbf: int, + std: Optional[float] = None, + ): + super(RBFExpansion, self).__init__() + rbf_centers = torch.linspace(value_min, value_max, num_rbf) + self.register_buffer("rbf_centers", rbf_centers) + if std is None: + std = (rbf_centers[1] - rbf_centers[0]).item() + self.std = std + + def forward(self, h: torch.Tensor) -> torch.Tensor: + shape = list(h.shape) + shape_ones = [1 for _ in range(len(shape))] + [-1] + rbf_centers = self.rbf_centers.view(shape_ones) + h = torch.exp(-(((h.unsqueeze(-1) - rbf_centers) / self.std) ** 2)) + h = h.view(shape[:-1] + [-1]) + return h + + +class NodeCartesianCoords(nn.Module): + """Node features containing raw relative coordinates. + + Warning: these features are not rotationally invariant. + + Args: + scale_factor (float, optional): Scale factor to rescale raw coordinates + for neural network processing. Default is 0.3. + num_atom_types (int, optional): Number of atom types. Default is 4. + + Attributes: + dim_out (int): Number of dimensions of the output features. + + Inputs: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atom_types, 3)`. + edge_idx (torch.LongTensor): Graph indices for expansion with shape + `(num_batch, num_residues, num_neighbors)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. + + Outputs: + edge_h (torch.Tensor): Node relative coordinates features with shape + `(num_batch, num_residues, 3 * (num_atom_types)**2)` + """ + + def __init__(self, scale_factor: float = 0.3, num_atom_types: int = 4): + super(NodeCartesianCoords, self).__init__() + self.scale_factor = scale_factor + self.num_atom_types = num_atom_types + + # Public attribute + self.dim_out = 3 * (num_atom_types**2) + + def forward( + self, + X: torch.Tensor, + edge_idx: torch.LongTensor, + C: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + num_batch, num_residues, num_neighbors = list(edge_idx.shape) + + dX = X.unsqueeze(-2) - X.unsqueeze(-3) + node_h = self.scale_factor * dX.reshape([num_batch, num_residues, -1]) + + if C is not None: + mask_i = chain_map_to_mask(C) + node_h = mask_i.unsqueeze(-1) * node_h + return node_h + + +class EdgeCartesianCoords(nn.Module): + """Edge features containing raw relative coordinates. + + Warning: these features are not rotationally invariant. + + Args: + scale_factor (float, optional): Scale factor to rescale raw coordinates + for neural network processing. Default is 0.1. + num_atom_types (int, optional): Number of atom types. Default is 4. + + Attributes: + dim_out (int): Number of dimensions of the output features. + + Inputs: + X (torch.Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atom_types, 3)`. + edge_idx (torch.LongTensor): Graph indices for expansion with shape + `(num_batch, num_residues, num_neighbors)`. + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. 
+ + Outputs: + edge_h (torch.Tensor): Edge relative coordinates features with shape + `(num_batch, num_residues, num_neighbors, 3 * (num_atom_types)**2)` + """ + + def __init__(self, scale_factor: float = 0.1, num_atom_types: int = 4): + super(EdgeCartesianCoords, self).__init__() + self.scale_factor = scale_factor + self.num_atom_types = num_atom_types + + # Public attribute + self.dim_out = 3 * (num_atom_types**2) + + def forward( + self, + X: torch.Tensor, + edge_idx: torch.LongTensor, + C: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + num_batch, num_residues, num_neighbors = list(edge_idx.shape) + + # Collect coordiates and j + X_flat = X.reshape([num_batch, num_residues, -1]) + X_j_flat = graph.collect_neighbors(X_flat, edge_idx) + X_j = X_j_flat.reshape( + [num_batch, num_residues, num_neighbors, 1, self.num_atom_types, 3] + ) + + X_i = X.reshape([num_batch, num_residues, 1, self.num_atom_types, 1, 3]) + dX = X_j - X_i + edge_h = self.scale_factor * dX.reshape( + [num_batch, num_residues, num_neighbors, -1] + ) + if C is not None: + mask_i = chain_map_to_mask(C) + mask_i_expand = mask_i.unsqueeze(-1) + mask_j = graph.collect_neighbors(mask_i_expand, edge_idx) + mask_ij = mask_j * mask_i_expand.unsqueeze(-1) + edge_h = mask_ij * edge_h + return edge_h + + +def chain_map_to_mask(C: torch.LongTensor) -> torch.Tensor: + """Convert chain map into a mask. + + Args: + C (torch.LongTensor): Chain map with shape + `(num_batch, num_residues)`. + + Returns: + mask (Tensor, optional): Mask tensor with shape + `(num_batch, num_residues)`. + """ + return (C > 0).type(torch.float32) + + +def _cgo_cylinder(X1, X2, radius=0.5, rgb=(0.0, 0.0, 1.0)): + x1, y1, z1 = X1.data.numpy().flatten().tolist() + x2, y2, z2 = X2.data.numpy().flatten().tolist() + r1, g1, b1 = rgb + r2, g2, b2 = rgb + cgo_str = ( + f"[ 9.0, {x1}, {y1}, {z1}, {x2}, {y2}, {z2}, {radius}, {r1}, {g1}, {b1}, {r2}," + f" {g2}, {b2} ]" + ) + return cgo_str + + +def _cgo_sphere(X1, radius=1.0): + x1, y1, z1 = X1.data.numpy().flatten().tolist() + cgo_str = f"[ 7.0, {x1}, {y1}, {z1}, {radius}]" + return cgo_str + + +def _cgo_color(rgb=(0.0, 0.0, 1.0)): + r, g, b = rgb + cgo_str = f"[ 6.0, {r}, {g}, {b}]" + return cgo_str + + +if __name__ == "__main__": + _debug_plot_random_graphs(num_neighbors=60) diff --git a/prtm/models/chroma/structure/protein_graph_allatom.py b/prtm/models/chroma/structure/protein_graph_allatom.py new file mode 100644 index 00000000..693654ea --- /dev/null +++ b/prtm/models/chroma/structure/protein_graph_allatom.py @@ -0,0 +1,246 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layers for building graph representations of protein structure, all-atom. + +This module contains pytorch layers for representing protein structure as a +graph with node and edge features based on geometric information. 
The graph +features are differentiable with respect to input coordinates and can be used +for building protein scoring functions and optimizing protein geometries +natively in pytorch. +""" + + +import numpy as np +import torch +import torch.nn as nn +from prtm.models.chroma import graph +from prtm.models.chroma.structure import geometry, sidechain + + +class NodeChiRBF(nn.Module): + """Layers for featurizing chi angles with a smooth binning + + Args: + num_chi_bins (int): Number of bins for discretizing chi angles. + num_chi (int): Number of chi angles. + dim_out (int): Number of output feature dimensions. + bin_scale (float, optional): Scaling parameter that sets bin smoothing. + + Input: + chi (Tensor): Chi angles with shape `(num_batch, num_residues, num_chi)`. + + Output: + h_chi (Tensor): Chi angle features with shape + `(num_batch, num_residues, num_chi * num_chi_bins)`. + """ + + def __init__(self, dim_out, num_chi, num_chi_bins, bin_scale=2.0): + super(NodeChiRBF, self).__init__() + self.dim_out = dim_out + self.num_chi = num_chi + self.num_chi_bins = num_chi_bins + self.bin_scale = bin_scale + + self.embed = nn.Linear(self.num_chi * self.num_chi_bins, dim_out) + + def _featurize(self, chi, mask_chi=None): + num_batch, num_residues, _ = chi.shape + + chi_bin_center = ( + torch.arange(0, self.num_chi_bins, device=chi.device) + * 2.0 + * np.pi + / self.num_chi_bins + ) + chi_bin_center = chi_bin_center.reshape([1, 1, 1, -1]) + + # Set smoothing length scale based on ratio beteen adjacent bin centers + # bin_i / bin_i+1 = 1 / scale + delta_adjacent = np.cos(0.0) - np.cos(2.0 * np.pi / self.num_chi_bins) + cosine = torch.cos(chi.unsqueeze(-1) - chi_bin_center) + chi_features = torch.exp((cosine - 1.0) * self.bin_scale / delta_adjacent) + if mask_chi is not None: + chi_features = mask_chi.unsqueeze(-1) * chi_features + chi_features = chi_features.reshape( + [num_batch, num_residues, self.num_chi * self.num_chi_bins] + ) + return chi_features + + def forward(self, chi, mask_chi=None): + chi_features = self._featurize(chi, mask_chi=mask_chi) + h_chi = self.embed(chi_features) + return h_chi + + +class EdgeSidechainsDirect(nn.Module): + """Layers for direct encoding of side chain geometries. + + Args: + dim_out (int): Number of output hidden dimensions. + max_D (float, optional): Maximum distance cutoff for encoding + of edges. + + Input: + X (Tensor): All atom coordinates with shape + `(num_batch, num_residues, 14, 3)`. + C (LongTensor): Chain map with shape `(num_batch, num_residues)`. + S (LongTensor): Sequence tensor with shape + `(num_batch, num_residues)`. + edge_idx (Tensor): Graph indices for expansion with shape + `(num_batch, num_residues_out, num_neighbors)`. The dimension + of output variables `num_residues_out` must either equal + `num_residues` or 1, the latter of which can be useful for sequential + decoding. + + Output: + h (Tensor): Features with shape + `(num_batch, num_residues_out, num_neighbors, num_hidden)`. 
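A free-standing sketch of the `rff` basis used by `_featurize` below: atom coordinates in the local residue frame are projected through fixed random frequencies, passed through cosine and sine, and damped by a Gaussian radial envelope. Names and sizes here are illustrative:

    import numpy as np
    import torch

    length_scale, num_fourier = 7.5, 30
    W = 2.0 * np.pi / length_scale * torch.randn(3, num_fourier)   # fixed random frequencies

    U = torch.randn(8, 14, 3)                          # atom coordinates in a local frame
    proj = torch.einsum("...ax,xy->...ay", U, W)       # (8, 14, num_fourier)
    features = torch.cat([torch.cos(proj), torch.sin(proj)], dim=-1)
    D = torch.sqrt((U ** 2).sum(-1) + 0.1)             # distances with the default eps
    envelope = torch.exp(-D * D / (2.0 * length_scale ** 2))
    features = envelope.unsqueeze(-1) * features       # (8, 14, 2 * num_fourier)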
+ """ + + def __init__( + self, + dim_out, + length_scale=7.5, + distance_eps=0.1, + num_fourier=30, + fourier_order=2, + basis_type="rff", + ): + super(EdgeSidechainsDirect, self).__init__() + self.dim_out = dim_out + self.length_scale = length_scale + self.distance_eps = distance_eps + + # self.embed = nn.Linear(14 * 3 , dim_out) + self.num_fourier = num_fourier + self.rff = torch.nn.Parameter( + 2.0 * np.pi / self.length_scale * torch.randn((3, self.num_fourier)) + ) + self.basis_type = basis_type + if self.basis_type == "rff": + self.embed = nn.Linear(14 * self.num_fourier * 2, dim_out) + elif self.basis_type == "spherical": + self.fourier_order = fourier_order + self.embed = nn.Linear(14 * (self.fourier_order * 2) ** 3, dim_out) + + def _local_coordinates(self, X, C, S, edge_idx): + num_batch, num_residues, num_neighbors = edge_idx.shape + + # Mask and transform into features + mask_atoms = sidechain.atom_mask(C, S) + mask_atoms_j = graph.collect_neighbors(mask_atoms, edge_idx) + mask_i = (C > 0).float().reshape([num_batch, num_residues, 1, 1]) + mask_atoms_ij = mask_i * mask_atoms_j + + # Build conditioning mask + R_i, CA = geometry.frames_from_backbone(X[:, :, :4, :]) + + # Transform neighbor X coordinates into local frames + X_flat = X.reshape([num_batch, num_residues, -1]) + X_j_flat = graph.collect_neighbors(X_flat, edge_idx) + X_j = X_j_flat.reshape([num_batch, num_residues, num_neighbors, 14, 3]) + dX_ij = X_j - CA.reshape([num_batch, num_residues, 1, 1, 3]) + U_ij = torch.einsum("niab,nijma->nijmb", R_i, dX_ij) + return U_ij, mask_atoms_ij + + def _local_coordinates_t(self, t, X, C, S, edge_idx_t): + num_batch, _, num_neighbors = edge_idx_t.shape + num_residues = X.shape[1] + + # Make a mask that + C_i = C[:, t].unsqueeze(1) + # S_i = S[:,t].unsqueeze(1) + # mask_atoms_i = sidechain.atom_mask(C_i, S_i) + C_j = graph.collect_neighbors(C.unsqueeze(-1), edge_idx_t).reshape( + [num_batch, num_neighbors] + ) + S_j = graph.collect_neighbors(S.unsqueeze(-1), edge_idx_t).reshape( + [num_batch, num_neighbors] + ) + mask_atoms_j = sidechain.atom_mask(C_j, S_j).unsqueeze(1) + mask_i = (C_i > 0).float().reshape([num_batch, 1, 1, 1]) + mask_atoms_ij = mask_i * mask_atoms_j + + # Build conditioning mask + X_bb_i = X[:, t, :4, :].unsqueeze(1) + R_i, CA = geometry.frames_from_backbone(X_bb_i) + + # Transform neighbor X coordinates into local frames + X_flat = X.reshape([num_batch, num_residues, -1]) + X_j_flat = graph.collect_neighbors(X_flat, edge_idx_t) + X_j = X_j_flat.reshape([num_batch, 1, num_neighbors, 14, 3]) + dX_ij = X_j - CA.reshape([num_batch, 1, 1, 1, 3]) + U_ij = torch.einsum("niab,nijma->nijmb", R_i, dX_ij) + return U_ij, mask_atoms_ij + + def _fourier_expand(self, h, order): + k = torch.arange(order, device=h.device) + k = k.reshape([1 for i in h.shape] + [-1]) + return torch.cat( + [torch.sin(h.unsqueeze(-1) * (k + 1)), torch.cos(h.unsqueeze(-1) * k)], + dim=-1, + ) + + def _featurize(self, U_ij, mask_atoms_ij): + if self.basis_type == "rff": + # Random fourier features + U_ij = mask_atoms_ij.unsqueeze(-1) * U_ij + U_ff = torch.einsum("nijax,xy->nijay", U_ij, self.rff) + U_ff = torch.concat([torch.cos(U_ff), torch.sin(U_ff)], -1) + + # Gaussian RBF envelope + D_ij = torch.sqrt((U_ij**2).sum(-1) + self.distance_eps) + magnitude = torch.exp(-D_ij * D_ij / (2 * self.length_scale**2)) + U_ff = magnitude.unsqueeze(-1) * U_ff + + U_ff = U_ff.reshape(list(D_ij.shape)[:3] + [-1]) + h = mask_atoms_ij[:, :, :, 0].unsqueeze(-1) * self.embed(U_ff) + + elif self.basis_type == "spherical": 
+ # Convert to spherical coordinates + r_ij = torch.sqrt((U_ij**2).sum(-1) + self.distance_eps) + r_ij_scale = r_ij * 2.0 * np.pi / self.length_scale + x, y, z = U_ij.unbind(-1) + theta_ij = torch.acos(z / r_ij) + phi_ij = torch.atan2(y, x) + + # Build Fourier expansions of each coordinate + r_ff, theta_ff, phi_ff = [ + self._fourier_expand(h, self.fourier_order) + for h in [r_ij_scale, theta_ij, phi_ij] + ] + # Radial envelope function + r_envelope = mask_atoms_ij * torch.exp( + -r_ij * r_ij / (2 * self.length_scale**2) + ) + + # Tensor outer product + bf_ij = torch.einsum( + "bika,bikar,bikat,bikap->bikartp", r_envelope, r_ff, theta_ff, phi_ff + ).reshape(list(r_ij.shape)[:3] + [-1]) + + h = mask_atoms_ij[:, :, :, 0].unsqueeze(-1) * self.embed(bf_ij) + + return h + + def forward(self, X, C, S, edge_idx): + U_ij, mask_atoms_ij = self._local_coordinates(X, C, S, edge_idx) + h = self._featurize(U_ij, mask_atoms_ij) + return h + + def step(self, t, X, C, S, edge_idx_t): + U_ij, mask_atoms_ij = self._local_coordinates_t(t, X, C, S, edge_idx_t) + h = self._featurize(U_ij, mask_atoms_ij) + return h diff --git a/prtm/models/chroma/structure/rmsd.py b/prtm/models/chroma/structure/rmsd.py new file mode 100644 index 00000000..76acae38 --- /dev/null +++ b/prtm/models/chroma/structure/rmsd.py @@ -0,0 +1,530 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import torch +import torch.linalg +import torch.nn as nn +from prtm.models.chroma import graph +from prtm.models.chroma.linalg import eig_leading +from prtm.models.chroma.structure import geometry, protein_graph + + +class CrossRMSD(nn.Module): + """Compute optimal RMSDs between two sets of structures. + + This module uses the quaternion-based approach for calculating RMSDs as + described in `Using Quaternions to Calculate RMSD`, 2004, by Coutsias, + Seok, and Dill. The minimal RMSD and associated rotation are computed in + terms of the most positive eigenvalue and associated eigvector of a special + 4x4 matrix. + + Args: + method (str, optional): Method for calculating the most postive + eigenvalue. Can be `power` or `symeig`. If `symeig`, this will use + `torch.symeig`, which is the most accurate method but tends to be + very slow on GPU for large batches of RMSDs. If `power`, then use + power iteration to estimate leading eigenvalues. Default is `power`. + method_iter (int, optional): When the method is `power`, this argument + sets the number of power iterations used for approximation. + The default is 50, which has tended to produce estimates of optimal + RMSD with sub-angstrom accuracy on test problems. Note: Convergence + rates of power iteration can be highly variable dependening on the + system. If accuracy is important, it is recommended to compare + outputs with `symeig`-based RMSDs. + + Inputs: + X_mobile (Tensor): Mobile coordinates, i.e. the "mobile" coordinates, + with shape `(num_source, num_atoms, 3)`. 
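The `rff` branch of `_featurize` above encodes each neighbor atom's local coordinates with random Fourier features: coordinates are projected through a fixed random matrix, passed through sin/cos, and damped by a Gaussian radial envelope. A minimal standalone sketch of that encoding (toy shapes, not the module's API):

```python
import math

import torch

# Toy random-Fourier-feature encoding of 3D points, mirroring the "rff"
# branch of `_featurize` above (illustrative only; shapes are made up).
torch.manual_seed(0)
length_scale = 7.5
num_fourier = 30

# Fixed random projection, analogous to the `self.rff` parameter (3, num_fourier)
rff = 2.0 * math.pi / length_scale * torch.randn(3, num_fourier)

coords = torch.randn(8, 3)                  # local atom coordinates
proj = coords @ rff                         # (8, num_fourier)
features = torch.cat([torch.cos(proj), torch.sin(proj)], dim=-1)  # (8, 60)

# Gaussian envelope so features decay with distance from the frame origin
dist = coords.norm(dim=-1, keepdim=True)
features = torch.exp(-dist**2 / (2 * length_scale**2)) * features
print(features.shape)  # torch.Size([8, 60])
```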
+ X_target (Tensor): Target coordinates with shape + `(num_target, num_atoms, 3)`. + + Outputs: + RMSD (Tensors): RMSDs after optimal superposition for all pairs of + source and target structures with shape `(num_source, num_target)`. + While `forward` returns the Cartesian product of all possible + alignments, i.e. (`num_source * num_target` alignments), the + `pairedRMSD` will do the same calculation for zipped batches, i.e. + `num_source` total alignments. + """ + + def __init__(self, method="power", method_iter=50, dither=True): + super(CrossRMSD, self).__init__() + + self.method = method + self.method_iter = method_iter + self._eps = 1e-5 + self.dither = dither + + # R_to_F converts xyz cross-covariance matrices (3x3) to the (4x4) F + # matrix of Coutsias et al. This F matrix encodes the optimal RMSD in + # its spectra; namely, the eigenvector associated with the most + # positive eigenvalue of F is the quaternion encoding the optimal + # 3D rotation for superposition. + # fmt: off + R_to_F = np.zeros((9, 16)).astype("f") + F_nonzero = [ + [(0,0,1.),(1,1,1.),(2,2,1.)], [(1,2,1.),(2,1,-1.)], [(2,0,1.),(0,2,-1.)], [(0,1,1.),(1,0,-1.)], + [(1,2,1.),(2,1,-1.)], [(0,0,1.),(1,1,-1.),(2,2,-1.)], [(0,1,1.),(1,0,1.)], [(0,2,1.),(2,0,1.)], + [(2,0,1.),(0,2,-1.)], [(0,1,1.),(1,0,1.)], [(0,0,-1.),(1,1,1.),(2,2,-1.)], [(1,2,1.),(2,1,1.)], + [(0,1,1.),(1,0,-1.)], [(0,2,1.),(2,0,1.)], [(1,2,1.),(2,1,1.)], [(0,0,-1.),(1,1,-1.),(2,2,1.)] + ] + # fmt: on + + for F_ij, nonzero in enumerate(F_nonzero): + for R_i, R_j, sign in nonzero: + R_to_F[R_i * 3 + R_j, F_ij] = sign + self.register_buffer("R_to_F", torch.tensor(R_to_F)) + + def forward(self, X_mobile, X_target): + num_source = X_mobile.size(0) + num_target = X_target.size(0) + num_atoms = X_mobile.size(1) + + # Center coordinates + X_mobile = X_mobile - X_mobile.mean(dim=1, keepdim=True) + X_target = X_target - X_target.mean(dim=1, keepdim=True) + + # CrossCov matrices contract over atoms + R = torch.einsum("sai,taj->stij", [X_mobile, X_target]) + + # F Matrix has leading eigenvector as optimal quaternion + R_flat = R.reshape(num_source, num_target, 9) + F = torch.matmul(R_flat, self.R_to_F).reshape(num_source, num_target, 4, 4) + + # Compute optimal quaternion by extracting leading eigenvector + if self.method == "symeig": + top_eig = torch.linalg.eigvalsh(F)[:, :, 3] + elif self.method == "power": + top_eig, vec = eig_leading(F, num_iterations=self.method_iter) + else: + raise NotImplementedError + + # Compute RMSD in terms of RMSD using the scheme of Coutsias et al + norms = (X_mobile**2).sum(dim=[-1, -2]).unsqueeze(1) + (X_target**2).sum( + dim=[-1, -2] + ).unsqueeze(0) + sqRMSD = torch.relu((norms - 2 * top_eig) / (num_atoms + self._eps)) + RMSD = torch.sqrt(sqRMSD) + return RMSD + + def pairedRMSD( + self, + X_mobile, + X_target, + mask=None, + compute_alignment=False, + align_unmasked=False, + ): + """Compute optimal RMSDs between each corresponding batch members. + + Args: + X_mobile (Tensor): Mobile coordinates with shape + `(..., num_atoms, 3)`. + X_target (Tensor): Target coordinates with shape + `(..., num_atoms, 3)`. + mask (Tensor, optional): Binary mask tensor for missing atoms with + shape `(..., num_atoms)`. + compute_alignment (boolean, optional): If True, also return the + superposed coordinates. + + Returns: + RMSD (Tensors): Optimal RMSDs after superposition for all pairs of + input structures with shape `(...)`. + X_mobile_transform (Tensor, optional): Superposed coordinates with + shape `(..., num_atoms, 3)`. 
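A quick sanity check of `CrossRMSD.forward` under the quaternion scheme above: the optimal RMSD between a structure and a rigidly rotated and translated copy of itself should be close to zero. This is a sketch only; the import assumes the module added by this diff.

```python
import math

import torch

from prtm.models.chroma.structure.rmsd import CrossRMSD  # added by this diff

torch.manual_seed(0)
X = torch.randn(1, 50, 3)
theta = 0.7
R = torch.tensor(
    [
        [math.cos(theta), -math.sin(theta), 0.0],
        [math.sin(theta), math.cos(theta), 0.0],
        [0.0, 0.0, 1.0],
    ]
)
X_moved = X @ R.T + torch.tensor([1.0, -2.0, 3.0])  # rigidly moved copy

print(CrossRMSD(method="symeig")(X, X_moved))  # ~0, up to float error
```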
Requires + `compute_alignment` = True`. + """ + # Collapse all leading batch dimensions + num_atoms = X_mobile.size(-2) + batch_dims = list(X_mobile.shape)[:-2] + X_mobile = X_mobile.reshape([-1, num_atoms, 3]) + X_target = X_target.reshape([-1, num_atoms, 3]) + num_batch = X_mobile.size(0) + if mask is not None: + mask = mask.reshape([-1, num_atoms]) + + # Center coordinates + if mask is None: + X_mobile_mean = X_mobile.mean(dim=1, keepdim=True) + X_target_mean = X_target.mean(dim=1, keepdim=True) + else: + mask_expand = mask.unsqueeze(-1) + X_mobile_mean = torch.sum(mask_expand * X_mobile, 1, keepdim=True) / ( + torch.sum(mask_expand, 1, keepdim=True) + self._eps + ) + X_target_mean = torch.sum(mask_expand * X_target, 1, keepdim=True) / ( + torch.sum(mask_expand, 1, keepdim=True) + self._eps + ) + + X_mobile_center = X_mobile - X_mobile_mean + X_target_center = X_target - X_target_mean + + if mask is not None: + X_mobile_center = mask_expand * X_mobile_center + X_target_center = mask_expand * X_target_center + + # Cross-covariance matrices contract over atoms + R = torch.einsum("sai,saj->sij", [X_mobile_center, X_target_center]) + + # F Matrix has leading eigenvector as optimal quaternion + R_flat = R.reshape(num_batch, 9) + R_to_F = self.R_to_F.type(R_flat.dtype) + F = torch.matmul(R_flat, R_to_F).reshape(num_batch, 4, 4) + if self.dither: + F = F + 1e-5 * torch.randn_like(F) + + # Compute optimal quaternion by extracting leading eigenvector + if self.method == "symeig": + L, V = torch.linalg.eigh(F) + top_eig = L[:, 3] + vec = V[:, :, 3] + elif self.method == "power": + top_eig, vec = eig_leading(F, num_iterations=self.method_iter) + else: + raise NotImplementedError + + # Compute RMSD using top eigenvalue + norms = (X_mobile_center**2).sum(dim=[-1, -2]) + (X_target_center**2).sum( + dim=[-1, -2] + ) + sqRMSD = torch.relu((norms - 2 * top_eig) / (num_atoms + self._eps)) + rmsd = torch.sqrt(sqRMSD) + + if not compute_alignment: + # Unpack leading batch dimensions + rmsd = rmsd.reshape(batch_dims) + return rmsd + else: + R = geometry.rotations_from_quaternions(vec, normalize=False) + + X_mobile_transform = torch.einsum("bxr,bir->bix", R, X_mobile_center) + X_mobile_transform = X_mobile_transform + X_target_mean + + if mask is not None: + X_mobile_transform = mask_expand * X_mobile_transform + + # Return the RMSD of the transformed coordinates + rmsd_direct = rmsd_unaligned(X_mobile_transform, X_target, mask) + + # Unpack leading batch dimensions + rmsd_direct = rmsd_direct.reshape(batch_dims) + X_mobile_transform = X_mobile_transform.reshape(batch_dims + [num_atoms, 3]) + if align_unmasked: + X_mobile_transform = X_mobile - X_mobile_mean + X_mobile_transform = torch.einsum( + "bxr, bir -> bix", + R, + X_mobile_transform.view(X_mobile.size(0), -1, 3), + ) + X_mobile_transform = X_mobile_transform + X_target_mean + + return rmsd_direct, X_mobile_transform + + +class BackboneRMSD(nn.Module): + """Compute optimal RMSDs between two sets of backbones. + + This wraps `CrossRMSD` for use with XCS-formatted protein data. + + Args: + method (str, optional): Method for calculating the most postive + eigenvalue. Can be `power` or `symeig`. Default is `power`. + method_iter (int, optional): Number of power iterations for eigenvalue + approximation. Requires `method=power`. Default is 50. + + Inputs: + X_mobile (Tensor): Mobile coordinates with shape + `(num_source, num_atoms, 4, 3)`. + X_target (Tensor): Target coordinates with shape + `(num_target, num_atoms, 4, 3)`. 
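A sketch of the batched, masked `pairedRMSD` path above, with illustrative shapes only:

```python
import torch

from prtm.models.chroma.structure.rmsd import CrossRMSD  # added by this diff

torch.manual_seed(1)
X_a = torch.randn(4, 60, 3)
X_b = X_a + 0.1 * torch.randn(4, 60, 3)
mask = torch.ones(4, 60)
mask[:, 50:] = 0.0  # treat the last ten atoms as missing

# One alignment per batch element (4 total), not the 4 x 4 Cartesian product
rmsd, X_a_aligned = CrossRMSD().pairedRMSD(
    X_a, X_b, mask=mask, compute_alignment=True
)
print(rmsd.shape, X_a_aligned.shape)  # torch.Size([4]) torch.Size([4, 60, 3])
```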
+ C (Tensor): Chain map with shape `(num_batch, num_residues)`. + + Outputs: + X_aligned (Tensor, optional): Superposed `X_mobile` with shape + `(num_batch, num_atoms, 3)`. + rmsd (Tensors): Optimal RMSDs after superposition with shape + `(num_batch)`. + """ + + def __init__(self, method="symeig"): + super(BackboneRMSD, self).__init__() + self.rmsd = CrossRMSD(method=method) + + def align(self, X_mobile, X_target, C, align_unmasked=False): + mask = (C > 0).type(torch.float32) + mask_flat = mask.unsqueeze(-1).expand(-1, -1, 4).reshape(mask.shape[0], -1) + + X_mobile_flat = X_mobile.reshape(X_mobile.size(0), -1, 3) + X_target_flat = X_target.reshape(X_target.size(0), -1, 3) + rmsd, X_aligned = self.rmsd.pairedRMSD( + X_mobile_flat, + X_target_flat, + mask=mask_flat, + compute_alignment=True, + align_unmasked=align_unmasked, + ) + X_aligned = X_aligned.reshape(X_mobile.size()).contiguous() + return X_aligned, rmsd + + +class LossFragmentRMSD(nn.Module): + """Compute optimal fragment-pair RMSDs between two sets of backbones. + + Args: + fragment_k (int, option): Fram + method (str, optional): Method for calculating the most postive + eigenvalue. Can be `power` or `symeig`. Default is `power`. + method_iter (int, optional): Number of power iterations for eigenvalue + approximation. Requires `method=power`. Default is 50. + + Inputs: + X_mobile (Tensor): Mobile coordinates with shape + `(num_source, num_atoms, 4, 3)`. + X_target (Tensor): Target coordinates with shape + `(num_target, num_atoms, 4, 3)`. + edge_idx + C (Tensor): Chain map with shape `(num_batch, num_residues)`. + + Outputs: + rmsd (Tensor, optional): Per-site fragment RMSDs with shape + `(num_batch)`. + """ + + def __init__(self, k=7, method="symeig", method_iter=50): + super(LossFragmentRMSD, self).__init__() + self.k = k + self.rmsd = CrossRMSD(method=method, method_iter=method_iter) + + def forward(self, X_mobile, X_target, C, return_coords=False): + # Discard potential sidechain coordinates + X_mobile = X_mobile[:, :, :4, :] + X_target = X_target[:, :, :4, :] + + # Build graph and pair fragments + + X_fragment_mobile, C_fragment_mobile = _collect_X_fragments(X_mobile, C, self.k) + X_fragment_target, C_fragment_target = _collect_X_fragments(X_target, C, self.k) + shape = list(C.shape) + [-1, 3] + X_fragment_mobile = X_fragment_mobile.reshape(shape) + X_fragment_target = X_fragment_target.reshape(shape) + + mask = (C_fragment_mobile > 0).float() + rmsd, X_fragment_mobile_align = self.rmsd.pairedRMSD( + X_fragment_mobile, X_fragment_target, mask, compute_alignment=True + ) + if return_coords: + return rmsd, X_fragment_target, X_fragment_mobile, X_fragment_mobile_align + else: + return rmsd + + +class LossFragmentPairRMSD(nn.Module): + """Compute optimal fragment-pair RMSDs between two sets of backbones. + + Args: + fragment_k (int, option): Fram + method (str, optional): Method for calculating the most postive + eigenvalue. Can be `power` or `symeig`. Default is `power`. + method_iter (int, optional): Number of power iterations for eigenvalue + approximation. Requires `method=power`. Default is 50. + + Inputs: + X_mobile (Tensor): Mobile coordinates with shape + `(num_source, num_atoms, 4, 3)`. + X_target (Tensor): Target coordinates with shape + `(num_target, num_atoms, 4, 3)`. + edge_idx + C (Tensor): Chain map with shape `(num_batch, num_residues)`. + + Outputs: + rmsd (Tensor, optional): Per-site fragment RMSDs with shape + `(num_batch)`. 
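A minimal usage sketch for `BackboneRMSD.align` on XCS-formatted tensors (toy data; a rigidly translated copy should align back to roughly zero RMSD):

```python
import torch

from prtm.models.chroma.structure.rmsd import BackboneRMSD  # added by this diff

torch.manual_seed(2)
X_target = torch.randn(1, 30, 4, 3)                  # backbone atoms: N, CA, C, O
X_mobile = X_target + torch.tensor([5.0, 0.0, 0.0])  # rigidly translated copy
C = torch.ones(1, 30, dtype=torch.long)              # single chain, all residues present

X_aligned, rmsd = BackboneRMSD().align(X_mobile, X_target, C)
print(rmsd)  # ~0 after optimal superposition
```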
+ """ + + def __init__(self, k=7, method="symeig", method_iter=50, graph_num_neighbors=30): + super(LossFragmentPairRMSD, self).__init__() + self.k = k + self.rmsd = CrossRMSD(method=method, method_iter=method_iter) + self.graph_builder = protein_graph.ProteinGraph( + num_neighbors=graph_num_neighbors + ) + + def _stack_neighbor(self, node_h, edge_idx): + neighbor_h = graph.collect_neighbors(node_h, edge_idx) + node_h = node_h[:, :, None, :].expand(neighbor_h.shape) + edge_h = torch.cat([neighbor_h, node_h], dim=-1) + return edge_h + + def _collect_X_fragment_pairs(self, X, C, edge_idx): + X_kmer, C_kmer = _collect_X_fragments(X, C, self.k) + X_pair = self._stack_neighbor(X_kmer, edge_idx) + C_pair = self._stack_neighbor(C_kmer, edge_idx) + X_pair = X_pair.reshape(list(X_pair.shape)[:-1] + [-1, 3]) + return X_pair, C_pair + + def forward(self, X_mobile, X_target, C, return_coords=False): + # Discard potential sidechain coordinates + X_mobile = X_mobile[:, :, :4, :] + X_target = X_target[:, :, :4, :] + + # Build graph and pair fragments + edge_idx, mask_ij = self.graph_builder(X_target, C) + X_pair_mobile, C_pair_mobile = self._collect_X_fragment_pairs( + X_mobile, C, edge_idx + ) + X_pair_target, C_pair_target = self._collect_X_fragment_pairs( + X_target, C, edge_idx + ) + + mask = (C_pair_mobile > 0).float() + + rmsd, X_pair_mobile_align = self.rmsd.pairedRMSD( + X_pair_mobile, X_pair_target, mask, compute_alignment=True + ) + if return_coords: + return rmsd, mask_ij, X_pair_target, X_pair_mobile, X_pair_mobile_align + else: + return rmsd, mask_ij + + +class LossNeighborhoodRMSD(nn.Module): + """Compute optimal fragment-pair RMSDs between two sets of backbones. + + Args: + fragment_k (int, option): Fram + method (str, optional): Method for calculating the most postive + eigenvalue. Can be `power` or `symeig`. Default is `power`. + method_iter (int, optional): Number of power iterations for eigenvalue + approximation. Requires `method=power`. Default is 50. + + Inputs: + X_mobile (Tensor): Mobile coordinates with shape + `(num_source, num_atoms, 4, 3)`. + X_target (Tensor): Target coordinates with shape + `(num_target, num_atoms, 4, 3)`. + edge_idx + C (Tensor): Chain map with shape `(num_batch, num_residues)`. + + Outputs: + rmsd (Tensor, optional): Per-site fragment RMSDs with shape + `(num_batch)`. 
+ """ + + def __init__(self, method="symeig", method_iter=50, graph_num_neighbors=30): + super(LossNeighborhoodRMSD, self).__init__() + self.rmsd = CrossRMSD(method=method, method_iter=method_iter) + self.graph_builder = protein_graph.ProteinGraph( + num_neighbors=graph_num_neighbors + ) + + def _collect_X_neighborhood(self, X, C, edge_idx): + num_batch, num_nodes, num_atoms, _ = X.shape + shape_flat = [num_batch, num_nodes, -1] + X_flat = X.reshape(shape_flat) + C_flat = C[..., None].expand([-1, -1, num_atoms]) + X_neighborhood = graph.collect_neighbors(X_flat, edge_idx).reshape( + [num_batch, num_nodes, -1, 3] + ) + C_neighborhood = graph.collect_neighbors(C_flat, edge_idx).reshape( + [num_batch, num_nodes, -1] + ) + return X_neighborhood, C_neighborhood + + def forward(self, X_mobile, X_target, C, return_coords=False): + # Discard potential sidechain coordinates + X_mobile = X_mobile[:, :, :4, :] + X_target = X_target[:, :, :4, :] + + # Build graph and pair fragments + edge_idx, mask_ij = self.graph_builder(X_target, C) + X_neighborhood_mobile, C_neighborhood_mobile = self._collect_X_neighborhood( + X_mobile, C, edge_idx + ) + X_neighborhood_target, C_neighborhood_target = self._collect_X_neighborhood( + X_target, C, edge_idx + ) + mask = (C_neighborhood_mobile > 0).float() + + rmsd, X_neighborhood_mobile_align = self.rmsd.pairedRMSD( + X_neighborhood_mobile, X_neighborhood_target, mask, compute_alignment=True + ) + mask = (mask.sum(-1) > 0).float() + if return_coords: + return ( + rmsd, + mask, + X_neighborhood_target, + X_neighborhood_mobile, + X_neighborhood_mobile_align, + ) + else: + return rmsd, mask + + +def rmsd_unaligned(X_a, X_b, mask=None, eps=1e-5, _min_rmsd=1e-8): + """Compute RMSD between two coordinate sets without alignment. + + Args: + X_a (Tensor): Coordinate set 1 with shape `(..., num_points, 3)`. + X_b (Tensor): Coordinate set 2 with shape `(..., num_points, 3)`. + mask (Tensor, optional): Mask with shape `(..., num_points)`. + eps (float, optional): Small number to prevent division by zero. + default is 1E-5. + + Returns: + rmsd (Tensor): Root mean squared deviations (raw) with shape `(...)`. 
+ """ + squared_dev = ((X_a - X_b) ** 2).sum(-1) + if mask is None: + rmsd = torch.sqrt(squared_dev.mean(-1).clamp(min=_min_rmsd)) + else: + rmsd = torch.sqrt( + (mask * squared_dev).sum(-1).clamp(min=_min_rmsd) / (mask.sum(-1) + eps) + ) + return rmsd + + +def _collect_X_fragments(X, C, k): + num_batch, num_nodes, num_atoms, _ = X.shape + shape_flat = [num_batch, num_nodes, -1] + X_flat = X.reshape(shape_flat) + C_flat = C[..., None].expand([-1, -1, num_atoms]) + + # Grab local kmers + X_kmer = _collect_kmers(X_flat, k).reshape(shape_flat) + C_kmer = _collect_kmers(C_flat, k).reshape(shape_flat) + + # Treat noncontiguous atoms as missing + C_kmer = torch.where(C[..., None].eq(C_kmer), C_kmer, -C_kmer.abs()) + return X_kmer, C_kmer + + +def _collect_kmers(node_h, k): + """Gather `(B,I,H) => (B,I,K,H)`""" + device = node_h.device + num_batch, num_nodes, _ = node_h.shape + + # Build indices + k_idx = torch.arange(k, device=device) - (k - 1) // 2 + node_idx = torch.arange(node_h.shape[1], device=device) + kmer_idx = node_idx[None, :, None] - k_idx[None, None, :] + kmer_idx = kmer_idx.clamp(min=0, max=num_nodes - 1).long() + kmer_idx = kmer_idx.expand([num_batch, -1, k]) + + # Collect neighbors + kmer_h = graph.collect_neighbors(node_h, kmer_idx) + return kmer_h diff --git a/prtm/models/chroma/structure/sidechain.py b/prtm/models/chroma/structure/sidechain.py new file mode 100644 index 00000000..c311ce56 --- /dev/null +++ b/prtm/models/chroma/structure/sidechain.py @@ -0,0 +1,810 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layers for modeling protein side chain conformations. + +This module contains layers for building, measuring, and scoring protein side +chain conformations in a differentiable way. These can be used for tasks such +as building differentiable all-atom structures from chi-angles, computing chi +angles from existing structures, and scoring or optimizing side chains using +symmetry-aware rmsds. +""" + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from prtm.constants import residue_constants +from prtm.models.chroma import graph +from prtm.models.chroma.geometry import AA_GEOMETRY, ATOM_SYMMETRIES +from prtm.models.chroma.structure import protein_graph +from prtm.models.chroma.structure.geometry import ( + dihedrals, + extend_atoms, + frames_from_backbone, + quaternions_from_rotations, + rotations_from_quaternions, +) + + +class SideChainBuilder(nn.Module): + """Protein side chain builder from chi angles. + + When only partial information is given such as chi angles, this module + will default to using the ideal geometries given in the CHARMM toppar + topology files. + + `Optimization of the additive CHARMM all-atom protein force + field targeting improved sampling of the backbone phi, + psi and side-chain chi1 and chi2 dihedral angles` + + Inputs: + X (tensor): Backbone coordinates with shape + `(batch_size, num_residues, 4, 3)`. 
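The fragment losses above rely on the k-mer gathering trick in `_collect_kmers`; a standalone sketch of the index construction for a single sequence of 10 residues (same arithmetic, minus batching):

```python
import torch

# For each residue i, build a window of k indices centered on i, clamped at
# the chain ends, exactly as `_collect_kmers` does before gathering features.
k, num_nodes = 7, 10
k_idx = torch.arange(k) - (k - 1) // 2
node_idx = torch.arange(num_nodes)
kmer_idx = (node_idx[:, None] - k_idx[None, :]).clamp(0, num_nodes - 1)
print(kmer_idx[0])  # tensor([3, 2, 1, 0, 0, 0, 0]) at the N-terminus
```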
+ C (tensor): Chain map with shape `(batch_size, num_residues)`. + S (tensor): Sequence tokens with shape `(batch_size, num_residues)`. + chi (tensor): Backbone chi angles with shape + `(batch_size, num_residues, 4)`. + + Outputs: + X (tensor): All-atom coordinates with shape + `(batch_size, num_residues, 14, 3)`. + mask_X (tensor): Atomic mask with shape + `(batch_size, num_residues, 14, 1)` + """ + + def __init__(self, distance_eps=1e-6): + super(SideChainBuilder, self).__init__() + self.num_atoms = 10 + self.num_chi = 4 + self.num_aa = len(residue_constants.alphabetical_restypes) + self.distance_eps = distance_eps + + self._init_maps() + + def _init_maps(self): + """Build geometry and topology maps in tensor form.""" + + shape = (3, self.num_atoms, self.num_aa) + self.register_buffer("_Z", torch.zeros(shape, dtype=torch.float)) + self.register_buffer("_parents", torch.zeros(shape, dtype=torch.long)) + self.register_buffer( + "_chi_ix", 10 * torch.ones((self.num_chi, self.num_aa), dtype=torch.long) + ) + + for i, aa in enumerate(residue_constants.alphabetical_restype_3): + aa_dict = AA_GEOMETRY[aa] + atoms_parents = residue_constants.backbone4_atoms + aa_dict["atoms"] + for j, atom in enumerate(aa_dict["atoms"]): + # Internal coordinates per atom + self._Z[0, j, i] = aa_dict["z-lengths"][j] + self._Z[1, j, i] = aa_dict["z-angles"][j] + self._Z[2, j, i] = aa_dict["z-dihedrals"][j] + + # Parent indices per atom + parents = [atoms_parents.index(p) for p in aa_dict["parents"][j]] + self._parents[0, j, i] = parents[0] + self._parents[1, j, i] = parents[1] + self._parents[2, j, i] = parents[2] + + # Map which chi angles are flexible + for j, parent_ix in enumerate(aa_dict["chi_indices"]): + self._chi_ix[j, i] = parent_ix + + # Convert angles from degrees to radians + self._Z[1:, :, :] = self._Z[1:, :, :] * np.pi / 180.0 + + # Manually fix Arginine, for which CHARMM places NH1 in trans to CD + self._Z[2, 5, residue_constants.alphabetical_restypes.index("R")] = 0.0 + + def forward(self, X, C, S, chi=None): + num_batch, num_residues = list(S.shape) + + if X.shape[2] > 4: + X = X[:, :, :4, :] + + # Expand sequence indexing tensors for gathering residue-specific info + # (B,L) => (B,L,4) + S_expand3 = S.unsqueeze(-1).expand(-1, -1, 4) + # (B,L) => (B,AA,ATOM,L) + S_expand4 = S.reshape([num_batch, 1, 1, num_residues]).expand( + -1, 3, self.num_atoms, -1 + ) + + def _gather(Z): + Z_expand = Z.unsqueeze(0).expand([num_batch, -1, -1, -1]) + # (B,3,ATOM,AA) @ (B,3,ATOM,L) => (B,3,ATOM,L) => (B,L,3,ATOM) + Z_i = torch.gather(Z_expand, -1, S_expand4).permute([0, 3, 1, 2]) + return Z_i + + # Build ideal geometry length, angle, and dihedral tensors 3x(B,L,10) + B, A, D = _gather(self._Z).unbind(-2) + + if chi is not None: + # Scatter chi angles (B,L,4) onto their corresponding dihedrals (B,L,10) + # (4,AA) => (B,AA,4) + chi_ix_expand = ( + self._chi_ix.unsqueeze(0).expand([num_batch, -1, -1]).transpose(-2, -1) + ) + # (B,AA,4) @ (B,L,4) => (B,L,4) + chi_ix_i = torch.gather(chi_ix_expand, -2, S_expand3) + + # Scatter extra chi angles into an extra pad dimension & re-slice + # (B,L,10) <- (B,L,4),(B,L,4) => (B,L,10) + D_pad = F.pad(D, (0, 1)) + D_pad = torch.scatter(D_pad, -1, chi_ix_i, chi) + D = D_pad[:, :, : self.num_atoms] + + # Build indices of parent atoms (B,L,3,10) + X_full = F.pad(X, (0, 0, 0, self.num_atoms)) + parents = _gather(self._parents) + + # Build atom i given current frame + for i in range(self.num_atoms): + # Gather parents (B,L,A,3) => (B,L,3,3) + parents_expand = parents[:, :, :, 
i].unsqueeze(-1).expand(-1, -1, -1, 3) + # (B,L,A,3) @ (B,L,3,3) => (B,L,3,3) + X1, X2, X3 = torch.gather(X_full, -2, parents_expand).unbind(-2) + + # Extend atom i + X4 = extend_atoms( + X1, + X2, + X3, + B[:, :, i], + A[:, :, i], + D[:, :, i], + degrees=False, + distance_eps=self.distance_eps, + ) + + # Scatter + # X[:,:,i+4,:] = X4 + # scatter_ix = (i+4) * torch.ones( + # (num_batch,num_residues,1,3), dtype=torch.long + # ) + # print(X_full.shape, X4.shape, scatter_ix.shape, i+4) + # print(scatter_ix) + # X_full.scatter_(-2, scatter_ix, X4.unsqueeze(-2)) + # X_full = torch.scatter(X_full, -2, scatter_ix, X4) + # X_full = X_full + 0.1*X4.mean() + + # For some reason direct scatter causes autograd issues + X4_expand = F.pad(X4.unsqueeze(-2), (0, 0, 4 + i, 9 - i)) + X_full = X_full + X4_expand + + # DEBUG: TEST + if False: + D_reconstruct = dihedrals(X1, X2, X3, X4) + D_error = ( + (torch.cos(D[:, :, i]) - torch.cos(D_reconstruct)) ** 2 + + (torch.sin(D[:, :, i]) - torch.sin(D_reconstruct)) ** 2 + ).mean() + print(D_error) + + mask_X = atom_mask(C, S).unsqueeze(-1) + X_full = mask_X * X_full + return X_full, mask_X + + +class ChiAngles(nn.Module): + """Computes Chi-angles from an all-atom protein structure. + + Inputs: + X (tensor): Atomic coordinates with shape + `(batch_size, num_residues, 14, 3)`. + C (tensor): Chain map with shape `(batch_size, num_residues)`. + S (tensor): Sequence tokens with shape `(batch_size, num_residues)`. + + Outputs: + chi (tensor): Backbone chi angles with shape + `(batch_size, num_residues, 4)`. + mask_chi (tensor): Chi angle mask with shape + `(batch_size, num_residues, 4)`. + """ + + def __init__(self, distance_eps=1e-6): + super(ChiAngles, self).__init__() + self.num_atoms = 10 + self.num_chi = 4 + self.num_aa = len(residue_constants.alphabetical_restypes) + + self.distance_eps = distance_eps + + self._init_maps() + + def _init_maps(self): + """Build geometry and topology maps in tensor form.""" + + self.register_buffer( + "_chi_atom_sets", + torch.zeros((self.num_aa, self.num_chi, 4), dtype=torch.long), + ) + + for i, aa in enumerate(residue_constants.alphabetical_restype_3): + aa_dict = AA_GEOMETRY[aa] + atoms_names = residue_constants.backbone4_atoms + aa_dict["atoms"] + + # Map which chi angles are flexible + for j, parent_ix in enumerate(aa_dict["chi_indices"]): + atom_quartet = aa_dict["parents"][parent_ix] + [ + aa_dict["atoms"][parent_ix] + ] + for k, atom in enumerate(atom_quartet): + self._chi_atom_sets[i, j, k] = atoms_names.index(atom) + + def forward(self, X, C, S): + num_batch, num_residues = list(S.shape) + # (B,L) => (B,L,16) + S_expand = S.unsqueeze(-1).expand([-1, -1, 16]) + + # (AA,CHI,ATOM) => (AA,16) => (B,AA,16) + chi_indices_per_aa = self._chi_atom_sets.reshape([1, self.num_aa, 16]) + chi_indices_per_aa = chi_indices_per_aa.expand([num_batch, -1, -1]) + + # (B,AA,16) @ (B,L,16) => (B,L,16) => (B,L,16) + chi_indices = torch.gather(chi_indices_per_aa, -2, S_expand) + chi_indices = chi_indices.unsqueeze(-1).expand([-1, -1, -1, 3]) + + # (B,L,14,3) @ (B,L,16,3) => (B,L,16,3) => (B,L,4,4,3) => (B,L,4) + X_chi = torch.gather(X, -2, chi_indices) + X_1, X_2, X_3, X_4 = X_chi.reshape([num_batch, num_residues, 4, 4, 3]).unbind( + -2 + ) + + chi = dihedrals(X_1, X_2, X_3, X_4, distance_eps=self.distance_eps) + + mask_chi = chi_mask(C, S) + chi = chi * mask_chi + return chi, mask_chi + + +class SideChainSymmetryRenamer(nn.Module): + """Rename atom to their 180-degree symmetry alternatives via permutation. 
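`ChiAngles` above ultimately reduces to the standard four-point torsion formula; a self-contained sketch of that formula (not the `geometry.dihedrals` implementation itself):

```python
import torch

# Torsion angle about the x2-x3 axis defined by four points (atan2 formula).
def dihedral(x1, x2, x3, x4):
    b1, b2, b3 = x2 - x1, x3 - x2, x4 - x3
    n1 = torch.cross(b1, b2, dim=-1)
    n2 = torch.cross(b2, b3, dim=-1)
    m1 = torch.cross(n1, b2 / b2.norm(dim=-1, keepdim=True), dim=-1)
    return torch.atan2((m1 * n2).sum(-1), (n1 * n2).sum(-1))

# Four points with a 90-degree torsion
x = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
print(torch.rad2deg(dihedral(*x)))  # ~90 degrees
```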
+ + Inputs: + X (tensor): Atomic coordinates with shape + `(batch_size, num_residues, 14, 3)`. + S (tensor): Sequence tokens with shape `(batch_size, num_residues)`. + + Outputs: + X_renamed (tensor): Renamed atomic coordinates with shape + `(batch_size, num_residues, 14, 3)`. + """ + + def __init__(self): + super(SideChainSymmetryRenamer, self).__init__() + self.num_atoms = 10 + self.num_aa = len(residue_constants.alphabetical_restypes) + + # Build symmetry indices give alternative atom labelings + self.register_buffer( + "_symmetry_indices", + torch.arange(self.num_atoms).unsqueeze(0).repeat(self.num_aa, 1), + ) + for i, aa in enumerate(residue_constants.alphabetical_restype_3): + if aa in ATOM_SYMMETRIES: + for aa_1, aa_2 in ATOM_SYMMETRIES[aa]: + atom_names = AA_GEOMETRY[aa]["atoms"] + ix_1 = atom_names.index(aa_1) + ix_2 = atom_names.index(aa_2) + self._symmetry_indices[i, ix_1] = ix_2 + self._symmetry_indices[i, ix_2] = ix_1 + + def _gather_per_residue(self, AA_table, S): + num_batch, num_residues = list(S.shape) + + # (B,L) => (B,L,ATOM) + S_expand = S.unsqueeze(-1).expand([-1, -1, self.num_atoms]) + + # (AA,ATOM) => (B,AA,ATOM) + value_per_aa = AA_table.unsqueeze(0).expand([num_batch, -1, -1]) + + # (B,AA,ATOM) @ (B,L,ATOM) => (B,L,ATOM) + value_per_residue = torch.gather(value_per_aa, -2, S_expand) + return value_per_residue + + def forward(self, X, S): + alt_indices = self._gather_per_residue(self._symmetry_indices, S) + alt_indices = alt_indices.unsqueeze(-1).expand(-1, -1, -1, 3) + + X_bb, X_sc = X[:, :, :4, :], X[:, :, 4:, :] + X_sc_alternate = torch.gather(X_sc, -2, alt_indices) + X_alternate = torch.cat([X_bb, X_sc_alternate], dim=-2) + return X_alternate + + +class AllAtomFrameBuilder(nn.Module): + """Build all-atom protein structure from oriented C-alphas and chi angles. + + Inputs: + x (Tensor): C-alpha coordinates with shape `(num_batch, num_residues, 3)`. + q (Tensor): Quaternions representing C-alpha orientiations with shape + with shape `(num_batch, num_residues, 4)`. + chi (tensor): Backbone chi angles with shape + `(num_batch, num_residues, 4)`. + C (tensor): Chain map with shape `(num_batch, num_residues)`. + S (tensor): Sequence tokens with shape `(num_batch, num_residues)`. 
+ + Outputs: + X (Tensor): All-atom protein coordinates with shape + `(num_batch, num_residues, 14, 3)` + """ + + def __init__(self): + super(AllAtomFrameBuilder, self).__init__() + self.sidechain_builder = SideChainBuilder() + self.chi_angles = ChiAngles() + + # Build idealized backbone fragment + # IC +N CA *C O 1.3558 116.8400 180.0000 122.5200 1.2297 + dX = torch.tensor( + [ + [1.459, 0.0, 0.0], # N-C via Engh & Huber is 1.459 + [0.0, 0.0, 0.0], # CA is origin + [-0.547, 0.0, -1.424], # C is placed 1.525 A @ 111 degrees from N + ], + dtype=torch.float32, + ) + self.register_buffer("_dX_local", dX) + + def forward(self, x, q, chi, C, S): + # Build backbone + R = rotations_from_quaternions(q, normalize=True) + dX = torch.einsum("ay,nixy->niax", self._dX_local, R) + X_chain = x.unsqueeze(-2) + dX + + # Build carboxyl groups + X_N, X_CA, X_C = X_chain.unbind(-2) + + # TODO: fix this behavior for termini + mask_next = (C > 0).float()[:, 1:].unsqueeze(-1) + X_N_next = F.pad( + mask_next + * X_N[ + :, + 1:, + ], + (0, 0, 0, 1), + ) + + num_batch, num_residues = C.shape + ones = torch.ones(list(C.shape), dtype=torch.float32, device=C.device) + X_O = extend_atoms( + X_N_next, + X_CA, + X_C, + 1.2297 * ones, + 122.5200 * ones, + 180 * ones, + degrees=True, + ) + X_bb = torch.stack([X_N, X_CA, X_C, X_O], dim=-2) + + # Build sidechains + X, mask_atoms = self.sidechain_builder(X_bb, C, S, chi) + return X, mask_atoms + + def inverse(self, X, C, S): + X_bb = X[:, :, :4, :] + R, x = frames_from_backbone(X_bb) + q = quaternions_from_rotations(R) + chi, mask_chi = self.chi_angles(X, C, S) + return x, q, chi + + +class LossSideChainRMSD(nn.Module): + """Compute side chain RMSDs per residues from an all-atom protein structure. + + Inputs: + X (tensor): Atomic coordinates with shape + `(batch_size, num_residues, 14, 3)`. + X_target (tensor): Atomic coordinates with shape + `(batch_size, num_residues, 14, 3)`. + S (tensor): Sequence tokens with shape `(batch_size, num_residues)`. + + Outputs: + chi (tensor): Backbone chi angles with shape + `(batch_size, num_residues, 4)`. + """ + + def __init__(self, rmsd_eps=1e-2): + super(LossSideChainRMSD, self).__init__() + self.num_atoms = 10 + self.num_aa = len(residue_constants.alphabetical_restypes) + + self.rmsd_eps = rmsd_eps + self.renamer = SideChainSymmetryRenamer() + + def _rmsd(self, X, X_target, atom_mask): + sd = atom_mask * ((X - X_target) ** 2).sum(-1) + rmsd = torch.sqrt( + sd.sum(-1) / (atom_mask.sum(-1) + self.rmsd_eps) + self.rmsd_eps + ) + return rmsd + + def forward(self, X, X_target, C, S, include_symmetry=True): + mask_atoms = atom_mask(C, S) + + X_alt = self.renamer(X, S)[:, :, 4:, :] + X = X[:, :, 4:, :] + X_target = X_target[:, :, 4:, :] + mask_atoms = mask_atoms[:, :, 4:] + + rmsd = self._rmsd(X, X_target, mask_atoms) + if include_symmetry: + rmsd_alternate = self._rmsd(X_alt, X_target, mask_atoms) + + # rmsd = torch.minimum(rmsd, rmsd_alternate) + rmsd = torch.stack([rmsd, rmsd_alternate], -1).min(-1)[0] + rmsd = (C > 0).float() * rmsd + return rmsd + + +class LossFrameAlignedGraph(nn.Module): + """Compute the frame-aligned loss on a nearest neighbors graph. + + Args: + num_neighbors (int): Number of neighbors to build in the graph. Default + is 30. + + Inputs: + X (tensor): Atomic coordinates with shape + `(batch_size, num_residues, 14, 3)`. + X_target (tensor): Atomic coordinates with shape + `(batch_size, num_residues, 14, 3)`. + C (tensor): Chain map with shape `(batch_size, num_residues)`. 
+ S (tensor): Sequence tokens with shape `(batch_size, num_residues)`. + + Outputs: + D (tensor): Per-residue losses with shape `(batch_size, num_residues)`. + """ + + def __init__( + self, + num_neighbors=30, + distance_eps=1e-2, + distance_scale=10.0, + interface_only=False, + ): + super(LossFrameAlignedGraph, self).__init__() + self.distance_eps = distance_eps + self.distance_scale = distance_scale + + self.renamer = SideChainSymmetryRenamer() + self.graph_builder = protein_graph.ProteinGraph(num_neighbors) + self.interface_only = interface_only + + def _frame_ij(self, X, edge_idx): + # Build local frames + num_batch, num_residues, num_atoms, _ = X.shape + + # Build frames at neighbor j (B,L,K,3,3), (B,L,K,3) + X_bb_flat = X[:, :, :4, :].reshape([num_batch, num_residues, -1]) + X_j_flat = graph.collect_neighbors(X_bb_flat, edge_idx) + X_j = X_j_flat.reshape([num_batch, num_residues, -1, 4, 3]) + R_j, X_j_CA = frames_from_backbone(X_j, distance_eps=self.distance_eps) + + # (B,L,1,A,3) - (B,L,K,1,3) => (B,L,K,A,3) + X_ij = X.unsqueeze(-3) - X_j_CA.unsqueeze(-2) + + # Rotate displacements into local frames + r_ij = torch.einsum("nijax,nijxy->nijay", X_ij, R_j) + return r_ij + + def _dist(self, r_ij_1, r_ij_2): + D_sq = (r_ij_1 - r_ij_2) ** 2 + D = torch.sqrt(D_sq.sum(-1) + self.distance_eps) + return D + + def forward(self, X, X_target, C, S): + if X_target.size(2) == 14: + mask_atoms = atom_mask(C, S) + X_alt = self.renamer(X, S) + elif X_target.size(2) == 4: + mask_atoms = (C > 0).float().unsqueeze(-1).expand([-1, -1, 4]) + X_alt = X + else: + raise Exception( + "Size of atom dimension must be 4 (backbone) or 14 (all-atom)." + ) + + # Build the union graph + custom_mask_2D = None + if self.interface_only: + custom_mask_2D = torch.ne(C.unsqueeze(1), C.unsqueeze(2)).float() + edge_idx_model, _ = self.graph_builder( + X[:, :, :4, :], C, custom_mask_2D=custom_mask_2D + ) + edge_idx_target, _ = self.graph_builder( + X_target[:, :, :4, :], C, custom_mask_2D=custom_mask_2D + ) + edge_idx = torch.cat([edge_idx_model, edge_idx_target], 2) + + # Build frame-aligned displacement vectors (B,N,K,A,3) + r_ij = self._frame_ij(X, edge_idx) + r_ij_alt = self._frame_ij(X_alt, edge_idx) + r_ij_target = self._frame_ij(X_target, edge_idx) + + # Build 2D masks (B,N,K,A) + num_batch, num_residues, num_atoms, _ = X.shape + mask_residues = (C > 0).float() + # (B,N,1,A) + mask_i = mask_atoms.reshape([num_batch, num_residues, 1, num_atoms]) + # (B,N,K,1) + mask_j = graph.collect_neighbors(mask_residues.unsqueeze(-1), edge_idx) + mask_ij = mask_i * mask_j + + # Build frame-aligned displacement vectors (B,N,N,A) + D = mask_ij * self._dist(r_ij, r_ij_target) + D_alt = mask_ij * self._dist(r_ij_alt, r_ij_target) + + # Which definition of atom j gives a better score? (B,N) + mask_reduce = mask_ij.sum([-2, -1]) + D_j = D.sum([-2, -1]) / (mask_reduce + self.distance_eps) + D_j_alt = D_alt.sum([-2, -1]) / (mask_reduce + self.distance_eps) + D_j_min = torch.stack([D_j, D_j_alt], -1).min(-1)[0] + + # Return as a per-residue loss + return D_j_min + + +class LossAllAtomDistances(nn.Module): + """Compute the interatomic distance loss on a nearest neighbors graph. + + Args: + num_neighbors (int): Number of neighbors to build in the graph. Default + is 30. + + Inputs: + X (tensor): Atomic coordinates with shape + `(batch_size, num_residues, 14, 3)`. + X_target (tensor): Atomic coordinates with shape + `(batch_size, num_residues, 14, 3)`. + C (tensor): Chain map with shape `(batch_size, num_residues)`. 
+ S (tensor): Sequence tokens with shape `(batch_size, num_residues)`. + + Outputs: + D (tensor): Per-residue losses with shape `(batch_size, num_residues)`. + """ + + def __init__(self, num_neighbors=30, distance_eps=1e-2): + super(LossAllAtomDistances, self).__init__() + self.distance_eps = distance_eps + + self.graph_builder = protein_graph.ProteinGraph(num_neighbors) + + def _dist_ij(self, X, edge_idx): + # Build local frames + num_batch, num_residues, num_atoms, _ = X.shape + + # Build frames at neighbor j (B,L,K,), (B,L,K,A,3) + X_flat = X.reshape([num_batch, num_residues, -1]) + X_j_flat = graph.collect_neighbors(X_flat, edge_idx) + X_j = X_j_flat.reshape([num_batch, num_residues, -1, num_atoms, 3]) + X_i = X.unsqueeze(2).expand([-1, -1, X_j.shape[2], -1, -1]) + + X_ij = torch.cat([X_i, X_j], -2) + D_ij = torch.sqrt( + ((X_ij.unsqueeze(-2) - X_ij.unsqueeze(-3)) ** 2).sum(-1) + self.distance_eps + ) + return D_ij + + def _mask_ij(self, C, S, edge_idx): + # (B,L,A) + mask_atoms = atom_mask(C, S) + + mask_j = graph.collect_neighbors(mask_atoms, edge_idx) + mask_i = mask_atoms.unsqueeze(2).expand([-1, -1, edge_idx.shape[2], -1]) + mask_ij = torch.cat([mask_i, mask_j], -1) + + mask_D = mask_ij.unsqueeze(-1) * mask_ij.unsqueeze(-2) + return mask_D + + def forward(self, X, X_target, C, S): + # Build the union graph + edge_idx_model, _ = self.graph_builder(X[:, :, :4, :], C) + edge_idx_target, _ = self.graph_builder(X_target[:, :, :4, :], C) + edge_idx = torch.cat([edge_idx_model, edge_idx_target], 2) + + mask_ij = self._mask_ij(C, S, edge_idx) + D_model = self._dist_ij(X, edge_idx) + D_target = self._dist_ij(X_target, edge_idx) + + loss = torch.sqrt((D_model - D_target) ** 2 + self.distance_eps) + loss_i = (mask_ij * loss).sum([2, 3, 4]) / ( + mask_ij.sum([2, 3, 4]) + self.distance_eps + ) + return loss_i + + +class LossSidechainClashes(nn.Module): + """Count sidechain clashes in a structure using a nearest neighbors graph. + + This uses the Van der Waals radii based definition of bonding + in pymol as described at https://pymolwiki.org/index.php/Connect_cutoff. + + Args: + num_neighbors (int, optional): Number of neighbors to + build in the graph. Default is 30. + connect_cutoff (float, optional): Bonding cutoff used in formula + `D_clash_cutoff = D_vdw / 2. + self.connect_cutoff`. Default is + 0.35. + use_smooth_cutoff (bool, optional): If True, use a differentiable + definition of clashes by replacing `D < cutoff` with + `sigmoid(smooth_cutoff_alpha * (cutoff - D))`. Default is False. + smooth_cutoff_alpha (float, optional): Steepness parameter for + differentiable clashes, as `alpha -> infinity` it will behave as + discrete cutoff. Default is 1.0. + + Inputs: + X (tensor): Atomic coordinates with shape + `(batch_size, num_residues, 14, 3)`. + C (tensor): Chain map with shape `(batch_size, num_residues)`. + S (tensor): Sequence tokens with shape `(batch_size, num_residues)`. + mask_j (tensor, optional): Binary mask encoding which side chains + should be tested for clashing. + + Outputs: + clashes (tensor): Per-residue number of clashes with shape + `(batch_size, num_residues)`. 
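For reference, with the default `connect_cutoff` of 0.35 Å, two carbon atoms (van der Waals radius 1.7 Å each) give `D_vdw = 3.4 Å` and a clash cutoff of `3.4 / 2 + 0.35 ≈ 2.05 Å`, so sidechain carbons closer than roughly 2.05 Å are counted as clashing.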
+ """ + + def __init__( + self, + num_neighbors=30, + distance_eps=1e-3, + connect_cutoff=0.35, + use_smooth_cutoff=False, + smooth_cutoff_alpha=1.0, + ): + super(LossSidechainClashes, self).__init__() + self.distance_eps = distance_eps + self.graph_builder = protein_graph.ProteinGraph(num_neighbors) + self.connect_cutoff = connect_cutoff + self.use_smooth_cutoff = use_smooth_cutoff + self.smooth_cutoff_alpha = smooth_cutoff_alpha + + def _dist_ij(self, X, edge_idx): + num_batch, num_residues, num_atoms, _ = X.shape + + # Build frames at neighbor j (B,L,K,), (B,L,K,A,3) + X_flat = X.reshape([num_batch, num_residues, -1]) + X_j_flat = graph.collect_neighbors(X_flat, edge_idx) + X_j = X_j_flat.reshape([num_batch, num_residues, -1, num_atoms, 3]) + X_i = X.unsqueeze(2).expand([-1, -1, X_j.shape[2], -1, -1]) + + D_ij = torch.sqrt( + ((X_i.unsqueeze(-2) - X_j.unsqueeze(-3)) ** 2).sum(-1) + self.distance_eps + ) + return D_ij + + def _mask_ij(self, C, S, edge_idx, mask_j=None): + # (B,L,A) + mask_atoms = atom_mask(C, S) + + # Mask only present atoms + mask_atoms_j = mask_atoms + if mask_j is not None: + mask_atoms_j = mask_atoms_j * mask_j.unsqueeze(-1) + mask_j = graph.collect_neighbors(mask_atoms_j, edge_idx) + mask_i = mask_atoms.unsqueeze(2).expand([-1, -1, edge_idx.shape[2], -1]) + mask_D = mask_i.unsqueeze(-1) * mask_j.unsqueeze(-2) + + # Mask self interactions + node_idx = torch.arange(C.shape[1], device=C.device).reshape([1, -1, 1]) + mask_ne = torch.ne(edge_idx, node_idx) + mask_D = mask_D * mask_ne.reshape(list(mask_ne.shape) + [1, 1]) + return mask_D + + def _gather_vdw_radii(self, C, S): + vdw_radii = {"C": 1.7, "N": 1.55, "O": 1.52, "S": 1.8} + + # Van der waal radii per atom per residue [AA,ATOM] + R = torch.zeros([20, 14], device=C.device) + for i, aa in enumerate(residue_constants.alphabetical_restype_3): + atoms = residue_constants.backbone4_atoms + AA_GEOMETRY[aa]["atoms"] + for j, atom_name in enumerate(atoms): + R[i, j] = vdw_radii[atom_name[0]] + + # (B, AA, ATOM) @ (B, L, ATOM) => (B, L, ATOM) + R = R.reshape([1, 20, 14]).expand([C.shape[0], -1, -1]) + S = S.unsqueeze(-1).expand([-1, -1, 14]) + atom_radii = torch.gather(R, 1, S) + return atom_radii + + def _gather_vdw_diameters(self, C, S, edge_idx): + num_batch, num_residues, num_neighbors = edge_idx.shape + + # Gather van der Waals radii + radii_i = self._gather_vdw_radii(C, S) + radii_j = graph.collect_neighbors(radii_i, edge_idx) + radii_i = radii_i.reshape([num_batch, num_residues, 1, -1, 1]) + radii_j = radii_j.reshape([num_batch, num_residues, num_neighbors, 1, -1]) + + D_vdw = radii_i + radii_j + return D_vdw + + def forward(self, X, C, S, edge_idx=None, mask_j=None, mask_ij=None): + # Compute sidechain interatomic distances + if edge_idx is None: + edge_idx, mask_ij = self.graph_builder(X[:, :, :4, :], C) + + # Distance with shape [B,L,K,AI,AJ] + mask_clash_ij = self._mask_ij(C, S, edge_idx, mask_j) + if mask_ij is not None: + mask_clash_ij = mask_clash_ij * mask_ij.reshape( + list(mask_ij.shape) + [1, 1] + ) + D = self._dist_ij(X, edge_idx) + D_vdw = self._gather_vdw_diameters(C, S, edge_idx) + D_clash_cutoff = D_vdw / 2.0 + self.connect_cutoff + + # Optionally use a smooth definition of clashes that is differentiable + if self.use_smooth_cutoff: + bond_clash = mask_clash_ij * torch.sigmoid( + self.smooth_cutoff_alpha * (D_clash_cutoff - D) + ) + else: + bond_clash = mask_clash_ij * (D < D_clash_cutoff).float() + + # Only cound outgoing clashes from sidechain atoms at i + bond_clash = bond_clash[:, :, :, 4:, :] + 
+        clashes = bond_clash.sum([2, 3, 4])
+        return clashes
+
+
+def _gather_atom_mask(C, S, atoms_per_aa, num_atoms):
+    device = S.device
+    atoms_per_aa = torch.tensor(atoms_per_aa, dtype=torch.long)
+    atoms_per_aa = atoms_per_aa.to(device).unsqueeze(0).expand(S.shape[0], -1)
+
+    # (B,A) @ (B,L) => (B,L)
+    atoms_per_residue = torch.gather(atoms_per_aa, -1, S)
+    atoms_per_residue = (C > 0).float() * atoms_per_residue
+
+    ix_expand = torch.arange(num_atoms, device=device).reshape([1, 1, -1])
+    mask_atoms = ix_expand < atoms_per_residue.unsqueeze(-1)
+    mask_atoms = mask_atoms.float()
+    return mask_atoms
+
+
+def atom_mask(C, S):
+    """Constructs an all-atom coordinate mask from a sequence and chain map.
+
+    Inputs:
+        C (tensor): Chain map with shape `(batch_size, num_residues)`.
+        S (tensor): Sequence tokens with shape `(batch_size, num_residues)`.
+
+    Outputs:
+        mask_atoms (tensor): Atomic mask with shape
+            `(batch_size, num_residues, 14)`.
+    """
+    return _gather_atom_mask(C, S, residue_constants.num_atoms_per_alphabetical_aa, 14)
+
+
+def chi_mask(C, S):
+    """Constructs a chi-angle mask from a sequence and chain map.
+
+    Inputs:
+        C (tensor): Chain map with shape `(batch_size, num_residues)`.
+        S (tensor): Sequence tokens with shape `(batch_size, num_residues)`.
+
+    Outputs:
+        mask_chi (tensor): Chi angle mask with shape
+            `(batch_size, num_residues, 4)`.
+    """
+    return _gather_atom_mask(C, S, residue_constants.num_chi_per_alphabetical_aa, 4)
diff --git a/prtm/models/chroma/structure/symmetry.py b/prtm/models/chroma/structure/symmetry.py
new file mode 100644
index 00000000..1a0c2f12
--- /dev/null
+++ b/prtm/models/chroma/structure/symmetry.py
@@ -0,0 +1,258 @@
+# Copyright Generate Biomedicines, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
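As a small illustration of the point groups constructed in this module: the cyclic group C_4 is just four rotations about the z-axis, and powers of its generator stay inside the group. A standalone sketch using the same construction as `get_Cn_groups` below:

```python
import numpy as np
import torch

n = 4
angles = 2 * np.pi * np.arange(n) / n
G = torch.tensor(
    [
        [[np.cos(a), -np.sin(a), 0.0], [np.sin(a), np.cos(a), 0.0], [0.0, 0.0, 1.0]]
        for a in angles
    ],
    dtype=torch.float32,
)
print(G.shape)  # torch.Size([4, 3, 3])
# The generator raised to the group order is the identity
print(torch.allclose(torch.linalg.matrix_power(G[1], 4), torch.eye(3), atol=1e-5))
```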
+ +"""Layers for euclidean symmetry group operations + +This module contains pytorch layers for symmetry operations for point groups (Cyclic, Dihedral, Tetrahedral, Octahedral and Icosahedral) +""" + +import math +from itertools import product +from typing import Optional, Tuple + +import numpy as np +import torch + +TAU = 0.5 * (1 + math.sqrt(5)) + +ROT_DICT = { + "O": [ + [[0.0, 1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], + [[0.0, 0.0, 1.0], [-1.0, 0.0, 0.0], [0.0, -1.0, 0.0]], + [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]], + [[0.0, 0.0, 1.0], [0.0, -1.0, 0.0], [1.0, 0.0, 0.0]], + [[-1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]], + [[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], + [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]], + [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], + [[-1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, 1.0]], + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, -1.0, 0.0]], + [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]], + [[0.0, -1.0, 0.0], [0.0, 0.0, 1.0], [-1.0, 0.0, 0.0]], + [[0.0, 1.0, 0.0], [0.0, 0.0, -1.0], [-1.0, 0.0, 0.0]], + [[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]], + [[0.0, 0.0, -1.0], [1.0, 0.0, 0.0], [0.0, -1.0, 0.0]], + [[0.0, 0.0, -1.0], [-1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], + [[0.0, -1.0, 0.0], [0.0, 0.0, -1.0], [1.0, 0.0, 0.0]], + [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, -1.0]], + [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]], + [[1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, -1.0]], + [[0.0, -1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, -1.0]], + [[0.0, 0.0, -1.0], [0.0, -1.0, 0.0], [-1.0, 0.0, 0.0]], + [[-1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, -1.0, 0.0]], + ], + "T": [ + [[-1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, 1.0]], + [[1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, -1.0]], + [[0.0, 0.0, -1.0], [1.0, 0.0, 0.0], [0.0, -1.0, 0.0]], + [[0.0, 0.0, -1.0], [-1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], + [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]], + [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]], + [[0.0, 0.0, 1.0], [-1.0, 0.0, 0.0], [0.0, -1.0, 0.0]], + [[0.0, -1.0, 0.0], [0.0, 0.0, 1.0], [-1.0, 0.0, 0.0]], + [[0.0, 1.0, 0.0], [0.0, 0.0, -1.0], [-1.0, 0.0, 0.0]], + [[0.0, -1.0, 0.0], [0.0, 0.0, -1.0], [1.0, 0.0, 0.0]], + ], +} + + +def get_point_group(group: str) -> torch.Tensor: + """get representation of group elements at torch.Tensor + + Args: + group (str): group names, selecting from {"C_{n}" , "D_{n}", "T", "O", "I" } + + Returns: + torch.Tensor: rotation matrices for queried point groups + """ + if group.startswith("C"): + n = group.split("_")[1] + G = get_Cn_groups(int(n)) + elif group.startswith("D"): + n = group.split("_")[1] + G = get_Dn_groups(int(n)) + elif group == "I": + G = get_I_rotations() + elif group == "O" or group == "T": + G = torch.Tensor(np.array(ROT_DICT[group])) + else: + raise ValueError("{ } not available".format(group)) + + return G + + +def get_Cn_groups(n: int) -> torch.Tensor: + """get rotation matrices for Cyclic groups + + Args: + n (int): symmetry order + + Returns: + torch.Tensor: n x 3 x 3 + """ + G = [] + for ri in range(n): + cos_phi = np.round(np.cos(ri * np.pi * 2 / n), 4) + sin_phi = np.round(np.sin(ri * np.pi * 2 / n), 4) + + g = np.array( + [[cos_phi, -sin_phi, 0.0], [sin_phi, cos_phi, 0.0], [0.0, 0.0, 1.0]] + ) + G.append(np.round(g, 4)) + + return torch.Tensor(np.array(G)) + + +def 
get_Dn_groups(n: int) -> torch.Tensor: + """get rotation matrices for Dihedral groups + + Args: + n (int): symmetry order + + Returns: + torch.Tensor: 2n x 3 x 3 + """ + cos_phi = np.round( + np.cos(np.pi * 2 / n), 8 + ) # unify the choice of # of decimals to keep + sin_phi = np.round(np.sin(np.pi * 2 / n), 8) + + rot_generator = np.array( + [[cos_phi, -sin_phi, 0.0], [sin_phi, cos_phi, 0.0], [0.0, 0.0, 1.0]] + ) + + b = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) + + G = [] + c = rot_generator + for itr in range(n): + c_new = c @ rot_generator + G.append(c_new) + c = c_new + G.append(b @ c) + return torch.Tensor(np.array(G)) + + +def get_I_rotations(tree_depth: int = 5) -> torch.Tensor: + """get rotation matrices for the Icosahedral group (I) + + Returns: + torch.Tensor: 60 x 3 x 3 + """ + + tree_depth = 5 # tree traverse depth + + g1 = torch.Tensor([[-1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, 1.0]]) + g2 = torch.Tensor([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) + g3 = torch.Tensor( + [ + [0.5, -0.5 * TAU, 0.5 / TAU], + [0.5 * TAU, 0.5 / TAU, -0.5], + [0.5 / TAU, 0.5, 0.5 * TAU], + ] + ) + # g4 = torch.Tensor([[-1., 0., 0.], [0., -1., 0.], [0., 0., -1.]]) reflection operation + generators = [g1, g2, g3] + + gen1 = generators + gen2 = generators + + sym_ops = [] + + # todo: there is probably better way to do this + # brute force search by traversing the Caley graph + for itr in range(tree_depth): + mat_prod = [mat_pair[0] @ mat_pair[1] for mat_pair in product(gen1, gen2)] + sym_ops += mat_prod + gen1 = mat_prod + + # find unique rotation matrices + sym_ops = torch.unique(torch.round(torch.stack(sym_ops, dim=0), decimals=6), dim=0) + + return sym_ops + + +def subsample( + X: torch.Tensor, + C: torch.Tensor, + G: torch.Tensor, + knbr: int, + seed_idx: Optional[int] = None, +): + """generate substructures based on distances between subunit COM + + Args: + X (torch.Tensor): structures + C (torch.Tensor): chain map + G (torch.Tensor): rotation matrices + knbr (int): number of nearest neighbors + seed_idx (int, optional): seed idx, this will be randomly selected if set to None. Defaults to None. + + Returns: + tuple: substructure coordinates, chain map, indices associated with all the substructure chains, seed idx + """ + + if knbr > G.shape[0] - 1: + knbr = G.shape[0] - 1 + + G.shape[0] + X_chain_com = X.reshape(1, G.shape[0], -1, 3).mean(-2) + + if seed_idx is None: + seed_idx = torch.randint(0, G.shape[0], (1,)).item() + + Dis_chain = ( + (X_chain_com.unsqueeze(-2) - X_chain_com.unsqueeze(-3)).pow(2).sum(-1).sqrt() + ) + + subdomain_idx = Dis_chain[0, seed_idx].topk(knbr + 1, largest=False)[1] + + X_subdomain = X.reshape(1, G.shape[0], -1, 4, 3)[:, subdomain_idx] + X_subdomain = X_subdomain.reshape(1, -1, 4, 3) + C_subdomain = C.reshape(1, G.shape[0], -1)[:, : knbr + 1, :].reshape(1, -1) + + return X_subdomain, C_subdomain, subdomain_idx, seed_idx + + +def symmetrize_XCS( + X: torch.Tensor, + C: torch.LongTensor, + S: torch.LongTensor, + G: torch.Tensor, + device: str = "cpu", +) -> Tuple[torch.Tensor, torch.LongTensor, torch.LongTensor]: + """Symmetrize a protein structure with a given symmetry group. + + This function takes a data tensor X, a chain tensor C, a Sequence Tensor S, and a symmetry group tensor G. The function returns a symmetrized data tensor X_complex, a modified chain tensor C_complex, and a replicated Sequence Tensor S_complex. The function uses the torch.einsum function to apply the symmetry group G to each chain in X and concatenate them into X_complex. 
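A standalone sketch of the einsum used by `symmetrize_XCS`: each group element produces one rotated copy of the chain, and the copies are concatenated along the residue axis (toy shapes; `G` here is just C_2 about z).

```python
import torch

G = torch.tensor([
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
    [[-1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, 1.0]],
])
X = torch.randn(1, 10, 4, 3)                      # one chain of 10 residues
X_complex = torch.einsum("gij,bnlj->gnli", G, X)  # (2, 10, 4, 3): one copy per group element
X_complex = X_complex.reshape(1, -1, 4, 3)        # (1, 20, 4, 3) symmetrized complex
print(X_complex.shape)
```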
The function also modifies the chain labels in C by multiplying them by the symmetry index and concatenates them into C_complex. The function also replicates the Sequence Tensor S by the number of symmetry elements and concatenates them into S_complex. + + Args: + X (torch.Tensor): Data tensor with shape `(batch_size, num_residues, 4, 3)`. + C (torch.LongTensor): Chain tensor with shape `(batch_size, num_residues)`. + S (torch.LongTensor): Sequence tensor with shape `(batch_size, num_residues)`. + G (torch.Tensor): Symmetry group tensor with shape `(n_sym, 3, 3)`. + device (str, optional): The device to use for computation. Defaults to "cpu". + + Returns: + torch.Tensor: Symmetrized data tensor with shape `(batch_size, num_residues * n_sym, 4, 3)`. + torch.LongTensor: Modified chain tensor with shape `(batch_size, num_residues * n_sym)`. + torch.LongTensor: Modified Sequence tensor with shape `(batch_size * n_sym,)`. + """ + G = G.to(S.device) + X_complex = torch.einsum("gij,bnlj->gnli", G, X).reshape(1, -1, 4, 3).to(device) + C_complex = torch.cat([C * (i + 1) for i in range(G.shape[0])], 1).to(device) + S_complex = torch.cat([S for i in range(G.shape[0])], 1).to(device) + return X_complex, C_complex, S_complex diff --git a/prtm/models/chroma/structure/transforms.py b/prtm/models/chroma/structure/transforms.py new file mode 100644 index 00000000..374a4e95 --- /dev/null +++ b/prtm/models/chroma/structure/transforms.py @@ -0,0 +1,441 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layers for batched 3D transformations, such as residue poses. + +This module contains pytorch layers for computing and composing with +3D, 6-degree-of freedom transformations. +""" + + +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from prtm.models.chroma import graph +from prtm.models.chroma.structure import geometry + + +def compose_transforms( + R_a: torch.Tensor, t_a: torch.Tensor, R_b: torch.Tensor, t_b: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compose transforms `T_compose = T_a * T_b` (broadcastable). + + Args: + R_a (torch.Tensor): Transform `T_a` rotation matrix with shape `(...,3,3)`. + t_a (torch.Tensor): Transform `T_a` translation with shape `(...,3)`. + R_b (torch.Tensor): Transform `T_b` rotation matrix with shape `(...,3,3)`. + t_b (torch.Tensor): Transform `T_b` translation with shape `(...,3)`. + + Returns: + R_composed (torch.Tensor): Composed transform `a * b` rotation matrix with + shape `(...,3,3)`. + t_composed (torch.Tensor): Composed transform `a * b` translation vector with + shape `(...,3)`. + """ + R_composed = R_a @ R_b + t_composed = t_a + (R_a @ t_b.unsqueeze(-1)).squeeze(-1) + return R_composed, t_composed + + +def compose_translation( + R_a: torch.Tensor, t_a: torch.Tensor, t_b: torch.Tensor +) -> torch.Tensor: + """Compose translation component of `T_compose = T_a * T_b` (broadcastable). + + Args: + R_a (torch.Tensor): Transform `T_a` rotation matrix with shape `(...,3,3)`. 
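A quick property check for `compose_transforms` above: applying the composed transform to a point should match applying `T_b` and then `T_a`. This is a sketch only; the random-rotation helper is ad hoc and not part of this module.

```python
import torch

from prtm.models.chroma.structure.transforms import compose_transforms  # added by this diff

torch.manual_seed(3)

def random_rotation():
    # Ad-hoc helper: orthogonalize a random matrix and force det = +1
    Q, _ = torch.linalg.qr(torch.randn(3, 3))
    return Q * torch.sign(torch.linalg.det(Q))

R_a, t_a = random_rotation(), torch.randn(3)
R_b, t_b = random_rotation(), torch.randn(3)
R_ab, t_ab = compose_transforms(R_a, t_a, R_b, t_b)

x = torch.randn(3)
lhs = R_ab @ x + t_ab
rhs = R_a @ (R_b @ x + t_b) + t_a
print(torch.allclose(lhs, rhs, atol=1e-5))  # True
```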
+ t_a (torch.Tensor): Transform `T_a` translation with shape `(...,3)`. + t_b (torch.Tensor): Transform `T_b` translation with shape `(...,3)`. + + Returns: + t_composed (torch.Tensor): Composed transform `a * b` translation vector with + shape `(...,3)`. + """ + t_composed = t_a + (R_a @ t_b.unsqueeze(-1)).squeeze(-1) + return t_composed + + +def compose_inner_transforms( + R_a: torch.Tensor, t_a: torch.Tensor, R_b: torch.Tensor, t_b: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compose the relative inner transform `T_ab = T_a^{-1} * T_b`. + + Args: + R_a (torch.Tensor): Transform `T_a` rotation matrix with shape `(...,3,3)`. + t_a (torch.Tensor): Transform `T_a` translation with shape `(...,3)`. + R_b (torch.Tensor): Transform `T_b` rotation matrix with shape `(...,3,3)`. + t_b (torch.Tensor): Transform `T_b` translation with shape `(...,3)`. + + Returns: + R_ab (torch.Tensor): Composed transform `T_a * T_b` rotation matrix with + shape `(...,3,3)`. Inner dimensions are broadcastable. + t_ab (torch.Tensor): Composed transform `T_a * T_b` translation vector with + shape `(...,3)`. + """ + R_a_inverse = R_a.transpose(-1, -2) + R_ab = R_a_inverse @ R_b + t_ab = (R_a_inverse @ ((t_b - t_a).unsqueeze(-1))).squeeze(-1) + return R_ab, t_ab + + +def fuse_gaussians_isometric_plus_radial( + x: torch.Tensor, + p_iso: torch.Tensor, + p_rad: torch.Tensor, + direction: torch.Tensor, + dim: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Fuse Gaussians along a dimension ``dim``. This assumes the Gaussian + precision matrices are a sum of an isometric part P_iso together with + a part P_rad that provides information only along one direction. + + Args: + x (torch.Tensor): A (...,3)-shaped tensor of means. + p_iso (torch.Tensor): A (...)-shaped tensor of weights of the isometric part of the + precision matrix. + p_rad (torch.Tensor): A (...)-shaped tensor of weights of the radial part of the + precision matrix. + direction (torch.Tensor): A (...,3)-shaped tensor of directions along which + information is available. + dim (int): The dimension over which to aggregate (fuse). + + Returns: + A tuple ``(x_fused, P_fused)`` of fused mean and precision, with + specified ``dim`` removed. + """ + assert dim >= 0, "dimension must index from the left" + + # P_rad has information only parallel to the edge. + outer = direction.unsqueeze(-1) * direction.unsqueeze(-2) + inner = direction.square().sum(-1).clamp(min=1e-10) + P_rad = (p_rad / inner)[..., None, None] * outer + P_iso = p_iso.unsqueeze(-1).expand(p_iso.shape + (3,)).diag_embed() + P = P_iso + P_rad + + # Compute the Bayesian fusion aka product-of-experts of the Gaussians. + P_fused = P.sum(dim) + Px_fused = (P @ x.unsqueeze(-1)).squeeze(-1).sum(dim) + # There might be a cheaper way to do this, either via Cholesky + # or hand-coding the 3x3 matrix solve operation. + x_fused = torch.linalg.solve(P_fused, Px_fused) + + return x_fused, P_fused + + +def collect_neighbor_transforms( + R_i: torch.Tensor, t_i: torch.Tensor, edge_idx: torch.LongTensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Collect neighbor transforms. + + Args: + R_i (torch.Tensor): Transform `T` rotation matrices with shape + `(num_batch, num_residues, 3, 3)`. + t_i (torch.Tensor): Transform `T` translations with shape + `(num_batch, num_residues, 3)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_nodes, num_neighbors)`. 
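# Sketch of the inverse relationship implied by compose_inner_transforms: the
# relative transform T_ab = T_a^{-1} * T_b, composed back onto T_a, recovers
# T_b (imports as defined in this patch).
import torch

from prtm.models.chroma.structure.transforms import (
    compose_inner_transforms,
    compose_transforms,
)

R_a = torch.eye(3)
t_a = torch.tensor([1.0, 0.0, 0.0])
R_b = torch.tensor([[0.0, -1.0, 0.0],
                    [1.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0]])  # 90 degree rotation about z
t_b = torch.tensor([0.0, 2.0, 0.0])

R_ab, t_ab = compose_inner_transforms(R_a, t_a, R_b, t_b)
R_rec, t_rec = compose_transforms(R_a, t_a, R_ab, t_ab)
assert torch.allclose(R_rec, R_b, atol=1e-6)
assert torch.allclose(t_rec, t_b, atol=1e-6)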
+ + Returns: + R_j (torch.Tensor): Rotation matrices of neighbor transforms, with shape + `(num_batch, num_residues, num_neighbors, 3, 3)`. + t_j (torch.Tensor): Translations of neighbor transforms, with shape + `(num_batch, num_residues, num_neighbors, 3)`. + """ + num_batch, num_residues, num_neighbors = edge_idx.shape + R_i_flat = R_i.reshape([num_batch, num_residues, 9]) + R_j = graph.collect_neighbors(R_i_flat, edge_idx).reshape( + [num_batch, num_residues, num_neighbors, 3, 3] + ) + t_j = graph.collect_neighbors(t_i, edge_idx) + return R_j, t_j + + +def collect_neighbor_inner_transforms( + R_i: torch.Tensor, t_i: torch.Tensor, edge_idx: torch.LongTensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Collect inner transforms between neighbors. + + Args: + R_i (torch.Tensor): Transform `T` rotation matrices with shape + `(num_batch, num_residues, 3, 3)`. + t_i (torch.Tensor): Transform `T` translations with shape + `(num_batch, num_residues, 3)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_nodes, num_neighbors)`. + + Returns: + R_ji (torch.Tensor): Rotation matrices of neighbor transforms, with shape + `(num_batch, num_residues, num_neighbors, 3, 3)`. + t_ji (torch.Tensor): Translations of neighbor transforms, with shape + `(num_batch, num_residues, num_neighbors, 3)`. + """ + R_j, t_j = collect_neighbor_transforms(R_i, t_i, edge_idx) + R_ji, t_ji = compose_inner_transforms( + R_j, t_j, R_i.unsqueeze(-3), t_i.unsqueeze(-2) + ) + return R_ji, t_ji + + +def equilibrate_transforms( + R_i: torch.Tensor, + t_i: torch.Tensor, + R_ji: torch.Tensor, + t_ji: torch.Tensor, + logit_ij: torch.Tensor, + mask_ij: torch.Tensor, + edge_idx: torch.LongTensor, + iterations: int = 1, + R_global: Optional[torch.Tensor] = None, + t_global: Optional[torch.Tensor] = None, + R_global_i: Optional[torch.Tensor] = None, + t_global_i: Optional[torch.Tensor] = None, + logit_global_i: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Equilibrate neighbor transforms. + + Args: + R_i (torch.Tensor): Transform `T` rotation matrices with shape + `(num_batch, num_residues, 3, 3)`. + t_i (torch.Tensor): Transform `T` translations with shape + `(num_batch, num_residues, 3)`. + R_ji (torch.Tensor): Rotation matrices to go between frames for nodes i and j + with shape `(num_batch, num_residues, num_neighbors, 3, 3)`. + t_ji (torch.Tensor): Translations to go between frames for nodes i and j with + shape `(num_batch, num_residues, num_neighbors, 3)`. + logit_ij (torch.Tensor): Logits for averaging neighbor transforms with shape + `(num_batch, num_residues, num_neighbors, num_weights)`. Note that + `num_weights` must be 1, 2, or 3; see the documentation for + `generate.layers.structure.transforms.average_transforms` for an + explanation of the interpretations with different `num_weights`. + mask_ij (torch.Tensor): Mask for averaging neighbor transforms with shape + `(num_batch, num_residues, num_neighbors)`. + edge_idx (torch.LongTensor): Edge indices for neighbors with shape + `(num_batch, num_nodes, num_neighbors)`. + iterations (int): Number of iterations to equilibrate for. + R_global (torch.Tensor): Optional global frame rotation matrix with shape + `(num_batch, 3, 3)`. + t_global (torch.Tensor): Optional global frame translation with shape + `(num_batch, 3)`. + R_global_i (torch.Tensor): Optional rotation matrix for global frame from + nodes with shape `(num_batch, num_residues, 3, 3)`. 
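# Shape-level sketch for the neighbor gathering above: given per-residue frames
# and a k-nearest-neighbor index map, the helpers return, for every residue i,
# the frames of its neighbors j and the frame-to-frame relative transforms. This
# only exercises the documented shapes and assumes the prtm package imports.
import torch

from prtm.models.chroma.structure.transforms import (
    collect_neighbor_inner_transforms,
    collect_neighbor_transforms,
)

num_batch, num_residues, num_neighbors = 2, 10, 4
R_i = torch.eye(3).expand(num_batch, num_residues, 3, 3).contiguous()
t_i = torch.randn(num_batch, num_residues, 3)
edge_idx = torch.randint(0, num_residues, (num_batch, num_residues, num_neighbors))

R_j, t_j = collect_neighbor_transforms(R_i, t_i, edge_idx)
R_ji, t_ji = collect_neighbor_inner_transforms(R_i, t_i, edge_idx)
assert R_j.shape == (num_batch, num_residues, num_neighbors, 3, 3)
assert t_ji.shape == (num_batch, num_residues, num_neighbors, 3)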
+ t_global_i (torch.Tensor): Optional translation for global frame from nodes + with shape `(num_batch, num_residues, 3)`. + logit_global_i (torch.Tensor): Logits for averaging global frame transform + with shape `(num_batch, num_residues, num_weights)`. `num_weights` + should match that of `logit_ij`. + + Returns: + R_i (torch.Tensor): Rotation matrices of equilibrated transforms, with shape + `(num_batch, num_residues, 3, 3)`. + t_i (torch.Tensor): Translations of equilibrated transforms, with shape + `(num_batch, num_residues, 3)`. + """ + + # Optional global frames are treated as additional neighbor + update_global = False + if None not in [R_global, t_global, R_global_i, t_global_i, logit_global_i]: + update_global = True + num_batch, num_residues, num_neighbors = list(mask_ij.shape) + R_global_i = R_global_i.unsqueeze(2) + t_global_i = t_global_i.unsqueeze(2) + R_ji = torch.cat((R_ji, R_global_i), dim=2) + t_ji = torch.cat((t_ji, t_global_i), dim=2) + logit_ij = torch.cat((logit_ij, logit_global_i.unsqueeze(2)), dim=2) + R_global = R_global.reshape([num_batch, 1, 1, 3, 3]).expand(R_global_i.shape) + t_global = t_global.reshape([num_batch, 1, 1, 3]).expand(t_global_i.shape) + mask_i = (mask_ij.sum(2, keepdims=True) > 0).float() + mask_ij = torch.cat((mask_ij, mask_i), dim=2) + + t_edge = None + for i in range(iterations): + R_j, t_j = collect_neighbor_transforms(R_i, t_i, edge_idx) + if update_global: + R_j = torch.cat((R_j, R_global), dim=2) + t_j = torch.cat((t_j, t_global), dim=2) + R_i_pred, t_i_pred = compose_transforms(R_j, t_j, R_ji, t_ji) + + if logit_ij.size(-1) == 3: + # Compute i-j displacement in the same coordinate system as + # t_i_pred, i.e. in global coords. Sign does not matter. + t_edge = t_j - t_i_pred + + R_i, t_i = average_transforms( + R_i_pred, t_i_pred, logit_ij, mask_ij, t_edge=t_edge, dim=2 + ) + + return R_i, t_i + + +def average_transforms( + R: torch.Tensor, + t: torch.Tensor, + w: torch.Tensor, + mask: torch.Tensor, + dim: int, + t_edge: Optional[torch.Tensor] = None, + dither: Optional[bool] = True, + dither_eps: float = 1e-4, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Average transforms with optional dithering. + + Args: + R (torch.Tensor): Transform `T` rotation matrix with shape `(...,3,3)`. + t (torch.Tensor): Transform `T` translation with shape `(...,3)`. + w (torch.Tensor): Logits for averaging weights with shape + `(...,num_weights)`. `num_weights` can be 1 (single scalar + weight per transform), 2 (separate weights for each rotation + and translation), or 3 (one weight for rotation, two weights + for translation corresponding to precision in all directions / + along t_edge). + mask (torch.Tensor): Mask for averaging weights with shape `(...)`. + dim (int): Dimension to average along. + t_edge (torch.Tensor, optional): Translation `T` of shape `(..., 3)` + indicating the displacement between source and target nodes. + dither (bool): Whether to noise final rotations. + dither_eps (float): Fractional amount by which to noise rotations. + + Returns: + R_avg (torch.Tensor): Average transform `T_avg` rotation matrix with + shape `(...{reduced}...,3,3)`. + t_avg (torch.Tensor): Average transform `T_avg` translation with + shape `(...{reduced}...,3)`. 
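# The rotation averaging below relies on an SVD to map a weighted mean of
# rotation matrices back to a valid rotation, rescaling one singular direction
# by the determinant so the result stays in SO(3). A standalone sketch of that
# projection step (pure torch, mirroring the lines that follow):
import torch
import torch.nn.functional as F


def project_to_rotation(M: torch.Tensor) -> torch.Tensor:
    # SVD of an arbitrary (possibly non-orthogonal) 3x3 matrix
    U, sigma, Vh = torch.linalg.svd(M, full_matrices=True)
    R = U @ Vh
    # scale one column of Vh by det(R) so the final product has determinant +1
    d = torch.linalg.det(R)
    d_expand = F.pad(d[..., None, None], (2, 0), value=1.0)
    return U @ (Vh * d_expand)


M = 0.5 * torch.eye(3) + 0.1 * torch.randn(3, 3)  # e.g. a weighted mean of rotations
R = project_to_rotation(M)
assert torch.allclose(R @ R.transpose(-1, -2), torch.eye(3), atol=1e-5)
assert torch.allclose(torch.linalg.det(R), torch.tensor(1.0), atol=1e-5)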
+ """ + assert dim >= 0, "dimension must index from the left" + w = torch.where( + mask[..., None].bool(), w, torch.full_like(w, torch.finfo(w.dtype).min) + ) + + # We use different averaging models based on the number of weights + num_transform_weights = w.size(-1) + if num_transform_weights == 1: + # Share a single scalar weight between t and R. + probs = w.softmax(dim) + t_probs = probs + R_probs = probs[..., None] + + # Average translation. + t_avg = (t * t_probs).sum(dim) + elif num_transform_weights == 2: + # Use separate scalar weights for each of t and R. + probs = w.softmax(dim) + t_probs, R_probs = probs.unbind(-1) + t_probs = t_probs[..., None] + R_probs = R_probs[..., None, None] + + # Average translation. + t_avg = (t * t_probs).sum(dim) + elif num_transform_weights == 3: + # For R use a signed scalar weight. + R_probs = w[..., 2].softmax(dim)[..., None, None] + + # For t use a two-parameter precision matrix P = P_isometric + P_radial. + # We need to hand compute softmax over the shared dim x 2 elements. + w_t = w[..., :2] + w_t_total = w_t.logsumexp([dim, -1], True) + p_iso, p_rad = (w_t - w_t_total).exp().unbind(-1) + + # Use Gaussian fusion for translation. + t_edge = t_edge * mask.to(t_edge.dtype)[..., None] + t_avg, _ = fuse_gaussians_isometric_plus_radial(t, p_iso, p_rad, t_edge, dim) + else: + raise NotImplementedError + + # Average rotation via SVD + R_avg_unc = (R * R_probs).sum(dim) + R_avg_unc = R_avg_unc + dither_eps * torch.randn_like(R_avg_unc) + U, S, Vh = torch.linalg.svd(R_avg_unc, full_matrices=True) + R_avg = U @ Vh + + # Enforce that matrix is rotation matrix + d = torch.linalg.det(R_avg) + d_expand = F.pad(d[..., None, None], (2, 0), value=1.0) + Vh = Vh * d_expand + R_avg = U @ Vh + return R_avg, t_avg + + +def _debug_plot_transforms( + R_ij: torch.Tensor, + t_ij: torch.Tensor, + logits_ij: torch.Tensor, + edge_idx: torch.LongTensor, + mask_ij: torch.Tensor, + dist_eps: float = 1e-3, +): + """Visualize 6dof frame transformations""" + from matplotlib import pyplot as plt + + num_batch = R_ij.shape[0] + R_ij.shape[1] + + # Masked softmax on logits + # logits_ij = torch.where( + # mask_ij.bool(), logits_ij, + # torch.full_like(logits_ij, torch.finfo(logits_ij.dtype).min) + # ) + p_ij = torch.softmax(logits_ij, 2) + p_ij = torch.log_softmax(logits_ij, 2) + # p_ij = torch.softmax(logits_ij, 2) + P_ij = graph.scatter_edges(p_ij[..., None], edge_idx)[..., 0] + + q_ij = geometry.quaternions_from_rotations(R_ij) + q_ij = graph.scatter_edges(q_ij, edge_idx) + t_ij = graph.scatter_edges(t_ij, edge_idx) + + # Converte to distance, direction, orientation + D = torch.sqrt(t_ij.square().sum(-1)) + U = t_ij / (D[..., None] + dist_eps) + D_max = D.max().item() + t_ij = t_ij / D_max + q_axis = q_ij[..., 1:] + + # Distance features + D_img = D + D_img.min().item() + D_img.max().item() + + def _format(T): + T = T.cpu().data.numpy() + # RGB on (0,1)^3 + if len(T.shape) == 3: + T = (T + 1) / 2 + return T + + base_width = 4 + num_cols = 4 + plt.figure(figsize=(base_width * 4, base_width * num_batch), dpi=300) + ix = 1 + for i in range(num_batch): + plt.subplot(num_batch, num_cols, ix) + plt.imshow(_format(D_img[i, :, :]), cmap="inferno") + # plt.clim([hD_min, hD_max]) + plt.axis("off") + + plt.subplot(num_batch, num_cols, ix + 1) + plt.imshow(_format(U[i, :, :, :])) + plt.axis("off") + plt.subplot(num_batch, num_cols, ix + 2) + plt.imshow(_format(q_axis[i, :, :, :])) + plt.axis("off") + + # Confidence plots + plt.subplot(num_batch, num_cols, ix + 3) + plt.imshow(_format(P_ij[i, 
:, :]), cmap="inferno") + # plt.clim([0, P_ij[i,:,:].max().item()]) + plt.axis("off") + ix = ix + num_cols + + plt.tight_layout() + return diff --git a/prtm/models/chroma/system.py b/prtm/models/chroma/system.py new file mode 100644 index 00000000..7dd2b966 --- /dev/null +++ b/prtm/models/chroma/system.py @@ -0,0 +1,4535 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import copy +import logging +import re +import warnings +from dataclasses import dataclass +from functools import partial +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch + +from prtm.constants.residue_constants import ( + PDB_CHAIN_IDS, + alphabetical_restypes, + backbone4_atoms, + restype_1to3, +) +from prtm.models.chroma import polyseq +from prtm.models.chroma import starparser as sp +from prtm.models.chroma.geometry import AA_GEOMETRY + + +@dataclass +class SystemAssemblyInfo: + """A class for representing the assembly information for System objects. + + assemblies (dict): a dictionary of assemblies with keys being assembly IDs + and values being dictionaries with of the following structure: + { + "details": "complete icosahedral assembly", + "instructions": [ + { + "oper_expression": "(1-60)", + "chains": [0, 1, 2], + + # Each assembly instruction has information for generating + # one or more images, with image `i` generated by applying + # the sequence of operations with IDs in `operations[i]` to the + # list of chains in `chains`. The corresponding operations + # are described under `assembly_info["operations"][ID]`. + "operations": [["X0", "1", "2", "3"], ["X0", "4", "5", "6"]]], + }, + ... + ], + } + + operations (dict): a dictionary with symmetry operations. Keys are operation IDs + and values being dictionaries with the following structure: + { + "type": "identity operation", + "name": "1_555", + "matrix": np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]), + "vector": np.array([0., 0., 0.]), + }, + ... + """ + + assemblies: dict + operations: dict + + def __init__(self, assemblies: dict = dict(), operations: dict = dict()): + self.assemblies = assemblies + self.operations = operations + + @staticmethod + def make_operation(type: str, name: str, matrix: list, vector: list): + op = { + "type": type, + "name": name, + "matrix": np.zeros([3, 3]), + "vector": np.zeros([3, 1]), + } + assert len(matrix) == 9, "expected 9 elements in rotation matrix" + assert len(vector) == 3, "expected 3 elements in translation vector" + for i in range(3): + op["vector"][i] = float(vector[i]) + for j in range(3): + op["matrix"][i][j] = float(matrix[i * 3 + j]) + return op + + def delete_chain(self, cid: str): + """Deletes the mention of the chain from assembly information. + + Args: + cid (str): Chain ID to delete. 
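# A small usage sketch for the make_operation helper above: the rotation is
# passed as a flat, row-major list of 9 numbers and the translation as a list
# of 3 (import path as introduced by this patch).
from prtm.models.chroma.system import SystemAssemblyInfo

identity_op = SystemAssemblyInfo.make_operation(
    type="identity operation",
    name="1_555",
    matrix=[1, 0, 0, 0, 1, 0, 0, 0, 1],  # row-major 3x3 identity
    vector=[0, 0, 0],
)
assert identity_op["matrix"].shape == (3, 3)
assert identity_op["vector"].shape == (3, 1)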
+ """ + for ass_id, assembly in self.assemblies.items(): + for ins in assembly["instructions"]: + ins["chains"] = [_id for _id in ins["chains"] if _id != cid] + + def rename_chain(self, old_cid: str, new_cid: str): + """Renames all mentions of a chain to its new chain ID. + + Args: + old_cid (str): Chain ID to rename. + new_cid (str): Newly assigned Chain ID. + """ + for ass_id, assembly in self.assemblies.items(): + for ins in assembly["instructions"]: + ins["chains"] = [ + new_cid if cid == old_cid else cid for cid in ins["chains"] + ] + + +class StringList: + """A class for representing and accessing a list of strings in a highly memory-efficient + manner. Access is constant time, but modification is linear time in length of list. + """ + + def __init__(self, init_list: List[str] = []): + self.string = "" + self.rng = ArrayList(2, dtype=int) + for i in range(len(init_list)): + self.append(init_list[i]) + + def __getitem__(self, i: int): + beg, length = self.rng[i] + return self.string[beg : beg + length] + + def __setitem__(self, i: int, new_string: str): + beg, length = self.rng[i] + self.string = self.string[:beg] + new_string + self.string[beg + length :] + if len(new_string) != length: + self.rng[i, 1] = len(new_string) + self.rng[i + 1 :, 0] = self.rng[i + 1 :, 0] + len(new_string) - length + + def __str__(self): + return self.string + + def __len__(self): + return len(self.rng) + + def copy(self): + new_list = StringList() + new_list.string = self.string + new_list.rng = self.rng.copy() + return new_list + + def append(self, new_string: str): + self.rng.append([len(self.string), len(new_string)]) + self.string = self.string + new_string + + def insert(self, i: int, new_string: str): + if i < len(self): + ix, _ = self.rng[i] + elif i == len(self): + if len(self) == 0: + ix = 0 + else: + ix = self.rng[i - 1].sum() + else: + raise Exception( + f"cannot insert in position {i} for stringList of length {len(self)}" + ) + self.string = self.string[0:ix] + new_string + self.string[ix:] + self.rng.insert(i, [ix, len(new_string)]) + if len(new_string) > 0: + self.rng[i + 1 :, 0] = self.rng[i + 1 :, 0] + len(new_string) + + def pop(self, i: int): + beg, length = self.rng[i] + val = self.string[beg : beg + length] + self.string = self.string[0:beg] + self.string[beg + length :] + self.rng[i + 1 :, 0] = self.rng[i + 1 :, 0] - len(val) + self.rng.pop(i) + return val + + def delete_range(self, rng: range): + rng = sorted(rng) + [i, j] = [rng[0], rng[-1]] + beg, _ = self.rng[i] + end = self.rng[j].sum() + self.string = self.string[0:beg] + self.string[end:] + self.rng[j + 1 :, 0] = self.rng[j + 1 :, 0] - (end - beg + 1) + self.rng.delete_range(rng) + + +class NameList: + """A class for representing and accessing a list of "names"--i.e., strings that tend to + have generic values, such that many repeat values are expected in a given list.""" + + def __init__(self, init_list: List[str] = []): + self._reindex(init_list) + + def _reindex(self, init_list: List[str]): + self.unique_names = [] + self.name_indicies = dict() + self.index_use = dict() + self.indices = ArrayList(1, dtype=int) + for name in init_list: + self.append(name) + + def copy(self): + new_list = NameList() + new_list.unique_names = self.unique_names.copy() + new_list.name_indicies = self.name_indicies.copy() + new_list.index_use = self.index_use.copy() + new_list.indices = self.indices.copy() + return new_list + + def _check_index(self): + L = len(self.unique_names) + I = len(self.index_use) + if (L > 2 * I) and (L - I > 10): + 
self._reindex([self[i] for i in range(len(self))]) + + def __getitem__(self, i: int): + try: + idx = self.indices[i].item() + except IndexError as e: + raise IndexError(f"index {i} out of range for nameList\n" + str(e)) + return self.unique_names[idx] + + def __setitem__(self, i: int, new_name: str): + try: + idx = self.indices[i] + except IndexError as e: + raise IndexError(f"index {i} out of range for nameList\n" + str(e)) + self.index_use[idx] = self.index_use[idx] - 1 + if self.index_use[idx] == 0: + del self.index_use[idx] + if new_name not in self.name_indicies: + idx = len(self.name_indicies) + self.name_indicies[new_name] = idx + self.unique_names.append(new_name) + else: + idx = self.name_indicies[new_name] + self.indices[i] = idx + self._update_use(idx, 1) + self._check_index() + + def __str__(self): + return str([self[i] for i in range(len(self))]) + + def __len__(self): + return len(self.indices) + + def _update_use(self, idx, delta): + self.index_use[idx] = self.index_use.get(idx, 0) + delta + if self.index_use[idx] <= 0: + del self.index_use[idx] + + def _get_name_index(self, name: str): + if name not in self.name_indicies: + idx = len(self.name_indicies) + self.name_indicies[name] = idx + self.unique_names.append(name) + else: + idx = self.name_indicies[name] + return idx + + def append(self, name: str): + idx = self._get_name_index(name) + self.indices.append(idx) + self.index_use[idx] = self.index_use.get(idx, 0) + 1 + + def insert(self, i: int, new_string: str): + idx = self._get_name_index(new_string) + self.indices.insert(i, idx) + self.index_use[idx] = self.index_use.get(idx, 0) + 1 + + def pop(self, i: int): + idx = self.indices.pop(i).item() + val = self.unique_names[idx] + self._update_use(idx, -1) + self._check_index() + return val + + def delete_range(self, rng: range): + for i in reversed(sorted(rng)): + self.pop(i) + + +class ArrayList: + def __init__(self, ndims: int, dtype: type, length: int = 0, val=0): + if ndims == 1: + self._array = np.ndarray(shape=(max(length, 2)), dtype=dtype) + else: + self._array = np.ndarray(shape=(max(length, 2), ndims), dtype=dtype) + self.ndims = ndims + self._array[:] = val + self.length = length + # view of just the data without the extra allocated stuff + self.array = self._array[: self.length] + + def convert_negative_slice(self, slice_obj): + start = slice_obj.start if slice_obj.start is not None else 0 + stop = slice_obj.stop if slice_obj.stop is not None else self.length + + if start < 0: + start = self.length + start + if stop < 0: + stop = self.length + stop + + return slice(start, stop, slice_obj.step) + + def copy(self): + new_list = ArrayList(ndims=self.ndims, dtype=self.array.dtype, length=len(self)) + new_list[:] = self[:] + return new_list + + def __len__(self): + return self.length + + def capacity(self): + return self._array.shape[0] + + def __getitem__(self, i: int): + return self.array[i] + + def __setitem__(self, i: int, row: list): + self.array[i] = row + + def resize(self, delta): + # for speed, hard-code instead of calling len() and capacity() + new_length = self.length + delta + cap = self._array.shape[0] + if (new_length > cap) or (new_length < cap / 3): + new_capacity = 2 * new_length + self._resize(new_capacity) + self.length = new_length + self.array = self._array[: self.length] + + def _resize(self, new_size): + if self.ndims == 1: + self._array.resize((new_size), refcheck=False) + else: + self._array.resize((new_size, self.ndims), refcheck=False) + + def items(self): + for i in range(self.length): + 
yield self.array[i, :] + + def append(self, row: list): + self.resize(1) + self.array[-1] = row + + def insert(self, i: int, row: list): + """Insert the row such that it ends up being at index ``i`` in the new arrayList""" + # resize by +1 + self.resize(1) + + # everything in range [i:end-1) moves over by +1 + self.array[i + 1 :] = self.array[i:-1] + + # set the value at index i + self.array[i] = row + + def pop(self, i: int): + """Remove and return element at index i""" + + # get the element at index i + row = self.array[i].copy() + + # everything from [i+1; end) moves over by -1 + self.array[i:-1] = self.array[i + 1 :] + + # resize by -1 + self.resize(-1) + + return row + + def delete_range(self, rng: range): + i, j = min(rng), max(rng) + + # move over to the left to account for the removed part + cut_length = j - i + 1 + new_length = len(self) - cut_length + self.array[i:new_length] = self.array[j + 1 :] + + # resize by -1 + self.resize(-cut_length) + + def __str__(self): + return str([self[i] for i in range(len(self))]) + + +@dataclass +class HierarchicList: + """A utility class that represents a hierarchy of lists. Each level represents + a list of elements, each element having a set of properties (each property being + stored as an array-like object over elements). Further, each element has a number + of children corresponding to a range of elements in a lower-hierarhy list.""" + + _properties: dict + _parent_list: HierarchicList + _child_list: HierarchicList + _num_children: ArrayList # (1, n) + _child_offset: ArrayList # (1, n) + + def __init__( + self, + properties: dict, + parent_list: HierarchicList = None, + num_children: ArrayList = ArrayList(1, dtype=int), + ): + self._properties = dict() + for key in properties: + self._properties[key] = properties[key].copy() + self._parent_list = parent_list + if self._parent_list is not None: + self._parent_list._child_list = self + self._child_list = None + self._num_children = num_children.copy() if num_children is not None else None + # start off with lazy offsets, self.reindex() creates them + self._child_offset = None + + def copy(self): + new_list = HierarchicList( + self._properties, self._parent_list, self._num_children + ) + new_list._child_list = self._child_list + if self._child_offset is None: + new_list._child_offset = None + else: + new_list._child_offset = self._child_offset.copy() + return new_list + + def set_parent(self, parent_list: HierarchicList): + self._parent_list = parent_list + + def child_index(self, i: int, at: int): + if self._child_offset is not None: + return self._child_offset[i] + at + return self._num_children[0:i].sum() + at + + def reindex(self): + if self._num_children is not None: + self._child_offset = ArrayList( + 1, dtype=int, length=len(self._num_children), val=0 + ) + for i in range(1, len(self)): + self._child_offset[i] = ( + self._child_offset[i - 1] + self._num_children[i - 1] + ) + + def append_child(self, properties): + self._num_children[len(self._num_children) - 1] += 1 + self._child_list.append(properties) + + def insert_child(self, i: int, at: int, properties): + idx = self.child_index(i, at) + self._num_children[i] += 1 + self._child_offset = None + self._child_list.insert(idx, properties) + return idx + + def delete_child(self, i: int, at: int): + idx = self.child_index(i, at) + self._num_children[i] -= 1 + self._child_offset = None + self._child_list.delete(idx) + + def append(self, properties): + if set(properties.keys()) != set(self._properties.keys()): + raise 
Exception(f"unexpected set of attributes '{list(properties.keys())}") + for key, value in properties.items(): + self._properties[key].append(value) + if self._child_offset is not None: + self._child_offset.append( + self._child_offset[-1:].sum() + self._num_children[-1:].sum() + ) + if self._num_children is not None: + self._num_children.append(0) + + def insert(self, i: int, properties): + if set(properties.keys()) != set(self._properties.keys()): + raise Exception(f"unexpected set of attributes '{list(properties.keys())}") + for key, value in properties.items(): + self._properties[key].insert(i, value) + if self._child_offset is not None: + if i >= len(self._child_offset): + off = self._child_offset[-1:].sum() + self._num_children[-1:].sum() + else: + off = self._child_offset[i] + self._child_offset.insert(i, off) + if self._num_children is not None: + self._num_children.insert(i, 0) + + def delete(self, i: int): + for key in self._properties: + self._properties[key].pop(i) + if self._num_children is not None and self._num_children[i] != 0: + for at in range(self._num_children[i] - 1, -1, -1): + self.delete_child(i, at) + self._num_children.pop(i) + self._child_offset = None + + def delete_range(self, rng: range): + for key in self._properties: + self._properties[key].delete_range(rng) + # iterating in descending order so that child offsets remain valid for subsequent elements + for i in reversed(sorted(rng)): + if self._num_children is not None and self._num_children[i] != 0: + idx = self.child_index(i, 0) + self._child_list.delete_range( + self, range(idx, idx + self._num_children[i]) + ) + self._num_children[i] = 0 + self._child_offset = None + + def __len__(self): + for key in self._properties: + return len(self._properties[key]) + return None + + def __getitem__(self, i: str): + return self._properties[i] + + # def __setitem__(self, i: tuple, val): + # self._properties[i[0]][i[1]] = val + + def num_children(self, i: int): + return self._num_children[i] + + def has_children(self, i: int): + return self._num_children is not None and self._num_children[i] + + def __str__(self): + string = "Properties:\n" + for key in self._properties: + string += f"{key}: {str(self._properties[key])}\n" + string += f"num_children: {str(self._num_children)}\n" + string += f"child_offset: {str(self._child_offset)}\n" + string += "----\n" + string += str(self._child_list) + return string + + +@dataclass +class System: + """A class for storing, accessing, managing, and manipulating a molecular + system's structure, sequence, and topological information. The class is + organized as a hierarchy of objects: + + System: top-level class containing all information about a molecular system + -> Chain: a sub-portion of the System; for polymers this is generally a + chemically connected molecular graph belong to a System (e.g., for + protein complexes, this would be one of the proteins). + -> Residue: a generally chemically-connected molecular unit (for polymers, + the repeating unit), belonging to a Chain. + -> Atom: an atom belonging to a Residue with zero, one, or more locations. + -> AtomLocation: the location of an Atom (3D coordinates and other information). 
+ + Attributes: + name (str): given name for System + _chains (list): a list of Chain objects + _entities (dict): a dictionary of SystemEntity objects, with keys being entity IDs + _chain_entities (list): `chain_entities[ci]` stores entity IDs (i.e., keys into + `entities`) corresponding to the entity for chain `ci` + _extra_models (list): a list of hierarchicList object, representing locations + for alternative models + _labels (dict): a dictionary of residue labels. A label is a string value, + under some category (also a string), associated with a residue. E.g., + the category could be "SSE" and the value could be "H" or "S". If entry + `labels[category][gti]` exists and is equal to `value`, this means that + residue with global template index `gti` has the label `category:value`. + _selections (dict): a dictionary of selections. Keys are selection names and + values are lists of corresponding gti indices. + _assembly_info (SystemAssemblyInfo): information on symmetric assemblies that can + be constructed from components of the molecular system. See ``SystemAssemblyInfo``. + """ + + name: str + _chains: HierarchicList + _residues: HierarchicList + _atoms: HierarchicList + _locations: HierarchicList + _entities: Dict[int, SystemEntity] + _chain_entities: List[int] + _extra_models: List[HierarchicList] + _labels: Dict[str, Dict[int, str]] + _selections: Dict[str, List[int]] + _assembly_info: SystemAssemblyInfo + + def __init__(self, name: str = "system"): + self.name = name + self._chains = HierarchicList( + properties={ + "cid": StringList(), + "segid": StringList(), + "authid": StringList(), + } + ) + self._residues = HierarchicList( + properties={ + "name": NameList(), + "resnum": ArrayList(1, dtype=int), + "authresid": StringList(), + "icode": ArrayList(1, dtype="U1"), + }, + parent_list=self._chains, + ) + self._atoms = HierarchicList( + properties={"name": NameList(), "het": ArrayList(1, dtype=bool)}, + parent_list=self._residues, + ) + self._locations = HierarchicList( + properties={ + "coor": ArrayList(5, dtype=float), + "alt": ArrayList(1, dtype="U1"), + }, + parent_list=self._atoms, + num_children=None, + ) + self._entities = dict() + self._chain_entities = [] + self._extra_models = [] + self._labels = dict() + self._selections = dict() + self._assembly_info = SystemAssemblyInfo() + + def _reindex(self): + self._chains.reindex() + self._residues.reindex() + self._atoms.reindex() + self._locations.reindex() + + def _print_indexing(self): + for chain in self.chains(): + off = self._chains.child_index(chain._ix, 0) + num = self._chains.num_children(chain._ix) + print(f"chain {chain._ix}, {chain}: [{off} - {off + num})") + for residue in chain.residues(): + off = self._residues.child_index(residue._ix, 0) + num = self._residues.num_children(residue._ix) + print(f"\tresidue {residue._ix}, {residue}: [{off} - {off + num})") + for atom in residue.atoms(): + off = self._atoms.child_index(atom._ix, 0) + num = self._atoms.num_children(atom._ix) + print(f"\t\tatom {atom._ix}, {atom}: [{off} - {off + num})") + for loc in atom.locations(): + has_children = self._locations.has_children(loc._ix) + print( + f"\t\t\tlocation {loc._ix}, {loc}: has children? {has_children}" + ) + + @classmethod + def from_XCS( + cls, + X: torch.Tensor, + C: torch.Tensor, + S: torch.Tensor, + alternate_alphabet: str = None, + ) -> System: + """Convert an XCS set of pytorch tensors to a new System object. 
+ + B is batch size (Function only handles batch size of one now) + N is the number of residues + + Args: + X (torch.Tensor): Coordinates with shape `(1, num_residues, num_atoms, 3)`. + `num_atoms` will be 14 if `all_atom=True` or 4 otherwise. + C (torch.LongTensor): Chain map with shape `(1, num_residues)`. It codes + positions as 0 when masked, positive integers for chain indices, + and negative integers to represent missing residues of the + corresponding positive integers. + S (torch.LongTensor): Sequence with shape `(1, num_residues)`. + alternate_alphabet (str, optional): Optional alternative alphabet for + sequence encoding. Otherwise the default alphabet is set in + `alphabetical_restypes`.Amino acid alphabet for embedding. + Returns: + System: A System object with the new XCS data. + + """ + alphabet = ( + alphabetical_restypes if alternate_alphabet is None else alternate_alphabet + ) + all_atom = X.shape[2] == 14 + + assert X.shape[0] == 1 + assert C.shape[0] == 1 + assert S.shape[0] == 1 + assert X.shape[1] == S.shape[1] + assert C.shape[1] == C.shape[1] + + X, C, S = [T.squeeze(0).cpu().data.numpy() for T in [X, C, S]] + + chain_ids = np.abs(C) + + atom_count = 0 + new_system = cls("system") + + for i, chain_id in enumerate(np.unique(chain_ids)): + if chain_id == 0: + continue + + chain_bool = chain_ids == chain_id + X_chain = X[chain_bool, :, :].tolist() + C_chain = C[chain_bool].tolist() + S_chain = S[chain_bool].tolist() + + # Build chain + chain = new_system.add_chain("A") + for chain_ix, (X_i, C_i, S_i) in enumerate(zip(X_chain, C_chain, S_chain)): + resname = polyseq.to_triple(alphabet[int(S_i)]) + + # Build residue + residue = chain.add_residue( + resname, chain_ix + 1, str(chain_ix + 1), " " + ) + + if C_i > 0: + atom_names = backbone4_atoms + + if all_atom and resname in AA_GEOMETRY: + atom_names = atom_names + AA_GEOMETRY[resname]["atoms"] + + for atom_ix, atom_name in enumerate(atom_names): + x, y, z = X_i[atom_ix] + atom_count += 1 + residue.add_atom(atom_name, False, x, y, z, 1.0, 0.0, " ") + + # add an entity for each chain (copy from chain information) + for ci, chain in enumerate(new_system.chains()): + seq = [None] * chain.num_residues() + het = [None] * chain.num_residues() + for ri, res in enumerate(chain.residues()): + seq[ri] = res.name + het[ri] = all(a.het for a in res.atoms()) + entity_type, polymer_type = SystemEntity.guess_entity_and_polymer_type(seq) + entity = SystemEntity( + entity_type, f"chain {chain.cid}", polymer_type, seq, het + ) + new_system.add_new_entity(entity, [ci]) + + return new_system + + def to_XCS( + self, + all_atom: bool = False, + batch_dimension: bool = True, + mask_unknown: bool = True, + unknown_token: int = 0, + reorder_chain: bool = True, + alternate_alphabet=None, + alternate_atoms=None, + get_indices=False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert System object to XCS format. + + `C` tensor has shape [num_residues], where it codes positions as 0 + when masked, positive integers for chain indices, and negative integers + to represent missing residues of the corresponding positive integers. + + `S` tensor has shape [num_residues], it will map residue amino acid to alphabet integers. + If it is not found in `alphabet`, it will default to `unknown_token`. Set `mask_unknown` to true if + also want to mask `unk residue` in `chain_map` + + This function takes into account missing residues and updates chain_map + accordingly. + + Args: + system (type): generate System object to convert. 
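# Round-trip sketch for from_XCS / to_XCS: build a System from backbone-only
# tensors and read them back. Assumes the 4-atom (N, CA, C, O) representation
# and the alphabetical_restypes ordering, where index 0 is alanine.
import torch

from prtm.models.chroma.system import System

num_residues = 8
X = torch.randn(1, num_residues, 4, 3)
C = torch.ones(1, num_residues, dtype=torch.long)   # one fully observed chain
S = torch.zeros(1, num_residues, dtype=torch.long)  # poly-alanine

system = System.from_XCS(X, C, S)
X2, C2, S2 = system.to_XCS()
assert X2.shape == X.shape
assert torch.equal(S2, S)
assert (C2 > 0).all()  # every residue has structure, so nothing is masked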
+ all_atom (bool): Include side chain atoms. Default is `False`. + batch_dimension (bool): Include a batch dimension. Default is `True`. + mask_unknown (bool): Mask residues not found in the alphabet. Default is + `True`. + unknown_token (int): Default token index if a residue is not found in + the alphabet. Default is `0`. + reorder_chain (bool): If set to true will start indexing chain at 1, + else will use the alphabet index (Default: True) + altenate_alphabet (str): Alternative alphabet if not `None`. + alternate_atoms (list): Alternate atom name subset for `X` if not `None`. + get_indices (bool): Also return the location indices corresponding to the + returned `X` tensor. + + Returns: + X (torch.Tensor): Coordinates with shape `(1, num_residues, num_atoms, 3)`. + `num_atoms` will be 14 if `all_atom=True` or 4 otherwise. + C (torch.LongTensor): Chain map with shape `(1, num_residues)`. It codes + positions as 0 when masked, positive integers for chain indices, + and negative integers to represent missing residues of the + corresponding positive integers. + S (torch.LongTensor): Sequence with shape `(1, num_residues)`. + location_indices (np.ndaray, optional): location indices corresponding to + the coordinates in `X`. + + """ + alphabet = ( + alphabetical_restypes if alternate_alphabet is None else alternate_alphabet + ) + + # Get chain map grabbing each chain in system and look at length + C = [] + # Unclear to me why the underscore is needed, maybe it isn't? + valid_chain_ids = "_" + PDB_CHAIN_IDS + for ch_id, chain in enumerate(self.chains()): + ch_str = chain.cid + if ch_str in list(valid_chain_ids): + map_ch_id = list(valid_chain_ids).index(ch_str) + else: + # fmt: off + map_ch_id = np.setdiff1d(np.arange(1, len(valid_chain_ids)), np.unique(C))[0] + # fmt: on + if reorder_chain: + map_ch_id = ch_id + 1 + C += [map_ch_id] * chain.num_residues() + + # Grab full sequence + oneLetterSeq = self.sequence(format="one-letter-string") + + if len(oneLetterSeq) != len(C): + logging.warning("Warning, System and chain_map length don't agree") + + # Initialize recipient arrays + atom_names = None + if all_atom: + num_atoms = 14 if all_atom else 4 + else: + if alternate_atoms is not None: + atom_names = alternate_atoms + else: + atom_names = backbone4_atoms + num_atoms = len(atom_names) + atom_names = {a: i for (i, a) in enumerate(atom_names)} + num_residues = self.num_residues() + X = np.zeros([num_residues, num_atoms, 3]) + location_indices = ( + np.zeros([num_residues * num_atoms], dtype=int) if get_indices else None + ) + + S = [] # Array will contain sequence indices + for i in range(num_residues): + # If residue should be mask or not + is_mask = False + + # Add sequence + if oneLetterSeq[i] in list(alphabet): + S.append(alphabet.index(oneLetterSeq[i])) + else: + S.append(unknown_token) + if mask_unknown: + is_mask = True + + # Get residue from system + res = self.get_residue(i) + if res is None or not res.has_structure(): + is_mask = True + + # If residue is mask because no structure or not found in alphabet + if is_mask: + # Set chain map to -x + C[i] = -abs(C[i]) + else: + # Loop through atoms + if all_atom: + code3 = restype_1to3[oneLetterSeq[i]] + atom_names = backbone4_atoms + AA_GEOMETRY[code3]["atoms"] + atom_names = {a: i for (i, a) in enumerate(atom_names)} + + X[ + i, : + ] = np.nan # so we can tell whether some atom was previously found + num_rem = len(atom_names) + for atom in res.atoms(): + name = System.protein_backbone_atom_type(atom.name, False, True) + if name is None: + 
name = atom.name + ix = atom_names.get(name, None) + if ix is None or not np.isnan(X[i, ix, 0]): + continue + for loc in atom.locations(): + X[i, ix] = loc.coors + if location_indices is not None: + location_indices[i * num_atoms + ix] = loc.get_index() + num_rem -= 1 + break + if num_rem == 0: + break + if num_rem != 0: + C[i] = -abs(C[i]) + X[i, :] = 0 + np.nan_to_num(X[i, :], copy=False, nan=0) + + # Tensor everything + X = torch.tensor(X).float() + C = torch.tensor(C).type(torch.long) + S = torch.tensor(S).type(torch.long) + + # Unsqueeze all the thing + if batch_dimension: + X = X.unsqueeze(0) + C = C.unsqueeze(0) + S = S.unsqueeze(0) + + if location_indices is not None: + return X, C, S, location_indices + + return X, C, S + + def update_with_XCS(self, X, C=None, S=None, alternate_alphabet=None): + """Update the System with XCS coordinates. NOTE: if the System has + more than one model, and if the shape of the System changes (i.e., + atoms are added or deleted), the additional models will be wiped. + + Args: + X (Tensor): Coordinates with shape `(1, num_residues, num_atoms, 3)`. + `num_atoms` will be 14 if `all_atom=True` or 4 otherwise. + C (LongTensor): Chain map with shape `(1, num_residues)`. It codes + positions as 0 when masked, positive integers for chain indices, + and negative integers to represent missing residues of the + corresponding positive integers. Defaults to the current System's + chain map. + S (LongTensor): Sequence with shape `(1, num_residues)`. Defaults to + the current System's sequence. + """ + if C is None or S is None: + _, _C, _S = self.to_XCS() + if C is None: + C = _C + if S is None: + S = _S + + # check to make sure sizes agree + if not ( + (X.shape[1] == self.num_residues()) + and (X.shape[1] == C.shape[1]) + and (X.shape[1] == S.shape[1]) + ): + raise Exception( + f"input tensor sizes {X.shape}, {C.shape}, and {S.shape}, disagree with System size {self.num_residues()}" + ) + + def _process_inputs(T): + if T is not None: + if len(T.shape) == 2 or len(T.shape) == 4: + T = T.squeeze(0) + T = T.to("cpu").detach().numpy() + return T + + X, C, S = map(_process_inputs, [X, C, S]) + + shape_changed = False + alphabet = ( + alphabetical_restypes if alternate_alphabet is None else alternate_alphabet + ) + for i, res in enumerate(self.residues()): + # atoms to update must have structure and are present in the chain map + if not res.has_structure() or C[i] <= 0: + continue + + # First, determine if the sequence has changed + resname = "UNK" + if S is not None and S[i] < len(alphabet): + resname = polyseq.to_triple(alphabet[S[i]]) + # If the identity changes, rename and delete side chain atoms + if res.name != resname: + res.rename(resname) + + # Second, delete all atoms that are missing in XCS or have duplicate names + atoms_sys = [atom.name for atom in res.atoms()] + atoms_XCS = backbone4_atoms + if resname in AA_GEOMETRY: + atoms_XCS = atoms_XCS + AA_GEOMETRY[resname]["atoms"] + atoms_XCS = atoms_XCS[: X.shape[1]] + to_delete = [] + for ix_a, atom in enumerate(res.atoms()): + name = atom.name + if name not in atoms_XCS or name in atoms_sys[:ix_a]: + to_delete.append(atom) + if len(to_delete) > 0: + shape_changed = True + res.delete_atoms(to_delete) + + # Finally, update all atom coordinates and manufacture any missing atoms + for x_id, atom_name in enumerate(atoms_XCS): + atom = res.find_atom(atom_name) + x, y, z = [X[i][x_id][k].item() for k in range(3)] + if atom is not None and atom.num_locations() > 0: + atom.x = x + atom.y = y + atom.z = z + else: + 
shape_changed = True + if atom is not None: + atom.add_location(x, y, z) + else: + res.add_atom(atom_name, False, x, y, z, 1.0, 0.0) + + # wipe extra models if the shape of the System changed + if shape_changed: + self._extra_models = [] + + def __str__(self): + return "system " + self.name + + def chains(self): + """Chain iterator (generator function).""" + for ci in range(len(self._chains)): + yield ChainView(ci, self) + + def get_chain(self, ci: int): + """Returns the chain by index. + + Args: + ci (int): Chain index (from 0) + + Returns: + ChainView object corresponding to the chain in question. + """ + return ChainView(ci, self) + + def get_chain_by_id(self, cid: str, segid=False): + """Returns the chain by its string ID. + + Args: + cid (str): Chain ID. + segid (bool, optional): If set to True (default is False) will + return the chain with the matching segment ID and not chain ID. + + Returns: + ChainView object corresponding to the chain in question. + """ + for ci, chain in enumerate(self.chains()): + if (not segid and cid == chain.cid) or (segid and cid == chain.segid): + return ChainView(ci, self) + return None + + def get_chains(self): + """Returns the list of all chains.""" + return [ChainView(ci, self) for ci in range(len(self._chains))] + + def get_chains_of_entity(self, entity_id: int, by=None): + """Returns the list of chains that correspond to the given entity ID. + + Args: + entity_id (int): Entity ID. + by (str, optional): If specified as "index", will return a + list of chain indices instead of ChainView objects. + + Returns: + List of ChainView objects or chain indices. + """ + cixs = [ci for (ci, eid) in enumerate(self._chain_entities) if entity_id == eid] + if by == "index": + return cixs + return [ChainView(ci, self) for ci in cixs] + + def residues(self): + """Residue iterator (generator function).""" + for chain in self.chains(): + for residue in chain.residues(): + yield residue + + def get_residue(self, gti: int): + """Returns the residue at the given global index. + + Args: + gti (int): Global residue index. + + Returns: + ResidueView object corresponding to the index. + """ + if gti < 0: + raise Exception(f"negative residue index: {gti}") + off = 0 + for chain in self.chains(): + nr = chain.num_residues() + if gti < off + nr: + return chain.get_residue(gti - off) + off = off + nr + raise Exception( + f"residue index {gti} out of range for System, which has {self.num_residues()} residues" + ) + + def atoms(self): + """Iterator of atoms in this System (generator function).""" + for chain in self.chains(): + for residue in chain.residues(): + for atom in residue.atoms(): + yield atom + + def get_atom(self, aidx: int): + """Returns the atom at the given global atom index. + + Args: + gti (int): Global atom index. + + Returns: + AtomView object corresponding to the index. 
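# Navigation sketch: the accessors above expose both per-chain iteration and
# flat global-index lookups. A System built from XCS tensors (as in the sketch
# above) is used here purely for illustration.
import torch

from prtm.models.chroma.system import System

system = System.from_XCS(
    torch.randn(1, 8, 4, 3),
    torch.ones(1, 8, dtype=torch.long),
    torch.zeros(1, 8, dtype=torch.long),
)

chain_a = system.get_chain_by_id("A")
print(chain_a.num_residues())

# global residue index 0 is the first residue of the first chain
first_res = system.get_residue(0)
print(first_res.name, [atom.name for atom in first_res.atoms()])

# flat iteration over every atom in the System
assert sum(1 for _ in system.atoms()) == system.num_atoms()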
+ """ + if aidx < 0: + raise Exception(f"negative atom index: {aidx}") + off = 0 + for chain in self.chains(): + na = chain.num_atoms() + if aidx < off + na: + return chain.get_atom(aidx - off) + off = off + na + raise Exception( + f"atom index {aidx} out of range for System, which has {self.num_atoms()} atoms" + ) + + def locations(self): + """Iterator of atoms in this System (generator function).""" + for chain in self.chains(): + for residue in chain.residues(): + for atom in residue.atoms(): + for loc in atom.locations(): + yield loc + + def _new_locations(self): + new_locs = self._locations.copy() + for li in range(len(new_locs)): + new_locs["coor"][li] = [np.nan] * 5 + return new_locs + + def select(self, expression: str, left_associativity: bool = True): + """Evalates the given selection expression and returns all atoms + involved in the result as a list of AtomView's. + + Args: + expression (str): selection expression. + left_associativity (bool, optional): determines whether operators + in the expression are left-associative. + + Returns: + List of AtomView's. + """ + val, selex_info = self._select( + expression, left_associativity=left_associativity + ) + + # make a list of AtomView + result = [selex_info["all_atoms"][i].atom for i in sorted(val)] + + return result + + def select_residues( + self, + expression: str, + gti: bool = False, + allow_unstructured=False, + left_associativity: bool = True, + ): + """Evalates the given selection expression and returns all residues with any + atoms involved in the result as a list of ResidueView's or list of gti's. + + Args: + expression (str): selection expression. + gti (bool): if True (default is False), will return a list of gti + instead of a list of ResidueView's. + allow_unstructured (bool): If True (default is False), will allow + unstructured residues to be selected. + left_associativity (bool, optional): determines whether operators + in the expression are left-associative. + + Returns: + List of ResidueView's or gti's (ints). + """ + val, selex_info = self._select( + expression, + unstructured=allow_unstructured, + left_associativity=left_associativity, + ) + + # make a list of ResidueView or gti's + if gti: + result = sorted(set([selex_info["all_atoms"][i].rix for i in val])) + else: + residues = dict() + for i in val: + a = selex_info["all_atoms"][i] + residues[a.rix] = a.atom.residue + result = [residues[rix] for rix in sorted(residues.keys())] + + return result + + def select_chains( + self, expression: str, allow_unstructured=False, left_associativity: bool = True + ): + """Evalates the given selection expression and returns all chains with any + atoms involved in the result as a list of ChainView's. + + Args: + expression (str): selection expression. + allow_unstructured (bool): If True (default is False), will allow + unstructured chains to be selected. + left_associativity (bool, optional): determines whether operators + in the expression are left-associative. + + Returns: + List of ResidueView's or gti's (ints). 
+ """ + val, selex_info = self._select( + expression, + unstructured=allow_unstructured, + left_associativity=left_associativity, + ) + + # make a list of ResidueView or gti's + chains = dict() + for i in val: + a = selex_info["all_atoms"][i] + chains[a.cix] = a.atom.chain + result = [chains[rix] for rix in sorted(chains.keys())] + + return result + + def _select( + self, + expression: str, + unstructured: bool = False, + left_associativity: bool = True, + ): + # Build some helpful data structures to support _selex_eval + @dataclass(frozen=True) + class MappableAtom: + atom: AtomView + aix: int + rix: int + cix: int + + def __hash__(self) -> int: + return self.aix + + # first collect all real atoms + all_atoms = [None] * self.num_atoms() + cix, rix, aix = 0, 0, 0 + for chain in self.chains(): + for residue in chain.residues(): + for atom in residue.atoms(): + all_atoms[aix] = MappableAtom(atom, aix, rix, cix) + aix = aix + 1 + + # for residues that do not have atoms, add a dummy atom with no location + # or name; that way, we can still select the residue even though selection + # algebra fundamentally works on atoms + if residue.num_atoms() == 0: + view = DummyAtomView(residue) + view.dummy = True + # make more room at the end of the list since as this is an "extra" atom + all_atoms.append(None) + all_atoms[aix] = MappableAtom(view, aix, rix, cix) + aix = aix + 1 + rix = rix + 1 + cix = cix + 1 + + _selex_info = {"all_atoms": all_atoms} + _selex_info["all_indices_set"] = set([a.aix for a in all_atoms]) + + # fmt: off + # make an expression tree object + tree = ExpressionTreeEvaluator( + ["hyd", "all", "none"], + ["not", "byres", "bychain", "first", "last", + "chain", "authchain", "segid", "namesel", "gti", "resix", "resid", + "authresid", "resname", "re", "x", "y", "z", "b", "icode", "name"], + ["and", "or", "around", "saround"], + eval_function=partial(self._selex_eval, _selex_info), + left_associativity=left_associativity, + debug=False, + ) + # fmt: on + + # evaluate the expression + val = tree.evaluate(expression) + + # if we are not looking to select unstructured residues, remove any dummy + # atoms. NOTE: making dummy atoms can still impact what structured atoms + # are selected (e.g., consider `saround` relative to an unstructured residue) + if not unstructured: + val = { + i for i in val if not hasattr(_selex_info["all_atoms"][i].atom, "dummy") + } + + return val, _selex_info + + def save_selection( + self, + expression: Optional[str] = None, + gti: Optional[List[int]] = None, + selname: str = "_default", + allow_unstructured=False, + left_associativity: bool = True, + ): + """Performs a selection on the System according to the given + selection string and saves the indices of residues involved in + the result (global template indices) under the given name. + + Args: + expression (str): (optional) selection expression. + gti (list of int): (optional) list of gti to define selection expression + selname (str): selection name. + allow_unstructured (bool): If True (default is False), will allow + unstructured residues to be selected. + left_associativity (bool, optional): determines whether operators + in the expression are left-associative. + """ + if gti is not None: + if expression is not None: + warnings.warn( + f"Expression and gti are both not null, expression will be ignored" + f" and gti will be used!" 
+ ) + result = sorted(gti) + else: + result = self.select_residues( + expression, + allow_unstructured=allow_unstructured, + left_associativity=left_associativity, + gti=True, + ) + + # save the list of gti's + self._selections[selname] = result + + def get_selected(self, selname: str = "_default"): + """Returns the list of gti saved under the specified name. + + Args: + selname (str): selection name. + + Returns: + List of global template indices. + """ + if selname not in self._selections: + raise Exception( + f"selection by name '{selname}' does not exist in the System" + ) + return self._selections[selname] + + def has_selection(self, selname: str = "_default"): + """Returns whether the given named selection exists. + + Args: + selname (str): selection name. + + Returns: + Whether the selection exists in the System. + """ + return selname in self._selections + + def get_selection_names(self): + """Returns the list of all currently stored named selections.""" + return list(self._selections.keys()) + + def remove_selection(self, selname: str = "_default"): + """Deletes the selection under the specified name. + + Args: + selname (str): selection name. + """ + if selname not in self._selections: + raise Exception( + f"selection by name '{selname}' does not exist in the System" + ) + del self._selections[selname] + + def _selex_eval(self, _selex_info, op: str, left, right): + def _is_numeric(string: str) -> bool: + try: + float(string) + return True + except ValueError: + return False + + def _is_int(string: str) -> bool: + try: + int(string) + return True + except ValueError: + return False + + def _unpack_operands(operands, dests): + assert len(operands) == len(dests) + unpacked = [None] * len(operands) + succ = True + for i, (operand, dest) in enumerate(zip(operands, dests)): + if dest is None: + if operand is not None: + succ = False + break + elif dest == "result": + if not (isinstance(operand, dict) and "result" in operand): + succ = False + break + unpacked[i] = operand["result"] + elif dest == "string": + if not (len(operand) == 1 and isinstance(operand[0], str)): + succ = False + break + unpacked[i] = operand[0] + elif dest == "strings": + if not ( + isinstance(operand, list) + and all([isinstance(val, str) for val in operands]) + ): + succ = False + break + unpacked[i] = operands + elif dest == "float": + if not (len(operand) == 1 and _is_numeric(operand[0])): + succ = False + break + unpacked[i] = float(operand[0]) + elif dest == "floats": + if not ( + isinstance(operand, list) + and all([_is_numeric(val) for val in operands]) + ): + succ = False + break + unpacked[i] = [float(val) for val in operands] + elif dest == "range": + test = _parse_range(operand) + if test is None: + succ = False + break + unpacked[i] = test + elif dest == "int": + if not (len(operand) == 1 and _is_int(operand[0])): + succ = False + break + unpacked[i] = int(operand[0]) + elif dest == "ints": + if not ( + isinstance(operand, list) + and all([_is_int(val) for val in operands]) + ): + succ = False + break + unpacked[i] = [int(val) for val in operands] + elif dest == "int_range": + test = _parse_int_range(operand) + if test is None: + succ = False + break + unpacked[i] = test + elif dest == "int_range_string": + test = _parse_int_range(operand, string=True) + if test is None: + succ = False + break + unpacked[i] = test + return unpacked, succ + + def _parse_range(operands: list): + """Parses range information given a list of operands that were originally separated + by spaces. 
Allowed range expressiosn are of the form: `< n`, `> n`, `n:m` with + optional spaces allowed between operands.""" + if not ( + isinstance(operands, list) + and all([isinstance(opr, str) for opr in operands]) + ): + return None + operand = "".join(operands) + if operand.startswith(">") or operand.startswith("<"): + if not _is_numeric(operand[1:]): + return None + num = float(operand[1:]) + if operand.startswith(">"): + test = lambda x, cut=num: x > cut + else: + test = lambda x, cut=num: x < cut + elif ":" in operand: + parts = operand.split(":") + if (len(parts) != 2) or not all([_is_numeric(p) for p in parts]): + return None + parts = [float(p) for p in parts] + test = lambda x, lims=parts: lims[0] < x < lims[1] + elif _is_numeric(operand): + target = float(operand) + test = lambda x, t=target: x == t + else: + return None + return test + + def _parse_int_range(operands: list, string: bool = False): + """Parses range of integers information given a list of operands that were + originally separated by spaces. Allowed range expressiosn are of the form: + `n`, `n-m`, `n+m`, with optional spaces allowed anywhere and combinations + also allowed (e.g., "n+m+s+r-p+a").""" + if not ( + isinstance(operands, list) + and all([isinstance(opr, str) for opr in operands]) + ): + return None + operand = "".join(operands) + operands = operand.split("+") + ranges = [] + for operand in operands: + m = re.fullmatch("(.*\d)-(.+)", operand) + if m: + if not all([_is_int(g) for g in m.groups()]): + return None + r = range(int(m.group(1)), int(m.group(2)) + 1) + ranges.append(r) + else: + if not _is_int(operand): + return None + if string: + ranges.append(set([operand])) + else: + ranges.append(set([int(operand)])) + if string: + ranges = [[str(x) for x in r] for r in ranges] + test = lambda x, ranges=ranges: any([x in r for r in ranges]) + return test + + # evaluate expression and store result in list `result` + result = set() + if op in ("and", "or"): + (Si, Sj), succ = _unpack_operands([left, right], ["result", "result"]) + if not succ: + return None + if op == "and": + result = set(Si).intersection(set(Sj)) + else: + result = set(Si).union(set(Sj)) + elif op == "not": + (_, S), succ = _unpack_operands([left, right], [None, "result"]) + if not succ: + return None + result = _selex_info["all_indices_set"].difference(S) + elif op == "all": + (_, _), succ = _unpack_operands([left, right], [None, None]) + if not succ: + return None + result = _selex_info["all_indices_set"] + elif op == "none": + (_, _), succ = _unpack_operands([left, right], [None, None]) + if not succ: + return None + elif op == "around": + (S, rad), succ = _unpack_operands([left, right], ["result", "float"]) + if not succ: + return None + + # Convert to numpy for distance calculation + atom_indices = np.asarray( + [ + ai.aix + for ai in _selex_info["all_atoms"] + for xi in ai.atom.locations() + ] + ) + X_i = np.asarray( + [ + [xi.x, xi.y, xi.z] + for ai in _selex_info["all_atoms"] + for xi in ai.atom.locations() + ] + ) + X_j = np.asarray( + [ + [xi.x, xi.y, xi.z] + for j in S + for xi in _selex_info["all_atoms"][j].atom.locations() + ] + ) + D = np.sqrt(((X_j[np.newaxis, :, :] - X_i[:, np.newaxis, :]) ** 2).sum(-1)) + ix_match = (D <= rad).sum(1) > 0 + match_hits = atom_indices[ix_match] + result = set(match_hits.tolist()) + elif op == "saround": + (S, srad), succ = _unpack_operands([left, right], ["result", "int"]) + if not succ: + return None + for j in S: + aj = _selex_info["all_atoms"][j] + rj = aj.rix + for ai in 
_selex_info["all_atoms"]: + if aj.atom.residue.chain != ai.atom.residue.chain: + continue + ri = ai.rix + if abs(ri - rj) <= srad: + result.add(ai.aix) + elif op == "byres": + (_, S), succ = _unpack_operands([left, right], [None, "result"]) + if not succ: + return None + gtis = set() + for j in S: + gtis.add(_selex_info["all_atoms"][j].rix) + for a in _selex_info["all_atoms"]: + if a.rix in gtis: + result.add(a.aix) + elif op == "bychain": + (_, S), succ = _unpack_operands([left, right], [None, "result"]) + if not succ: + return None + cixs = set() + for j in S: + cixs.add(_selex_info["all_atoms"][j].cix) + for a in _selex_info["all_atoms"]: + if a.cix in cixs: + result.add(a.aix) + elif op in ("first", "last"): + (_, S), succ = _unpack_operands([left, right], [None, "result"]) + if not succ: + return None + if op == "first": + mi = min([_selex_info["all_atoms"][i].aix for i in S]) + else: + mi = max([_selex_info["all_atoms"][i].aix for i in S]) + result.add(mi) + elif op == "name": + (_, name), succ = _unpack_operands([left, right], [None, "string"]) + if not succ: + return None + for a in _selex_info["all_atoms"]: + if a.atom.name == name: + result.add(a.aix) + elif op in ("re", "hyd"): + if op == "re": + (_, regex), succ = _unpack_operands([left, right], [None, "string"]) + else: + (_, _), succ = _unpack_operands([left, right], [None, None]) + regex = "[0123456789]?H.*" + if not succ: + return None + ex = re.compile(regex) + for a in _selex_info["all_atoms"]: + if a.atom.name is not None and ex.fullmatch(a.atom.name): + result.add(a.aix) + elif op in ("chain", "authchain", "segid"): + (_, match_id), succ = _unpack_operands([left, right], [None, "string"]) + if not succ: + return None + if op == "chain": + prop = "cid" + elif op == "authchain": + prop = "authid" + elif op == "segid": + prop = "segid" + for a in _selex_info["all_atoms"]: + if getattr(a.atom.residue.chain, prop) == match_id: + result.add(a.aix) + elif op == "resid": + (_, test), succ = _unpack_operands([left, right], [None, "int_range"]) + if not succ: + return None + for a in _selex_info["all_atoms"]: + if test(a.atom.residue.num): + result.add(a.aix) + elif op in ("resname", "icode"): + (_, match_id), succ = _unpack_operands([left, right], [None, "string"]) + if not succ: + return None + if op == "resname": + prop = "name" + elif op == "icode": + prop = "icode" + for a in _selex_info["all_atoms"]: + if getattr(a.atom.residue, prop) == match_id: + result.add(a.aix) + elif op == "authresid": + (_, test), succ = _unpack_operands( + [left, right], [None, "int_range_string"] + ) + if not succ: + return None + for a in _selex_info["all_atoms"]: + if test(a.atom.residue.authid): + result.add(a.aix) + elif op == "gti": + (_, test), succ = _unpack_operands([left, right], [None, "int_range"]) + if not succ: + return None + for a in _selex_info["all_atoms"]: + if test(a.rix): + result.add(a.aix) + elif op in ("x", "y", "z", "b", "occ"): + (_, test), succ = _unpack_operands([left, right], [None, "range"]) + if not succ: + return None + prop = op + if op == "b": + prop = "B" + for a in _selex_info["all_atoms"]: + for loc in a.atom.locations(): + if test(getattr(loc, prop)): + result.add(a.aix) + break + elif op == "namesel": + (_, selname), succ = _unpack_operands([left, right], [None, "string"]) + if not succ: + return None + if selname not in self._selections: + return None + gtis = set(self._selections[selname]) + for a in _selex_info["all_atoms"]: + if a.rix in gtis: + result.add(a.aix) + else: + return None + + return {"result": 
result} + + def __getitem__(self, chain_idx: int): + """Returns the chain at the given index.""" + return self.get_chain(chain_idx) + + def add_chain( + self, + cid: str, + segid: str = None, + authid: str = None, + entity_id: int = None, + auto_rename: bool = True, + at: int = None, + ): + """Adds a new chain to the System and returns a reference to it. + + Args: + cid (str): Chain ID. + segid (str): Segment ID. + authid (str): Author chain ID. + entity_id (int, optional): Entity ID of the entity corresponding to this chain. + auto_rename (bool, optional): If True, will pick a unique chain ID if the specified + one clashes with an already existing chain. + + Returns: + AtomView object corresponding to the index. + """ + if auto_rename: + cid = self._pick_unique_chain_name(cid) + if segid is None: + segid = cid + if authid is None: + authid = cid + if at is None: + at = self.num_chains() + self._chains.append({"cid": cid, "segid": segid, "authid": authid}) + self._chain_entities.append(entity_id) + else: + self._chains.insert(at, {"cid": cid, "segid": segid, "authid": authid}) + self._chain_entities.insert(at, entity_id) + return ChainView(at, self) + + def _append_residue(self, name: str, num: int, authid: str, icode: str): + """Add a new residue to the end this System. Internal method, do not use. + + Args: + name (str): Residue name. + num (int): Residue number (i.e., residue ID). + authid (str): Author residue ID. + icode (str): Insertion code. + + Returns: + Global index to the newly added residue. + """ + self._chains.append_child( + {"name": name, "resnum": num, "authresid": authid, "icode": icode} + ) + return len(self._residues) - 1 + + def _append_atom( + self, + name: str, + het: bool, + x: float = None, + y: float = None, + z: float = None, + occ: float = None, + B: float = None, + alt: str = None, + ): + """Adds a new atom to the end of this System. Internal method, do not use. + + Args: + name (str): Atom name. + het (bool): Whether it is a hetero-atom. + x, y, z (float): Atom location coordinates. + occ (float): Occupancy. + B (float): B-factor. + alt (str): Alternative position character. + + Returns: + Global index to the newly added atom. + """ + self._residues.append_child({"name": name, "het": het}) + return len(self._atoms) - 1 + + def _append_location(self, x, y, z, occ, B, alt): + """Adds a location to the end of this System. Internal method, do not use. + + Args: + x, y, z (float): coordinates of the location. + occ (float): occupancy for the location. + B (float): B-factor for the location. + alt (str): alternative location character. + + Returns: + Global index to the newly added location. + """ + self._atoms.append_child({"coor": [x, y, z, occ, B], "alt": alt}) + return len(self._locations) - 1 + + def add_new_entity(self, entity: SystemEntity, chain_indices: list): + """Adds a new entity to the list contained within the System and + assigns chains with provided indices to this entity. + + Args: + entity (SystemEntity): The new entity to add to the System. + chain_indices (list): a list of Chain indices for chains to + assign to this entity. + + Returns: + The entity ID of the newly added entity. + """ + new_entity_id = len(self._entities) + while new_entity_id in self._entities: + new_entity_id = new_entity_id + 1 + self._entities[new_entity_id] = entity + for ci in chain_indices: + self._chain_entities[ci] = new_entity_id + return new_entity_id + + def delete_entity(self, entity_id: int): + """Deletes the entity with the specified ID. 
Takes care to unlink + any chains belonging to this entity from it. + + Args: + entity_id (int): Entity ID. + """ + chain_indices = self.get_chains_of_entity(entity_id) + for ci in chain_indices: + self._chain_entities[ci] = None + del self._entities[entity_id] + + def _pick_unique_chain_name(self, hint: str, verbose=False): + goodNames = list( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + ) + taken = set([chain.cid for chain in self.chains()]) + + # first try to pick a conforming chain name (single alpha-numeric character) + for cid in [hint] + goodNames: + if cid not in taken: + return cid + if verbose: + warnings.warn( + "ran out of reasonable single-letter chain names, will use more than one character (PDB sctructure may be repeating chain IDs upon writing, but should still have unique segment IDs)!" + ) + + # if that does not work, get a longer chain name + for i in range(-1, len(goodNames)): + # first try to expand the original chain ID + base = hint if i < 0 else goodNames[i : i + 1] + if base == "": + continue + for k in range(1000): + longName = f"{base}{k}" + if longName not in taken: + return longName + raise Exception( + "ran out of even multi-character chain names; PDB structure appears to have an enormous number of chains" + ) + + def _ensure_unique_entity(self, ci: int): + """Any time we need to update some piece of information about a Chain that + relates to its entity (e.g., sequence info or hetero info), we cannot just + update it directly because other Chains may be pointing to the same entity. + This function checks for any other chains pointing to the same entity as the + specified chain, and (if so) assigns the given chain to a new (duplicate) + entity and returns its new ID. This clears the way updates of this Chain's entity. + + Args: + ci (int): Index of the Chain for which we are trying to update + entity information. + + Returns: + entity ID for either a newly created entity mapped to the Chain or its + initial entity ID if no other chains point to the same entity. + """ + chain = self.get_chain(ci) + entity_id = chain.get_entity_id() + if entity_id is None: + return entity_id + + # see if any other chains point to the same entity + unique = True + for other in self.chains(): + if (other != chain) and (entity_id == other.get_entity_id()): + unique = False + break + if unique: + return entity_id + + # if so, we need to make a new entity and point the chain to it + new_entity = copy.deepcopy(self._entities[entity_id]) + new_entity_id = self.add_new_entity(new_entity, [ci]) + return new_entity_id + + def num_chains(self): + """Returns the number of chains in the System.""" + return len(self._chains) + + def num_chains_of_entity(self, entity_id: int): + """Returns the number of chains of a given entity. + + Args: + entity_id (int): Entity ID. + + Returns: + number of chains mapping to the entity. 
+ """ + + return sum([entity_id == eid for eid in self._chain_entities]) + + def num_molecules_of_entity(self, entity_id: int): + if self._entities[entity_id].is_polymer(): + return self.num_chains_of_entity(entity_id) + cixs = [ci for (ci, id) in enumerate(self._chain_entities) if id == entity_id] + return sum([self[ci].num_residues() for ci in cixs]) + + def num_entities(self): + """Returns the number of entities in the System.""" + return len(self._entities) + + def num_residues(self): + """Returns the number of residues in the System.""" + return len(self._residues) + + def num_structured_residues(self): + """Returns the number of residues with any structure information.""" + return sum([chain.num_structured_residues() for chain in self.chains()]) + + def num_atoms(self): + """Returns the number of atoms in the System.""" + return len(self._atoms) + + def num_structured_atoms(self): + """Returns the number of atoms with any location information.""" + num = 0 + for chain in self.chains(): + for residue in chain.residues(): + for atom in residue.atoms(): + num = num + (atom.num_locations() > 0) + return num + + def num_atom_locations(self): + """Returns the number of atom locations. Note that an atom can have + multiple (alternative) locations and this functions counts all. + """ + return len(self._locations) + + def num_models(self): + """Returns the number of models in the System. A model is effectively + a conformation of the molecular system and each System object can have + an arbitrary number of different conformations. + """ + return len(self._extra_models) + 1 + + def swap_model(self, i: int): + """Swaps the model at index `i` with the current model (i.e., the + model at index 0). + + Args: + i (int): Model index + """ + if i == 0: + return + if i < 0 or i >= self.num_models(): + raise Exception(f"model index {i} out of range") + tmp = self._locations + self._locations = self._extra_models[i - 1] + self._extra_models[i - 1] = tmp + + def add_model(self, other: System): + """Adds a new model to the System by taking the current model from the + specified System `other`. Note that `other` and the present System + must have the same number of atom locations of matching atom and + residue names. + + Args: + other (System): The System to take the model from. + """ + if len(self._locations) != len(other._locations): + raise Exception( + f"System has {len(self._locations)} atom locations while {len(other._locations)} were provided" + ) + self._extra_models.append(other._locations.copy()) + self._extra_models[-1].set_parent(self._atoms) + + def add_model_from_X(self, X: torch.Tensor): + """Adds a new model to the System with given coordinates. 
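+        A minimal illustrative call (assuming the System stores exactly the four
+        backbone atoms N, CA, C, O for every residue) would be
+        `add_model_from_X(torch.zeros(system.num_residues(), 4, 3))`.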
+ + Args: + X (torch.Tensor): Coordinate tensor of shape + `(residues, atoms (4 or 14), coordinates (3))` + """ + if len(self._locations) != X.numel() / 3: + raise Exception( + f"System has {len(self._locations)} atom locations while provided tensor shape is {X.shape}" + ) + X = X.detach().cpu() + self._extra_models.append(self._locations.copy()) + self._extra_models[-1]["coor"][:, 0:3] = X.flatten(0, 1) + return None + + def num_assemblies(self): + """Returns the number of biological assemblies defined in this System.""" + return len(self._assembly_info.assemblies) + + @staticmethod + def from_CIF_string(cif_string: str): + """Initializes and returns a System object from a CIF string.""" + import io + + f = io.StringIO(cif_string) + return System._read_cif(f)[0] + + @staticmethod + def from_CIF(input_file: str): + """Initializes and returns a System object from a CIF file.""" + f = open(input_file, "r") + return System._read_cif(f)[0] + + @staticmethod + def _read_cif(f, strict=False): + def _warn_or_error(strict: bool, msg: str): + if strict: + raise Exception(msg) + else: + warnings.warn(msg) + + is_read = { + part: False for part in ["coors", "entities", "sequence", "entity_poly"] + } + category = "" + (in_loop, success) = (False, True) + peeked = sp.PeekedLine("", 0) + # number of molecules per entity prescribed in the CIF file + num_of_mols = dict() + + system = System("system") + while sp.peek_line(f, peeked): + if peeked.line.startswith("#"): + # nothing to do, skip comments + sp.advance(f, peeked) + elif peeked.line.startswith("data_"): + # nothing to do, this is the beginning of the file + sp.advance(f, peeked) + elif peeked.line.startswith("loop_"): + in_loop = True + category = "" + sp.advance(f, peeked) + else: + (cat, name, val) = ("", "", "") + if peeked.line.startswith("_"): + (cat, name, val) = sp.star_item_parse(peeked.line) + if cat != category: + if category != "": + in_loop = False + category = cat + + if (cat == "_entry") and (name == "id"): + if val != "": + system.name = val + sp.advance(f, peeked) + elif cat == "_entity_poly": + if is_read["entity_poly"]: + raise Exception("entity_poly block encountered multiple times") + tab = sp.star_read_data(f, ["entity_id", "type"], in_loop) + for row in tab: + ent_id = int(row[0]) - 1 + if ent_id not in system._entities: + system._entities[ent_id] = SystemEntity( + None, None, row[1], None, None + ) + else: + system._entities[ent_id]._polymer_type = row[1] + is_read["entity_poly"] = True + elif cat == "_entity": + if is_read["entities"]: + raise Exception( + f"entities block encountered multiple times: {peeked.line}" + ) + tab = sp.star_read_data( + f, + ["id", "type", "pdbx_description", "pdbx_number_of_molecules"], + in_loop, + ) + for row in tab: + ent_id = int(row[0]) - 1 + if ent_id not in system._entities: + system._entities[ent_id] = SystemEntity( + row[1], row[2], None, None, None + ) + else: + system._entities[ent_id]._type = row[1] + system._entities[ent_id]._desc = row[2] + if row[3].isnumeric(): + num_of_mols[ent_id] = int(row[3]) + is_read["entities"] = True + elif cat == "_entity_poly_seq": + if is_read["sequence"]: + raise Exception(f"sequence block encountered multiple times") + tab = sp.star_read_data( + f, ["entity_id", "num", "mon_id", "hetero"], in_loop + ) + (seq, het) = ([], []) + for i in range(len(tab)): + # accumulate sequence information until we reach the end or a new entity ID + seq.append(tab[i][2]) + het.append(tab[i][3].startswith("y")) + if (i == len(tab) - 1) or (tab[i][0] != tab[i + 1][0]): + 
ent_id = int(tab[i][0]) - 1 + system._entities[ent_id]._seq = seq + system._entities[ent_id]._het = het + (seq, het) = ([], []) + is_read["sequence"] = True + elif cat == "_pdbx_struct_assembly": + tab = sp.star_read_data(f, ["id", "details"], in_loop) + for row in tab: + system._assembly_info.assemblies[row[0]] = {"details": row[1]} + elif cat == "_pdbx_struct_assembly_gen": + tab = sp.star_read_data( + f, ["assembly_id", "oper_expression", "asym_id_list"], in_loop + ) + for row in tab: + assembly = system._assembly_info.assemblies[row[0]] + if "instructions" not in assembly: + assembly["instructions"] = [] + chain_ids = [cid.strip() for cid in row[2].strip().split(",")] + assembly["instructions"].append( + {"oper_expression": row[1], "chains": chain_ids} + ) + elif cat == "_pdbx_struct_oper_list": + tab = sp.star_read_data( + f, + [ + "id", + "type", + "name", + "matrix[1][1]", + "matrix[1][2]", + "matrix[1][3]", + "matrix[2][1]", + "matrix[2][2]", + "matrix[2][3]", + "matrix[3][1]", + "matrix[3][2]", + "matrix[3][3]", + "vector[1]", + "vector[2]", + "vector[3]", + ], + in_loop, + ) + for row in tab: + system._assembly_info.operations[ + row[0] + ] = SystemAssemblyInfo.make_operation( + row[1], row[2], row[3:12], row[12:15] + ) + elif cat == "_generate_selections": + tab = sp.star_read_data(f, ["name", "indices"], in_loop) + for row in tab: + system._selections[row[0]] = [ + int(gti.strip()) for gti in row[1].strip().split() + ] + elif cat == "_generate_labels": + tab = sp.star_read_data(f, ["name", "index", "value"], in_loop) + for row in tab: + if row[0] not in system._labels: + system._labels[row[0]] = dict() + idx = int(row[1]) + system._labels[row[0]][int(row[1])] = row[2] + elif cat == "_atom_site": + if is_read["coors"]: + raise Exception(f"ATOM_SITE block encountered multiple times") + # this section is special as it cannot have quoted blocks (because some atom names have the single quote character in them) + tab = sp.star_read_data( + f, + [ + "group_PDB", + "id", + "label_atom_id", + "label_alt_id", + "label_comp_id", + "label_asym_id", + "label_entity_id", + "label_seq_id", + "pdbx_PDB_ins_code", + "Cartn_x", + "Cartn_y", + "Cartn_z", + "occupancy", + "B_iso_or_equiv", + "pdbx_PDB_model_num", + "auth_seq_id", + "auth_asym_id", + ], + in_loop, + cols=False, + has_blocks=False, + ) + + groupCol = 0 + idxCol = 1 + atomNameCol = 2 + altIdCol = 3 + resNameCol = 4 + chainNameCol = 5 + entityIdCol = 6 + seqIdCol = 7 + insCodeCol = 8 + xCol = 9 + yCol = 10 + zCol = 11 + occCol = 12 + bCol = 13 + modelCol = 14 + authSeqIdCol = 15 + authChainNameCol = 16 + + ( + atom, + residue, + chain, + prev_chain, + prev_residue, + prev_atom, + prev_entity_id, + prev_seq_id, + prev_auth_seq_id, + ) = (None, None, None, None, None, None, None, None, None) + loc = None # first model location + aIdx = 0 + for i in range(len(tab)): + if i == 0: + first_model = tab[i][modelCol] + prev_model = first_model + elif (tab[i][modelCol] != prev_model) or ( + tab[i][modelCol] != first_model + ): + if tab[i][modelCol] != prev_model: + aIdx = 0 + num_loc = system.num_atom_locations() + # setting the default value to None allows us to tell when the + # same coordinate in a subsequent model was not specified (e.g., + # when an alternative coordinate is not specified) + system._extra_models.append(system._new_locations()) + prev_model = tab[i][modelCol] + locations_generator = (l for l in system.locations()) + + loc = next(locations_generator, None) + if aIdx >= num_loc: + _warn_or_error( + strict, + f"at atom id: 
{tab[i][idxCol]} -- too many atoms in model {tab[i][modelCol]} relative to first model {first_model}", + ) + success = False + system._extra_models.clear() + break + + # check that the atoms correspond + same = ( + (loc is not None) + and (tab[i][chainNameCol] == loc.atom.residue.chain.cid) + and (tab[i][resNameCol] == loc.atom.residue.name) + and ( + int( + sp.star_value( + tab[i][seqIdCol], loc.atom.residue.num + ) + ) + == loc.atom.residue.num + ) + and (tab[i][atomNameCol] == loc.atom.name) + ) + if not same: + _warn_or_error( + strict, + f"at atom id: {tab[i][idxCol]} -- atoms in model {tab[i][modelCol]} do not correspond exactly to atoms in first model", + ) + success = False + system._extra_models.clear() + break + + coor = [ + float(tab[i][c]) + for c in [xCol, yCol, zCol, occCol, bCol] + ] + system._extra_models[-1]["coor"][aIdx] = coor + system._extra_models[-1]["alt"][aIdx] = sp.star_value( + tab[i][altIdCol], " " + )[0] + aIdx = aIdx + 1 + continue + + # new chain? + if ( + (chain is None) + or (prev_entity_id != tab[i][entityIdCol]) + or (tab[i][chainNameCol] != chain.cid) + ): + authid = ( + tab[i][authChainNameCol] + if (tab[i][authChainNameCol] != "") + else tab[i][chainNameCol] + ) + chain = system.add_chain( + tab[i][chainNameCol], + tab[i][chainNameCol], + authid, + int(tab[i][entityIdCol]) - 1, + ) + + # new residue + if ( + (residue is None) + or (chain != prev_chain) + or (prev_seq_id != tab[i][seqIdCol]) + or (prev_auth_seq_id != tab[i][authSeqIdCol]) + ): + resnum = ( + int(tab[i][seqIdCol]) + if sp.star_value_defined(tab[i][seqIdCol]) + else chain.num_residues() + 1 + ) + ri = system._append_residue( + tab[i][resNameCol], + resnum, + tab[i][authSeqIdCol], + sp.star_value(tab[i][insCodeCol], " ")[0], + ) + residue = ResidueView(ri, chain) + + # usually will be a new atom, but may be an alternative coordinate + # TODO: this only covers cases where alternative atom coordinates are listed next to each other, + # but sometimes they are not -- need to actively use the altIdCol information + x, y, z, occ, B = [ + float(tab[i][col]) + for col in [xCol, yCol, zCol, occCol, bCol] + ] + alt = sp.star_value(tab[i][altIdCol], " ")[0] + if ( + (atom is None) + or (residue != prev_residue) + or (tab[i][atomNameCol] != atom.name) + ): + ai = system._append_atom( + tab[i][atomNameCol], (tab[i][groupCol] == "HETATM") + ) + atom = AtomView(ai, residue) + system._append_location(x, y, z, occ, B, alt) + + prev_chain = chain + prev_residue = residue + prev_entity_id = tab[i][entityIdCol] + prev_seq_id = tab[i][seqIdCol] + prev_auth_seq_id = tab[i][authSeqIdCol] + is_read["coors"] = True + else: + sp.advance(f, peeked) + + # fill in any "missing" polymer chains (e.g., chains with no visible density + # or known structure, but which are nevertheless present) + for entity_id in num_of_mols: + if system._entities[entity_id].is_polymer(): + rem = num_of_mols[entity_id] - system.num_chains_of_entity(entity_id) + for _ in range(rem): + # the chain will get renamed to avoid clashes + system.add_chain("A", None, None, entity_id, auto_rename=True) + + # fill in missing residues (i.e., those that exist in the entity but not + # the atomistic section) + for chain in system.chains(): + entity = chain.get_entity() + if not entity.is_polymer() or entity._seq is None: + continue + k = 0 + for ri in range(len(entity._seq)): + cur_res = chain.get_residue(k) if k < chain.num_residues() else None + if (cur_res is None) or (cur_res.num > ri + 1): + # insert new residue to correspond to entity monomer with 
index ri + chain.add_residue(entity._seq[ri], ri + 1, str(ri + 1), " ", at=k) + elif cur_res.num < ri + 1: + _warn_or_error( + strict, f"inconsistent numbering in chain {chain.cid}" + ) + break + k = k + 1 + + # do an entity-to-structure sequence check for all chains + for chain in system.chains(): + if not chain.check_sequence(): + _warn_or_error( + strict, + f"chain {chain.cid} did not pass sequence check against corresponding entity", + ) + + system._reindex() + return system, success + + @staticmethod + def from_PDB_string(cif_string: str, options=""): + """Initializes and returns a System object from a PDB string.""" + import io + + f = io.StringIO(cif_string) + sys = System._read_pdb(f, options) + sys.name = "from_string" + return sys + + @staticmethod + def from_PDB(input_file: str, options=""): + """Initializes and returns a System object from a PDB file.""" + f = open(input_file, "r") + sys = System._read_pdb(f, options) + sys.name = input_file + return sys + + @staticmethod + def _read_pdb(f, strict=False, options=""): + def _to_float(strval, default): + v = default + try: + v = float(strval) + except: + pass + return v + + last_resnum = None + last_resname = None + last_icode = None + last_chain_id = None + last_alt = None + chain = None + residue = None + + # flag to indicate that chain terminus was reached. Initialize to true so as to create a new chain upon reading the first atom. + ter = True + + # various parsing options (the wonders of dealing with the good-old PDB format) + # and any user-specified overrides + options = options.upper() + # use segment IDs to name chains instead of chain IDs? (useful when the latter + # are absent OR when too many chains, so need multi-letter names) + usese_gid = True if ("USESEGID" in options) else False + + # the PDB file was written by CHARMM (slightly different format) + charmm_format = True if ("CHARMM" in options) else False + + # upon reading, convert from all-hydrogen topology (param22 and higher) to + # the CHARMM19 united-atom topology (matters for HIS protonation states) + charmm19_format = True if ("CHARMM19" in options) else False + + # make sure chain IDs are unique, even if they are not unique in the read file + uniq_chain_ids = False if ("ALLOW DUPLICATE CIDS" in options) else True + + # rename CD in ILE to CD1 (as is standard in PDB, but not some MM packages) + fix_Ile_CD = False if ("ALLOW ILE CD" in options) else True + + # consequtive residues that differ only in their insertion code will be treated + # as separate residues + icodes_as_sep_res = True + + # if true, will not pay attention to TER lines in deciding when chains end/begin + ignore_ter = True if ("IGNORE-TER" in options) else False + + # report various warnings when weird things are found and fixed? + verbose = False if ("QUIET" in options) else True + + chains_to_rename = [] + + # read line by line and build the System + system = System("system") + all_system = system + model_index = 0 + for line in f: + line = line.strip() + if line.startswith("ENDMDL"): + # merge the last read model with the overall System + if model_index: + try: + all_system.add_model(system) + except Exception as e: + warnings.warn( + f"error when adding model {model_index + 1}: {str(e)}, skipping model..." 
+                        )
+                system = System("system")
+                model_index = model_index + 1
+                last_resnum = None
+                last_resname = None
+                last_icode = None
+                last_chain_id = None
+                last_alt = None
+                chain = None
+                residue = None
+                continue
+            if line.startswith("END"):
+                break
+            if line.startswith("MODEL"):
+                # new model
+                continue
+            if line.startswith("TER") and not ignore_ter:
+                ter = True
+                continue
+            if not (line.startswith("ATOM") or line.startswith("HETATM")):
+                continue
+
+            """ Now read atom record. Sometimes PDB lines are too short (if they do not contain some
+            of the last optional columns). We don't want to read past the end of the string!"""
+            line += " " * 100
+            atominx = int(line[6:11])
+            atomname = line[12:16].strip()
+            alt = line[16:17]
+            resname = line[17:21].strip()
+            chain_id = line[21:22].strip()
+            resnum = int(line[23:27]) if charmm_format else int(line[22:26])
+            icode = " " if charmm_format else line[26:27]
+            x = float(line[30:38])
+            y = float(line[38:46])
+            z = float(line[46:54])
+            seg_id = line[72:76].strip()
+            B = _to_float(line[60:66], 0.0)
+            occ = _to_float(line[54:60], 0.0)
+            het = line.startswith("HETATM")
+
+            # use segment IDs instead of chain IDs?
+            if usese_gid:
+                chain_id = seg_id
+            elif (chain_id == "") and (len(seg_id) > 0) and seg_id[0].isalnum():
+                # use first character of segment name if no chain name is specified, a segment ID
+                # is specified, and the latter starts with an alphanumeric character
+                chain_id = seg_id[0:1]
+
+            # create a new chain object, if necessary
+            if (chain_id != last_chain_id) or ter:
+                cid_used = system.get_chain_by_id(chain_id) is not None
+                chain = system.add_chain(chain_id, seg_id, chain_id, auto_rename=False)
+                # non-unique chains will be automatically renamed (unless the user specified not to rename chains), BUT we need to
+                # remember the name that was actually read, since this name is what will be used to determine when the next chain comes
+                if uniq_chain_ids and cid_used:
+                    chain.cid = chain.cid + f"|to rename {len(chains_to_rename)}"
+                    if model_index == 0:
+                        chains_to_rename.append(chain)
+                    if verbose:
+                        warnings.warn(
+                            "chain name '"
+                            + chain_id
+                            + "' was repeated while reading, will rename at the end..."
+                        )
+
+                # start to count residue numbers in this chain
+                last_resnum = None
+                last_resname = None
+                ter = False
+
+            if charmm19_format:
+                if resname == "HSE":
+                    resname = "HSD"  # neutral HIS, proton on ND1
+                if resname == "HSD":
+                    resname = "HIS"  # neutral HIS, proton on NE2
+                if resname == "HSC":
+                    resname = "HSP"  # doubly-protonated +1 HIS
+
+            # many PDB files in the Protein Data Bank call the delta carbon of isoleucine CD1, but
+            # the convention in basically all MM packages is to call it CD, since there is only one
+            if fix_Ile_CD and (resname == "ILE") and (atomname == "CD"):
+                atomname = "CD1"
+
+            # if necessary, make a new residue
+            really_new_atom = True  # is this a truly new atom, as opposed to an alternative position?
+            if (
+                (resnum != last_resnum)
+                or (resname != last_resname)
+                or (icodes_as_sep_res and (icode != last_icode))
+            ):
+                # this corresponds to a case, where the alternative location flag is being used to
+                # designate two (or more) different possible amino acids at a particular position
+                # (e.g., where the density is not clear to assign one). In this case, we shall keep
+                # only the first option, because we don't know any better. But we need to separate
+                # this from the case, where we end up here because we are trying to separate residues
+                # by insertion code.
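+                # (An illustrative scenario, not taken from a specific structure: a file
+                # could list residue 25 of a chain both as ALA with alt-loc "A" and as
+                # GLY with alt-loc "B"; only the variant read first is kept.)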
+ if ( + (resnum == last_resnum) + and (resname != last_resname) + and (alt != last_alt) + and (not icodes_as_sep_res or (icode == last_icode)) + ): + continue + + residue = chain.add_residue( + resname, chain.num_residues() + 1, str(resnum), icode[0] + ) + elif alt != " ": + # if this is not a new residue AND the alternative location flag is specified, + # figure out if another location for this atom has already been given. If not, + # then treat this as the "primary" location, and whatever other locations + # are specified will be treated as alternatives. + a = residue.find_atom(atomname) + if a is not None: + really_new_atom = False + a.add_location(x, y, z, occ, B, alt[0]) + + # if necessary, make a new atom + if really_new_atom: + a = residue.add_atom(atomname, het, x, y, z, occ, B, alt[0]) + + # remember previous values for determining whether something interesting happens next + last_resnum = resnum + last_icode = icode + last_resname = resname + last_chain_id = chain_id + last_alt = alt + + # take care of renaming any chains that had duplicate IDs + for chain in chains_to_rename: + parts = chain.cid.split("|") + assert ( + len(parts) > 1 + ), "something went wrong when renaming a chain at the end of reading" + name = all_system._pick_unique_chain_name(parts[0], verbose) + chain.cid = name + if len(name): + chain.segid = name + + # add an entity for each chain (copy from chain information) + for ci, chain in enumerate(all_system.chains()): + seq = [None] * chain.num_residues() + het = [None] * chain.num_residues() + for ri, res in enumerate(chain.residues()): + seq[ri] = res.name + het[ri] = all(a.het for a in res.atoms()) + entity_type, polymer_type = SystemEntity.guess_entity_and_polymer_type(seq) + entity = SystemEntity( + entity_type, f"chain {chain.cid}", polymer_type, seq, het + ) + all_system.add_new_entity(entity, [ci]) + + return all_system + + def to_CIF(self, output_file: str): + """Writes the System to a CIF file.""" + f = open(output_file, "w") + self._write_cif(f) + + def to_CIF_string(self): + """Returns a CIF string representing the System.""" + import io + + f = io.StringIO("") + self._write_cif(f) + cif_str = f.getvalue() + f.close() + return cif_str + + def _write_cif(self, f): + # fmt: off + _specials_atom_names = [ + "MG", "CL", "FE", "ZN", "MN", "NI", "SE", "CU", "BR", "CO", "AS", + "BE", "RU", "RB", "ZR", "OS", "SR", "GD", "MO", "AU", "AG", "PT", + "AL", "XE", "BE", "CS", "EU", "IR", "AM", "TE", "BA", "SB" + ] + # fmt: on + _ambiguous_atom_names = ["CA", "CD", "NA", "HG", "PB"] + + def _guess_type(atom_name, res_name): + if len(atom_name) > 0 and atom_name[0] == '"': + atom_name = atom_name.replace('"', "") + if atom_name[:2] in _specials_atom_names: + return atom_name[:2] + else: + if atom_name in _ambiguous_atom_names and res_name == atom_name: + return atom_name + elif atom_name == "UNK": + return "X" + return atom_name[:1] + + entry_id = self.name.strip() + if entry_id == "": + entry_id = "system" + f.write( + "data_GNR8\n#\n" + + "_entry.id " + + sp.star_string_escape(entry_id) + + "\n#\n" + ) + + # write entities table + sp.star_loop_header_write( + f, "_entity", ["id", "type", "pdbx_description", "pdbx_number_of_molecules"] + ) + for id, entity in self._entities.items(): + num_mol = self.num_molecules_of_entity(id) + f.write( + f"{id + 1} {sp.star_string_escape(entity._type)} {sp.star_string_escape(entity._desc)} {num_mol}\n" + ) + f.write("#\n") + + # write entity polymer sequences + sp.star_loop_header_write( + f, "_entity_poly_seq", ["entity_id", 
"num", "mon_id", "hetero"] + ) + for id, entity in self._entities.items(): + if entity._seq is not None: + for i, (res, het) in enumerate(zip(entity._seq, entity._het)): + f.write(f"{id + 1} {i + 1} {res} {'y' if het else 'n'}\n") + f.write("#\n") + + # write entity polymer types + sp.star_loop_header_write(f, "_entity_poly", ["entity_id", "type"]) + for id, entity in self._entities.items(): + if entity.is_polymer(): + f.write(f"{id + 1} {sp.star_string_escape(entity._polymer_type)}\n") + f.write("#\n") + + if self.num_assemblies(): + assemblies = self._assembly_info.assemblies + ops = self._assembly_info.operations + # assembly info table + sp.star_loop_header_write(f, "_pdbx_struct_assembly", ["id", "details"]) + for assembly_id, assembly in assemblies.items(): + f.write(f"{assembly_id} {sp.star_string_escape(assembly['details'])}\n") + f.write("#\n") + + # assembly generation instructions table + sp.star_loop_header_write( + f, + "_pdbx_struct_assembly_gen", + ["assembly_id", "oper_expression", "asym_id_list"], + ) + for assembly_id, assembly in assemblies.items(): + for instruction in assembly["instructions"]: + chain_list = ",".join([str(ci) for ci in instruction["chains"]]) + f.write( + f"{assembly_id} {sp.star_string_escape(instruction['oper_expression'])} {chain_list}\n" + ) + f.write("#\n") + + # symmetry operations table + sp.star_loop_header_write( + f, + "_pdbx_struct_oper_list", + [ + "id", + "type", + "name", + "matrix[1][1]", + "matrix[1][2]", + "matrix[1][3]", + "matrix[2][1]", + "matrix[2][2]", + "matrix[2][3]", + "matrix[3][1]", + "matrix[3][2]", + "matrix[3][3]", + "vector[1]", + "vector[2]", + "vector[3]", + ], + ) + for op_id, op in ops.items(): + f.write( + f"{op_id} {sp.star_string_escape(op['type'])} {sp.star_string_escape(op['name'])} " + ) + f.write( + f"{float(op['matrix'][0][0]):g} {float(op['matrix'][0][1]):g} {float(op['matrix'][0][2]):g} " + ) + f.write( + f"{float(op['matrix'][1][0]):g} {float(op['matrix'][1][1]):g} {float(op['matrix'][1][2]):g} " + ) + f.write( + f"{float(op['matrix'][2][0]):g} {float(op['matrix'][2][1]):g} {float(op['matrix'][2][2]):g} " + ) + f.write( + f"{float(op['vector'][0]):g} {float(op['vector'][1]):g} {float(op['vector'][2]):g}\n" + ) + f.write("#\n") + + sp.star_loop_header_write( + f, + "_atom_site", + [ + "group_PDB", + "id", + "label_atom_id", + "label_alt_id", + "label_comp_id", + "label_asym_id", + "label_entity_id", + "label_seq_id", + "pdbx_PDB_ins_code", + "Cartn_x", + "Cartn_y", + "Cartn_z", + "occupancy", + "B_iso_or_equiv", + "pdbx_PDB_model_num", + "auth_seq_id", + "auth_asym_id", + "type_symbol", + ], + ) + idx = -1 + for model_index in range(self.num_models()): + self.swap_model(model_index) + for chain, entity_id in zip(self.chains(), self._chain_entities): + authchainid = ( + chain.authid if sp.star_value_defined(chain.authid) else chain.cid + ) + for residue in chain.residues(): + authresid = ( + residue.authid + if sp.star_value_defined(residue.authid) + else residue.num + ) + for atom in residue.atoms(): + idx = idx + 1 + for location in atom.locations(): + # this means this coordinate was not specified for this model + if not location.defined(): + continue + + coor = location.coor_info + f.write("HETATM " if atom.het else "ATOM ") + f.write( + f"{idx + 1} {atom.name} {sp.atom_site_token(location.alt)} " + ) + entity_id_str = ( + f"{entity_id + 1}" if entity_id is not None else "?" 
+ ) + f.write( + f"{residue.name} {chain.cid} {entity_id_str} {residue.num} " + ) + f.write( + f"{sp.atom_site_token(residue.icode)} {coor[0]:g} {coor[1]:g} {coor[2]:g} " + ) + f.write(f"{coor[3]:g} {coor[4]:g} {model_index} ") + f.write(f"{authresid} {authchainid}\n") + f.write( + f"{authresid} {authchainid} {_guess_type(atom.name, residue.name)}\n" + ) + self.swap_model(model_index) + f.write("#\n") + + # write out selections + if len(self._selections): + sp.star_loop_header_write(f, "_generate_selections", ["name", "indices"]) + for name, indices in self._selections.items(): + f.write( + f"{sp.star_string_escape(name)} \"{' '.join([str(i) for i in indices])}\"\n" + ) + f.write("#\n") + + # write out labels + if len(self._labels): + sp.star_loop_header_write(f, "_generate_labels", ["name", "index", "value"]) + for category, label_dict in self._labels.items(): + for gti, label in label_dict.items(): + f.write( + f"{sp.star_string_escape(category)} {gti} {sp.star_string_escape(label)}\n" + ) + f.write("#\n") + + def to_PDB(self, output_file: str, options: str = ""): + """Writes the System to a PDB file. + + Args: + output_file (str): output PDB file name. + options (str, optional): a string specifying various options for + the writing process. The presence of certain sub-strings will + trigger specific behaviors. Currently recognized sub-strings + include "CHARMM", "CHARMM19", "CHARMM22", "RENUMBER", "NOEND", + "NOTER", and "NOALT". This option is case-insensitive. + """ + f = open(output_file, "w") + self._write_pdb(f, options) + + def to_PDB_string(self, options=""): + """Writes the System to a PDB string. The options string has the same + interpretation as with System::toPDB. + """ + import io + + f = io.StringIO("") + self._write_pdb(f, options) + cif_str = f.getvalue() + f.close() + return cif_str + + def _write_pdb(self, f, options=""): + def _pdb_line(loc: AtomLocationView, ai: int, ri=None, rn=None, an=None): + if rn is None: + rn = loc.atom.residue.name + if ri is None: + ri = loc.atom.residue.num + if an is None: + an = loc.atom.name + icode = loc.atom.residue.icode + cid = loc.atom.residue.chain.cid + if len(cid) > 1: + cid = cid[0] + segid = loc.atom.residue.chain.segid + if len(segid) > 4: + segid = segid[0:4] + + # atom name placement is different when it is 4 characters long + if len(an) < 4: + an_str = " %-.3s" % an + else: + an_str = "%.4s" % an + + # moduli are used to make sure numbers do not go over prescribe field widths + # (this is not enforced by sprintf like with strings) + line = ( + "%6s%5d %-4s%c%-4s%.1s%4d%c %8.3f%8.3f%8.3f%6.2f%6.2f %.4s" + % ( + "HETATM" if loc.atom.het else "ATOM ", + ai % 100000, + an_str, + loc.alt, + rn, + cid, + ri % 10000, + icode, + loc.x, + loc.y, + loc.z, + loc.occ, + loc.B, + segid, + ) + ) + + return line + + # various formating options (the wonders of dealing with the good-old PDB format) + # and user-defined overrides + options = options.upper() + # the PDB file is intended for use in CHARMM or some other MM package + charmmFormat = True if "CHARMM" in options else False + + # upon writing, convert from all-hydrogen topology (param 22 and higher) + # to CHARMM19 united-atom topology (matters for HIS protonation states) + charmm19Format = True if "CHARMM19" in options else False + + # upon writing, convert from CHARMM19 united-atom topology to all-hydrogen + # param 22 topology (matters for HIS protonation states). Also works for + # converting generic PDB files downloaded from the PDB. 
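+        # (As an illustration, a call such as system.to_PDB("out.pdb", options="CHARMM22 RENUMBER NOTER")
+        # would combine several of the flags parsed here; the output file name is hypothetical.)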
+ charmm22Format = True if "CHARMM22" in options else False + + # upon writing, renumber residue and atom names to start from 1 and go in order + renumber = True if "RENUMBER" in options else False + + # do not write END at the end of the PDB file (e.g., useful for + # concatenating chains from several structures) + noend = True if "NOEND" in options else False + + # do not demark the end of each chain with TER (this is not _really_ + # necessary, assuming chain names are unique, and it is sometimes nice + # not to have extra lines other than atoms) + noter = True if "NOTER" in options else False + + # write alternative locations by default + writeAlt = True if "NOALT" in options else False + + # upon writing, convert to a generic PDB naming convention (no + # protonation state specified for HIS) + genericFormat = False + + if charmm19Format and charmm22Format: + raise Exception( + "CHARMM 19 and 22 formatting options cannot be specified together" + ) + + atomIndex = 1 + for ci, chain in enumerate(self.chains()): + for ri, residue in enumerate(chain.residues()): + for ai, atom in enumerate(residue.atoms()): + # dirty details of formating for MM purposes converting + atomname = atom.name + resname = residue.name + if charmmFormat: + if (residue.name == "ILE") and (atom.name == "CD1"): + atomname = "CD" + if (atom.name == "O") and (ri == chain.num_residues() - 1): + atomname = "OT1" + if (atom.name == "OXT") and (ri == chain.num_residues() - 1): + atomname = "OT2" + if residue.name == "HOH": + resname = "TIP3" + + if charmm19Format: + if residue.name == "HSD": # neutral HIS, proton on ND1 + resname = "HIS" + if residue.name == "HSE": # neutral HIS, proton on NE2 + resname = "HSD" + if residue.name == "HSC": # doubley-protonated +1 HIS + resname = "HSP" + elif charmm22Format: + """This will convert from CHARMM19 to CHARMM22 as well as from a generic downlodaded + * PDB file to one ready for use in CHARMM22. The latter is because in the all-hydrogen + * topology, HIS protonation state must be explicitly specified, so there is no HIS per se. + * Whereas in typical downloaded PDB files HIS is used for all histidines (usually, one + * does not even really know the protonation state). Whether sometimes people do specify it + * nevertheless, and what naming format they use to do so, I am not sure (welcome to the + * PDB file format). But certainly almost always it is just HIS. Below HIS is renamed to + * HSD, the neutral form with proton on ND1. This is an assumption; not a perfect one, but + * something needs to be assumed. 
Doing this renaming will make the PDB file work in MM + * packages with the all-hydrogen model.""" + if residue.name == "HSD": # neutral HIS, proton on NE2 + resname = "HSE" + if residue.name == "HIS": # neutral HIS, proton on ND1 + resname = "HSD" + if residue.name == "HSP": # doubley-protonated +1 HIS + resname = "HSC" + elif genericFormat: + if residue.name in ["HSD", "HSP", "HSE", "HSC"]: + resname = "HIS" + if (residue.name == "ILE") and (atom.name == "CD"): + atomname = "CD1" + + # write the atom line + for li in range(atom.num_locations()): + if renumber: + f.write( + _pdb_line( + atom.get_location(li), + atomIndex, + ri=ri + 1, + rn=resname, + an=atomname, + ) + + "\n" + ) + else: + f.write( + _pdb_line( + atom.get_location(li), + atomIndex, + rn=resname, + an=atomname, + ) + + "\n" + ) + atomIndex = atomIndex + 1 + + if not noter and (ri == chain.num_residues() - 1): + f.write("TER\n") + if not noend and (ci == self.num_chains() - 1): + f.write("END\n") + + def canonicalize_protein( + self, + level=2, + drop_coors_unknowns=False, + drop_coors_missing_backbone=False, + filter_by_entity=False, + ): + """Canonicalize the calling System object (in place) by assuming that it represents + a protein molecular system. Different canonicalization rigor and options + can be specified but are all optional. + + Args: + level (int): Canonicalization level that determines which nonstandard-to-standard + residue mappings are performed. Possible values are 1, 2 or 3, with 2 being + the default and higher values meaning more rigorous (and less conservative) + canonicalization. With level 1, only truly equivalent mappings are performed + (e.g., different His protonation states are mapped to the canonical residue + name HIS that does not specify protonation). Level 2 adds to this some less + exact but still quite close mappings--i.e., seleno-methionine (MSE) and seleno- + cystine (SEC) to methionine (MET) and cystine (CYS). Level 3 further adds + even less equivalent but still reasonable mappings--i.e., phosphorylated SER, + THR, TYR, and HIS to their unphosphorylated counterparts as well as S-oxy Cys + to Cys. + drop_coors_unknowns (bool, optional): if True, will discard structural information + for all residues that are not natural or mappable under the current level. + NOTE: any sequence record for these residues (i.e., if they are part of a + polymer entity) will be preserved. + drop_coors_missing_backbone (bool, optional): if True, will discard structural + information for residues that do not have at least the N, CA, C, and O + backbone atoms. Same note applies regarding the full sequence record as in + the above. + filter_by_entity (bool, optional): if True, will remove any chains that do not + represent polymer/polypeptide entities. This is convenient for cases where a + System object has both protein and non-protein components. However, depending + on how the System object was generated, entity metadata may not have been filled, + so applying this canonicalization approach will remove the entire structure. + For this reason, the option is False by default. 
+ """ + + def _mod_to_standard_aa_mappings( + less_standard: bool, almost_standard: bool, standard: bool + ): + # Perfectly corresponding to standard residues + standard_map = {"HSD": "HIS", "HSE": "HIS", "HSC": "HIS", "HSP": "HIS"} + + # Almost perfectly corresponding to standard residues: + # * MSE -- selenomethyonine; SEC -- selenocysteine + almost_standard_map = {"MSE": "MET", "SEC": "CYS"} + + # A little less perfectly corresponding pairings, but can be acceptable (depends): + # * HIP -- ND1-phosphohistidine; SEP -- phosphoserine; TPO -- phosphothreonine; + # * PTR -- o-phosphotyrosine. + less_standard_map = { + "HIP": "HIS", + "CSX": "CYS", + "SEP": "SER", + "TPO": "THR", + "PTR": "TYR", + } + + ret = dict() + if standard: + ret.update(standard_map) + if almost_standard: + ret.update(almost_standard_map) + if less_standard: + ret.update(less_standard_map) + return ret + + def _to_standard_aa_mappings( + less_standard: bool, almost_standard: bool, standard: bool + ): + # get the mapping between modifications and their corresponding standard forms + mapping = _mod_to_standard_aa_mappings( + less_standard, almost_standard, standard + ) + + # add mapping between standard names and themselves + from prtm.models.chroma import polyseq + + for aa in polyseq.canonical_amino_acids(): + mapping[aa] = aa + + return mapping + + less_standard, almost_standard, standard = False, False, False + if level == 3: + less_standard, almost_standard, standard = True, True, True + elif level == 2: + less_standard, almost_standard, standard = False, True, True + elif level == 1: + less_standard, almost_standard, standard = False, False, True + else: + raise Exception(f"unknown canonicalization level {level}") + + to_standard = _to_standard_aa_mappings(less_standard, almost_standard, standard) + + # NOTE: need to re-implement the canonicalization procedure such that it: + # 1. checks to make sure entity sequence and structure sequence agree (error if not) + # 2. goes over entities and looks for residues to rename, does the renaming on the entities + # and all chains simultaneously (so that no new entities are created) + # 3. then goes over the structured part and fixes atoms + + # For residue renamings, we will first record all edits and will perform them + # afterwards in one go, so we can judge whether any new entities have to be + # created. 
The dictionary `esidues_to_rename`` will be as follows: + # entity_id: { + # chain_index: [list of (residue index, rew name) tuples] + # } + chains_to_delete = [] + residues_to_rename = dict() + for ci, chain in enumerate(self.chains()): + entity = chain.get_entity() + if filter_by_entity: + if ( + (entity is None) + or (entity._type != "polymer") + or ("polypeptide" not in entity.polymer_time) + ): + chains_to_delete.append(chain) + continue + + # iterate in reverse order so we can safely delete any residues we find necessary + cleared_residues = 0 + for residue in reversed(list(chain.residues())): + aa = residue.name + delete_atoms = False + # canonicalize amino acid (delete structure if unknown, provided this was asked for) + if aa in to_standard: + aa_new = to_standard[aa] + if aa != aa_new: + # edit any atoms to reflect the mutation + if ( + (aa == "HSD") + or (aa == "HSE") + or (aa == "HSC") + or (aa == "HSP") + ) and (aa_new == "HIS"): + pass + elif ((aa == "MSE") and (aa_new == "MET")) or ( + (aa == "SEC") and (aa_new == "CYS") + ): + SE = residue.find_atom("SE") + if SE is not None: + if aa == "MSE": + SE.residue.rename("SD") + else: + SE.residue.rename("SG") + elif ( + ((aa == "HIP") and (aa_new == "HIS")) + or ((aa == "SEP") and (aa_new == "SER")) + or ((aa == "TPO") and (aa_new == "THR")) + or ((aa == "PTR") and (aa_new == "TYR")) + ): + # delete the phosphate group + for atomname in ["P", "O1P", "O2P", "O3P", "HOP2", "HOP3"]: + a = residue.find_atom(atomname) + if a is not None: + a.delete() + elif (aa == "CSX") and (aa_new == "CYS"): + a = residue.find_atom("OD") + if a is not None: + a.delete() + + # record residue renaming operation to be done later + entity_id = chain.get_entity_id() + if entity_id is None: + residue.rename(aa_new) + else: + if entity_id not in residues_to_rename: + residues_to_rename[entity_id] = dict() + if ci not in residues_to_rename[entity_id]: + residues_to_rename[entity_id][ci] = list() + residues_to_rename[entity_id][ci].append( + (residue.get_index_in_chain(), aa_new) + ) + else: + if aa == "ARG": + A = {an: None for an in ["CD", "NE", "CZ", "NH1", "NH2"]} + for an in A: + atom = residue.find_atom(an) + if atom is not None and atom.num_locations(): + A[an] = atom.get_location(0) + if all([a is not None for n, a in A.items()]): + dihe1 = System.dihedral( + A["CD"], A["NE"], A["CZ"], A["NH1"] + ) + dihe2 = System.dihedral( + A["CD"], A["NE"], A["CZ"], A["NH2"] + ) + if abs(dihe1) > abs(dihe2): + A["NH1"].name = "NH2" + A["NH2"].name = "NH1" + elif drop_coors_unknowns: + delete_atoms = True + + if not drop_coors_missing_backbone: + if not delete_atoms and not residue.has_full_backbone(): + delete_atoms = True + + if delete_atoms: + residue.delete_atoms() + cleared_residues += 1 + + # If we have deleted all residues in this chain, then this is probably not + # a protein chain, so get rid of it. Unless we are asked to pay attention to + # the entity type (i.e., whether it is peptidic), in which case the decision + # of whether to keep the chain would have been made previously. 
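+            # (For example, a chain consisting entirely of waters (HOH) or other
+            # non-amino-acid residues will have had all of its residues cleared above
+            # when `drop_coors_unknowns` is True, and gets removed here.)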
+ if ( + not filter_by_entity + and (cleared_residues != 0) + and (cleared_residues == chain.num_residues()) + ): + chains_to_delete.append(chain) + + # rename residues differently depending on whether all chains of a given entity + # have the same set of renamings + for entity_id, ops in residues_to_rename.items(): + chain_indices = set(ops.keys()) + entity_chains = set(self.get_chains_of_entity(entity_id, by="index")) + unique_renames = set([tuple(v) for v in ops.values()]) + fork = True + if (chain_indices == entity_chains) and (len(unique_renames) == 1): + # we can rename without updating entities, because all entity chains are updated the same way + fork = False + for ci, renames in ops.items(): + chain = self.get_chain(ci) + for ri, new_name in renames: + chain.get_residue(ri).rename(new_name, fork_entity=fork) + + # now delete any chains + for chain in reversed(chains_to_delete): + chain.delete() + + self._reindex() + + def sequence(self, format="three-letter-list"): + """Returns the full sequence of this System, concatenated over all + chains in their order within the System. + + Args: + format (str): sequence format. Possible options are either + "three-letter-list" (default) or "one-letter-string". + + Returns: + List (default) or string. + """ + if format == "three-letter-list": + seq = [] + else: + seq = "" + + for chain in self.chains(): + seq = seq + chain.sequence(format) + return seq + + @staticmethod + def distance(a1: AtomLocationView, a2: AtomLocationView): + """Computes the distance between atom locations `a1` and `a2`.""" + v21 = a1.coors - a2.coors + return np.linalg.norm(v21) + + @staticmethod + def angle( + a1: AtomLocationView, a2: AtomLocationView, a3: AtomLocationView, radians=False + ): + """Computes the angle formed by three 3D points represented by AtomLocationView objects. + + Args: + a1, a2, a3 (AtomLocationView): three 3D points. + radian (bool, optional): if True (default False), will return the angle in radians. + Otherwise, in degrees. + + Returns: + Angle `a1`-`a2`-`a3`. + """ + v21 = a1.coors - a2.coors + v23 = a3.coors - a2.coors + v21 = v21 / np.linalg.norm(v21) + v23 = v23 / np.linalg.norm(v23) + c = np.dot(v21, v23) + return np.arctan2(np.sqrt(1 - c * c), c) * (1 if radians else 180.0 / np.pi) + + @staticmethod + def dihedral( + a1: AtomLocationView, + a2: AtomLocationView, + a3: AtomLocationView, + a4: AtomLocationView, + radians=False, + ): + """Computes the dihedral angle formed by four 3D points represented by AtomLocationView objects. + + Args: + a1, a2, a3, a4 (AtomLocationView): four 3D points. + radian (bool, optional): if True (default False), will return the angle in radians. + Otherwise, in degrees. + + Returns: + Dihedral angle `a1`-`a2`-`a3`-`a4`. 
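+
+        For example, the N, CA, and C atoms of one residue followed by the N atom
+        of the next residue give the backbone psi dihedral.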
+ """ + AB = a1.coors - a2.coors + CB = a3.coors - a2.coors + DC = a4.coors - a3.coors + + if min([np.linalg.norm(p) for p in [AB, CB, DC]]) == 0.0: + raise Exception("some points coincide in dihedral calculation") + + ABxCB = np.cross(AB, CB) + ABxCB = ABxCB / np.linalg.norm(ABxCB) + DCxCB = np.cross(DC, CB) + DCxCB = DCxCB / np.linalg.norm(DCxCB) + + # the following is necessary for values very close to 1 but just above + dotp = np.dot(ABxCB, DCxCB) + if dotp > 1.0: + dotp = 1.0 + elif dotp < -1.0: + dotp = -1.0 + + angle = np.arccos(dotp) + if np.dot(ABxCB, DC) > 0: + angle *= -1 + if not radians: + angle *= 180.0 / np.pi + + return angle + + @staticmethod + def protein_backbone_atom_type(atom_name: str, no_hyd=True, by_name=True): + """Backbone atoms can be either nitrogens, carbons, oxigens, or hydrogens. + Specifically, possible known names in each category are: + 'N', 'NT' + 'CA', 'C', 'CY', 'CAY' + 'OY', 'O', 'OCT*', 'OXT', 'OT1', 'OT2' + 'H', 'HY*', 'HA*', 'HN', 'HT*', '1H', '2H', '3H' + """ + array = ["N", "CA", "C", "O", "H"] if by_name else [0, 1, 2, 3, 4] + if atom_name in ["N", "NT"]: + return array[0] + if atom_name == "CA": + return array[1] + if (atom_name == "C") or (atom_name == "CY"): + return array[2] + if atom_name in ["O", "OY", "OXT", "OT1", "OT2"] or atom_name.startswith("OCT"): + return array[3] + if not no_hyd: + if atom_name in ["H", "HA", "HN"]: + return array[4] + if atom_name.startswith("HT") or atom_name.startswith("HY"): + return array[4] + # Rosetta's N-terinal amine has hydrogens named 1H, 2H, and 3H + if ( + atom_name.startswith("1H") + or atom_name.startswith("2H") + or atom_name.startswith("3H") + ): + return array[4] + return None + + +@dataclass +class SystemEntity: + """A molecular entity represented in a molecular system.""" + + _type: str + _desc: str + _polymer_type: str + _seq: list + _het: list + + def is_polymer(self): + """Returns whether the entity represents a polymer.""" + return self._type == "polymer" + + @classmethod + def guess_entity_and_polymer_type(cls, seq: List): + is_poly = np.mean([polyseq.is_polymer_residue(res, None) for res in seq]) > 0.8 + polymer_type = None + if is_poly: + entity_type = "polymer" + for ptype in polyseq.polymerType: + if ( + np.mean([polyseq.is_polymer_residue(res, ptype) for res in seq]) + > 0.8 + ): + polymer_type = polyseq.polymer_type_name(ptype) + break + else: + entity_type = "unknown" + + return entity_type, polymer_type + + @property + def type(self): + return self._type + + @property + def description(self): + return self._desc + + @property + def polymer_type(self): + return self._polymer_type + + @property + def sequence(self): + return self._seq + + @property + def hetero(self): + return self._het + + +@dataclass +class BaseView: + """An abstract base "view" class for accessing different parts of System.""" + + _ix: int + _parent: object + + def get_index(self): + """Return the index of this atom location in its System.""" + return self._ix + + def is_valid(self): + return self._ix >= 0 and self._parent is not None + + def _delete(self): + at = self._ix - self.parent._siblings.child_index(self.parent._ix, 0) + self.parent._siblings.delete_child(self.parent._ix, at) + + @property + def parent(self): + return self._parent + + +@dataclass +class ChainView(BaseView): + """A Chain view, allowing hierarchical exploration and editing.""" + + def __init__(self, ix: int, system: System): + self._ix = ix + self._parent = system + self._siblings = system._chains + + def __str__(self): + return f"{self.cid} 
({self.segid}/{self.authid}) -> {str(self.system)}" + + def residues(self): + for rn in range(self.num_residues()): + ri = self._siblings.child_index(self._ix, rn) + yield ResidueView(ri, self) + + def num_residues(self): + """Returns the number of residues in the Chain.""" + return self._siblings.num_children(self._ix) + + def num_structured_residues(self): + return sum([res.has_structure() for res in self.residues()]) + + def num_atoms(self): + return sum([res.num_atoms() for res in self.residues()]) + + def num_atom_locations(self): + return sum([res.num_atom_locations() for res in self.residues()]) + + def sequence(self, format="three-letter-list"): + """Returns the sequence of this chain. See `System::sequence()` for + possible formats. + """ + if format == "three-letter-list": + seq = [None] * self.num_residues() + for ri, residue in enumerate(self.residues()): + seq[ri] = residue.name + return seq + elif format == "one-letter-string": + from prtm.models.chroma import polyseq + + seq = [None] * self.num_residues() + for ri, residue in enumerate(self.residues()): + seq[ri] = polyseq.to_single(residue.name) + return "".join(seq) + else: + raise Exception(f"unknown sequence format {format}") + + def get_residue(self, ri: int): + """Get the residue at the specified index within the Chain. + + Args: + ri (int): Residue index within the Chain. + + Returns: + ResidueView object corresponding to the residue in question. + """ + if ri < 0 or ri >= self.num_residues(): + raise Exception( + f"residue index {ri} out of range for Chain, which has {self.num_residues()} residues" + ) + ri = self._siblings.child_index(self._ix, ri) + return ResidueView(ri, self) + + def get_residue_index(self, residue: ResidueView): + """Get the index of the given residue in this Chain.""" + return residue._ix - self._siblings.child_index(self._ix, 0) + + def get_atom(self, aidx: int): + """Get the atom at index `aidx` within this chain.""" + if aidx < 0: + raise Exception(f"negative atom index: {aidx}") + off = 0 + for residue in self.residues(): + na = residue.num_atoms() + if aidx < off + na: + return residue.get_atom(aidx - off) + off = off + na + raise Exception( + f"atom index {aidx} out of range for System, which has {self.num_atoms()} atoms" + ) + + def get_atoms(self): + """Return a list of all atoms in this chain.""" + atoms_views = [] + for residue in self.residues(): + atoms_views.extend(residue.get_atoms()) + return atoms_views + + def __getitem__(self, res_idx: int): + return self.get_residue(res_idx) + + def get_entity_id(self): + """Return the entity ID corresponding to this chain.""" + return self.system._chain_entities[self._ix] + + def get_entity(self): + """Return the entity this chain belongs to.""" + entity_id = self.get_entity_id() + if entity_id is None: + return None + return self.system._entities[entity_id] + + def check_sequence(self): + """Compare the list of residue names of this chain to the corresponding entity sequence record.""" + entity = self.get_entity() + if entity is not None and entity.is_polymer(): + if self.num_residues() != len(entity._seq): + return False + for res, ent_aan in zip(self.residues(), entity._seq): + if res.name != ent_aan: + return False + return True + + def add_residue(self, name: str, num: int, authid: str, icode: str = " ", at=None): + """Add a new residue to this chain. + + Args: + name (str): Residue name. + num (int): Residue number (i.e., residue ID). + authid (str): Author residue ID. + icode (str): Insertion code. 
+ at (int, optional): Index at which to insert the residue. Default + is to append to the end of the chain (i.e., equivalent of ``at` + being equal to the present length of the chain). + """ + if at is None: + at = self.num_residues() + ri = self._siblings.insert_child( + self._ix, + at, + {"name": name, "resnum": num, "authresid": authid, "icode": icode}, + ) + return ResidueView(ri, self) + + def delete(self, keep_entity=False): + """Deletes this Chain from its System. + + Args: + keep_entity (bool, optional): If False (default) and if the chain + being deleted happens to be the last representative of the + entity it belongs to, the entity will be deleted. If True, the + entity will always be kept. + """ + # delete the mention of the chain from assembly information + self.system._assembly_info.delete_chain(self.cid) + + # optionally, delete the corresponding entity if no other chains point to it + if not keep_entity: + eid = self.get_entity_id() + if self.system.num_chains_of_entity(eid) == 0: + self.system.delete_entity(eid) + + self.system._chain_entities.pop(self._ix) + self._siblings.delete(self._ix) + self._ix = -1 # invalidate the view + + @property + def system(self): + return self._parent + + @property + def cid(self): + return self._siblings["cid"][self._ix] + + @property + def segid(self): + return self._siblings["segid"][self._ix] + + @property + def authid(self): + return self._siblings["authid"][self._ix] + + @cid.setter + def cid(self, val): + self._siblings["cid"][self._ix] = val + + @segid.setter + def segid(self, val): + self._siblings["segid"][self._ix] = val + + @authid.setter + def authid(self, val): + self._siblings["authid"][self._ix] = val + + +@dataclass +class ResidueView(BaseView): + """A Residue view, allowing hierarchical exploration and editing.""" + + def __init__(self, ix: int, chain: ChainView): + self._ix = ix + self._parent = chain + self._siblings = chain.system._residues + + def __str__(self): + return f"{self.name} {self.num} ({self.authid}) -> {str(self.chain)}" + + def atoms(self): + off = self._siblings.child_index(self._ix, 0) + for an in range(self.num_atoms()): + yield AtomView(off + an, self) + + def num_atoms(self): + return self._siblings.num_children(self._ix) + + def num_atom_locations(self): + return sum([a.num_locations() for a in self.atoms()]) + + def has_structure(self): + """Returns whether the atom has any structural information (i.e., one or more locations).""" + for a in self.atoms(): + if a.num_locations(): + return True + return False + + def get_atom(self, ai: int): + """Get the atom at the specified index within the Residue. + + Args: + atom_idx (int): Atom index within the Residue. + + Returns: + AtomView object corresponding to the atom in question. 
+ """ + + if ai < 0 or ai >= self.num_atoms(): + raise Exception( + f"atom index {ai} out of range for Residue, which has {self.num_atoms()} atoms" + ) + ai = self._siblings.child_index(self._ix, ai) + return AtomView(ai, self) + + def get_atom_index(self, atom: AtomView): + """Get the index of the given atom in this Residue.""" + return atom._ix - self._siblings.child_index(self._ix, 0) + + def find_atom(self, name): + """Find and return the first atom (as AtomView object) with the given name + within the Residue or None.""" + for atom in self.atoms(): + if atom.name == name: + return atom + return None + + def __getitem__(self, atom_idx: int): + return self.get_atom(atom_idx) + + def get_index_in_chain(self): + """Return the index of the Residue in its parent Chain.""" + return self.chain.get_residue_index(self) + + def rename(self, new_name: str, fork_entity=True): + """Assigns the residue a new name with all proper updates. + + Args: + new_name (str): New residue name. + fork_entity (bool, optional): If True (default) and if parent + chain corresponds to an entity that has other chains + associated with it and there is a real renaming (i.e., + the old name is not the same as the new name), will + make a new (duplicate) entity for to this chain and + will edit the new one, leaving the old one unchanged. + If False, will not perform this regardless. NOTE: + setting this to False can create an inconsistent state + between chain and entity sequence information. + """ + entity_id = self.chain.get_entity_id() + if entity_id is not None: + entity = self.system._entities[entity_id] + ri = self.get_index_in_chain() + if fork_entity and (entity._seq[ri] != new_name): + ci = self.chain.get_index() + entity_id = self.system._ensure_unique_entity(ci) + entity = self.system._entities[entity_id] + entity._seq[ri] = new_name + self._siblings["name"][self._ix] = new_name + + def add_atom( + self, + name: str, + het: bool, + x: float = None, + y: float = None, + z: float = None, + occ: float = 1.0, + B: float = 0.0, + alt: str = " ", + at=None, + ): + """Adds a new atom to the residue (appending it at the end) and + returns an AtomView to it. If atom location information is + specified, will also add a location to the atom. + + Args: + name (str): Atom name. + het (bool): Whether it is a hetero-atom. + x, y, z (float): Atom location coordinates. + occ (float): Occupancy. + B (float): B-factor. + alt (str): Alternative position character. + at (int, optional): Index at which to insert the atom. Default + is to append to the end of the residue (i.e., equivalent of + ``at` being equal to the number of atoms in the residue). + + Returns: + AtomView object corresponding to the newly added atom. + """ + if at is None: + at = self.num_atoms() + ai = self._siblings.insert_child(self._ix, at, {"name": name, "het": het}) + atom = AtomView(ai, self) + + # now add a location to this atom + if x is not None: + atom.add_location(x, y, z, occ, B, alt) + + return atom + + def delete(self, fork_entity=True): + """Deletes this residue from its Chain/System. + + Args: + fork_entity (bool, optional): If True (default) and if parent + chain corresponds to an entity that has other chains + associated with it, will make a new (duplicate) entity + for to this chain and will edit the new one, leaving the + old one unchanged. If False, will not perform this. + NOTE: setting this to False can create an inconsistent state + between chain and entity sequence information. 
+ """ + # update the entity (duplicating, if necessary) + entity_id = self.chain.get_entity_id() + if entity_id is not None: + entity = self.system._entities[entity_id] + ri = self.get_index_in_chain() + if fork_entity: + ci = self.chain.get_index() + entity_id = self.system._ensure_unique_entity(ci) + entity = self.system._entities[entity_id] + entity._seq.pop(ri) + + # delete the residue + self._delete() + self._ix = -1 # invalidate the view + + def delete_atoms(self, atoms=None): + """Delete either the specified list of atoms or all atoms from the residue. + + Args: + atoms (list, optional): List of AtomView objects corresponding to the + atoms to delete. If not specified, will delete all atoms in the residue. + """ + if atoms is None: + atoms = list(self.atoms()) + for atom in reversed(atoms): + if atom.residue != self: + raise Exception(f"Atom {atom} does not belong to Residue {self}") + atom.delete() + + @property + def chain(self): + return self._parent + + @property + def system(self): + return self.chain.system + + @property + def name(self): + return self._siblings["name"][self._ix] + + @property + def num(self): + return self._siblings["resnum"][self._ix] + + @property + def authid(self): + return self._siblings["authresid"][self._ix] + + @property + def icode(self): + return self._siblings["icode"][self._ix] + + def get_backbone(self, no_hyd=True): + """Assuming that this is a protein residue (i.e., an amino acid), returns the + list of atoms corresponding to the residue's backbone, in the order: + backbone amide (N), alpha carbon (CA), carbonyl carbon (C), carbonyl oxygen (O), + and amide hydrogen (H, optional). + + Args: + no_hyd (bool, optional): If True (default), will exclude the amide hydrogen + and only return four atoms. If False, will include the amide hydrogen. + + Returns: + A list with each entry being an AtomView object corresponding to the backbone + atom in the order above or None if the atom does not exist in the residue. + """ + bb = [None] * (4 if no_hyd else 5) + left = len(bb) + for atom in self.atoms(): + i = System.protein_backbone_atom_type(atom.name, no_hyd) + if i is None or bb[i] is not None: + continue + bb[i] = atom + left = left - 1 + if left == 0: + break + return bb + + def has_full_backbone(self, no_hyd=True): + """Assuming that this is a protein residue (i.e., an amino acid), returns + whether the residue harbors a structurally defined backbone (i.e., has + all backbone atoms each of which has location information). + + Args: + no_hyd (bool, optional): If True (default), will ignore whether the amide + hydrogen exists or not (if False will consider it). + + Returns: + Boolean indicating whether there is a full backbone in the residue. + """ + bb = self.get_backbone(no_hyd) + return all([(a is not None) and a.num_locations() for a in bb]) + + def delete_non_backbone(self, no_hyd=True): + """Assuming that this is a protein residue (i.e., an amino acid), deletes + all atoms except backbone atoms. + + Args: + no_hyd (bool, optional): If True (default), will not consider the amide + hydrogen as a backbone atom (if False will consider it). 
+ """ + to_delete = [] + for atom in self.atoms(): + if System.protein_backbone_atom_type(atom.name, no_hyd) is None: + to_delete.append(atom) + self.delete_atoms(to_delete) + + +@dataclass +class AtomView(BaseView): + """An Atom view, allowing hierarchical exploration and editing.""" + + def __init__(self, ix: int, residue: ResidueView): + self._ix = ix + self._parent = residue + self._siblings = residue.system._atoms + + def __str__(self): + string = self.name + (" (HET) " if self.het else " ") + if self.num_locations() > 0: + string = string + str(self.get_location(0)) + string = string + f" ({self.num_locations()})" + return string + " -> " + str(self.residue) + + def locations(self): + off = self._siblings.child_index(self._ix, 0) + for ln in range(self.num_locations()): + yield AtomLocationView(off + ln, self) + + def num_locations(self): + return self._siblings.num_children(self._ix) + + def __getitem__(self, loc_idx: int): + return self.get_location(loc_idx) + + def get_location(self, li: int = 0): + """Returns the (li+1)-th location of the atom.""" + if li < 0 or li >= self.num_locations(): + raise Exception( + f"location index {li} out of range for Atom with {self.num_locations()} locations" + ) + li = self._siblings.child_index(self._ix, li) + return AtomLocationView(li, self) + + def add_location(self, x, y, z, occ=1.0, B=0.0, alt=" ", at=None): + """Adds a location to this atom, append it to the end. + + Args: + x, y, z (float): coordinates of the location. + occ (float): occupancy for the location. + B (float): B-factor for the location. + alt (str): alternative location character. + at (int, optional): Index at which to insert the location. Default + is to append at the end (i.e., equivalent of ``at` being equal + to the current number of locations). 
+ """ + if at is None: + at = self.num_locations() + li = self._siblings.insert_child( + self._ix, at, {"coor": [x, y, z, occ, B], "alt": alt} + ) + return AtomLocationView(li, self) + + def delete(self): + """Deletes this atom from its Residue/Chain/System.""" + self._delete() + self._ix = -1 # invalidate the view + + @property + def residue(self): + return self._parent + + @property + def chain(self): + return self.residue.chain + + @property + def system(self): + return self.chain.system + + @property + def name(self): + return self._siblings["name"][self._ix] + + @property + def het(self): + return self._siblings["het"][self._ix] + + """Location information getters and setters operate on the default (first) + location for this atom and throw an index error if there are no locations.""" + + @property + def x(self): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + return self.system._locations["coor"][ix, 0] + + @property + def y(self): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + return self.system._locations["coor"][ix, 1] + + @property + def z(self): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + return self.system._locations["coor"][ix, 2] + + @property + def coors(self): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + return self.system._locations["coor"][ix, 0:3] + + @property + def occ(self): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + return self.system._locations["coor"][ix, 3] + + @property + def B(self): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + return self.system._locations["coor"][ix, 4] + + @property + def alt(self): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + return self.system._locations["alt"][ix] + + @x.setter + def x(self, val): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + self.system._locations["coor"][ix, 0] = val + + @y.setter + def y(self, val): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + self.system._locations["coor"][ix, 1] = val + + @z.setter + def z(self, val): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + self.system._locations["coor"][ix, 2] = val + + @occ.setter + def occ(self, val): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + self.system._locations["coor"][ix, 3] = val + + @B.setter + def B(self, val): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = self._siblings.child_index(self._ix, 0) + self.system._locations["coor"][ix, 4] = val + + @alt.setter + def alt(self, val): + if self._siblings.num_children(self._ix) == 0: + raise Exception("atom has no locations") + ix = 
self._siblings.child_index(self._ix, 0) + self.system._locations["alt"][ix] = val + + +class DummyAtomView(AtomView): + """An dummy Atom view that can be attached to a residue but that does not + have any locations and with no other information.""" + + def __init__(self, residue: ResidueView): + self._ix = -1 + self._parent = residue + + def __str__(self): + return "DUMMY -> " + str(self.residue) + + def locations(self): + return + yield + + def num_locations(self): + return 0 + + def __getitem__(self, loc_idx: int): + return None + + def get_location(self, li: int = 0): + raise Exception(f"no locations in DUMMY atom") + + def add_location(self, x, y, z, occ, B, alt, at=None): + raise Exception(f"can't add no locations to DUMMY atom") + + @property + def residue(self): + return self._parent + + @property + def chain(self): + return self.residue.chain + + @property + def system(self): + return self.chain.system + + @property + def name(self): + return None + + @property + def het(self): + return None + + @property + def x(self): + raise Exception(f"no coordinates in DUMMY atom") + + @property + def y(self): + raise Exception(f"no coordinates in DUMMY atom") + + @property + def z(self): + raise Exception(f"no coordinates in DUMMY atom") + + @property + def occ(self): + raise Exception(f"no occupancy in DUMMY atom") + + @property + def B(self): + raise Exception(f"no B-factor in DUMMY atom") + + @property + def alt(self): + raise Exception(f"no alt flag in DUMMY atom") + + @x.setter + def x(self, val): + raise Exception(f"can't set coordinate for DUMMY atom") + + @y.setter + def y(self, val): + raise Exception(f"can't set coordinate for DUMMY atom") + + @z.setter + def z(self, val): + raise Exception(f"can't set coordinate for DUMMY atom") + + @occ.setter + def occ(self, val): + raise Exception(f"can't set occupancy for DUMMY atom") + + @B.setter + def B(self, val): + raise Exception(f"can't set B-factor for DUMMY atom") + + @alt.setter + def alt(self, val): + raise Exception(f"can't set alt flag for DUMMY atom") + + +@dataclass +class AtomLocationView(BaseView): + """An AtomLocation view, allowing hierarchical exploration and editing.""" + + def __init__(self, ix: int, atom: AtomView): + self._ix = ix + self._parent = atom + self._siblings = atom.system._locations + + def __str__(self): + return f"{self.x} {self.y} {self.z}" + + def swap(self, other: AtomLocationView): + """Swaps information between itself and the provided atom location. + + Args: + other (AtomLocationView): the other atom location to swap with. 
+ """ + self.x, other.x = other.x, self.x + self.y, other.y = other.y, self.y + self.z, other.z = other.z, self.z + self.occ, other.occ = other.occ, self.occ + self.B, other.B = other.B, self.B + self.alt, other.alt = other.alt, self.alt + + def defined(self): + """Return whether this is a valid location.""" + return (self.x is not None) and (self.y is not None) and (self.z is not None) + + @property + def atom(self): + return self._parent + + @property + def residue(self): + return self.atom.residue + + @property + def chain(self): + return self.residue.chain + + @property + def system(self): + return self.chain.system + + @property + def x(self): + return self.system._locations["coor"][self._ix, 0] + + @property + def y(self): + return self.system._locations["coor"][self._ix, 1] + + @property + def z(self): + return self.system._locations["coor"][self._ix, 2] + + @property + def occ(self): + return self.system._locations["coor"][self._ix, 3] + + @property + def B(self): + return self.system._locations["coor"][self._ix, 4] + + @property + def alt(self): + return self.system._locations["alt"][self._ix] + + @property + def coors(self): + return np.array(self.system._locations["coor"][self._ix, 0:3]) + + @property + def coor_info(self): + return np.array(self.system._locations["coor"][self._ix]) + + @x.setter + def x(self, val): + self.system._locations["coor"][self._ix, 0] = val + + @y.setter + def y(self, val): + self.system._locations["coor"][self._ix, 1] = val + + @z.setter + def z(self, val): + self.system._locations["coor"][self._ix, 2] = val + + @coors.setter + def coors(self, val): + self.system._locations["coor"][self._ix, 0:3] = val + + @coor_info.setter + def coor_info(self, val): + self.system._locations["coor"][self._ix] = val + + @occ.setter + def occ(self, val): + self.system._locations["coor"][self._ix, 3] = val + + @B.setter + def B(self, val): + self.system._locations["coor"][self._ix, 4] = val + + @alt.setter + def alt(self, val): + self.system._locations["alt"][self._ix] = val + + +class ExpressionTreeEvaluator: + """A class for evaluating custom logical parenthetical expressions. The + implementation is very generic, supports nullary, unary, and binary + operators, and does not know anything about what the expressions actually + mean. Instead the class interprets the expression as a tree of sub- + expressions, governed by parentheses and operators, and traverses the + calling upon a user-specified evaluation function to evaluate leaf + nodes as the tree is gradually collapsed into a single node. This + can be used for evaluating set expressions, algebraic expressions, and + others. + + Args: + operators_nullary (list): A list of strings designating nullary operators + (i.e., operators that do not have any operands). E.g., if the language + describes selection algebra, these could be "hyd", "all", or "none"]. + operators_unary (list): A list of strings designating unary operators + (i.e., operators that have one operand, which must comes to the right + of the operator). E.g., if the language describes selection algebra, + these could be "name", "resid", or "chain". + operators_binary (list): A list of strings designating binary operators + (i.e., operators that have two operands, one on each side of the + operator). E.g., if the language describes selection algebra, thse + could be "and", "or", or "around". + eval_function (str): A function that is able to evaluate a leaf node of + the expression tree. 
It shall accept three parameters: + + operator (str): name of the operator + left: the left operand. Will be None if the left operand is missing or + not relevant. Otherwise, can be either a list of strings, which + should represent an evaluatable sub-expression corresponding to the + left operand, or the result of a prior evaluation of this function. + right: Same as `left` but for the right operand. + + The function should attempt to evaluate the resulting expression and + return None in the case of failing or a dictionary with the result of + the evaluation stored under key "result". + left_associativity (bool): If True (the default), operators are taken to be + left-associative. Meaning something like "A and B or C" is "(A and B) or C". + If False, the operators are taken to be right-associative, such that + the same expression becomes "A and (B or C)". NOTE: MST is right-associative + but often human intiution tends to be left-associative. + debug (bool): If True (default is false), will print a great deal of debugging + messages to help diagnose any evaluation problems. + """ + + def __init__( + self, + operators_nullary: list, + operators_unary: list, + operators_binary: list, + eval_function: function, + left_associativity: bool = True, + debug: bool = False, + ): + self.operators_nullary = operators_nullary + self.operators_unary = operators_unary + self.operators_binary = operators_binary + self.operators = operators_nullary + operators_unary + operators_binary + self.eval_function = eval_function + self.debug = debug + self.left_associativity = left_associativity + + def _traverse_expression_tree(self, E, i=0, eval_all=True, debug=False): + def _collect_operands(E, j): + # collect all operands before hitting an operator + operands = [] + for k in range(len(E[j:])): + if E[j + k] in self.operators: + k = k - 1 + break + operands.append(E[j + k]) + return operands, j + k + 1 + + def _find_matching_close_paren(E, beg: int): + c = 0 + for i in range(beg, len(E)): + if E[i] == "(": + c = c + 1 + elif E[i] == ")": + c = c - 1 + if c == 0: + return i + return None + + def _my_eval(op, left, right, debug=False): + if debug: + print( + f"\t-> evaluating {operand_str(left)} | {op} | {operand_str(right)}" + ) + result = self.eval_function(op, left, right) + if debug: + print(f"\t-> got result {operand_str(result)}") + return result + + def operand_str(operand): + if isinstance(operand, dict): + if "result" in operand and len(operand["result"]) > 15: + vec = list(operand["result"]) + beg = ", ".join([str(i) for i in vec[:5]]) + end = ", ".join([str(i) for i in vec[-5:]]) + return "{'result': " + f"{beg} ... 
{end} ({len(vec)} long)" + "}" + return str(operand) + return str(operand) + + left, right, op = None, None, None + if debug: + print(f"-> received {E[i:]}") + + while i < len(E): + if all([x is None for x in (left, right, op)]): + # first part can either be a left parenthesis, a left operand, a nullary operator, or a unary operator + if E[i] == "(": + end = _find_matching_close_paren(E, i) + if end is None: + return None, f"parenthesis imbalance starting with {E[i:]}" + # evaluate expression inside the parentheses, and it becomes the left operand + left, rem = self._traverse_expression_tree( + E[i + 1 : end], 0, eval_all=True, debug=debug + ) + if left is None: + return None, rem + i = end + 1 + if not eval_all: + return left, i + elif E[i] in self.operators_nullary: + # evaluate nullary op + left = _my_eval(E[i], None, None, debug) + if left is None: + return None, f"failed to evaluate nullary operator '{E[i]}'" + i = i + 1 + elif E[i] in self.operators_unary: + op = E[i] + i = i + 1 + elif E[i] in self.operators: + # an operator other than a unary operator cannot appear first + return None, f"unexpected binary operator in the context {E[i:]}" + else: + # if not an operator, then we are looking at operand(s) + left, i = _collect_operands(E, i) + elif (left is not None) and (op is None) and (right is None): + # we have a left operand and now looking for a binary operator + if E[i] not in self.operators_binary: + return ( + None, + f"expected end or a binary operator when got '{E[i]}' in expression: {E}", + ) + op = E[i] + i = i + 1 + elif ( + (left is None) and (op in self.operators_unary) and (right is None) + ) or ( + (left is not None) and (op in self.operators_binary) and (right is None) + ): + # we saw a unary operator before and now looking for a right operand, another unary operator, or a nullary operator + # OR + # we have a left operand and a binary operator before, now looking for a right operand, a unary operator, or a nullary operator + if ( + E[i] in (self.operators_nullary + self.operators_unary) + or E[i] == "(" + ): + right, i = self._traverse_expression_tree( + E, i, eval_all=not self.left_associativity, debug=debug + ) + if right is None: + return None, i + else: + right, i = _collect_operands(E, i) + + # We are now ready to evaluate, because: + # we have a unary operator and a right operand + # OR + # we have a left operand, a binary operator, and a right operand + result = _my_eval(op, left, right, debug) + if result is None: + return ( + None, + f"failed to evaluate operator '{op}' (in expression {E}) with operands {operand_str(left)} and {operand_str(right)}", + ) + if not eval_all: + return result, i + left = result + op, right = None, None + + else: + return ( + None, + f"encountered an unexpected condition when evaluating {E}: left is {operand_str(left)}, op is {op}, or right {operand_str(right)}", + ) + + if (op is not None) or (right is not None): + return None, f"expression ended unexpectedly" + if left is None: + return None, f"failed to evaluate expression: {E}" + + return left, i + + def evaluate(self, expression: str): + """Evaluates the expression and returns the result.""" + + def _split_tokens(expr): + # first split by parentheses (preserving the parentheses themselves) + parts = list(re.split("([()])", expr)) + # then split by space (getting rid of space) + return [ + t.strip() + for p in parts + for t in re.split("\s+", p.strip()) + if t.strip() != "" + ] + + # parse expression into tokens + E = _split_tokens(expression) + val, rem = 
self._traverse_expression_tree(E, debug=self.debug) + if val is None: + raise Exception( + f"failed to evaluate expression: '{expression}', reason: {rem}" + ) + + return val["result"] diff --git a/prtm/models/chroma/xcs.py b/prtm/models/chroma/xcs.py new file mode 100644 index 00000000..02d2c4c0 --- /dev/null +++ b/prtm/models/chroma/xcs.py @@ -0,0 +1,121 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""XCS represents protein structure as a tuple of PyTorch tensors. + +The tensors in an XCS representation are: + + `X` (FloatTensor), the Cartesian coordinates representing the protein + structure with shape `(num_batch, num_residues, num_atoms, 3)`. The + `num_atoms` dimension can be one of two sizes: `num_atoms=4` for + backbone-only structures or `num_atoms=14` for all-atom structures + (excluding hydrogens). The first four atoms will always be + `N, CA, C, O`, and the meaning of the optional 10 additional atom + positions will vary based on the residue identity at + a given position. Atom orders for each amino acid are defined in + `constants.AA_GEOMETRY[TRIPLET_CODE]["atoms"]`. + + `C` (LongTensor), the chain map encoding per-residue chain assignments with + shape `(num_batch, num_residues)`.The chain map codes positions as `0` + when masked, poitive integers for chain indices, and negative integers + to represent missing residues (of the corresponding positive integers). + + `S` (LongTensor), the sequence of the protein as alphabet indices with + shape `(num_batch, num_residues)`. The standard alphabet is + `ACDEFGHIKLMNPQRSTVWY`, also defined in `residue_constants.alphabetical_restypes`. +""" + + +from functools import partial, wraps +from inspect import getfullargspec + +import torch +from torch.nn import functional as F + +try: + pass +except ImportError: + print("MST not installed!") + + +def validate_XCS(all_atom=None, sequence=True): + """Decorator factory that adds XCS validation to any function. + + Args: + all_atom (bool, optional): If True, requires that input structure + tensors have 14 residues per atom. If False, reduces to 4 residues + per atom. If None, applies no transformation on input structures. + sequence (bool, optional): If True, makes sure that if S and O are both + provided, that they match, i.e. that O is a one-hot version of S. + If only one of S or O is provided, the other is generated, and both + are passed. + """ + + def decorator(func): + @wraps(func) + def new_func(*args, **kwargs): + args = list(args) + arg_list = getfullargspec(func)[0] + tensors = {} + for var in ["X", "C", "S", "O"]: + try: + if var in kwargs: + tensors[var] = kwargs[var] + else: + tensors[var] = args[arg_list.index(var)] + except IndexError: # empty args_list + tensors[var] = None + except ValueError: # variable not an argument of function + if not sequence and var in ["S", "O"]: + pass + else: + raise Exception( + f"Variable {var} is required by validation but not defined!" 
+ ) + if tensors["X"] is not None and tensors["C"] is not None: + if tensors["X"].shape[:2] != tensors["C"].shape[:2]: + raise ValueError( + f"X shape {tensors['X'].shape} does not match C shape" + f" {tensors['C'].shape}" + ) + if all_atom is not None and tensors["X"] is not None: + if all_atom and tensors["X"].shape[2] != 14: + raise ValueError("Side chain atoms missing!") + elif not all_atom: + if "X" in kwargs: + kwargs["X"] = tensors["X"][:, :, :4] + else: + args[arg_list.index("X")] = tensors["X"][:, :, :4] + if sequence and (tensors["S"] is not None or tensors["O"] is not None): + if tensors["O"] is None: + if "O" in kwargs: + kwargs["O"] = F.one_hot(tensors["S"], 20).float() + else: + args[arg_list.index("O")] = F.one_hot(tensors["S"], 20).float() + elif tensors["S"] is None: + if "S" in kwargs: + kwargs["S"] = tensors["O"].argmax(dim=2) + else: + args[arg_list.index("S")] = tensors["O"].argmax(dim=2) + else: + if not torch.allclose(tensors["O"].argmax(dim=2), tensors["S"]): + raise ValueError("S and O are both provided but don't match!") + return func(*args, **kwargs) + + return new_func + + return decorator + + +validate_XC = partial(validate_XCS, sequence=False) diff --git a/prtm/models/igfold/modeling.py b/prtm/models/igfold/modeling.py index 9e19a77f..a5ce8602 100644 --- a/prtm/models/igfold/modeling.py +++ b/prtm/models/igfold/modeling.py @@ -4,11 +4,11 @@ import numpy as np import torch from einops import rearrange -from prtm.constants.residue_constants import restype_order +from prtm.constants.residue_constants import PDB_CHAIN_IDS, restype_order from prtm.models.antiberty.modeling import _AntiBERTyBase from prtm.models.igfold import config from prtm.models.igfold.model import IgFold -from prtm.protein import PDB_CHAIN_IDS, Protein5 +from prtm.protein import Protein5 from prtm.utils import hub_utils __all__ = ["IgFoldForFolding"] diff --git a/prtm/models/pifold/featurizer.py b/prtm/models/pifold/featurizer.py index 0de5c77c..5426db86 100644 --- a/prtm/models/pifold/featurizer.py +++ b/prtm/models/pifold/featurizer.py @@ -3,7 +3,7 @@ import numpy as np import torch from prtm import protein -from prtm.constants.residue_constants import proteinmppn_restypes +from prtm.constants.residue_constants import alphabetical_restypes def featurize_structure( @@ -20,7 +20,7 @@ def featurize_structure( X = structure.atom_positions score = 100 * np.ones([1, length]) - S = np.asarray([proteinmppn_restypes.index(a) for a in sequence], dtype=np.int32) + S = np.asarray([alphabetical_restypes.index(a) for a in sequence], dtype=np.int32) # Add batch dimension X = X[None] diff --git a/prtm/models/pifold/modeling.py b/prtm/models/pifold/modeling.py index 884c684e..9aa20f5a 100644 --- a/prtm/models/pifold/modeling.py +++ b/prtm/models/pifold/modeling.py @@ -5,7 +5,7 @@ import torch import torch.nn.functional as F from prtm import protein -from prtm.constants.residue_constants import proteinmppn_restypes +from prtm.constants.residue_constants import alphabetical_restypes from prtm.models.pifold import config from prtm.models.pifold.featurizer import featurize_structure from prtm.models.pifold.model import PiFoldModel @@ -67,7 +67,7 @@ def __call__( """Design a protein sequence for a given structure.""" # Expects 3 atom protein structure structure = structure.to_protein4() - alphabet = "".join(proteinmppn_restypes) + alphabet = "".join(alphabetical_restypes) X, S, score, mask = featurize_structure(structure, device=self.device) ( diff --git a/prtm/models/pifold/simdesign_utils.py 
b/prtm/models/pifold/simdesign_utils.py index 1ebe5bcd..9291fc3d 100644 --- a/prtm/models/pifold/simdesign_utils.py +++ b/prtm/models/pifold/simdesign_utils.py @@ -4,6 +4,8 @@ import torch import torch.nn.functional as F +from prtm.constants.residue_constants import backbone4_atoms + # Thanks for StructTrans # https://github.com/jingraham/neurips19-graph-protein-design @@ -64,7 +66,7 @@ def _dihedrals(X, dihedral_type=0, eps=1e-7): def _hbonds(X, E_idx, mask_neighbors, eps=1e-3): - X_atoms = dict(zip(["N", "CA", "C", "O"], torch.unbind(X, 2))) + X_atoms = dict(zip(backbone4_atoms, torch.unbind(X, 2))) X_atoms["C_prev"] = F.pad(X_atoms["C"][:, 1:, :], (0, 0, 0, 1), "constant", 0) X_atoms["H"] = X_atoms["N"] + _normalize( diff --git a/prtm/models/protein_generator/modeling.py b/prtm/models/protein_generator/modeling.py index cdc91122..7f27836c 100644 --- a/prtm/models/protein_generator/modeling.py +++ b/prtm/models/protein_generator/modeling.py @@ -5,7 +5,7 @@ import numpy as np import torch from prtm import protein -from prtm.constants.residue_constants import restypes +from prtm.constants.residue_constants import PDB_CHAIN_IDS, restypes from prtm.models.protein_generator import config from prtm.models.protein_generator.rosettafold_model import RoseTTAFoldModule from prtm.models.protein_generator.sampler import SeqDiffSampler @@ -131,7 +131,7 @@ def __call__( residue_index=np.array([t[1] for t in features["pdb_idx"]]), b_factors=100 * bfactors, chain_index=np.array( - [protein.PDB_CHAIN_IDS.index(t[0]) for t in features["pdb_idx"]] + [PDB_CHAIN_IDS.index(t[0]) for t in features["pdb_idx"]] ), ) diff --git a/prtm/models/proteinmpnn/config.py b/prtm/models/proteinmpnn/config.py index 81d87054..56838a8d 100644 --- a/prtm/models/proteinmpnn/config.py +++ b/prtm/models/proteinmpnn/config.py @@ -4,7 +4,7 @@ import numpy as np import torch from prtm import protein -from prtm.constants.residue_constants import proteinmppn_restypes +from prtm.constants.residue_constants import alphabetical_restypes_x @dataclass @@ -45,12 +45,12 @@ class InferenceConfig: pssm_bias_flag: bool = False # Unbiased - bias_aas: np.ndarray = np.zeros(len(proteinmppn_restypes), dtype=np.float32) + bias_aas: np.ndarray = np.zeros(len(alphabetical_restypes_x), dtype=np.float32) omit_aas_mask: np.ndarray = field(init=False) def __post_init__(self): self.omit_aas_mask: np.ndarray = np.array( - [aa in self.omit_aas for aa in proteinmppn_restypes] + [aa in self.omit_aas for aa in alphabetical_restypes_x] ).astype(np.float32) diff --git a/prtm/models/proteinmpnn/featurizer.py b/prtm/models/proteinmpnn/featurizer.py index a6dc9f59..dfad31a6 100644 --- a/prtm/models/proteinmpnn/featurizer.py +++ b/prtm/models/proteinmpnn/featurizer.py @@ -3,7 +3,7 @@ import numpy as np import torch -from prtm.constants.residue_constants import proteinmppn_restypes +from prtm.constants.residue_constants import alphabetical_restypes_x from prtm.models.proteinmpnn.config import DesignableProtein, TiedFeaturizeOutput @@ -19,7 +19,7 @@ def get_sequence_scores(S, log_probs, mask): def decode_sequence(S, mask): return "".join( - [proteinmppn_restypes[c] for c, m in zip(S.tolist(), mask.tolist()) if m > 0] + [alphabetical_restypes_x[c] for c, m in zip(S.tolist(), mask.tolist()) if m > 0] ) diff --git a/prtm/models/proteinmpnn/model.py b/prtm/models/proteinmpnn/model.py index 8f42c41b..25180bf0 100644 --- a/prtm/models/proteinmpnn/model.py +++ b/prtm/models/proteinmpnn/model.py @@ -5,7 +5,6 @@ import torch import torch.nn as nn import torch.nn.functional as F 
-from prtm.constants.residue_constants import proteinmppn_restypes from prtm.models.proteinmpnn import config # A number of functions/classes are adopted from: https://github.com/jingraham/neurips19-graph-protein-design diff --git a/prtm/models/proteinmpnn/modeling.py b/prtm/models/proteinmpnn/modeling.py index f3c131e8..c371d625 100644 --- a/prtm/models/proteinmpnn/modeling.py +++ b/prtm/models/proteinmpnn/modeling.py @@ -5,7 +5,7 @@ import numpy as np import torch from prtm import protein -from prtm.constants.residue_constants import proteinmppn_restypes +from prtm.constants.residue_constants import alphabetical_restypes_x from prtm.models.proteinmpnn import config from prtm.models.proteinmpnn.featurizer import ( decode_sequence, @@ -38,7 +38,7 @@ def _get_model_config(model_name: str) -> config.ProteinMPNNConfig: def _get_default_design_params(sequence_length: int) -> config.DesignParams: """Make default design params for a given sequence length.""" - num_aa = len(proteinmppn_restypes) + num_aa = len(alphabetical_restypes_x) design_params = config.DesignParams( design_mask=np.ones(sequence_length), design_aatype_mask=np.zeros([sequence_length, num_aa], np.int32), diff --git a/prtm/models/rfdiffusion/inference_utils.py b/prtm/models/rfdiffusion/inference_utils.py index 5f7dfa7e..8fefa1b8 100644 --- a/prtm/models/rfdiffusion/inference_utils.py +++ b/prtm/models/rfdiffusion/inference_utils.py @@ -6,6 +6,7 @@ import numpy as np import torch from prtm import protein +from prtm.constants.residue_constants import PDB_CHAIN_IDS from prtm.common_modules.rosetta.util import ComputeAllAtomCoords, rigid_from_3_points from prtm.models.rfdiffusion import config from prtm.models.rfdiffusion.diffusion import get_beta_schedule @@ -787,7 +788,7 @@ def __init__( hotspots = [] if hotspots is None else hotspots self.hotspots = np.array( [ - True if f"{protein.PDB_CHAIN_IDS[cix]}{rix}" in hotspots else False + True if f"{PDB_CHAIN_IDS[cix]}{rix}" in hotspots else False for cix, rix in zip( target_struct.chain_index, target_struct.residue_index ) @@ -830,7 +831,7 @@ def contig_crop(self, contig_crop, residue_offset=200) -> None: # add residue offset between chains if multiple chains in receptor file pdb_idx = list( zip( - [protein.PDB_CHAIN_IDS[i] for i in self.target_struct.chain_index], + [PDB_CHAIN_IDS[i] for i in self.target_struct.chain_index], self.target_struct.residue_index, ) ) diff --git a/prtm/models/rfdiffusion/modeling.py b/prtm/models/rfdiffusion/modeling.py index 629af406..64b86de8 100644 --- a/prtm/models/rfdiffusion/modeling.py +++ b/prtm/models/rfdiffusion/modeling.py @@ -5,6 +5,7 @@ import numpy as np import torch from prtm import protein +from prtm.constants.residue_constants import PDB_CHAIN_IDS from prtm.models.rfdiffusion import config from prtm.models.rfdiffusion.rosettafold_model import RoseTTAFoldModule from prtm.models.rfdiffusion.samplers import ( @@ -240,7 +241,7 @@ def __call__( residue_index=np.arange(1, len(final_seq) + 1, dtype=np.int32), b_factors=bfacts[:, None].repeat(14, axis=-1), chain_index=np.array( - [protein.PDB_CHAIN_IDS.index(char) for char in sampler.chain_idx] + [PDB_CHAIN_IDS.index(char) for char in sampler.chain_idx] ), ) return result, {} diff --git a/prtm/models/rfdiffusion/samplers.py b/prtm/models/rfdiffusion/samplers.py index 2b8982f3..ddf2afcc 100644 --- a/prtm/models/rfdiffusion/samplers.py +++ b/prtm/models/rfdiffusion/samplers.py @@ -9,6 +9,7 @@ from prtm.common_modules.rosetta.contigs import ContigMap from prtm.common_modules.rosetta.kinematics import 
get_init_xyz, xyz_to_t2d from prtm.common_modules.rosetta.util import ComputeAllAtomCoords +from prtm.constants.residue_constants import PDB_CHAIN_IDS from prtm.models.rfdiffusion import config from prtm.models.rfdiffusion import inference_utils as iu from prtm.models.rfdiffusion import symmetry @@ -1096,7 +1097,7 @@ def sample_init(self): pdb_idx = list( zip( [ - protein.PDB_CHAIN_IDS[i] + PDB_CHAIN_IDS[i] for i in self.target_struct.chain_index ], self.target_struct.residue_index, diff --git a/prtm/protein.py b/prtm/protein.py index ff1b523d..02d32022 100644 --- a/prtm/protein.py +++ b/prtm/protein.py @@ -20,7 +20,7 @@ import os import string from collections import namedtuple -from typing import Any, Dict, Mapping, Optional, Sequence, Union +from typing import Any, Dict, List, Mapping, Optional, Sequence, Union import modelcif import modelcif.alignment @@ -49,11 +49,7 @@ FeatureDict = Mapping[str, np.ndarray] ModelOutput = Mapping[str, Any] # Is a nested dict. -PICO_TO_ANGSTROM = 0.01 -PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" -PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) -assert PDB_MAX_CHAINS == 62 """ Some notes about number of atoms: @@ -138,6 +134,10 @@ def parse_pdb_string( if parse_hetatom: hetatom_positions = [] hetatom_names = [] + hetatom_types = [] + hetatom_index = [] + hetatom_chain_ids = [] + hetatom_b_factors = [] for chain in model: if chain_id is not None and chain.id != chain_id: @@ -158,6 +158,10 @@ def parse_pdb_string( for atom in res: hetatom_positions.append(atom.coord) hetatom_names.append(res.id[0].lstrip("H_")) + hetatom_types.append(atom.fullname.strip()) + hetatom_index.append(res.id[1]) + hetatom_chain_ids.append(chain.id.strip()) + hetatom_b_factors.append(atom.bfactor) continue else: if len(res.id[0].strip()) > 0: @@ -207,9 +211,17 @@ def parse_pdb_string( if parse_hetatom: hetatom_positions = np.array(hetatom_positions) hetatom_names = np.array(hetatom_names) + hetatom_types = np.array(hetatom_types) + hetatom_index = np.array(hetatom_index) + hetatom_chain_ids = np.array(hetatom_chain_ids) + hetatom_b_factors = np.array(hetatom_b_factors) else: hetatom_positions = None hetatom_names = None + hetatom_types = None + hetatom_index = None + hetatom_chain_ids = None + hetatom_b_factors = None protein_dict = dict( atom_positions=np.array(atom_positions), @@ -222,6 +234,10 @@ def parse_pdb_string( parents_chain_index=parents_chain_index, hetatom_positions=hetatom_positions, hetatom_names=hetatom_names, + hetatom_types=hetatom_types, + hetatom_index=hetatom_index, + hetatom_chain_ids=hetatom_chain_ids, + hetatom_b_factors=hetatom_b_factors, ) return protein_dict @@ -287,6 +303,10 @@ def __init__( parents_chain_index: Optional[Sequence[int]] = None, hetatom_positions: Optional[Union[np.ndarray, torch.Tensor]] = None, hetatom_names: Optional[Union[np.ndarray, torch.Tensor]] = None, + hetatom_types: Optional[Union[np.ndarray, torch.Tensor]] = None, + hetatom_index: Optional[Union[np.ndarray, torch.Tensor]] = None, + hetatom_chain_ids: Optional[Union[np.ndarray, torch.Tensor]] = None, + hetatom_b_factors: Optional[Union[np.ndarray, torch.Tensor]] = None, remark: Optional[str] = None, ): """ @@ -322,6 +342,14 @@ def __init__( HETATM positions. Defaults to None. Expected shape is [num_hetatoms, 3]. hetatom_names (Optional[Union[np.ndarray, torch.Tensor]], optional): HETATM names. Defaults to None. Expected shape is [num_hetatoms]. + hetatom_types (Optional[Union[np.ndarray, torch.Tensor]], optional): + HETATM types. Defaults to None. 
Expected shape is [num_hetatoms]. + hetatom_index (Optional[Union[np.ndarray, torch.Tensor]], optional): + HETATM index as used in PDB. Defaults to None. Expected shape is [num_hetatoms]. The lowest index value is 1. + hetatom_chain_ids (Optional[Union[np.ndarray, torch.Tensor]], optional): + Chain IDs for each HETATM. Defaults to None. Expected shape is [num_hetatoms]. + hetatom_b_factors (Optional[Union[np.ndarray, torch.Tensor]], optional): + B-factors, or temperature factors, of each HETATM. Defaults to None. Expected shape is [num_hetatoms]. """ self.atom_positions = atom_positions self.aatype = aatype @@ -333,6 +361,10 @@ def __init__( self.parents_chain_index = parents_chain_index self.hetatom_positions = hetatom_positions self.hetatom_names = hetatom_names + self.hetatom_types = hetatom_types + self.hetatom_index = hetatom_index + self.hetatom_chain_ids = hetatom_chain_ids + self.hetatom_b_factors = hetatom_b_factors self.remark = remark self.fields = [ @@ -346,6 +378,10 @@ def __init__( "parents_chain_index", "hetatom_positions", "hetatom_names", + "hetatom_types", + "hetatom_index", + "hetatom_chain_ids", + "hetatom_b_factors", "remark", ] @@ -375,7 +411,9 @@ def _validate_inputs(self): if arr is not None: assert arr.shape[1] == num_atom_type - assert self.chain_index.max() < PDB_MAX_CHAINS, "Chain index must be < 62" + assert ( + self.chain_index.max() < residue_constants.PDB_MAX_CHAINS + ), "Chain index must be < 62" def to_torch(self) -> ProteinBase: """Converts a `Protein` instance to torch tensors.""" @@ -429,7 +467,7 @@ def get_pdb_headers(self, chain_id: int = 0) -> Sequence[str]: return pdb_headers - def _to_pdb_from_atom37(self) -> str: + def _to_pdb_from_atom37(self, write_hetatoms=True) -> str: """Converts this `Protein` instance to a PDB string. This is a private method because children should have an appropriate to_pdb method that makes sure there are 37 @@ -458,11 +496,11 @@ def _to_pdb_from_atom37(self) -> str: chain_ids = {} unique_fn = np.unique if isinstance(chain_index, np.ndarray) else torch.unique for i in unique_fn(chain_index): - if i >= PDB_MAX_CHAINS: + if i >= residue_constants.PDB_MAX_CHAINS: raise ValueError( - f"The PDB format supports at most {PDB_MAX_CHAINS} chains." + f"The PDB format supports at most {residue_constants.PDB_MAX_CHAINS} chains." ) - chain_ids[i] = PDB_CHAIN_IDS[i] + chain_ids[i] = residue_constants.PDB_CHAIN_IDS[i] headers = self.get_pdb_headers() if len(headers) > 0: @@ -546,13 +584,43 @@ def _to_pdb_from_atom37(self) -> str: pdb_lines.extend(self.get_pdb_headers(prev_chain_index)) pdb_lines.append("ENDMDL") + + # Add all HETATM sites. 
+        if write_hetatoms and self.hetatom_positions is not None:
+            for name, pos, atom_type, res_index, chain_id, b_factor in zip(
+                self.hetatom_names,
+                self.hetatom_positions,
+                self.hetatom_types,
+                self.hetatom_index,
+                self.hetatom_chain_ids,
+                self.hetatom_b_factors,
+            ):
+                atom_index += 1
+                record_type = "HETATM"
+                name = name if len(name) == 4 else f" {name}"
+                alt_loc = ""
+                insertion_code = ""
+                occupancy = 1.00
+                element, charge = atom_type[0], ""  # no charge is parsed for HETATMs
+                atom_line = (
+                    f"{record_type:<6}{atom_index:>5} {atom_type:<4}{alt_loc:>1}"
+                    # TODO: check this refactor, chose main branch version
+                    # f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
+                    f"{name:>3} {chain_id:>1}"
+                    f"{res_index:>4}{insertion_code:>1}   "
+                    f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
+                    f"{occupancy:>6.2f}{b_factor:>6.2f}          "
+                    f"{element:>2}{charge:>2}"
+                )
+                pdb_lines.append(atom_line)
+
         pdb_lines.append("END")
 
         # Pad all lines to 80 characters
         pdb_lines = [line.ljust(80) for line in pdb_lines]
         return "\n".join(pdb_lines) + "\n"  # Add terminating newline.
 
-    def to_pdb(self):
+    def to_pdb(self, write_hetatoms=True):
         raise NotImplementedError
 
     def _to_modelcif_from_atom37(self) -> str:
@@ -784,6 +852,10 @@ def _pad_to_n_atoms(self, n: int) -> Dict[str, np.ndarray]:
             remark=protein.remark,
             hetatom_positions=protein.hetatom_positions,
             hetatom_names=protein.hetatom_names,
+            hetatom_types=protein.hetatom_types,
+            hetatom_index=protein.hetatom_index,
+            hetatom_chain_ids=protein.hetatom_chain_ids,
+            hetatom_b_factors=protein.hetatom_b_factors,
         )
 
     def _crop_n_atoms(self, n: int) -> Dict[str, np.ndarray]:
@@ -805,6 +877,10 @@ def _crop_n_atoms(self, n: int) -> Dict[str, np.ndarray]:
             remark=protein.remark,
             hetatom_positions=protein.hetatom_positions,
             hetatom_names=protein.hetatom_names,
+            hetatom_types=protein.hetatom_types,
+            hetatom_index=protein.hetatom_index,
+            hetatom_chain_ids=protein.hetatom_chain_ids,
+            hetatom_b_factors=protein.hetatom_b_factors,
         )
 
     def to_protein37(self) -> Protein37:
@@ -843,11 +919,17 @@ def to_ca_trace(self) -> ProteinCATrace:
             remark=self.remark,
             hetatom_positions=self.hetatom_positions,
             hetatom_names=self.hetatom_names,
+            hetatom_types=self.hetatom_types,
+            hetatom_index=self.hetatom_index,
+            hetatom_chain_ids=self.hetatom_chain_ids,
+            hetatom_b_factors=self.hetatom_b_factors,
         )
 
     def get_chain(self, chain_id: str) -> ProteinBase:
-        assert chain_id in PDB_CHAIN_IDS, f"Invalid chain_id: {chain_id}"
-        chain_index = PDB_CHAIN_IDS.index(chain_id)
+        assert (
+            chain_id in residue_constants.PDB_CHAIN_IDS
+        ), f"Invalid chain_id: {chain_id}"
+        chain_index = residue_constants.PDB_CHAIN_IDS.index(chain_id)
         # Get the mask for indices in this chain
         chain_mask = self.chain_index == chain_index
         # Create a new Protein instance that only includes the
@@ -864,13 +946,24 @@ def get_chain(self, chain_id: str) -> ProteinBase:
             remark=self.remark,
             hetatom_positions=self.hetatom_positions,
             hetatom_names=self.hetatom_names,
+            hetatom_types=self.hetatom_types,
+            hetatom_index=self.hetatom_index,
+            hetatom_chain_ids=self.hetatom_chain_ids,
+            hetatom_b_factors=self.hetatom_b_factors,
         )
 
+    def get_all_chains(self) -> List[ProteinBase]:
+        return [self.get_chain(chain_id) for chain_id in self.chains]
+
     def sequence(self, chain_id: Optional[str] = None) -> str:
         # Decode the aatype sequence to a string
         if chain_id is not None:
-            assert chain_id in PDB_CHAIN_IDS, f"Invalid chain_id: {chain_id}"
-            aatypes = self.aatype[self.chain_index == PDB_CHAIN_IDS.index(chain_id)]
+            assert (
+                chain_id in residue_constants.PDB_CHAIN_IDS
+            ), f"Invalid chain_id: {chain_id}"
{chain_id}" + aatypes = self.aatype[ + self.chain_index == residue_constants.PDB_CHAIN_IDS.index(chain_id) + ] else: aatypes = self.aatype @@ -904,7 +997,9 @@ def shape(self): @property def chains(self): """Returns a string with all available chains.""" - return "".join(PDB_CHAIN_IDS[i] for i in np.unique(self.chain_index)) + return "".join( + residue_constants.PDB_CHAIN_IDS[i] for i in np.unique(self.chain_index) + ) def superimpose(self, other: ProteinBase) -> ProteinBase: """Superimposes another protein onto this protein.""" @@ -917,8 +1012,8 @@ class Protein37(ProteinBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def to_pdb(self) -> str: - return self._to_pdb_from_atom37() + def to_pdb(self, write_hetatoms=True) -> str: + return self._to_pdb_from_atom37(write_hetatoms) def to_modelcif(self) -> str: return self._to_modelcif_from_atom37() @@ -956,6 +1051,10 @@ def to_protein14(self) -> Protein14: remark=protein37.remark, hetatom_positions=protein37.hetatom_positions, hetatom_names=protein37.hetatom_names, + hetatom_types=protein37.hetatom_types, + hetatom_index=protein37.hetatom_index, + hetatom_chain_ids=protein37.hetatom_chain_ids, + hetatom_b_factors=protein37.hetatom_b_factors, ) def to_protein27(self) -> Protein27: @@ -1026,6 +1125,10 @@ def to_protein37(self) -> Protein37: remark=protein14.remark, hetatom_positions=protein14.hetatom_positions, hetatom_names=protein14.hetatom_names, + hetatom_types=protein14.hetatom_types, + hetatom_index=protein14.hetatom_index, + hetatom_chain_ids=protein14.hetatom_chain_ids, + hetatom_b_factors=protein14.hetatom_b_factors, ) def to_protein27(self) -> Protein27: @@ -1034,9 +1137,9 @@ def to_protein27(self) -> Protein27: def to_protein14(self) -> Protein14: return self - def to_pdb(self) -> str: + def to_pdb(self, write_hetatoms=True) -> str: protein37 = self.to_protein37() - return protein37.to_pdb() + return self._to_pdb_from_atom37(write_hetatoms) def to_modelcif(self) -> str: protein37 = self.to_protein37() @@ -1082,9 +1185,9 @@ def to_protein37(self) -> Protein37: def to_protein27(self) -> Protein27: return self - def to_pdb(self) -> str: + def to_pdb(self, write_hetatoms=True) -> str: protein37 = self.to_protein37() - return protein37.to_pdb() + return self._to_pdb_from_atom37(write_hetatoms) def to_modelcif(self) -> str: protein37 = self.to_protein37() @@ -1134,9 +1237,9 @@ def to_protein37(self) -> Protein37: def to_protein5(self) -> Protein5: return self - def to_pdb(self) -> str: + def to_pdb(self, write_hetatoms=True) -> str: protein37 = self.to_protein37() - return protein37.to_pdb() + return self._to_pdb_from_atom37(write_hetatoms) def to_modelcif(self) -> str: protein37 = self.to_protein37() @@ -1270,11 +1373,15 @@ def to_protein3(self): remark=protein_ca.remark, hetatom_positions=protein_ca.hetatom_positions, hetatom_names=protein_ca.hetatom_names, + hetatom_types=protein_ca.hetatom_types, + hetatom_index=protein_ca.hetatom_index, + hetatom_chain_ids=protein_ca.hetatom_chain_ids, + hetatom_b_factors=protein_ca.hetatom_b_factors, ) - def to_pdb(self) -> str: + def to_pdb(self, write_hetatoms=True) -> str: protein3 = self.to_protein3() - return protein3.to_pdb() + return protein3.to_pdb()(write_hetatoms) def to_modelcif(self) -> str: protein3 = self.to_protein3() diff --git a/prtm/visual.py b/prtm/visual.py index 0a404852..5443b0b8 100644 --- a/prtm/visual.py +++ b/prtm/visual.py @@ -57,7 +57,7 @@ def make_visualization_pdb( color_map = {i: band_colors_hex[i] for i in range(len(band_edges))} # Update 
-    to_viz_pdb = overwrite_b_factors(structure.to_pdb(), banded_b_factors)
+    to_viz_pdb = overwrite_b_factors(structure.to_pdb(write_hetatoms=False), banded_b_factors)

     return to_viz_pdb, color_map
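
Usage note (illustrative only, not part of the diff): the sketch below shows how the new write_hetatoms flag and the get_all_chains() helper introduced above might be exercised. It assumes `structure` is an already-constructed Protein37 instance obtained from whatever parser prtm exposes; the summarize_structure helper is hypothetical.

    def summarize_structure(structure):
        # Iterate per-chain views via the new get_all_chains() helper.
        for chain in structure.get_all_chains():
            print(chain.chains, chain.sequence())

        # HETATM records can now be skipped when writing PDB text, which is
        # what make_visualization_pdb relies on in the visual.py change above.
        pdb_no_het = structure.to_pdb(write_hetatoms=False)
        pdb_with_het = structure.to_pdb()  # write_hetatoms defaults to True
        return pdb_no_het, pdb_with_het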