90 changes: 89 additions & 1 deletion src/spac/transformations.py
@@ -16,9 +16,15 @@
import multiprocessing
import parmap
from spac.utag_functions import utag
from anndata import AnnData
from spac.utils import compute_summary_qc_stats
from typing import List, Optional
import plotly.express as px
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics.cluster import normalized_mutual_info_score

# Configure logging
logging.basicConfig(level=logging.INFO,
@@ -1290,6 +1296,7 @@ def run_utag_clustering(
adata.obs[output_annotation] = cluster_list.copy()
adata.uns["utag_features"] = features

# add QC metrics to AnnData object
def add_qc_metrics(adata,
organism="hs",
@@ -1464,4 +1471,85 @@ def get_qc_summary_table(
    summary_table = pd.concat(stat_dfs, ignore_index=True)
    # Reset the index and store the table in adata.uns
    summary_table = summary_table.reset_index(drop=True)
    adata.uns["qc_summary_table"] = summary_table


def compare_annotations(adata, annotation_list, metric="adjusted_rand_score"):
    """
    Compute a similarity score for every pair of the given annotations and
    store the resulting matrix, which can later be rendered as a heatmap.
    The score is a measure of agreement between two clusterings.

    The function adds these two attributes to `adata`:
        `.uns["compare_annotations"]`
            Numpy array holding the pairwise similarity scores;
            every pair of the provided annotations has one value.

        `.uns["compare_annotations_list"]`
            The annotations used to compute the matrix, in the order of
            its rows and columns.

    Parameters
    ----------
    adata : anndata.AnnData
        The AnnData object.
    annotation_list : list of str
        Names of existing annotations in `adata.obs`.
    metric : str
        Similarity metric to use. Must be "adjusted_rand_score" or
        "normalized_mutual_info_score".
        Default = "adjusted_rand_score".

    Returns
    -------
    fig : plotly.graph_objs.Figure
        Heatmap of the pairwise similarity matrix.
    """

    # Input validation - at least two annotations are needed for a comparison
    if len(annotation_list) < 2:
        raise ValueError("annotation_list must contain at least 2 annotations")

    # Input validation - make sure every listed annotation exists
    for ann in annotation_list:
        check_annotation(adata, annotations=ann)

    # 2D list that will hold each computed score
    matrix = []

    # If metric = "adjusted_rand_score", compute the scores with adjusted_rand_score
    if metric == "adjusted_rand_score":
        for annotation1 in annotation_list:
            scores = []
            for annotation2 in annotation_list:
                scores.append(
                    adjusted_rand_score(adata.obs[annotation1],
                                        adata.obs[annotation2]))
            matrix.append(scores)

    # If metric = "normalized_mutual_info_score", compute the scores with
    # normalized_mutual_info_score
    elif metric == "normalized_mutual_info_score":
        for annotation1 in annotation_list:
            scores = []
            for annotation2 in annotation_list:
                scores.append(
                    normalized_mutual_info_score(adata.obs[annotation1],
                                                 adata.obs[annotation2]))
            matrix.append(scores)

    else:
        raise ValueError(
            "metric should be 'adjusted_rand_score' or "
            "'normalized_mutual_info_score'")

    # Convert the 2D list of scores to a numpy array
    matrix_final = np.array(matrix)

    # Create a heatmap of the pairwise score matrix
    fig = px.imshow(matrix_final,
                    labels=dict(x="Clusters", y="Clusters", color="Value"),
                    x=annotation_list,
                    y=annotation_list)
    fig.update_xaxes(side="top")

    # Store the score matrix in the AnnData object
    adata.uns["compare_annotations"] = matrix_final

    # Store the list of annotations used in the AnnData object
    adata.uns["compare_annotations_list"] = annotation_list

    return fig
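
For reference, a minimal usage sketch of the new compare_annotations helper (not part of the diff; the annotation names and toy data below are made up for illustration):

import numpy as np
import pandas as pd
from anndata import AnnData
from spac.transformations import compare_annotations

# Toy dataset with two hypothetical clustering annotations.
obs = pd.DataFrame({
    "leiden": ["0", "1", "0", "1"],
    "kmeans": ["a", "b", "a", "b"],
})
adata = AnnData(X=np.random.rand(4, 3), obs=obs)

# Compare the two annotations with the default adjusted Rand index.
fig = compare_annotations(adata, annotation_list=["leiden", "kmeans"])

# The pairwise score matrix and the annotation order are stored on adata.
print(adata.uns["compare_annotations"])
print(adata.uns["compare_annotations_list"])

# fig is a Plotly figure; fig.show() renders the heatmap.
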
81 changes: 81 additions & 0 deletions tests/test_transformations/test_compare_annotations.py
@@ -0,0 +1,81 @@
import unittest
import numpy as np
import pandas as pd
from anndata import AnnData
from spac.transformations import compare_annotations
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics.cluster import normalized_mutual_info_score
from plotly.graph_objs import Figure

class TestCompareAnnotations(unittest.TestCase):

def setUp(self):
self.obs = pd.DataFrame({
'cluster1': [0, 1, 0, 1],
'cluster2': [1, 0, 1, 0],
'cluster3': [0, 0, 1, 2]
})
self.adata = AnnData(X=np.random.rand(4, 5))
self.adata.obs = self.obs.copy()

def test_annotation_length(self):
# check if there are at least two annotations
with self.assertRaises(ValueError):
            compare_annotations(self.adata, annotation_list=['cluster1'])

def test_annotation_exists(self):
# check if the annotations exist in the given list
with self.assertRaises(ValueError):
            compare_annotations(self.adata, annotation_list=['cluster1', 'cluster4'])

def test_adjusted_rand_score(self):
# check the output of adjusted rand score
expected = adjusted_rand_score(self.adata.obs['cluster1'], self.adata.obs['cluster2'])
        compare_annotations(self.adata, annotation_list=['cluster1', 'cluster2'])

# check adjusted rand score of comparing cluster1 and cluster2
self.assertEqual(self.adata.uns["compare_annotations"][0,1], expected)
self.assertEqual(self.adata.uns["compare_annotations"][1,0], expected)

# check adjusted rand score on matrix diagonal where cluster is compared to itself
self.assertEqual(self.adata.uns["compare_annotations"][0,0], 1.0)
self.assertEqual(self.adata.uns["compare_annotations"][1,1], 1.0)

def test_normalized_mutual_info_score(self):
# check the output of normalized mutual info score
expected = normalized_mutual_info_score(self.adata.obs['cluster1'], self.adata.obs['cluster3'])
compare_annotations(self.adata, annotation_list=['cluster1', 'cluster3'], metric="normalized_mutual_info_score")

# check normalized mutual info score of comparing cluster1 and cluster3
self.assertEqual(self.adata.uns["compare_annotations"][0,1], expected)
self.assertEqual(self.adata.uns["compare_annotations"][1,0], expected)

# check normalized mutual info score on matrix diagonal where cluster is compared to itself
self.assertEqual(self.adata.uns["compare_annotations"][0,0], 1.0)
self.assertEqual(self.adata.uns["compare_annotations"][1,1], 1.0)

def test_metric_error(self):
# check if the given metric doesn't exist
with self.assertRaises(ValueError):
            compare_annotations(self.adata, annotation_list=['cluster1', 'cluster2'], metric='invalid')

def test_heatmap_creation(self):
# check if a proper figure is created
        fig = compare_annotations(self.adata, annotation_list=['cluster1', 'cluster2', 'cluster3'])

# Check function returns a Figure object
self.assertIsInstance(fig, Figure)

# Check that the trace is a heatmap
self.assertEqual(fig.data[0].type, "heatmap")

# Check the axis labels
x_labels = list(fig.data[0].x)
y_labels = list(fig.data[0].y)
self.assertEqual(x_labels, self.adata.uns["compare_annotations_list"])
self.assertEqual(y_labels, self.adata.uns["compare_annotations_list"])



if __name__ == "__main__":
unittest.main()
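
Since the test module ends with a unittest.main() guard, the new tests can be run on their own, assuming spac and its dependencies (anndata, plotly, scikit-learn) are importable in the environment:

python tests/test_transformations/test_compare_annotations.py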