Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 14 additions & 15 deletions docs/usage/embedding_different_models.ipynb

Large diffs are not rendered by default.

425 changes: 425 additions & 0 deletions model_diagram.json

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ shapely = "^2.0.6"
torch = "^2.4.1"
transformers = "^4.45.2"
scikit-learn = "^1.5.2"
numpy = ">=1.14.5,<2.0"
openai = "^1.52.2"
esm = "^3.1.3"
rdflib = "^6.0.0"
Expand All @@ -42,7 +41,7 @@ pysam = "0.23.0"
types-requests = "2.32.0.20250328"
ipywidgets = "^8.1.7"
sentencepiece = "^0.2.0"
umap = "^0.1.1"
umap-learn = "^0.5.7"

[tool.poetry.group.dev.dependencies]
mkdocstrings = {extras = ["python"], version = "^0.26.2"}
Expand Down
24 changes: 11 additions & 13 deletions src/pyeed/analysis/mutation_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def get_sequence_data(
db: DatabaseConnector,
standard_numbering_tool_name: str,
node_type: str = "Protein",
region_ids_neo4j: Optional[list[int]] = None,
region_ids_neo4j: Optional[list[str]] = None,
) -> tuple[dict[str, str], dict[str, list[str]]]:
"""
Fetch sequence and standard numbering position data for two sequences from the database.
Expand All @@ -39,7 +39,7 @@ def get_sequence_data(
if region_ids_neo4j is not None:
query = f"""
MATCH (p:{node_type})-[rel:HAS_REGION]->(r:Region)
WHERE id(r) IN $region_ids_neo4j
WHERE elementId(r) IN $region_ids_neo4j
MATCH (r)-[rel2:HAS_STANDARD_NUMBERING]->(s:StandardNumbering)
WHERE p.accession_id IN ['{sequence_id1}', '{sequence_id2}']
AND s.name = '{standard_numbering_tool_name}'
Expand Down Expand Up @@ -134,7 +134,7 @@ def save_mutations_to_db(
sequence_id1: str,
sequence_id2: str,
node_type: str = "Protein",
region_ids_neo4j: Optional[list[int]] = None,
region_ids_neo4j: Optional[list[str]] = None,
) -> None:
"""
Save detected mutations to the database as relationships between nodes.
Expand All @@ -155,9 +155,9 @@ def save_mutations_to_db(
if region_ids_neo4j is not None:
query = f"""
MATCH (p1:{node_type} {{accession_id: $sequence_id1}})-[rel:HAS_REGION]->(r1:Region)
WHERE id(r1) IN $region_ids_neo4j
MATCH (r1)-[rel_mutation:MUTATION]->(r2:Region)
WHERE id(r2) IN $region_ids_neo4j
WHERE elementId(r1) IN $region_ids_neo4j
MATCH (r1)-[rel_mutation:MUTATION]-(r2:Region)
WHERE elementId(r2) IN $region_ids_neo4j
MATCH (r2)<-[:HAS_REGION]-(p2:{node_type} {{accession_id: $sequence_id2}})
RETURN rel_mutation
"""
Expand All @@ -172,7 +172,7 @@ def save_mutations_to_db(
else:
existing_mutations = db.execute_read(
f"""
MATCH (p1:{node_type})-[r:MUTATION]->(p2:{node_type})
MATCH (p1:{node_type})-[r:MUTATION]-(p2:{node_type})
WHERE p1.accession_id = $sequence_id1 AND p2.accession_id = $sequence_id2
RETURN r
""",
Expand All @@ -188,10 +188,10 @@ def save_mutations_to_db(
# saving the mutation between the regions
query = f"""
MATCH (r1:Region)
WHERE id(r1) IN $region_ids_neo4j
WHERE elementId(r1) IN $region_ids_neo4j
MATCH (r1)<-[:HAS_REGION]-(p1:{node_type} {{accession_id: $sequence_id1}})
MATCH (r2:Region)
WHERE id(r2) IN $region_ids_neo4j
WHERE elementId(r2) IN $region_ids_neo4j
MATCH (r2)<-[:HAS_REGION]-(p2:{node_type} {{accession_id: $sequence_id2}})
CREATE (r1)-[r:MUTATION]->(r2)
SET r.from_positions = $from_positions,
Expand Down Expand Up @@ -230,7 +230,7 @@ def save_mutations_to_db(
db.execute_write(query, params)

logger.debug(
f"Saved {len(list(params['from_positions']))} mutations to database"
f"Saved {len(list(params['from_positions']))} mutations to database between {sequence_id1} and {sequence_id2}"
)

def get_mutations_between_sequences(
Expand All @@ -241,7 +241,7 @@ def get_mutations_between_sequences(
standard_numbering_tool_name: str,
save_to_db: bool = True,
node_type: str = "Protein",
region_ids_neo4j: Optional[list[int]] = None,
region_ids_neo4j: Optional[list[str]] = None,
) -> dict[str, list[int | str]]:
"""
Get mutations between two sequences using standard numbering and optionally save them to the database.
Expand Down Expand Up @@ -274,8 +274,6 @@ def get_mutations_between_sequences(
region_ids_neo4j,
)

logger.debug(f"Debug mode output: {sequences} and {positions}")

mutations = self.find_mutations(
sequences[sequence_id1],
sequences[sequence_id2],
Expand Down
56 changes: 27 additions & 29 deletions src/pyeed/analysis/ontology_loading.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict
from typing import Any, Dict

from pyeed.dbconnect import DatabaseConnector
from rdflib import OWL, RDF, RDFS, Graph, Namespace, URIRef
Expand Down Expand Up @@ -105,26 +105,28 @@ def _process_relationships(
for s, p, o in g.triples((None, RDFS.subClassOf, None)):
subclass = str(s)

if (o, RDF.type, OWL.Class) in g:
# Handle direct subclass relationships
superclass = str(o)
db.execute_write(
"""
MATCH (sub:OntologyObject {name: $subclass}),
(super:OntologyObject {name: $superclass})
CREATE (sub)-[:SUBCLASS_OF]->(super)
""",
parameters={"subclass": subclass, "superclass": superclass},
)

elif (o, RDF.type, OWL.Restriction) in g:
# Check for the more specific owl:Restriction first.
if (o, RDF.type, OWL.Restriction) in g:
# Handle OWL restrictions (e.g., RO_ in CARD)
self._process_restriction(g, str(o), subclass, db, dicts_labels)
self._process_restriction(g, o, subclass, db, dicts_labels)
# Only if it's not a restriction, check if it's a direct superclass.
elif (o, RDF.type, OWL.Class) in g:
# Ensure we are linking to a named class, not a blank node
if isinstance(o, URIRef):
superclass = str(o)
db.execute_write(
"""
MATCH (sub:OntologyObject {name: $subclass}),
(super:OntologyObject {name: $superclass})
CREATE (sub)-[:SUBCLASS_OF]->(super)
""",
parameters={"subclass": subclass, "superclass": superclass},
)

def _process_restriction(
self,
g: Graph,
restriction_node: str,
restriction_node: Any,
subclass: str,
db: DatabaseConnector,
dicts_labels: Dict[str, str],
Expand All @@ -133,32 +135,28 @@ def _process_restriction(
on_property = None
some_values_from = None

# Convert restriction_node string to RDFLib URIRef
restriction = URIRef(restriction_node)

# Extract onProperty
for _, _, prop in g.triples((restriction, OWL.onProperty, None)):
for _, _, prop in g.triples((restriction_node, OWL.onProperty, None)):
on_property = str(prop)

# Extract someValuesFrom
for _, _, value in g.triples((restriction, OWL.someValuesFrom, None)):
for _, _, value in g.triples((restriction_node, OWL.someValuesFrom, None)):
some_values_from = str(value)

if on_property and some_values_from:
rel_type = dicts_labels.get(on_property, "RELATED_TO")
rel_type = rel_type.replace(" ", "_").replace("-", "_").upper()

query_params = {
"subclass": subclass,
"some_values_from": some_values_from,
"on_property": on_property,
"description": dicts_labels.get(on_property, ""),
}

query = """
MATCH (sub:OntologyObject {name: $subclass}),
(super:OntologyObject {name: $some_values_from})
CREATE (sub)-[:CustomRelationship {
name: $on_property,
description: $description
}]->(super)
query = f"""
MATCH (sub:OntologyObject {{name: $subclass}}),
(super:OntologyObject {{name: $some_values_from}})
CREATE (sub)-[:`{rel_type}` {{uri: $on_property}}]->(super)
"""

db.execute_write(query, parameters=query_params)
31 changes: 23 additions & 8 deletions src/pyeed/analysis/sequence_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from Bio.Align import PairwiseAligner as BioPairwiseAligner
from Bio.Align.substitution_matrices import Array as BioSubstitutionMatrix
from joblib import Parallel, cpu_count, delayed
from loguru import logger
from pyeed.dbconnect import DatabaseConnector
from pyeed.tools.utility import chunks
from rich.progress import Progress
Expand Down Expand Up @@ -157,9 +158,11 @@ def align_multipairwise(
pair for pair in pairs if tuple(sorted(pair)) not in existing_pairs
]

print(f"Number of existing pairs: {len(existing_pairs)}")
print(f"Number of total pairs: {len(pairs)}")
print(f"Number of pairs to align: {len(new_pairs)}")
logger.info(f"Number of existing pairs: {len(existing_pairs)}")
logger.info(f"Number of total pairs: {len(pairs)}")
logger.info(f"Number of pairs to align: {len(new_pairs)}")

logger.info(f"Length of sequences: {len(sequences)}")

with Progress() as progress:
align_task = progress.add_task(
Expand Down Expand Up @@ -226,7 +229,7 @@ def _to_db(
UNWIND $alignments AS alignment
MATCH (p1:{node_type} {{accession_id: alignment.query_id}})-[rel1:HAS_REGION]->(r1:Region)
MATCH (p2:{node_type} {{accession_id: alignment.target_id}})-[rel2:HAS_REGION]->(r2:Region)
WHERE id(r1) IN $region_ids_neo4j AND id(r2) IN $region_ids_neo4j
WHERE elementId(r1) IN $region_ids_neo4j AND elementId(r2) IN $region_ids_neo4j
MERGE (r1)-[r:PAIRWISE_ALIGNED]->(r2)
SET r.similarity = alignment.identity,
r.mismatches = alignment.mismatches,
Expand Down Expand Up @@ -313,14 +316,20 @@ def _get_id_sequence_dict(
if ids != []:
if region_ids_neo4j is not None:
query = f"""
MATCH (p:{node_type})-[e:HAS_REGION]->(r:Region)
WHERE id(r) IN $region_ids_neo4j AND p.accession_id IN $ids
MATCH (p:{node_type})-[e:HAS_REGION]-(r:Region)
WHERE elementId(r) IN $region_ids_neo4j AND p.accession_id IN $ids
RETURN p.accession_id AS accession_id, e.start AS start, e.end AS end, p.sequence AS sequence
"""
nodes = db.execute_read(
query,
parameters={"region_ids_neo4j": region_ids_neo4j, "ids": ids},
)
logger.info(f" Full query: {query}")
logger.info(f"The ids are: {ids}")
logger.info(f"The region ids are: {region_ids_neo4j}")
logger.info(
f"Length of nodes (run query of type both region and ids): {len(nodes)}"
)
else:
query = f"""
MATCH (p:{node_type})
Expand All @@ -329,26 +338,32 @@ def _get_id_sequence_dict(
"""
nodes = db.execute_read(query, parameters={"ids": ids})

logger.info(f"Length of nodes (run query of type ids): {len(nodes)}")

else:
if region_ids_neo4j is not None:
query = f"""
MATCH (p:{node_type})-[e:HAS_REGION]->(r:Region)
WHERE id(r) IN $region_ids_neo4j
WHERE elementId(r) IN $region_ids_neo4j
RETURN p.accession_id AS accession_id, e.start AS start, e.end AS end, p.sequence AS sequence
"""
nodes = db.execute_read(
query,
parameters={
"region_ids_neo4j": region_ids_neo4j,
},
)
) #

logger.info(f"Length of nodes (run query of type region): {len(nodes)}")
else:
query = f"""
MATCH (p:{node_type})
RETURN p.accession_id AS accession_id, p.sequence AS sequence
"""
nodes = db.execute_read(query)

logger.info(f"Length of nodes (run query of type): {len(nodes)}")

if region_ids_neo4j is not None:
return {
node["accession_id"]: node["sequence"][node["start"] : node["end"]]
Expand Down
21 changes: 13 additions & 8 deletions src/pyeed/analysis/standard_numbering.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_node_base_sequence(
if region_ids_neo4j:
query = f"""
MATCH (p:{node_type})-[e:HAS_REGION]->(r:Region)
WHERE id(r) IN $region_ids_neo4j
WHERE elementId(r) IN $region_ids_neo4j
WHERE p.accession_id = '{base_sequence_id}'
RETURN p.accession_id AS accession_id, e.start AS start, e.end AS end, p.sequence AS sequence
"""
Expand Down Expand Up @@ -113,7 +113,7 @@ def save_positions(
if region_ids_neo4j:
query = f"""
MATCH (p:{node_type} {{accession_id: '{protein_id}'}})-[e:HAS_REGION]->(r:Region)
WHERE id(r) IN $region_ids_neo4j
WHERE elementId(r) IN $region_ids_neo4j
MATCH (s:StandardNumbering {{name: '{self.name}'}})
MERGE (r)-[rel:HAS_STANDARD_NUMBERING]->(s)
SET rel.positions = {str(positions[protein_id])}
Expand Down Expand Up @@ -402,7 +402,7 @@ def apply_standard_numbering_pairwise(
query = """
MATCH (s:StandardNumbering {name: $name})
MATCH (d:DNA)-[e:HAS_REGION]-(r:Region)-[:HAS_STANDARD_NUMBERING]-(s)
WHERE id(r) IN $region_ids_neo4j
WHERE elementId(r) IN $region_ids_neo4j
AND d.accession_id IN $list_of_seq_ids
RETURN d.accession_id AS accession_id
"""
Expand Down Expand Up @@ -431,29 +431,34 @@ def apply_standard_numbering_pairwise(
for row in results:
if row is not None:
if row.get("accession_id"):
pairs.remove((base_sequence_id, row["accession_id"]))
logger.info(
f"Pair {base_sequence_id} and {row['accession_id']} already exists under the same standard numbering node"
f"Pair {base_sequence_id} and {row['accession_id']} already exists under the same standard numbering node \n Removing x from the list: {(base_sequence_id, row['accession_id'])}"
)
pairs.remove((base_sequence_id, row["accession_id"]))
break

# remove double pairs in the list of pairs
pairs = list(set(pairs))
logger.info(f"Pairs: {pairs}")

# Run the pairwise alignment using the PairwiseAligner.
pairwise_aligner = PairwiseAligner(node_type=node_type)
input = (list_of_seq_ids or []) + [base_sequence_id]
input = list_of_seq_ids + [base_sequence_id]
if not input:
raise ValueError("No input sequences provided")

logger.info(f"Input: {input}")
logger.info(f"Input: {input} with length of {len(input)}")
logger.info(
f"Length of region ids: {len(region_ids_neo4j) if region_ids_neo4j else 0}"
)

results_pairwise = pairwise_aligner.align_multipairwise(
ids=input, # Combine ids for alignment
db=db,
pairs=pairs, # List of sequence pairs to be aligned
node_type=node_type,
region_ids_neo4j=region_ids_neo4j,
num_cores=1,
)

# logger.info(f"Pairwise alignment results: {results_pairwise}")
Expand Down Expand Up @@ -551,7 +556,7 @@ def apply_standard_numbering(
# get the region objects for each of the nodes as well
query = f"""
MATCH (p:{node_type})-[e:HAS_REGION]->(r:Region)
WHERE id(r) IN $region_ids_neo4j
WHERE elementId(r) IN $region_ids_neo4j
WHERE p.accession_id IN $list_of_seq_ids
RETURN p.accession_id AS accession_id, e.start AS start, e.end AS end, p.sequence AS sequence
"""
Expand Down
Loading
Loading