diff --git a/src/spac/transformations.py b/src/spac/transformations.py
index 228c55f1..b0f01ea9 100644
--- a/src/spac/transformations.py
+++ b/src/spac/transformations.py
@@ -16,9 +16,12 @@
 import multiprocessing
 import parmap
 from spac.utag_functions import utag
 from anndata import AnnData
 from spac.utils import compute_summary_qc_stats
 from typing import List, Optional
+import plotly.express as px
+from sklearn.metrics.cluster import adjusted_rand_score
+from sklearn.metrics.cluster import normalized_mutual_info_score
 
 # Configure logging
 logging.basicConfig(level=logging.INFO,
@@ -1464,4 +1467,86 @@ def get_qc_summary_table(
     summary_table = pd.concat(stat_dfs, ignore_index=True)
     # Reset index and store in adata.uns
     summary_table = summary_table.reset_index(drop=True)
-    adata.uns["qc_summary_table"] = summary_table
\ No newline at end of file
+    adata.uns["qc_summary_table"] = summary_table
+
+
+def compare_annotations(adata, annotation_list, metric="adjusted_rand_score"):
+    """
+    Compute a pairwise similarity matrix for every combination of the given
+    annotations and visualize it as a heatmap.
+
+    The similarity score measures how closely two clusterings agree.
+
+    The function adds two attributes to `adata`:
+
+    `.uns["compare_annotations"]`
+        A numpy array holding the similarity score for every pair of
+        annotations in `annotation_list`.
+
+    `.uns["compare_annotations_list"]`
+        The list of annotations used to compute the matrix.
+
+    Parameters
+    ----------
+    adata : anndata.AnnData
+        The AnnData object.
+    annotation_list : list of str
+        Names of existing annotations to compare.
+    metric : str, optional
+        Similarity metric to use. Must be "adjusted_rand_score" or
+        "normalized_mutual_info_score". Default is "adjusted_rand_score".
+
+    Returns
+    -------
+    fig : plotly.graph_objects.Figure
+        A heatmap of the pairwise similarity matrix.
+    """
+
+    # Input validation: at least two annotations are needed for a comparison
+    if len(annotation_list) < 2:
+        raise ValueError("annotation_list must contain at least 2 annotations")
+
+    # Input validation: every annotation listed must exist in adata.obs
+    for ann in annotation_list:
+        check_annotation(adata, annotations=ann)
+
+    # Map the metric name to its scoring function and validate it up front
+    metric_functions = {
+        "adjusted_rand_score": adjusted_rand_score,
+        "normalized_mutual_info_score": normalized_mutual_info_score,
+    }
+    if metric not in metric_functions:
+        raise ValueError(
+            "metric should be 'adjusted_rand_score' or "
+            "'normalized_mutual_info_score'"
+        )
+    score_function = metric_functions[metric]
+
+    # Compute the score for every pair of annotations
+    matrix = []
+    for annotation1 in annotation_list:
+        scores = [
+            score_function(adata.obs[annotation1], adata.obs[annotation2])
+            for annotation2 in annotation_list
+        ]
+        matrix.append(scores)
+
+    # Convert the 2D list of scores to a numpy array
+    matrix_final = np.array(matrix)
+
+    # Create a heatmap of the pairwise similarity matrix
+    fig = px.imshow(
+        matrix_final,
+        labels=dict(x="Annotations", y="Annotations", color="Value"),
+        x=annotation_list,
+        y=annotation_list
+    )
+    fig.update_xaxes(side="top")
+
+    # Store the similarity matrix in the AnnData object
+    adata.uns["compare_annotations"] = matrix_final
+
+    # Store the list of annotations used in the AnnData object
+    adata.uns["compare_annotations_list"] = annotation_list
+
+    return fig
diff --git a/tests/test_transformations/test_compare_annotations.py b/tests/test_transformations/test_compare_annotations.py
new file mode 100644
index 00000000..e880f8c9
--- /dev/null
+++ b/tests/test_transformations/test_compare_annotations.py
@@ -0,0 +1,81 @@
+import unittest
+import numpy as np
+import pandas as pd
+from anndata import AnnData
+from spac.transformations import compare_annotations
+from sklearn.metrics.cluster import adjusted_rand_score
+from sklearn.metrics.cluster import normalized_mutual_info_score
+from plotly.graph_objs import Figure
+
+
+class TestCompareAnnotations(unittest.TestCase):
+
+    def setUp(self):
+        self.obs = pd.DataFrame({
+            'cluster1': [0, 1, 0, 1],
+            'cluster2': [1, 0, 1, 0],
+            'cluster3': [0, 0, 1, 2]
+        })
+        self.adata = AnnData(X=np.random.rand(4, 5))
+        self.adata.obs = self.obs.copy()
+
+    def test_annotation_length(self):
+        # check that at least two annotations are required
+        with self.assertRaises(ValueError):
+            compare_annotations(self.adata, annotation_list=['cluster1'])
+
+    def test_annotation_exists(self):
+        # check that every annotation in the list must exist
+        with self.assertRaises(ValueError):
+            compare_annotations(self.adata, annotation_list=['cluster1', 'cluster4'])
+
+    def test_adjusted_rand_score(self):
+        # check the output of the adjusted rand score
+        expected = adjusted_rand_score(self.adata.obs['cluster1'], self.adata.obs['cluster2'])
+        compare_annotations(self.adata, annotation_list=['cluster1', 'cluster2'])
+
+        # check the adjusted rand score comparing cluster1 and cluster2
+        self.assertEqual(self.adata.uns["compare_annotations"][0, 1], expected)
+        self.assertEqual(self.adata.uns["compare_annotations"][1, 0], expected)
+
+        # check the diagonal, where each annotation is compared to itself
+        self.assertEqual(self.adata.uns["compare_annotations"][0, 0], 1.0)
+        self.assertEqual(self.adata.uns["compare_annotations"][1, 1], 1.0)
+
+    def test_normalized_mutual_info_score(self):
+        # check the output of the normalized mutual info score
+        expected = normalized_mutual_info_score(self.adata.obs['cluster1'], self.adata.obs['cluster3'])
+        compare_annotations(self.adata, annotation_list=['cluster1', 'cluster3'], metric="normalized_mutual_info_score")
+
+        # check the normalized mutual info score comparing cluster1 and cluster3
+        self.assertEqual(self.adata.uns["compare_annotations"][0, 1], expected)
+        self.assertEqual(self.adata.uns["compare_annotations"][1, 0], expected)
+
+        # check the diagonal, where each annotation is compared to itself
+        self.assertEqual(self.adata.uns["compare_annotations"][0, 0], 1.0)
+        self.assertEqual(self.adata.uns["compare_annotations"][1, 1], 1.0)
+
+    def test_metric_error(self):
+        # check that an unknown metric raises a ValueError
+        with self.assertRaises(ValueError):
+            compare_annotations(self.adata, annotation_list=['cluster1', 'cluster2'], metric='invalid')
+
+    def test_heatmap_creation(self):
+        # check that a proper figure is created
+        fig = compare_annotations(self.adata, annotation_list=['cluster1', 'cluster2', 'cluster3'])
+
+        # check that the function returns a Figure object
+        self.assertIsInstance(fig, Figure)
+
+        # check that the trace is a heatmap
+        self.assertEqual(fig.data[0].type, "heatmap")
+
+        # check the axis labels
+        x_labels = list(fig.data[0].x)
+        y_labels = list(fig.data[0].y)
+        self.assertEqual(x_labels, self.adata.uns["compare_annotations_list"])
+        self.assertEqual(y_labels, self.adata.uns["compare_annotations_list"])
+
+
+if __name__ == "__main__":
+    unittest.main()
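
Usage sketch (not part of the diff above): a minimal example of how the new compare_annotations function could be called once this branch is merged. The small AnnData object and the annotation names "leiden" and "utag" are hypothetical placeholders introduced only for illustration.

import numpy as np
import pandas as pd
from anndata import AnnData
from spac.transformations import compare_annotations

# Build a small AnnData object with two hypothetical clustering annotations.
obs = pd.DataFrame(
    {
        "leiden": ["0", "1", "0", "1", "2", "2"],
        "utag": ["a", "b", "a", "b", "c", "c"],
    },
    index=[f"cell_{i}" for i in range(6)],
)
adata = AnnData(X=np.random.rand(6, 3), obs=obs)

# Compare the two annotations; the pairwise similarity matrix is stored in
# adata.uns["compare_annotations"] and a plotly heatmap figure is returned.
fig = compare_annotations(
    adata,
    annotation_list=["leiden", "utag"],
    metric="adjusted_rand_score",
)
print(adata.uns["compare_annotations"])
# fig.show()  # render the heatmap interactively

The returned figure is a plotly heatmap of the pairwise scores, and the same matrix is persisted in adata.uns["compare_annotations"] (with the annotation order in adata.uns["compare_annotations_list"]) for downstream use.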