From 97695292dd05b78e828627b8f95cb6f41380069e Mon Sep 17 00:00:00 2001 From: Arshnoor Randhawa Date: Thu, 24 Apr 2025 17:05:48 -0400 Subject: [PATCH 1/8] feat(core): added compare_annotations function and unit tests --- src/spac/transformations.py | 83 +++++++++++++++++++ .../test_compare_annotations.py | 81 ++++++++++++++++++ 2 files changed, 164 insertions(+) create mode 100644 tests/test_transformations/test_compare_annotations.py diff --git a/src/spac/transformations.py b/src/spac/transformations.py index f497d402..e5c319b4 100644 --- a/src/spac/transformations.py +++ b/src/spac/transformations.py @@ -14,6 +14,9 @@ import multiprocessing import parmap from spac.utag_functions import utag +import plotly.express as px +from sklearn.metrics.cluster import adjusted_rand_score +from sklearn.metrics.cluster import normalized_mutual_info_score # Configure logging logging.basicConfig(level=logging.INFO, @@ -1164,3 +1167,83 @@ def run_utag_clustering( cluster_list = utag_results.obs[curClusterCol].copy() adata.obs[output_annotation] = cluster_list.copy() adata.uns["utag_features"] = features + +def compare_annotations(adata, annotation_list, metric="adjusted_rand_score"): + """ + Create matrix storing metric information with every combination of annotations given + (which can later be used to create a heatmap) + Metric information typically is similarity measure between different clusterings + + The function will add these two attributes to `adata`: + `.uns["compare_annotations"]` + The matrix (numpy array) that stores measure of mutual information between clusters + Each pair of annotations provided will have a value + + `.uns["compare_annotations_list"]` + The annotations used to calculate and create the matrix that stores + + Parameters + ---------- + adata : anndata.AnnData + The AnnData object. + annotation_list : list of str + List of names of (existing) annotations + metric : str + Metric type for calculations to be used + Should be adjusted_rand_score or normalized_mutual_info_score + Default = "adjusted_rand_score" + + Returns + ---------- + fig : plt.figure + it returns a heatmap of the return matrix + + """ + + #Input Validation - there should be more than 1 annotation in order to compare annotations + if len(annotation_list) < 2: + raise ValueError("annotation_list must contain at least 2 annotations") + + #Input Validation - make sure every annotation listed exists + for ann in annotation_list: + check_annotation(adata, annotations=ann) + + #2D array that will contain each computed score + matrix = [] + + #If metric = "adjusted_rand_score", compute the metric using adjusted_rand_score + if metric == "adjusted_rand_score": + for annotation1 in annotation_list: + scores = [] + for annotation2 in annotation_list: + scores.append(adjusted_rand_score(adata.obs[annotation1], adata.obs[annotation2])) + matrix.append(scores) + + #If metric = "normalized_mutual_info_score", compute the metric using normalized_mutual_info_score + elif metric == "normalized_mutual_info_score": + for annotation1 in annotation_list: + scores = [] + for annotation2 in annotation_list: + scores.append(normalized_mutual_info_score(adata.obs[annotation1], adata.obs[annotation2])) + matrix.append(scores) + + else: + raise ValueError("metric should be 'adjusted_rand_score' or 'normalized_mutual_info_score'") + + #Convert 2D list containing scores to numpy array + matrix_final = np.array(matrix) + + # creates a heatmap that corresponds to the final correlational matrix + fig = px.imshow(matrix_final, + labels=dict(x="Clusters", y="Clusters", color="Value"), + x=annotation_list, + y=annotation_list) + fig.update_xaxes(side="top") + + #Store output in AnnData + adata.uns["compare_annotations"] = matrix_final + + #Store list of annotations used in Anndata + adata.uns["compare_annotations_list"] = annotation_list + + return fig diff --git a/tests/test_transformations/test_compare_annotations.py b/tests/test_transformations/test_compare_annotations.py new file mode 100644 index 00000000..e880f8c9 --- /dev/null +++ b/tests/test_transformations/test_compare_annotations.py @@ -0,0 +1,81 @@ +import unittest +import numpy as np +import pandas as pd +from anndata import AnnData +from spac.transformations import compare_annotations +from sklearn.metrics.cluster import adjusted_rand_score +from sklearn.metrics.cluster import normalized_mutual_info_score +from plotly.graph_objs import Figure + +class TestCompareAnnotations(unittest.TestCase): + + def setUp(self): + self.obs = pd.DataFrame({ + 'cluster1': [0, 1, 0, 1], + 'cluster2': [1, 0, 1, 0], + 'cluster3': [0, 0, 1, 2] + }) + self.adata = AnnData(X=np.random.rand(4, 5)) + self.adata.obs = self.obs.copy() + + def test_annotation_length(self): + # check if there are at least two annotations + with self.assertRaises(ValueError): + compare_annotations(self.adata, annotation_list = ['cluster1']) + + def test_annotation_exists(self): + # check if the annotations exist in the given list + with self.assertRaises(ValueError): + compare_annotations(self.adata, annotation_list = ['cluster1', 'cluster4']) + + def test_adjusted_rand_score(self): + # check the output of adjusted rand score + expected = adjusted_rand_score(self.adata.obs['cluster1'], self.adata.obs['cluster2']) + compare_annotations(self.adata, annotation_list = ['cluster1', 'cluster2']) + + # check adjusted rand score of comparing cluster1 and cluster2 + self.assertEqual(self.adata.uns["compare_annotations"][0,1], expected) + self.assertEqual(self.adata.uns["compare_annotations"][1,0], expected) + + # check adjusted rand score on matrix diagonal where cluster is compared to itself + self.assertEqual(self.adata.uns["compare_annotations"][0,0], 1.0) + self.assertEqual(self.adata.uns["compare_annotations"][1,1], 1.0) + + def test_normalized_mutual_info_score(self): + # check the output of normalized mutual info score + expected = normalized_mutual_info_score(self.adata.obs['cluster1'], self.adata.obs['cluster3']) + compare_annotations(self.adata, annotation_list=['cluster1', 'cluster3'], metric="normalized_mutual_info_score") + + # check normalized mutual info score of comparing cluster1 and cluster3 + self.assertEqual(self.adata.uns["compare_annotations"][0,1], expected) + self.assertEqual(self.adata.uns["compare_annotations"][1,0], expected) + + # check normalized mutual info score on matrix diagonal where cluster is compared to itself + self.assertEqual(self.adata.uns["compare_annotations"][0,0], 1.0) + self.assertEqual(self.adata.uns["compare_annotations"][1,1], 1.0) + + def test_metric_error(self): + # check if the given metric doesn't exist + with self.assertRaises(ValueError): + compare_annotations(self.adata, annotation_list = ['cluster1', 'cluster2'], metric = 'invalid') + + def test_heatmap_creation(self): + # check if a proper figure is created + fig = compare_annotations(self.adata, annotation_list = ['cluster1', 'cluster2', 'cluster3']) + + # Check function returns a Figure object + self.assertIsInstance(fig, Figure) + + # Check that the trace is a heatmap + self.assertEqual(fig.data[0].type, "heatmap") + + # Check the axis labels + x_labels = list(fig.data[0].x) + y_labels = list(fig.data[0].y) + self.assertEqual(x_labels, self.adata.uns["compare_annotations_list"]) + self.assertEqual(y_labels, self.adata.uns["compare_annotations_list"]) + + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 76f381316acd57ff9fbe3e5bea4e37b0a9ae3dd9 Mon Sep 17 00:00:00 2001 From: Arshnoor Randhawa Date: Thu, 1 May 2025 16:52:10 -0400 Subject: [PATCH 2/8] undo --- src/spac/transformations.py | 83 ------------------- .../test_compare_annotations.py | 81 ------------------ 2 files changed, 164 deletions(-) delete mode 100644 tests/test_transformations/test_compare_annotations.py diff --git a/src/spac/transformations.py b/src/spac/transformations.py index e5c319b4..f497d402 100644 --- a/src/spac/transformations.py +++ b/src/spac/transformations.py @@ -14,9 +14,6 @@ import multiprocessing import parmap from spac.utag_functions import utag -import plotly.express as px -from sklearn.metrics.cluster import adjusted_rand_score -from sklearn.metrics.cluster import normalized_mutual_info_score # Configure logging logging.basicConfig(level=logging.INFO, @@ -1167,83 +1164,3 @@ def run_utag_clustering( cluster_list = utag_results.obs[curClusterCol].copy() adata.obs[output_annotation] = cluster_list.copy() adata.uns["utag_features"] = features - -def compare_annotations(adata, annotation_list, metric="adjusted_rand_score"): - """ - Create matrix storing metric information with every combination of annotations given - (which can later be used to create a heatmap) - Metric information typically is similarity measure between different clusterings - - The function will add these two attributes to `adata`: - `.uns["compare_annotations"]` - The matrix (numpy array) that stores measure of mutual information between clusters - Each pair of annotations provided will have a value - - `.uns["compare_annotations_list"]` - The annotations used to calculate and create the matrix that stores - - Parameters - ---------- - adata : anndata.AnnData - The AnnData object. - annotation_list : list of str - List of names of (existing) annotations - metric : str - Metric type for calculations to be used - Should be adjusted_rand_score or normalized_mutual_info_score - Default = "adjusted_rand_score" - - Returns - ---------- - fig : plt.figure - it returns a heatmap of the return matrix - - """ - - #Input Validation - there should be more than 1 annotation in order to compare annotations - if len(annotation_list) < 2: - raise ValueError("annotation_list must contain at least 2 annotations") - - #Input Validation - make sure every annotation listed exists - for ann in annotation_list: - check_annotation(adata, annotations=ann) - - #2D array that will contain each computed score - matrix = [] - - #If metric = "adjusted_rand_score", compute the metric using adjusted_rand_score - if metric == "adjusted_rand_score": - for annotation1 in annotation_list: - scores = [] - for annotation2 in annotation_list: - scores.append(adjusted_rand_score(adata.obs[annotation1], adata.obs[annotation2])) - matrix.append(scores) - - #If metric = "normalized_mutual_info_score", compute the metric using normalized_mutual_info_score - elif metric == "normalized_mutual_info_score": - for annotation1 in annotation_list: - scores = [] - for annotation2 in annotation_list: - scores.append(normalized_mutual_info_score(adata.obs[annotation1], adata.obs[annotation2])) - matrix.append(scores) - - else: - raise ValueError("metric should be 'adjusted_rand_score' or 'normalized_mutual_info_score'") - - #Convert 2D list containing scores to numpy array - matrix_final = np.array(matrix) - - # creates a heatmap that corresponds to the final correlational matrix - fig = px.imshow(matrix_final, - labels=dict(x="Clusters", y="Clusters", color="Value"), - x=annotation_list, - y=annotation_list) - fig.update_xaxes(side="top") - - #Store output in AnnData - adata.uns["compare_annotations"] = matrix_final - - #Store list of annotations used in Anndata - adata.uns["compare_annotations_list"] = annotation_list - - return fig diff --git a/tests/test_transformations/test_compare_annotations.py b/tests/test_transformations/test_compare_annotations.py deleted file mode 100644 index e880f8c9..00000000 --- a/tests/test_transformations/test_compare_annotations.py +++ /dev/null @@ -1,81 +0,0 @@ -import unittest -import numpy as np -import pandas as pd -from anndata import AnnData -from spac.transformations import compare_annotations -from sklearn.metrics.cluster import adjusted_rand_score -from sklearn.metrics.cluster import normalized_mutual_info_score -from plotly.graph_objs import Figure - -class TestCompareAnnotations(unittest.TestCase): - - def setUp(self): - self.obs = pd.DataFrame({ - 'cluster1': [0, 1, 0, 1], - 'cluster2': [1, 0, 1, 0], - 'cluster3': [0, 0, 1, 2] - }) - self.adata = AnnData(X=np.random.rand(4, 5)) - self.adata.obs = self.obs.copy() - - def test_annotation_length(self): - # check if there are at least two annotations - with self.assertRaises(ValueError): - compare_annotations(self.adata, annotation_list = ['cluster1']) - - def test_annotation_exists(self): - # check if the annotations exist in the given list - with self.assertRaises(ValueError): - compare_annotations(self.adata, annotation_list = ['cluster1', 'cluster4']) - - def test_adjusted_rand_score(self): - # check the output of adjusted rand score - expected = adjusted_rand_score(self.adata.obs['cluster1'], self.adata.obs['cluster2']) - compare_annotations(self.adata, annotation_list = ['cluster1', 'cluster2']) - - # check adjusted rand score of comparing cluster1 and cluster2 - self.assertEqual(self.adata.uns["compare_annotations"][0,1], expected) - self.assertEqual(self.adata.uns["compare_annotations"][1,0], expected) - - # check adjusted rand score on matrix diagonal where cluster is compared to itself - self.assertEqual(self.adata.uns["compare_annotations"][0,0], 1.0) - self.assertEqual(self.adata.uns["compare_annotations"][1,1], 1.0) - - def test_normalized_mutual_info_score(self): - # check the output of normalized mutual info score - expected = normalized_mutual_info_score(self.adata.obs['cluster1'], self.adata.obs['cluster3']) - compare_annotations(self.adata, annotation_list=['cluster1', 'cluster3'], metric="normalized_mutual_info_score") - - # check normalized mutual info score of comparing cluster1 and cluster3 - self.assertEqual(self.adata.uns["compare_annotations"][0,1], expected) - self.assertEqual(self.adata.uns["compare_annotations"][1,0], expected) - - # check normalized mutual info score on matrix diagonal where cluster is compared to itself - self.assertEqual(self.adata.uns["compare_annotations"][0,0], 1.0) - self.assertEqual(self.adata.uns["compare_annotations"][1,1], 1.0) - - def test_metric_error(self): - # check if the given metric doesn't exist - with self.assertRaises(ValueError): - compare_annotations(self.adata, annotation_list = ['cluster1', 'cluster2'], metric = 'invalid') - - def test_heatmap_creation(self): - # check if a proper figure is created - fig = compare_annotations(self.adata, annotation_list = ['cluster1', 'cluster2', 'cluster3']) - - # Check function returns a Figure object - self.assertIsInstance(fig, Figure) - - # Check that the trace is a heatmap - self.assertEqual(fig.data[0].type, "heatmap") - - # Check the axis labels - x_labels = list(fig.data[0].x) - y_labels = list(fig.data[0].y) - self.assertEqual(x_labels, self.adata.uns["compare_annotations_list"]) - self.assertEqual(y_labels, self.adata.uns["compare_annotations_list"]) - - - -if __name__ == "__main__": - unittest.main() \ No newline at end of file From 9fe387bb37ddbaf26601a3f90532d16398a43f20 Mon Sep 17 00:00:00 2001 From: Sam Ying Date: Tue, 3 Jun 2025 09:56:50 -0400 Subject: [PATCH 3/8] addition of pinned colors, must be rebased to pass --- src/spac/visualization.py | 270 +++++++++++++++++++++++++++++++++++++- 1 file changed, 268 insertions(+), 2 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 0ab0ee11..da0ee387 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -35,7 +35,7 @@ def visualize_2D_scatter( x, y, labels=None, point_size=None, theme=None, ax=None, annotate_centers=False, x_axis_title='Component 1', y_axis_title='Component 2', plot_title=None, - color_representation=None, **kwargs + color_representation=None, color_map=None, **kwargs ): """ Visualize 2D data using plt.scatter. @@ -65,6 +65,8 @@ def visualize_2D_scatter( Title for the plot. color_representation : str, optional Description of what the colors represent. + color_map : dictionary, optional + Dictionary containing colors for label annotations. **kwargs Additional keyword arguments passed to plt.scatter. @@ -83,6 +85,8 @@ def visualize_2D_scatter( raise ValueError("x and y must have the same length.") if labels is not None and len(labels) != len(x): raise ValueError("Labels length should match x and y length.") + if color_map is not None and not isinstance(color_map, dict): + raise ValueError("`color_map` must be a dict mapping label→color.") # Define color themes themes = { @@ -141,15 +145,20 @@ def visualize_2D_scatter( cmap2 = plt.get_cmap('tab20b') cmap3 = plt.get_cmap('tab20c') colors = cmap1.colors + cmap2.colors + cmap3.colors + cluster_to_color = color_map if color_map is not None else { + str(cluster): colors[i % len(colors)] + for i, cluster in enumerate(unique_clusters) + } # Use the number of unique clusters to set the colormap length cmap = ListedColormap(colors[:len(unique_clusters)]) for idx, cluster in enumerate(unique_clusters): mask = np.array(labels) == cluster + color = cluster_to_color.get(str(cluster), 'gray') ax.scatter( x[mask], y[mask], - color=cmap(idx), + color=color, label=cluster, s=point_size ) @@ -196,6 +205,263 @@ def visualize_2D_scatter( return fig, ax +def embedded_scatter_plot( + adata, + method=None, + annotation=None, + feature=None, + layer=None, + ax=None, + associated_table=None, + spot_size=20, + alpha=0.5, + vmin=-999, + vmax=-999, + color_map=None, + **kwargs): + """ + Visualize scatter plot in PCA, t-SNE, UMAP, spatial or associated table. + + Parameters + ---------- + adata : anndata.AnnData + The AnnData object with coordinates precomputed by the 'tsne' or 'UMAP' + function and stored in 'adata.obsm["X_tsne"]' or 'adata.obsm["X_umap"]' + method : str, optional (default: None) + Visualization method specifying the coordinate system to plot. + Choose from {'tsne', 'umap', 'pca', 'spatial'}. + annotation : str, optional + The name of the column in `adata.obs` to use for coloring + the scatter plot points based on cell annotations. + feature : str, optional + The name of the gene or feature in `adata.var_names` to use + for coloring the scatter plot points based on feature expression. + layer : str, optional + The name of the data layer in `adata.layers` to use for visualization. + If None, the main data matrix `adata.X` is used. + ax : matplotlib.axes.Axes, optional (default: None) + A matplotlib axes object to plot on. + If not provided, a new figure and axes will be created. + associated_table : str, optional (default: None) + Name of the key in `obsm` that contains the numpy array. Takes + precedence over `method` + color_map : str, optional (default: None) + Name of the key in adata.uns that contains color-mapping for + the plot + **kwargs + Parameters passed to visualize_2D_scatter function, + including point_size. + + Returns + ------- + fig : matplotlib.figure.Figure + The created figure for the plot. + ax : matplotlib.axes.Axes + The axes of the plot. + """ + + # Check if both annotation and feature are specified, raise error if so + if annotation and feature: + raise ValueError( + "Please specify either an annotation or a feature for coloring, " + "not both.") + + # Use utility functions for input validation + if layer: + check_table(adata, tables=layer) + if annotation: + check_annotation(adata, annotations=annotation) + if feature: + check_feature(adata, features=[feature]) + color_mapping = None + if color_map is not None: + color_mapping = get_defined_color_map( + adata, + defined_color_map=color_map, + annotations=annotation + ) + + # Validate the method and check if the necessary data exists in adata.obsm + if associated_table is None: + valid_methods = ['tsne', 'umap', 'pca', 'spatial'] + if method not in valid_methods: + raise ValueError("Method should be one of {'tsne', 'umap', 'pca', 'spatial'}" + f'. Got:"{method}"') + if method == "spatial": + key = "spatial" + else: + key = f'X_{method}' + if key not in adata.obsm.keys(): + raise ValueError( + f"{key} coordinates not found in adata.obsm. " + f"Please run {method.upper()} before calling this function." + ) + + else: + check_table( + adata=adata, + tables=associated_table, + should_exist=True, + associated_table=True + ) + + associated_table_shape = adata.obsm[associated_table].shape + if associated_table_shape[1] != 2: + raise ValueError( + f'The associated table:"{associated_table}" does not have' + f' two dimensions. It shape is:"{associated_table_shape}"' + ) + key = associated_table + + err_msg_layer = "The 'layer' parameter must be a string, " + \ + f"got {str(type(layer))}" + err_msg_feature = "The 'feature' parameter must be a string, " + \ + f"got {str(type(feature))}" + err_msg_annotation = "The 'annotation' parameter must be a string, " + \ + f"got {str(type(annotation))}" + err_msg_feat_annotation_coe = "Both annotation and feature are passed, " +\ + "please provide sinle input." + err_msg_feat_annotation_non = "Both annotation and feature are None, " + \ + "please provide single input." + err_msg_spot_size = "The 'spot_size' parameter must be an integer, " + \ + f"got {str(type(spot_size))}" + err_msg_alpha_type = "The 'alpha' parameter must be a float," + \ + f"got {str(type(alpha))}" + err_msg_alpha_value = "The 'alpha' parameter must be between " + \ + f"0 and 1 (inclusive), got {str(alpha)}" + err_msg_vmin = "The 'vmin' parameter must be a float or an int, " + \ + f"got {str(type(vmin))}" + err_msg_vmax = "The 'vmax' parameter must be a float or an int, " + \ + f"got {str(type(vmax))}" + err_msg_ax = "The 'ax' parameter must be an instance " + \ + f"of matplotlib.axes.Axes, got {str(type(ax))}" + + if adata is None: + raise ValueError("The input dataset must not be None.") + + if not isinstance(adata, anndata.AnnData): + err_msg_adata = "The 'adata' parameter must be an " + \ + f"instance of anndata.AnnData, got {str(type(adata))}." + raise ValueError(err_msg_adata) + + if layer is not None and not isinstance(layer, str): + raise ValueError(err_msg_layer) + + if layer is not None and layer not in adata.layers.keys(): + err_msg_layer_exist = f"Layer {layer} does not exists, " + \ + f"available layers are {str(adata.layers.keys())}" + raise ValueError(err_msg_layer_exist) + + if feature is not None and not isinstance(feature, str): + raise ValueError(err_msg_feature) + + if annotation is not None and not isinstance(annotation, str): + raise ValueError(err_msg_annotation) + + if annotation is not None and feature is not None: + raise ValueError(err_msg_feat_annotation_coe) + + if key == "spatial": + if annotation is None and feature is None: + raise ValueError(err_msg_feat_annotation_non) + + if 'spatial' not in adata.obsm_keys(): + err_msg = "Spatial coordinates not found in the 'obsm' attribute." + raise ValueError(err_msg) + +# Extract feature name + if not isinstance(spot_size, int): + raise ValueError(err_msg_spot_size) + + if not isinstance(alpha, float): + raise ValueError(err_msg_alpha_type) + + if not (0 <= alpha <= 1): + raise ValueError(err_msg_alpha_value) + + if vmin != -999 and not ( + isinstance(vmin, float) or isinstance(vmin, int) + ): + raise ValueError(err_msg_vmin) + + if vmax != -999 and not ( + isinstance(vmax, float) or isinstance(vmax, int) + ): + raise ValueError(err_msg_vmax) + + if ax is not None and not isinstance(ax, plt.Axes): + raise ValueError(err_msg_ax) + + print(f'Running visualization using the coordinates: "{key}"') + + # Extract the 2D coordinates + x, y = adata.obsm[key].T + + # Determine coloring scheme + if color_mapping is None: + if annotation: + color_values = adata.obs[annotation].astype('category').values + color_representation = annotation + elif feature: + data_src = adata.layers[layer] if layer else adata.X + color_values = data_src[:, adata.var_names == feature].squeeze() + color_representation = feature + else: + color_values = None + color_representation = None + else: + color_values = adata.obs[annotation].astype('category').values + color_representation = annotation + + # Set axis titles based on method and color representation + if method == 'tsne': + x_axis_title = 't-SNE 1' + y_axis_title = 't-SNE 2' + plot_title = f'TSNE-{color_representation}' + elif method == 'pca': + x_axis_title = 'PCA 1' + y_axis_title = 'PCA 2' + plot_title = f'PCA-{color_representation}' + elif method == 'umap': + x_axis_title = 'UMAP 1' + y_axis_title = 'UMAP 2' + plot_title = f'UMAP-{color_representation}' + elif method == 'spatial': + x_axis_title = 'SPATIAL 1' + y_axis_title = 'SPATIAL 2' + plot_title = f'SPATIAL-{color_representation}' + + else: + x_axis_title = f'{associated_table} 1' + y_axis_title = f'{associated_table} 2' + plot_title = f'{associated_table}-{color_representation}' + + # Remove conflicting keys from kwargs + kwargs.pop('x_axis_title', None) + kwargs.pop('y_axis_title', None) + kwargs.pop('plot_title', None) + kwargs.pop('color_representation', None) + + # Set Min and Max in kwargs + kwargs['vmin'] = vmin + kwargs['vmax'] = vmax + + fig, ax = visualize_2D_scatter( + x=x, + y=y, + ax=ax, + labels=color_values, + x_axis_title=x_axis_title, + y_axis_title=y_axis_title, + plot_title=plot_title, + color_representation=color_representation, + color_map=color_mapping, + **kwargs + ) + + return fig, ax + + def dimensionality_reduction_plot( adata, method=None, From 4e59324d9fc456ba2cf38c6b774a4d391fe04ffb Mon Sep 17 00:00:00 2001 From: Sam Ying Date: Tue, 3 Jun 2025 09:56:50 -0400 Subject: [PATCH 4/8] addition of pinned colors, correcting rebase through cherry pick --- src/spac/visualization.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index da0ee387..91e48ff4 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -15,7 +15,10 @@ from spac.utils import check_feature, annotation_category_relations from spac.utils import check_label from spac.utils import get_defined_color_map +<<<<<<< HEAD from spac.utils import compute_boxplot_metrics +======= +>>>>>>> 2a60092 (addition of pinned colors, must be rebased to pass) from functools import partial from spac.utils import color_mapping, spell_out_special_characters from spac.data_utils import select_values From ab1a87e3fc6740c22e9d3e7ded2735ca83c9f1db Mon Sep 17 00:00:00 2001 From: Sam Ying Date: Fri, 30 May 2025 14:04:41 -0400 Subject: [PATCH 5/8] Use of 2D_scatter for spatial --- src/spac/visualization.py | 59 ++ .../test_embedded_scatter_plot.py | 524 ++++++++++++++++++ 2 files changed, 583 insertions(+) create mode 100644 tests/test_visualization/test_embedded_scatter_plot.py diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 91e48ff4..643e4a1e 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -276,6 +276,7 @@ def embedded_scatter_plot( check_annotation(adata, annotations=annotation) if feature: check_feature(adata, features=[feature]) +<<<<<<< HEAD color_mapping = None if color_map is not None: color_mapping = get_defined_color_map( @@ -283,6 +284,8 @@ def embedded_scatter_plot( defined_color_map=color_map, annotations=annotation ) +======= +>>>>>>> 2820b07 (Use of 2D_scatter for spatial) # Validate the method and check if the necessary data exists in adata.obsm if associated_table is None: @@ -315,6 +318,7 @@ def embedded_scatter_plot( f' two dimensions. It shape is:"{associated_table_shape}"' ) key = associated_table +<<<<<<< HEAD err_msg_layer = "The 'layer' parameter must be a string, " + \ f"got {str(type(layer))}" @@ -339,6 +343,32 @@ def embedded_scatter_plot( err_msg_ax = "The 'ax' parameter must be an instance " + \ f"of matplotlib.axes.Axes, got {str(type(ax))}" +======= + + err_msg_layer = "The 'layer' parameter must be a string, " + \ + f"got {str(type(layer))}" + err_msg_feature = "The 'feature' parameter must be a string, " + \ + f"got {str(type(feature))}" + err_msg_annotation = "The 'annotation' parameter must be a string, " + \ + f"got {str(type(annotation))}" + err_msg_feat_annotation_coe = "Both annotation and feature are passed, " +\ + "please provide sinle input." + err_msg_feat_annotation_non = "Both annotation and feature are None, " + \ + "please provide single input." + err_msg_spot_size = "The 'spot_size' parameter must be an integer, " + \ + f"got {str(type(spot_size))}" + err_msg_alpha_type = "The 'alpha' parameter must be a float," + \ + f"got {str(type(alpha))}" + err_msg_alpha_value = "The 'alpha' parameter must be between " + \ + f"0 and 1 (inclusive), got {str(alpha)}" + err_msg_vmin = "The 'vmin' parameter must be a float or an int, " + \ + f"got {str(type(vmin))}" + err_msg_vmax = "The 'vmax' parameter must be a float or an int, " + \ + f"got {str(type(vmax))}" + err_msg_ax = "The 'ax' parameter must be an instance " + \ + f"of matplotlib.axes.Axes, got {str(type(ax))}" + +>>>>>>> 2820b07 (Use of 2D_scatter for spatial) if adata is None: raise ValueError("The input dataset must not be None.") @@ -373,6 +403,11 @@ def embedded_scatter_plot( raise ValueError(err_msg) # Extract feature name +<<<<<<< HEAD +======= + feature_names = adata.var_names.tolist() + +>>>>>>> 2820b07 (Use of 2D_scatter for spatial) if not isinstance(spot_size, int): raise ValueError(err_msg_spot_size) @@ -401,6 +436,7 @@ def embedded_scatter_plot( x, y = adata.obsm[key].T # Determine coloring scheme +<<<<<<< HEAD if color_mapping is None: if annotation: color_values = adata.obs[annotation].astype('category').values @@ -415,6 +451,25 @@ def embedded_scatter_plot( else: color_values = adata.obs[annotation].astype('category').values color_representation = annotation +======= + if annotation: + color_values = adata.obs[annotation].astype('category').values + color_representation = annotation + vmin = None + vmax = None + elif feature: + data_source = adata.layers[layer] if layer else adata.X + color_values = data_source[:, adata.var_names == feature].squeeze() + color_representation = feature + feature_index = feature_names.index(feature) + if vmin == -999: + vmin = np.min(data_source[:, feature_index]) + if vmax == -999: + vmax = np.max(data_source[:, feature_index]) + else: + color_values = None + color_representation = None +>>>>>>> 2820b07 (Use of 2D_scatter for spatial) # Set axis titles based on method and color representation if method == 'tsne': @@ -444,6 +499,10 @@ def embedded_scatter_plot( kwargs.pop('y_axis_title', None) kwargs.pop('plot_title', None) kwargs.pop('color_representation', None) + + # Set Min and Max in kwargs + kwargs['vmin'] = vmin + kwargs['vmax'] = vmax # Set Min and Max in kwargs kwargs['vmin'] = vmin diff --git a/tests/test_visualization/test_embedded_scatter_plot.py b/tests/test_visualization/test_embedded_scatter_plot.py new file mode 100644 index 00000000..fff9aad3 --- /dev/null +++ b/tests/test_visualization/test_embedded_scatter_plot.py @@ -0,0 +1,524 @@ +import unittest +from unittest.mock import patch +import anndata +import numpy as np +import matplotlib +import matplotlib.pyplot as plt +import itertools +from spac.visualization import embedded_scatter_plot +matplotlib.use('Agg') + + +class TestStaticScatterPlot(unittest.TestCase): + + def setUp(self): + self.adata = anndata.AnnData(X=np.random.rand(10, 10)) + self.adata.obsm['X_tsne'] = np.random.rand(10, 2) + self.adata.obsm['X_umap'] = np.random.rand(10, 2) + self.adata.obsm['X_pca'] = np.random.rand(10, 2) + self.adata.obsm['sumap'] = np.random.rand(10, 2) + self.adata.obsm['3dsumap'] = np.random.rand(10, 3) + self.adata.obs['annotation_column'] = np.random.choice( + ['A', 'B', 'C'], size=10 + ) + self.adata.var_names = ['gene_' + str(i) for i in range(10)] + + def test_missing_umap_coordinates(self): + del self.adata.obsm['X_umap'] + with self.assertRaises(ValueError) as cm: + embedded_scatter_plot(self.adata, 'umap') + expected_msg = ( + "X_umap coordinates not found in adata.obsm. " + "Please run UMAP before calling this function." + ) + self.assertEqual(str(cm.exception), expected_msg) + + def test_missing_tsne_coordinates(self): + del self.adata.obsm['X_tsne'] + with self.assertRaises(ValueError) as cm: + embedded_scatter_plot(self.adata, 'tsne') + expected_msg = ( + "X_tsne coordinates not found in adata.obsm. " + "Please run TSNE before calling this function." + ) + self.assertEqual(str(cm.exception), expected_msg) + + def test_annotation_and_feature(self): + with self.assertRaises(ValueError) as cm: + embedded_scatter_plot( + self.adata, 'tsne', + annotation='annotation_column', + feature='feature_column' + ) + expected_msg = ( + "Please specify either an annotation or a feature for coloring, " + "not both." + ) + self.assertEqual(str(cm.exception), expected_msg) + + def test_annotation_column(self): + fig, ax = embedded_scatter_plot( + self.adata, 'tsne', annotation='annotation_column' + ) + self.assertIsNotNone(fig) + self.assertIsNotNone(ax) + self.assertEqual(ax.get_xlabel(), 't-SNE 1') + self.assertEqual(ax.get_ylabel(), 't-SNE 2') + self.assertEqual(ax.get_title(), 'TSNE-annotation_column') + + def test_associated_table(self): + fig, ax = embedded_scatter_plot( + self.adata, + annotation='annotation_column', + associated_table='sumap' + ) + self.assertIsNotNone(fig) + self.assertIsNotNone(ax) + self.assertEqual(ax.get_xlabel(), 'sumap 1') + self.assertEqual(ax.get_ylabel(), 'sumap 2') + self.assertEqual(ax.get_title(), 'sumap-annotation_column') + + def test_feature_column(self): + fig, ax = embedded_scatter_plot( + self.adata, 'tsne', feature='gene_1' + ) + self.assertIsNotNone(fig) + self.assertIsNotNone(ax) + self.assertEqual(ax.get_xlabel(), 't-SNE 1') + self.assertEqual(ax.get_ylabel(), 't-SNE 2') + self.assertEqual(ax.get_title(), 'TSNE-gene_1') + + def test_ax_provided(self): + fig, ax_provided = plt.subplots() + fig_returned, ax_returned = embedded_scatter_plot( + self.adata, 'tsne', ax=ax_provided + ) + self.assertIs(fig, fig_returned) + self.assertIs(ax_provided, ax_returned) + + def test_real_tsne_plot(self): + fig, ax = embedded_scatter_plot( + self.adata, 'tsne', annotation='annotation_column' + ) + self.assertIsInstance(fig, plt.Figure) + self.assertIsInstance(ax, plt.Axes) + self.assertEqual(ax.get_xlabel(), 't-SNE 1') + self.assertEqual(ax.get_ylabel(), 't-SNE 2') + self.assertEqual(ax.get_title(), 'TSNE-annotation_column') + + def test_real_umap_plot(self): + fig, ax = embedded_scatter_plot( + self.adata, 'umap', feature='gene_1' + ) + self.assertIsInstance(fig, plt.Figure) + self.assertIsInstance(ax, plt.Axes) + self.assertEqual(ax.get_xlabel(), 'UMAP 1') + self.assertEqual(ax.get_ylabel(), 'UMAP 2') + self.assertEqual(ax.get_title(), 'UMAP-gene_1') + + def test_real_pca_plot(self): + fig, ax = embedded_scatter_plot( + self.adata, 'pca', annotation='annotation_column' + ) + self.assertIsInstance(fig, plt.Figure) + self.assertIsInstance(ax, plt.Axes) + self.assertEqual(ax.get_xlabel(), 'PCA 1') + self.assertEqual(ax.get_ylabel(), 'PCA 2') + self.assertEqual(ax.get_title(), 'PCA-annotation_column') + + def test_invalid_method(self): + with self.assertRaises(ValueError) as cm: + embedded_scatter_plot(self.adata, 'invalid_method') + expected_msg = ("Method should be one of {'tsne', 'umap', 'pca'," + " 'spatial'}." + ' Got:"invalid_method"' + ) + self.assertEqual(str(cm.exception), expected_msg) + + def test_input_derived_feature_3d(self): + with self.assertRaises(ValueError) as cm: + embedded_scatter_plot( + self.adata, + associated_table='3dsumap') + expected_msg = ('The associated table:"3dsumap" does not have' + ' two dimensions. It shape is:"(10, 3)"') + + self.assertEqual(str(cm.exception), expected_msg) + + def test_conflicting_kwargs(self): + # This test ensures conflicting keys are removed from kwargs + fig, ax = embedded_scatter_plot( + self.adata, + 'tsne', + annotation='annotation_column', + x_axis_title='Conflict X', + y_axis_title='Conflict Y', + plot_title='Conflict Title', + color_representation='Conflict Color' + ) + self.assertIsNotNone(fig) + self.assertIsNotNone(ax) + self.assertEqual(ax.get_xlabel(), 't-SNE 1') + self.assertEqual(ax.get_ylabel(), 't-SNE 2') + self.assertEqual(ax.get_title(), 'TSNE-annotation_column') + + +class SpatialPlotTestCase(unittest.TestCase): + def setUp(self): + # Set up test data + num_rows = 100 + num_cols = 100 + + # Generate data matrix + features = np.random.randint(0, 100, size=(num_rows, num_cols)) + feature_names = [f'Intensity_{i}' for i in range(num_cols)] + + # Generate annotation metadata + annotation_data = { + 'annotation1': np.random.choice(['A', 'B', 'C'], size=num_rows), + 'annotation2': np.random.normal(0, 1, size=num_rows), + 'annotation3': np.random.uniform(0, 1, size=num_rows) + } + + # Create the AnnData object + self.adata = anndata.AnnData( + X=features, + obs=annotation_data + ) + + numpy_array = np.random.uniform(0, 100, size=(num_rows, 2)) + self.adata.obsm["spatial"] = numpy_array + self.adata.var_names = feature_names + + # Generate layer data + layer_data = np.random.randint(0, 100, size=(num_rows, num_cols)) + self.adata.layers['Normalized'] = layer_data + self.adata.layers['Standardized'] = layer_data + + self.spot_size = 10 + self.alpha = 0.5 + + def test_invalid_adata(self): + # Test when adata is not an instance of anndata.AnnData + with self.assertRaises(ValueError): + embedded_scatter_plot(adata=None, + method='spatial', + spot_size=self.spot_size, + alpha=self.alpha + ) + + def test_invalid_layer(self): + # Test when layer is not a string + with self.assertRaises(ValueError): + embedded_scatter_plot(adata=self.adata, + method='spatial', + layer=123, + spot_size=self.spot_size, + alpha=self.alpha + ) + + def test_invalid_feature(self): + # Test when feature is not a string + with self.assertRaises(ValueError): + embedded_scatter_plot(adata=self.adata, + method='spatial', + feature=123, + spot_size=self.spot_size, + alpha=self.alpha + ) + + def test_invalid_annotation(self): + # Test when annotation is not a string + with self.assertRaises(ValueError): + embedded_scatter_plot( + adata=self.adata, + method='spatial', + annotation=123, + spot_size=self.spot_size, + alpha=self.alpha + ) + + def test_invalid_spot_size(self): + # Test when spot_size is not an integer + with self.assertRaises(ValueError): + embedded_scatter_plot(adata=self.adata, method='spatial', + spot_size=10.5, alpha=self.alpha) + + def test_invalid_alpha(self): + # Test when alpha is not a float + with self.assertRaises(ValueError): + embedded_scatter_plot(adata=self.adata, method='spatial', + spot_size=self.spot_size, alpha="0.5") + + def test_invalid_alpha_range(self): + # Test when alpha is outside the range of 0 to 1 + with self.assertRaises(ValueError): + embedded_scatter_plot(adata=self.adata, method='spatial', + spot_size=self.spot_size, alpha=-0.5) + + def test_missing_annotation(self): + # Test when annotation is None and feature is None + with self.assertRaises(ValueError) as cm: + embedded_scatter_plot(adata=self.adata, method='spatial', + spot_size=self.spot_size, alpha=self.alpha) + error_msg = str(cm.exception) + err_msg_exp = "Both annotation and feature are None, " + \ + "please provide single input." + self.assertEqual(error_msg, err_msg_exp) + + def test_invalid_annotation_name(self): + # Test when annotation name is not found in the dataset + with self.assertRaises(ValueError) as cm: + embedded_scatter_plot( + adata=self.adata, + method='spatial', + annotation='annotation4', + spot_size=self.spot_size, + alpha=self.alpha + ) + error_msg = str(cm.exception) + err_msg_exp = ("The annotation 'annotation4' does not exist " + "in the provided dataset.\n" + "Existing annotations are:\n" + "annotation1\nannotation2\nannotation3" + ) + self.assertEqual(error_msg, err_msg_exp) + + def test_invalid_feature_name(self): + # Test when feature name is not found in the layer + with self.assertRaises(ValueError) as cm: + embedded_scatter_plot( + adata=self.adata, + method='spatial', + feature='feature1', + spot_size=self.spot_size, + alpha=self.alpha + ) + error_msg = str(cm.exception) + target_features = "\n".join(self.adata.var_names) + err_msg_exp = ("The feature 'feature1' does not exist " + "in the provided dataset.\n" + f"Existing features are:\n{target_features}" + ) + self.assertEqual(error_msg, err_msg_exp) + + def test_spatial_plot_annotation(self): + # This test verifies that the spatial_plot function + # correctly interacts with the spatial function, + # and returns the expected Axes object when + # the annotation parameter is used. + + # Mock the spatial function + # This mock the spatial function to replace the + # original implementation.The mock function, mock_spatial, + # checks if the inputs match the expected values. + # The purpose of mocking the spatial function with the + # mock_spatial function is to simulate the behavior of the + # original spatial function during testing and verify if the + # inputs passed to it match the expected values. + def mock_spatial( + adata, + color, + spot_size, + alpha, + ax, + show, + layer, + vmin, + vmax, + **kwargs): + # Assert that the inputs match the expected values + assert layer is None + self.assertEqual(color, 'annotation1') + self.assertEqual(spot_size, self.spot_size) + self.assertEqual(alpha, self.alpha) + # self.assertEqual(vmin, None) + assert vmax is None + assert vmin is None + # self.assertEqual(vmax, None) + self.assertIsInstance(ax, plt.Axes) + self.assertFalse(show) + # Return a list containing the ax object to mimic + # the behavior of the spatial function + return [ax] + + # Mock the spatial function with the mock_spatial function + # spatial_plot.__globals__['sc.pl.spatial'] = mock_spatial + + with patch('scanpy.pl.spatial', new=mock_spatial): + # Create an instance of Axes + ax = plt.Axes( + plt.figure(), + rect=[0, 0, 1, 1] + ) + + # Call the spatial_plot function with the ax object + fig, returned_ax = embedded_scatter_plot( + adata=self.adata, + method='spatial', + annotation='annotation1', + layer=None, + ax=ax, + spot_size=self.spot_size, + alpha=self.alpha + ) + + # Assert that the spatial_plot function returned a list + # containing an Axes object with the same properties + self.assertEqual(returned_ax.get_title(), ax.get_title()) + self.assertEqual(returned_ax.get_xlabel(), ax.get_xlabel()) + self.assertEqual(returned_ax.get_ylabel(), ax.get_ylabel()) + + # Restore the original spatial function + # del spatial_plot.__globals__['sc.pl.spatial'] + + def test_spatial_plot_feature(self): + # Mock the spatial function + def mock_spatial( + adata, + layer, + color, + spot_size, + alpha, + vmin, + vmax, + ax, + show, + **kwargs): + # Assert that the inputs match the expected values + assert layer is None + self.assertEqual(color, 'Intensity_10spatial_plot') + self.assertEqual(spot_size, self.spot_size) + self.assertEqual(alpha, self.alpha) + self.assertEqual(vmin, 0) + self.assertEqual(vmax, 100) + self.assertIsInstance(ax, plt.Axes) + self.assertFalse(show) + # Return a list containing the ax object to mimic + # the behavior of the spatial function + return [ax] + + # Mock the spatial function with the mock_spatial function + # spatial_plot.__globals__['sc.pl.spatial'] = mock_spatial + with patch('scanpy.pl.spatial', new=mock_spatial): + + # Create an instance of Axes + ax = plt.Axes( + plt.figure(), + rect=[0, 0, 1, 1] + ) + + # Call the spatial_plot function with the ax object + fig, returned_ax = embedded_scatter_plot( + adata=self.adata, + method='spatial', + feature='Intensity_10', + ax=ax, + spot_size=self.spot_size, + alpha=self.alpha, + vmin=0, + vmax=100 + ) + + # Assert that the spatial_plot function returned a list + # containing an Axes object with the same properties + self.assertEqual(returned_ax.get_title(), ax.get_title()) + self.assertEqual(returned_ax.get_xlabel(), ax.get_xlabel()) + self.assertEqual(returned_ax.get_ylabel(), ax.get_ylabel()) + + # Restore the original spatial function + # del spatial_plot.__globals__['sc.pl.spatial'] + + def test_spatial_plot_combos_feature(self): + # Define the parameter combinations to test + spot_sizes = [10, 20] + alphas = [0.5, 0.8] + vmins = [-999, 0, 5] + vmaxs = [-999, 10, 20] + features = ['Intensity_20', 'Intensity_80'] + layers = [None, 'Standardized'] + + # Generate all combinations of parameters + # excluding both None values for features and annotations + + parameter_combinations = list(itertools.product( + spot_sizes, alphas, vmins, vmaxs, features, layers + )) + + parameter_combinations = [ + params for params in parameter_combinations if not ( + params[4] is None and params[5] is None + ) + ] + + for params in parameter_combinations: + spot_size, alpha, vmin, vmax, feature, layer = params + print(layer) + # Test the spatial_plot function with the + # given parameter combination + + returned_fig, ax = embedded_scatter_plot( + adata=self.adata, + method='spatial', + feature=feature, + layer=layer, + spot_size=spot_size, + alpha=alpha, + vmin=vmin, + vmax=vmax + ) + + # Perform assertions on the spatial plot + # Check if ax has data plotted + self.assertTrue(ax.has_data()) + + def test_spatial_plot_combos_annotation(self): + # Define the parameter combinations to test + spot_sizes = [10, 20] + alphas = [0.5, 0.8] + vmins = [-999, 0, 5] + vmaxs = [-999, 10, 20] + annotation = ['annotation1', 'annotation2'] + layers = [None, 'Normalized'] + + # Generate all combinations of parameters + # excluding both None values for features and annotations + + parameter_combinations = list(itertools.product( + spot_sizes, alphas, vmins, vmaxs, annotation, layers + )) + + parameter_combinations = [ + params for params in parameter_combinations if not ( + params[4] is None and params[5] is None + ) + ] + + for params in parameter_combinations: + spot_size, alpha, vmin, vmax, annotation, layer = params + # Test the spatial_plot function with the + # given parameter combination + + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1) + + returned_fig, ax = embedded_scatter_plot( + adata=self.adata, + method='spatial', + annotation=annotation, + layer=layer, + ax=ax, + spot_size=spot_size, + alpha=alpha, + vmin=vmin, + vmax=vmax + ) + + plt.close(fig) + # Perform assertions on the spatial plot + # Check if ax has data plotted + self.assertTrue(ax.has_data()) + + +if __name__ == '__main__': + unittest.main() From 3ef099e794d17c33098be03c22e9aec97d0c0619 Mon Sep 17 00:00:00 2001 From: Sam Ying Date: Tue, 3 Jun 2025 15:14:38 -0400 Subject: [PATCH 6/8] Combined dimensionality and spatial plot --- src/spac/visualization.py | 62 ------------------- .../test_embedded_scatter_plot.py | 29 +++++++++ .../test_visualize_2D_scatter.py | 20 ++++++ 3 files changed, 49 insertions(+), 62 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 643e4a1e..da0ee387 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -15,10 +15,7 @@ from spac.utils import check_feature, annotation_category_relations from spac.utils import check_label from spac.utils import get_defined_color_map -<<<<<<< HEAD from spac.utils import compute_boxplot_metrics -======= ->>>>>>> 2a60092 (addition of pinned colors, must be rebased to pass) from functools import partial from spac.utils import color_mapping, spell_out_special_characters from spac.data_utils import select_values @@ -276,7 +273,6 @@ def embedded_scatter_plot( check_annotation(adata, annotations=annotation) if feature: check_feature(adata, features=[feature]) -<<<<<<< HEAD color_mapping = None if color_map is not None: color_mapping = get_defined_color_map( @@ -284,8 +280,6 @@ def embedded_scatter_plot( defined_color_map=color_map, annotations=annotation ) -======= ->>>>>>> 2820b07 (Use of 2D_scatter for spatial) # Validate the method and check if the necessary data exists in adata.obsm if associated_table is None: @@ -318,7 +312,6 @@ def embedded_scatter_plot( f' two dimensions. It shape is:"{associated_table_shape}"' ) key = associated_table -<<<<<<< HEAD err_msg_layer = "The 'layer' parameter must be a string, " + \ f"got {str(type(layer))}" @@ -343,32 +336,6 @@ def embedded_scatter_plot( err_msg_ax = "The 'ax' parameter must be an instance " + \ f"of matplotlib.axes.Axes, got {str(type(ax))}" -======= - - err_msg_layer = "The 'layer' parameter must be a string, " + \ - f"got {str(type(layer))}" - err_msg_feature = "The 'feature' parameter must be a string, " + \ - f"got {str(type(feature))}" - err_msg_annotation = "The 'annotation' parameter must be a string, " + \ - f"got {str(type(annotation))}" - err_msg_feat_annotation_coe = "Both annotation and feature are passed, " +\ - "please provide sinle input." - err_msg_feat_annotation_non = "Both annotation and feature are None, " + \ - "please provide single input." - err_msg_spot_size = "The 'spot_size' parameter must be an integer, " + \ - f"got {str(type(spot_size))}" - err_msg_alpha_type = "The 'alpha' parameter must be a float," + \ - f"got {str(type(alpha))}" - err_msg_alpha_value = "The 'alpha' parameter must be between " + \ - f"0 and 1 (inclusive), got {str(alpha)}" - err_msg_vmin = "The 'vmin' parameter must be a float or an int, " + \ - f"got {str(type(vmin))}" - err_msg_vmax = "The 'vmax' parameter must be a float or an int, " + \ - f"got {str(type(vmax))}" - err_msg_ax = "The 'ax' parameter must be an instance " + \ - f"of matplotlib.axes.Axes, got {str(type(ax))}" - ->>>>>>> 2820b07 (Use of 2D_scatter for spatial) if adata is None: raise ValueError("The input dataset must not be None.") @@ -403,11 +370,6 @@ def embedded_scatter_plot( raise ValueError(err_msg) # Extract feature name -<<<<<<< HEAD -======= - feature_names = adata.var_names.tolist() - ->>>>>>> 2820b07 (Use of 2D_scatter for spatial) if not isinstance(spot_size, int): raise ValueError(err_msg_spot_size) @@ -436,7 +398,6 @@ def embedded_scatter_plot( x, y = adata.obsm[key].T # Determine coloring scheme -<<<<<<< HEAD if color_mapping is None: if annotation: color_values = adata.obs[annotation].astype('category').values @@ -451,25 +412,6 @@ def embedded_scatter_plot( else: color_values = adata.obs[annotation].astype('category').values color_representation = annotation -======= - if annotation: - color_values = adata.obs[annotation].astype('category').values - color_representation = annotation - vmin = None - vmax = None - elif feature: - data_source = adata.layers[layer] if layer else adata.X - color_values = data_source[:, adata.var_names == feature].squeeze() - color_representation = feature - feature_index = feature_names.index(feature) - if vmin == -999: - vmin = np.min(data_source[:, feature_index]) - if vmax == -999: - vmax = np.max(data_source[:, feature_index]) - else: - color_values = None - color_representation = None ->>>>>>> 2820b07 (Use of 2D_scatter for spatial) # Set axis titles based on method and color representation if method == 'tsne': @@ -499,10 +441,6 @@ def embedded_scatter_plot( kwargs.pop('y_axis_title', None) kwargs.pop('plot_title', None) kwargs.pop('color_representation', None) - - # Set Min and Max in kwargs - kwargs['vmin'] = vmin - kwargs['vmax'] = vmax # Set Min and Max in kwargs kwargs['vmin'] = vmin diff --git a/tests/test_visualization/test_embedded_scatter_plot.py b/tests/test_visualization/test_embedded_scatter_plot.py index fff9aad3..202f9429 100644 --- a/tests/test_visualization/test_embedded_scatter_plot.py +++ b/tests/test_visualization/test_embedded_scatter_plot.py @@ -256,6 +256,12 @@ def test_invalid_alpha_range(self): embedded_scatter_plot(adata=self.adata, method='spatial', spot_size=self.spot_size, alpha=-0.5) + def test_invalid_theme(self): + # Should raise ValueError for invalid theme + with self.assertRaises(ValueError): + embedded_scatter_plot(self.adata, 'umap', annotation='cat_anno', + theme='not_a_theme') + def test_missing_annotation(self): # Test when annotation is None and feature is None with self.assertRaises(ValueError) as cm: @@ -519,6 +525,29 @@ def test_spatial_plot_combos_annotation(self): # Check if ax has data plotted self.assertTrue(ax.has_data()) + def test_color_map_from_uns(self): + # Should use color map from adata.uns + self.adata.uns['anno_colors'] = {'A': '#111111', 'B': '#222222', + 'C': '#333333', 'D': '#444444'} + returned_fig, ax = embedded_scatter_plot( + adata=self.adata, + method='spatial', + annotation='annotation1', + color_map='anno_colors' + ) + self.assertTrue(ax.has_data()) + + def test_color_map_input(self): + # Should raise failure as it accepts colormap as str not dict + anno_colors = {'A': '#111111', 'B': '#222222', + 'C': '#333333', 'D': '#444444'} + with self.assertRaises(TypeError): + embedded_scatter_plot( + adata=self.adata, + method='spatial', + annotation='annotation1', + color_map=anno_colors) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_visualization/test_visualize_2D_scatter.py b/tests/test_visualization/test_visualize_2D_scatter.py index f79b792c..39a44cc3 100644 --- a/tests/test_visualization/test_visualize_2D_scatter.py +++ b/tests/test_visualization/test_visualize_2D_scatter.py @@ -129,6 +129,26 @@ def test_color_representation(self): legend = axis.get_legend() self.assertIn(color_representation, legend.get_title().get_text()) + def test_color_map(self): + color_map = {'A': '#123456', 'B': '#abcdef', 'C': '#654321'} + fig, ax = visualize_2D_scatter( + self.x, self.y, labels=self.labels_categorical, + color_map=color_map) + # Get all PathCollections (one per category) + collections = [c for c in ax.collections if + isinstance(c, matplotlib.collections.PathCollection)] + import matplotlib.colors as mcolors + for i, cluster in enumerate(['A', 'B', 'C']): + expected_rgba = mcolors.to_rgba(color_map[cluster]) + facecolors = collections[i].get_facecolors() + # All points in this collection should have the same color + for fc in facecolors: + self.assertTrue( + all(abs(a - b) < 1e-3 for a, b in zip(fc, expected_rgba)), + f"Color mismatch for {cluster}: expected {expected_rgba}," + f"got {fc}" + ) + if __name__ == '__main__': unittest.main() From 95ca56dead6004b3dceb5f6498c955e41b4dec60 Mon Sep 17 00:00:00 2001 From: Sam Ying Date: Tue, 1 Jul 2025 11:31:03 -0400 Subject: [PATCH 7/8] correction of errors and unittests --- src/spac/visualization.py | 42 ++++------------- .../test_embedded_scatter_plot.py | 46 +++++++++++++++---- 2 files changed, 46 insertions(+), 42 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index da0ee387..4a5984a6 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -13,7 +13,7 @@ from matplotlib.colors import ListedColormap, BoundaryNorm from spac.utils import check_table, check_annotation from spac.utils import check_feature, annotation_category_relations -from spac.utils import check_label +from spac.utils import check_label, check_list_in_list from spac.utils import get_defined_color_map from spac.utils import compute_boxplot_metrics from functools import partial @@ -150,9 +150,6 @@ def visualize_2D_scatter( for i, cluster in enumerate(unique_clusters) } - # Use the number of unique clusters to set the colormap length - cmap = ListedColormap(colors[:len(unique_clusters)]) - for idx, cluster in enumerate(unique_clusters): mask = np.array(labels) == cluster color = cluster_to_color.get(str(cluster), 'gray') @@ -205,7 +202,7 @@ def visualize_2D_scatter( return fig, ax -def embedded_scatter_plot( +def embedded_scatter_plot( adata, method=None, annotation=None, @@ -284,9 +281,11 @@ def embedded_scatter_plot( # Validate the method and check if the necessary data exists in adata.obsm if associated_table is None: valid_methods = ['tsne', 'umap', 'pca', 'spatial'] - if method not in valid_methods: - raise ValueError("Method should be one of {'tsne', 'umap', 'pca', 'spatial'}" - f'. Got:"{method}"') + check_list_in_list(input=method, input_name="method", + input_type="method", + target_list=valid_methods, + need_exist=True + ) if method == "spatial": key = "spatial" else: @@ -313,14 +312,6 @@ def embedded_scatter_plot( ) key = associated_table - err_msg_layer = "The 'layer' parameter must be a string, " + \ - f"got {str(type(layer))}" - err_msg_feature = "The 'feature' parameter must be a string, " + \ - f"got {str(type(feature))}" - err_msg_annotation = "The 'annotation' parameter must be a string, " + \ - f"got {str(type(annotation))}" - err_msg_feat_annotation_coe = "Both annotation and feature are passed, " +\ - "please provide sinle input." err_msg_feat_annotation_non = "Both annotation and feature are None, " + \ "please provide single input." err_msg_spot_size = "The 'spot_size' parameter must be an integer, " + \ @@ -344,27 +335,10 @@ def embedded_scatter_plot( f"instance of anndata.AnnData, got {str(type(adata))}." raise ValueError(err_msg_adata) - if layer is not None and not isinstance(layer, str): - raise ValueError(err_msg_layer) - - if layer is not None and layer not in adata.layers.keys(): - err_msg_layer_exist = f"Layer {layer} does not exists, " + \ - f"available layers are {str(adata.layers.keys())}" - raise ValueError(err_msg_layer_exist) - - if feature is not None and not isinstance(feature, str): - raise ValueError(err_msg_feature) - - if annotation is not None and not isinstance(annotation, str): - raise ValueError(err_msg_annotation) - - if annotation is not None and feature is not None: - raise ValueError(err_msg_feat_annotation_coe) - if key == "spatial": if annotation is None and feature is None: raise ValueError(err_msg_feat_annotation_non) - + if 'spatial' not in adata.obsm_keys(): err_msg = "Spatial coordinates not found in the 'obsm' attribute." raise ValueError(err_msg) diff --git a/tests/test_visualization/test_embedded_scatter_plot.py b/tests/test_visualization/test_embedded_scatter_plot.py index 202f9429..76f30e37 100644 --- a/tests/test_visualization/test_embedded_scatter_plot.py +++ b/tests/test_visualization/test_embedded_scatter_plot.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt import itertools from spac.visualization import embedded_scatter_plot +from matplotlib.collections import PathCollection matplotlib.use('Agg') @@ -129,9 +130,14 @@ def test_real_pca_plot(self): def test_invalid_method(self): with self.assertRaises(ValueError) as cm: embedded_scatter_plot(self.adata, 'invalid_method') - expected_msg = ("Method should be one of {'tsne', 'umap', 'pca'," - " 'spatial'}." - ' Got:"invalid_method"' + expected_msg = ( + "The method 'invalid_method' does not exist" + " in the provided dataset.\n" + "Existing methods are:\n" + "tsne\n" + "umap\n" + "pca\n" + "spatial" ) self.assertEqual(str(cm.exception), expected_msg) @@ -526,21 +532,45 @@ def test_spatial_plot_combos_annotation(self): self.assertTrue(ax.has_data()) def test_color_map_from_uns(self): + # Color conversion formula + def hex_to_rgb(hex_color): + hex_color = hex_color.lstrip('#') + return tuple(int(hex_color[i:i+2], 16) / 255.0 for i in (0, 2, 4)) + # Should use color map from adata.uns - self.adata.uns['anno_colors'] = {'A': '#111111', 'B': '#222222', - 'C': '#333333', 'D': '#444444'} + self.adata.uns['anno_colors'] = {'A': '#ff1111', 'B': '#22ff22', + 'C': '#ee3333'} + annotation = 'annotation1' returned_fig, ax = embedded_scatter_plot( adata=self.adata, method='spatial', - annotation='annotation1', + annotation=annotation, color_map='anno_colors' ) - self.assertTrue(ax.has_data()) + expected_rgb = {k: hex_to_rgb(v) for k, v in + self.adata.uns['anno_colors'].items()} + found_labels = set() + for collection in ax.collections: + if isinstance(collection, PathCollection): + label = collection.get_label() + color = collection.get_facecolor()[0][:3] + expected = expected_rgb.get(label) + self.assertIsNotNone(expected, f"Unexpected label" + f"'{label}' in plot.") + np.testing.assert_allclose(color, expected, atol=1e-2, + err_msg=f"Color mismatch" + f"for label '{label}':" + f" got {color}," + f" expected {expected}") + found_labels.add(label) + + # Ensure all expected labels were found in the plot + self.assertEqual(found_labels, set(expected_rgb.keys())) def test_color_map_input(self): # Should raise failure as it accepts colormap as str not dict anno_colors = {'A': '#111111', 'B': '#222222', - 'C': '#333333', 'D': '#444444'} + 'C': '#333333'} with self.assertRaises(TypeError): embedded_scatter_plot( adata=self.adata, From 729cfeeac40d1ad9652f2b1782ec2619cc06b4a3 Mon Sep 17 00:00:00 2001 From: Sam Ying Date: Tue, 1 Jul 2025 11:55:06 -0400 Subject: [PATCH 8/8] test_visualize_2d test case update --- tests/test_visualization/test_visualize_2D_scatter.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_visualization/test_visualize_2D_scatter.py b/tests/test_visualization/test_visualize_2D_scatter.py index 39a44cc3..f97c6af2 100644 --- a/tests/test_visualization/test_visualize_2D_scatter.py +++ b/tests/test_visualization/test_visualize_2D_scatter.py @@ -138,9 +138,11 @@ def test_color_map(self): collections = [c for c in ax.collections if isinstance(c, matplotlib.collections.PathCollection)] import matplotlib.colors as mcolors - for i, cluster in enumerate(['A', 'B', 'C']): + # Map collections to their labels + collection_map = {c.get_label(): c for c in collections} + for cluster in ['A', 'B', 'C']: expected_rgba = mcolors.to_rgba(color_map[cluster]) - facecolors = collections[i].get_facecolors() + facecolors = collection_map[cluster].get_facecolors() # All points in this collection should have the same color for fc in facecolors: self.assertTrue(