diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 0ab0ee11..4a5984a6 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -13,7 +13,7 @@ from matplotlib.colors import ListedColormap, BoundaryNorm from spac.utils import check_table, check_annotation from spac.utils import check_feature, annotation_category_relations -from spac.utils import check_label +from spac.utils import check_label, check_list_in_list from spac.utils import get_defined_color_map from spac.utils import compute_boxplot_metrics from functools import partial @@ -35,7 +35,7 @@ def visualize_2D_scatter( x, y, labels=None, point_size=None, theme=None, ax=None, annotate_centers=False, x_axis_title='Component 1', y_axis_title='Component 2', plot_title=None, - color_representation=None, **kwargs + color_representation=None, color_map=None, **kwargs ): """ Visualize 2D data using plt.scatter. @@ -65,6 +65,8 @@ def visualize_2D_scatter( Title for the plot. color_representation : str, optional Description of what the colors represent. + color_map : dictionary, optional + Dictionary containing colors for label annotations. **kwargs Additional keyword arguments passed to plt.scatter. @@ -83,6 +85,8 @@ def visualize_2D_scatter( raise ValueError("x and y must have the same length.") if labels is not None and len(labels) != len(x): raise ValueError("Labels length should match x and y length.") + if color_map is not None and not isinstance(color_map, dict): + raise ValueError("`color_map` must be a dict mapping label→color.") # Define color themes themes = { @@ -141,15 +145,17 @@ def visualize_2D_scatter( cmap2 = plt.get_cmap('tab20b') cmap3 = plt.get_cmap('tab20c') colors = cmap1.colors + cmap2.colors + cmap3.colors - - # Use the number of unique clusters to set the colormap length - cmap = ListedColormap(colors[:len(unique_clusters)]) + cluster_to_color = color_map if color_map is not None else { + str(cluster): colors[i % len(colors)] + for i, cluster in enumerate(unique_clusters) + } for idx, cluster in enumerate(unique_clusters): mask = np.array(labels) == cluster + color = cluster_to_color.get(str(cluster), 'gray') ax.scatter( x[mask], y[mask], - color=cmap(idx), + color=color, label=cluster, s=point_size ) @@ -196,6 +202,240 @@ def visualize_2D_scatter( return fig, ax +def embedded_scatter_plot( + adata, + method=None, + annotation=None, + feature=None, + layer=None, + ax=None, + associated_table=None, + spot_size=20, + alpha=0.5, + vmin=-999, + vmax=-999, + color_map=None, + **kwargs): + """ + Visualize scatter plot in PCA, t-SNE, UMAP, spatial or associated table. + + Parameters + ---------- + adata : anndata.AnnData + The AnnData object with coordinates precomputed by the 'tsne' or 'UMAP' + function and stored in 'adata.obsm["X_tsne"]' or 'adata.obsm["X_umap"]' + method : str, optional (default: None) + Visualization method specifying the coordinate system to plot. + Choose from {'tsne', 'umap', 'pca', 'spatial'}. + annotation : str, optional + The name of the column in `adata.obs` to use for coloring + the scatter plot points based on cell annotations. + feature : str, optional + The name of the gene or feature in `adata.var_names` to use + for coloring the scatter plot points based on feature expression. + layer : str, optional + The name of the data layer in `adata.layers` to use for visualization. + If None, the main data matrix `adata.X` is used. + ax : matplotlib.axes.Axes, optional (default: None) + A matplotlib axes object to plot on. + If not provided, a new figure and axes will be created. + associated_table : str, optional (default: None) + Name of the key in `obsm` that contains the numpy array. Takes + precedence over `method` + color_map : str, optional (default: None) + Name of the key in adata.uns that contains color-mapping for + the plot + **kwargs + Parameters passed to visualize_2D_scatter function, + including point_size. + + Returns + ------- + fig : matplotlib.figure.Figure + The created figure for the plot. + ax : matplotlib.axes.Axes + The axes of the plot. + """ + + # Check if both annotation and feature are specified, raise error if so + if annotation and feature: + raise ValueError( + "Please specify either an annotation or a feature for coloring, " + "not both.") + + # Use utility functions for input validation + if layer: + check_table(adata, tables=layer) + if annotation: + check_annotation(adata, annotations=annotation) + if feature: + check_feature(adata, features=[feature]) + color_mapping = None + if color_map is not None: + color_mapping = get_defined_color_map( + adata, + defined_color_map=color_map, + annotations=annotation + ) + + # Validate the method and check if the necessary data exists in adata.obsm + if associated_table is None: + valid_methods = ['tsne', 'umap', 'pca', 'spatial'] + check_list_in_list(input=method, input_name="method", + input_type="method", + target_list=valid_methods, + need_exist=True + ) + if method == "spatial": + key = "spatial" + else: + key = f'X_{method}' + if key not in adata.obsm.keys(): + raise ValueError( + f"{key} coordinates not found in adata.obsm. " + f"Please run {method.upper()} before calling this function." + ) + + else: + check_table( + adata=adata, + tables=associated_table, + should_exist=True, + associated_table=True + ) + + associated_table_shape = adata.obsm[associated_table].shape + if associated_table_shape[1] != 2: + raise ValueError( + f'The associated table:"{associated_table}" does not have' + f' two dimensions. It shape is:"{associated_table_shape}"' + ) + key = associated_table + + err_msg_feat_annotation_non = "Both annotation and feature are None, " + \ + "please provide single input." + err_msg_spot_size = "The 'spot_size' parameter must be an integer, " + \ + f"got {str(type(spot_size))}" + err_msg_alpha_type = "The 'alpha' parameter must be a float," + \ + f"got {str(type(alpha))}" + err_msg_alpha_value = "The 'alpha' parameter must be between " + \ + f"0 and 1 (inclusive), got {str(alpha)}" + err_msg_vmin = "The 'vmin' parameter must be a float or an int, " + \ + f"got {str(type(vmin))}" + err_msg_vmax = "The 'vmax' parameter must be a float or an int, " + \ + f"got {str(type(vmax))}" + err_msg_ax = "The 'ax' parameter must be an instance " + \ + f"of matplotlib.axes.Axes, got {str(type(ax))}" + + if adata is None: + raise ValueError("The input dataset must not be None.") + + if not isinstance(adata, anndata.AnnData): + err_msg_adata = "The 'adata' parameter must be an " + \ + f"instance of anndata.AnnData, got {str(type(adata))}." + raise ValueError(err_msg_adata) + + if key == "spatial": + if annotation is None and feature is None: + raise ValueError(err_msg_feat_annotation_non) + + if 'spatial' not in adata.obsm_keys(): + err_msg = "Spatial coordinates not found in the 'obsm' attribute." + raise ValueError(err_msg) + +# Extract feature name + if not isinstance(spot_size, int): + raise ValueError(err_msg_spot_size) + + if not isinstance(alpha, float): + raise ValueError(err_msg_alpha_type) + + if not (0 <= alpha <= 1): + raise ValueError(err_msg_alpha_value) + + if vmin != -999 and not ( + isinstance(vmin, float) or isinstance(vmin, int) + ): + raise ValueError(err_msg_vmin) + + if vmax != -999 and not ( + isinstance(vmax, float) or isinstance(vmax, int) + ): + raise ValueError(err_msg_vmax) + + if ax is not None and not isinstance(ax, plt.Axes): + raise ValueError(err_msg_ax) + + print(f'Running visualization using the coordinates: "{key}"') + + # Extract the 2D coordinates + x, y = adata.obsm[key].T + + # Determine coloring scheme + if color_mapping is None: + if annotation: + color_values = adata.obs[annotation].astype('category').values + color_representation = annotation + elif feature: + data_src = adata.layers[layer] if layer else adata.X + color_values = data_src[:, adata.var_names == feature].squeeze() + color_representation = feature + else: + color_values = None + color_representation = None + else: + color_values = adata.obs[annotation].astype('category').values + color_representation = annotation + + # Set axis titles based on method and color representation + if method == 'tsne': + x_axis_title = 't-SNE 1' + y_axis_title = 't-SNE 2' + plot_title = f'TSNE-{color_representation}' + elif method == 'pca': + x_axis_title = 'PCA 1' + y_axis_title = 'PCA 2' + plot_title = f'PCA-{color_representation}' + elif method == 'umap': + x_axis_title = 'UMAP 1' + y_axis_title = 'UMAP 2' + plot_title = f'UMAP-{color_representation}' + elif method == 'spatial': + x_axis_title = 'SPATIAL 1' + y_axis_title = 'SPATIAL 2' + plot_title = f'SPATIAL-{color_representation}' + + else: + x_axis_title = f'{associated_table} 1' + y_axis_title = f'{associated_table} 2' + plot_title = f'{associated_table}-{color_representation}' + + # Remove conflicting keys from kwargs + kwargs.pop('x_axis_title', None) + kwargs.pop('y_axis_title', None) + kwargs.pop('plot_title', None) + kwargs.pop('color_representation', None) + + # Set Min and Max in kwargs + kwargs['vmin'] = vmin + kwargs['vmax'] = vmax + + fig, ax = visualize_2D_scatter( + x=x, + y=y, + ax=ax, + labels=color_values, + x_axis_title=x_axis_title, + y_axis_title=y_axis_title, + plot_title=plot_title, + color_representation=color_representation, + color_map=color_mapping, + **kwargs + ) + + return fig, ax + + def dimensionality_reduction_plot( adata, method=None, diff --git a/tests/test_visualization/test_embedded_scatter_plot.py b/tests/test_visualization/test_embedded_scatter_plot.py new file mode 100644 index 00000000..76f30e37 --- /dev/null +++ b/tests/test_visualization/test_embedded_scatter_plot.py @@ -0,0 +1,583 @@ +import unittest +from unittest.mock import patch +import anndata +import numpy as np +import matplotlib +import matplotlib.pyplot as plt +import itertools +from spac.visualization import embedded_scatter_plot +from matplotlib.collections import PathCollection +matplotlib.use('Agg') + + +class TestStaticScatterPlot(unittest.TestCase): + + def setUp(self): + self.adata = anndata.AnnData(X=np.random.rand(10, 10)) + self.adata.obsm['X_tsne'] = np.random.rand(10, 2) + self.adata.obsm['X_umap'] = np.random.rand(10, 2) + self.adata.obsm['X_pca'] = np.random.rand(10, 2) + self.adata.obsm['sumap'] = np.random.rand(10, 2) + self.adata.obsm['3dsumap'] = np.random.rand(10, 3) + self.adata.obs['annotation_column'] = np.random.choice( + ['A', 'B', 'C'], size=10 + ) + self.adata.var_names = ['gene_' + str(i) for i in range(10)] + + def test_missing_umap_coordinates(self): + del self.adata.obsm['X_umap'] + with self.assertRaises(ValueError) as cm: + embedded_scatter_plot(self.adata, 'umap') + expected_msg = ( + "X_umap coordinates not found in adata.obsm. " + "Please run UMAP before calling this function." + ) + self.assertEqual(str(cm.exception), expected_msg) + + def test_missing_tsne_coordinates(self): + del self.adata.obsm['X_tsne'] + with self.assertRaises(ValueError) as cm: + embedded_scatter_plot(self.adata, 'tsne') + expected_msg = ( + "X_tsne coordinates not found in adata.obsm. " + "Please run TSNE before calling this function." + ) + self.assertEqual(str(cm.exception), expected_msg) + + def test_annotation_and_feature(self): + with self.assertRaises(ValueError) as cm: + embedded_scatter_plot( + self.adata, 'tsne', + annotation='annotation_column', + feature='feature_column' + ) + expected_msg = ( + "Please specify either an annotation or a feature for coloring, " + "not both." + ) + self.assertEqual(str(cm.exception), expected_msg) + + def test_annotation_column(self): + fig, ax = embedded_scatter_plot( + self.adata, 'tsne', annotation='annotation_column' + ) + self.assertIsNotNone(fig) + self.assertIsNotNone(ax) + self.assertEqual(ax.get_xlabel(), 't-SNE 1') + self.assertEqual(ax.get_ylabel(), 't-SNE 2') + self.assertEqual(ax.get_title(), 'TSNE-annotation_column') + + def test_associated_table(self): + fig, ax = embedded_scatter_plot( + self.adata, + annotation='annotation_column', + associated_table='sumap' + ) + self.assertIsNotNone(fig) + self.assertIsNotNone(ax) + self.assertEqual(ax.get_xlabel(), 'sumap 1') + self.assertEqual(ax.get_ylabel(), 'sumap 2') + self.assertEqual(ax.get_title(), 'sumap-annotation_column') + + def test_feature_column(self): + fig, ax = embedded_scatter_plot( + self.adata, 'tsne', feature='gene_1' + ) + self.assertIsNotNone(fig) + self.assertIsNotNone(ax) + self.assertEqual(ax.get_xlabel(), 't-SNE 1') + self.assertEqual(ax.get_ylabel(), 't-SNE 2') + self.assertEqual(ax.get_title(), 'TSNE-gene_1') + + def test_ax_provided(self): + fig, ax_provided = plt.subplots() + fig_returned, ax_returned = embedded_scatter_plot( + self.adata, 'tsne', ax=ax_provided + ) + self.assertIs(fig, fig_returned) + self.assertIs(ax_provided, ax_returned) + + def test_real_tsne_plot(self): + fig, ax = embedded_scatter_plot( + self.adata, 'tsne', annotation='annotation_column' + ) + self.assertIsInstance(fig, plt.Figure) + self.assertIsInstance(ax, plt.Axes) + self.assertEqual(ax.get_xlabel(), 't-SNE 1') + self.assertEqual(ax.get_ylabel(), 't-SNE 2') + self.assertEqual(ax.get_title(), 'TSNE-annotation_column') + + def test_real_umap_plot(self): + fig, ax = embedded_scatter_plot( + self.adata, 'umap', feature='gene_1' + ) + self.assertIsInstance(fig, plt.Figure) + self.assertIsInstance(ax, plt.Axes) + self.assertEqual(ax.get_xlabel(), 'UMAP 1') + self.assertEqual(ax.get_ylabel(), 'UMAP 2') + self.assertEqual(ax.get_title(), 'UMAP-gene_1') + + def test_real_pca_plot(self): + fig, ax = embedded_scatter_plot( + self.adata, 'pca', annotation='annotation_column' + ) + self.assertIsInstance(fig, plt.Figure) + self.assertIsInstance(ax, plt.Axes) + self.assertEqual(ax.get_xlabel(), 'PCA 1') + self.assertEqual(ax.get_ylabel(), 'PCA 2') + self.assertEqual(ax.get_title(), 'PCA-annotation_column') + + def test_invalid_method(self): + with self.assertRaises(ValueError) as cm: + embedded_scatter_plot(self.adata, 'invalid_method') + expected_msg = ( + "The method 'invalid_method' does not exist" + " in the provided dataset.\n" + "Existing methods are:\n" + "tsne\n" + "umap\n" + "pca\n" + "spatial" + ) + self.assertEqual(str(cm.exception), expected_msg) + + def test_input_derived_feature_3d(self): + with self.assertRaises(ValueError) as cm: + embedded_scatter_plot( + self.adata, + associated_table='3dsumap') + expected_msg = ('The associated table:"3dsumap" does not have' + ' two dimensions. It shape is:"(10, 3)"') + + self.assertEqual(str(cm.exception), expected_msg) + + def test_conflicting_kwargs(self): + # This test ensures conflicting keys are removed from kwargs + fig, ax = embedded_scatter_plot( + self.adata, + 'tsne', + annotation='annotation_column', + x_axis_title='Conflict X', + y_axis_title='Conflict Y', + plot_title='Conflict Title', + color_representation='Conflict Color' + ) + self.assertIsNotNone(fig) + self.assertIsNotNone(ax) + self.assertEqual(ax.get_xlabel(), 't-SNE 1') + self.assertEqual(ax.get_ylabel(), 't-SNE 2') + self.assertEqual(ax.get_title(), 'TSNE-annotation_column') + + +class SpatialPlotTestCase(unittest.TestCase): + def setUp(self): + # Set up test data + num_rows = 100 + num_cols = 100 + + # Generate data matrix + features = np.random.randint(0, 100, size=(num_rows, num_cols)) + feature_names = [f'Intensity_{i}' for i in range(num_cols)] + + # Generate annotation metadata + annotation_data = { + 'annotation1': np.random.choice(['A', 'B', 'C'], size=num_rows), + 'annotation2': np.random.normal(0, 1, size=num_rows), + 'annotation3': np.random.uniform(0, 1, size=num_rows) + } + + # Create the AnnData object + self.adata = anndata.AnnData( + X=features, + obs=annotation_data + ) + + numpy_array = np.random.uniform(0, 100, size=(num_rows, 2)) + self.adata.obsm["spatial"] = numpy_array + self.adata.var_names = feature_names + + # Generate layer data + layer_data = np.random.randint(0, 100, size=(num_rows, num_cols)) + self.adata.layers['Normalized'] = layer_data + self.adata.layers['Standardized'] = layer_data + + self.spot_size = 10 + self.alpha = 0.5 + + def test_invalid_adata(self): + # Test when adata is not an instance of anndata.AnnData + with self.assertRaises(ValueError): + embedded_scatter_plot(adata=None, + method='spatial', + spot_size=self.spot_size, + alpha=self.alpha + ) + + def test_invalid_layer(self): + # Test when layer is not a string + with self.assertRaises(ValueError): + embedded_scatter_plot(adata=self.adata, + method='spatial', + layer=123, + spot_size=self.spot_size, + alpha=self.alpha + ) + + def test_invalid_feature(self): + # Test when feature is not a string + with self.assertRaises(ValueError): + embedded_scatter_plot(adata=self.adata, + method='spatial', + feature=123, + spot_size=self.spot_size, + alpha=self.alpha + ) + + def test_invalid_annotation(self): + # Test when annotation is not a string + with self.assertRaises(ValueError): + embedded_scatter_plot( + adata=self.adata, + method='spatial', + annotation=123, + spot_size=self.spot_size, + alpha=self.alpha + ) + + def test_invalid_spot_size(self): + # Test when spot_size is not an integer + with self.assertRaises(ValueError): + embedded_scatter_plot(adata=self.adata, method='spatial', + spot_size=10.5, alpha=self.alpha) + + def test_invalid_alpha(self): + # Test when alpha is not a float + with self.assertRaises(ValueError): + embedded_scatter_plot(adata=self.adata, method='spatial', + spot_size=self.spot_size, alpha="0.5") + + def test_invalid_alpha_range(self): + # Test when alpha is outside the range of 0 to 1 + with self.assertRaises(ValueError): + embedded_scatter_plot(adata=self.adata, method='spatial', + spot_size=self.spot_size, alpha=-0.5) + + def test_invalid_theme(self): + # Should raise ValueError for invalid theme + with self.assertRaises(ValueError): + embedded_scatter_plot(self.adata, 'umap', annotation='cat_anno', + theme='not_a_theme') + + def test_missing_annotation(self): + # Test when annotation is None and feature is None + with self.assertRaises(ValueError) as cm: + embedded_scatter_plot(adata=self.adata, method='spatial', + spot_size=self.spot_size, alpha=self.alpha) + error_msg = str(cm.exception) + err_msg_exp = "Both annotation and feature are None, " + \ + "please provide single input." + self.assertEqual(error_msg, err_msg_exp) + + def test_invalid_annotation_name(self): + # Test when annotation name is not found in the dataset + with self.assertRaises(ValueError) as cm: + embedded_scatter_plot( + adata=self.adata, + method='spatial', + annotation='annotation4', + spot_size=self.spot_size, + alpha=self.alpha + ) + error_msg = str(cm.exception) + err_msg_exp = ("The annotation 'annotation4' does not exist " + "in the provided dataset.\n" + "Existing annotations are:\n" + "annotation1\nannotation2\nannotation3" + ) + self.assertEqual(error_msg, err_msg_exp) + + def test_invalid_feature_name(self): + # Test when feature name is not found in the layer + with self.assertRaises(ValueError) as cm: + embedded_scatter_plot( + adata=self.adata, + method='spatial', + feature='feature1', + spot_size=self.spot_size, + alpha=self.alpha + ) + error_msg = str(cm.exception) + target_features = "\n".join(self.adata.var_names) + err_msg_exp = ("The feature 'feature1' does not exist " + "in the provided dataset.\n" + f"Existing features are:\n{target_features}" + ) + self.assertEqual(error_msg, err_msg_exp) + + def test_spatial_plot_annotation(self): + # This test verifies that the spatial_plot function + # correctly interacts with the spatial function, + # and returns the expected Axes object when + # the annotation parameter is used. + + # Mock the spatial function + # This mock the spatial function to replace the + # original implementation.The mock function, mock_spatial, + # checks if the inputs match the expected values. + # The purpose of mocking the spatial function with the + # mock_spatial function is to simulate the behavior of the + # original spatial function during testing and verify if the + # inputs passed to it match the expected values. + def mock_spatial( + adata, + color, + spot_size, + alpha, + ax, + show, + layer, + vmin, + vmax, + **kwargs): + # Assert that the inputs match the expected values + assert layer is None + self.assertEqual(color, 'annotation1') + self.assertEqual(spot_size, self.spot_size) + self.assertEqual(alpha, self.alpha) + # self.assertEqual(vmin, None) + assert vmax is None + assert vmin is None + # self.assertEqual(vmax, None) + self.assertIsInstance(ax, plt.Axes) + self.assertFalse(show) + # Return a list containing the ax object to mimic + # the behavior of the spatial function + return [ax] + + # Mock the spatial function with the mock_spatial function + # spatial_plot.__globals__['sc.pl.spatial'] = mock_spatial + + with patch('scanpy.pl.spatial', new=mock_spatial): + # Create an instance of Axes + ax = plt.Axes( + plt.figure(), + rect=[0, 0, 1, 1] + ) + + # Call the spatial_plot function with the ax object + fig, returned_ax = embedded_scatter_plot( + adata=self.adata, + method='spatial', + annotation='annotation1', + layer=None, + ax=ax, + spot_size=self.spot_size, + alpha=self.alpha + ) + + # Assert that the spatial_plot function returned a list + # containing an Axes object with the same properties + self.assertEqual(returned_ax.get_title(), ax.get_title()) + self.assertEqual(returned_ax.get_xlabel(), ax.get_xlabel()) + self.assertEqual(returned_ax.get_ylabel(), ax.get_ylabel()) + + # Restore the original spatial function + # del spatial_plot.__globals__['sc.pl.spatial'] + + def test_spatial_plot_feature(self): + # Mock the spatial function + def mock_spatial( + adata, + layer, + color, + spot_size, + alpha, + vmin, + vmax, + ax, + show, + **kwargs): + # Assert that the inputs match the expected values + assert layer is None + self.assertEqual(color, 'Intensity_10spatial_plot') + self.assertEqual(spot_size, self.spot_size) + self.assertEqual(alpha, self.alpha) + self.assertEqual(vmin, 0) + self.assertEqual(vmax, 100) + self.assertIsInstance(ax, plt.Axes) + self.assertFalse(show) + # Return a list containing the ax object to mimic + # the behavior of the spatial function + return [ax] + + # Mock the spatial function with the mock_spatial function + # spatial_plot.__globals__['sc.pl.spatial'] = mock_spatial + with patch('scanpy.pl.spatial', new=mock_spatial): + + # Create an instance of Axes + ax = plt.Axes( + plt.figure(), + rect=[0, 0, 1, 1] + ) + + # Call the spatial_plot function with the ax object + fig, returned_ax = embedded_scatter_plot( + adata=self.adata, + method='spatial', + feature='Intensity_10', + ax=ax, + spot_size=self.spot_size, + alpha=self.alpha, + vmin=0, + vmax=100 + ) + + # Assert that the spatial_plot function returned a list + # containing an Axes object with the same properties + self.assertEqual(returned_ax.get_title(), ax.get_title()) + self.assertEqual(returned_ax.get_xlabel(), ax.get_xlabel()) + self.assertEqual(returned_ax.get_ylabel(), ax.get_ylabel()) + + # Restore the original spatial function + # del spatial_plot.__globals__['sc.pl.spatial'] + + def test_spatial_plot_combos_feature(self): + # Define the parameter combinations to test + spot_sizes = [10, 20] + alphas = [0.5, 0.8] + vmins = [-999, 0, 5] + vmaxs = [-999, 10, 20] + features = ['Intensity_20', 'Intensity_80'] + layers = [None, 'Standardized'] + + # Generate all combinations of parameters + # excluding both None values for features and annotations + + parameter_combinations = list(itertools.product( + spot_sizes, alphas, vmins, vmaxs, features, layers + )) + + parameter_combinations = [ + params for params in parameter_combinations if not ( + params[4] is None and params[5] is None + ) + ] + + for params in parameter_combinations: + spot_size, alpha, vmin, vmax, feature, layer = params + print(layer) + # Test the spatial_plot function with the + # given parameter combination + + returned_fig, ax = embedded_scatter_plot( + adata=self.adata, + method='spatial', + feature=feature, + layer=layer, + spot_size=spot_size, + alpha=alpha, + vmin=vmin, + vmax=vmax + ) + + # Perform assertions on the spatial plot + # Check if ax has data plotted + self.assertTrue(ax.has_data()) + + def test_spatial_plot_combos_annotation(self): + # Define the parameter combinations to test + spot_sizes = [10, 20] + alphas = [0.5, 0.8] + vmins = [-999, 0, 5] + vmaxs = [-999, 10, 20] + annotation = ['annotation1', 'annotation2'] + layers = [None, 'Normalized'] + + # Generate all combinations of parameters + # excluding both None values for features and annotations + + parameter_combinations = list(itertools.product( + spot_sizes, alphas, vmins, vmaxs, annotation, layers + )) + + parameter_combinations = [ + params for params in parameter_combinations if not ( + params[4] is None and params[5] is None + ) + ] + + for params in parameter_combinations: + spot_size, alpha, vmin, vmax, annotation, layer = params + # Test the spatial_plot function with the + # given parameter combination + + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1) + + returned_fig, ax = embedded_scatter_plot( + adata=self.adata, + method='spatial', + annotation=annotation, + layer=layer, + ax=ax, + spot_size=spot_size, + alpha=alpha, + vmin=vmin, + vmax=vmax + ) + + plt.close(fig) + # Perform assertions on the spatial plot + # Check if ax has data plotted + self.assertTrue(ax.has_data()) + + def test_color_map_from_uns(self): + # Color conversion formula + def hex_to_rgb(hex_color): + hex_color = hex_color.lstrip('#') + return tuple(int(hex_color[i:i+2], 16) / 255.0 for i in (0, 2, 4)) + + # Should use color map from adata.uns + self.adata.uns['anno_colors'] = {'A': '#ff1111', 'B': '#22ff22', + 'C': '#ee3333'} + annotation = 'annotation1' + returned_fig, ax = embedded_scatter_plot( + adata=self.adata, + method='spatial', + annotation=annotation, + color_map='anno_colors' + ) + expected_rgb = {k: hex_to_rgb(v) for k, v in + self.adata.uns['anno_colors'].items()} + found_labels = set() + for collection in ax.collections: + if isinstance(collection, PathCollection): + label = collection.get_label() + color = collection.get_facecolor()[0][:3] + expected = expected_rgb.get(label) + self.assertIsNotNone(expected, f"Unexpected label" + f"'{label}' in plot.") + np.testing.assert_allclose(color, expected, atol=1e-2, + err_msg=f"Color mismatch" + f"for label '{label}':" + f" got {color}," + f" expected {expected}") + found_labels.add(label) + + # Ensure all expected labels were found in the plot + self.assertEqual(found_labels, set(expected_rgb.keys())) + + def test_color_map_input(self): + # Should raise failure as it accepts colormap as str not dict + anno_colors = {'A': '#111111', 'B': '#222222', + 'C': '#333333'} + with self.assertRaises(TypeError): + embedded_scatter_plot( + adata=self.adata, + method='spatial', + annotation='annotation1', + color_map=anno_colors) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_visualization/test_visualize_2D_scatter.py b/tests/test_visualization/test_visualize_2D_scatter.py index f79b792c..f97c6af2 100644 --- a/tests/test_visualization/test_visualize_2D_scatter.py +++ b/tests/test_visualization/test_visualize_2D_scatter.py @@ -129,6 +129,28 @@ def test_color_representation(self): legend = axis.get_legend() self.assertIn(color_representation, legend.get_title().get_text()) + def test_color_map(self): + color_map = {'A': '#123456', 'B': '#abcdef', 'C': '#654321'} + fig, ax = visualize_2D_scatter( + self.x, self.y, labels=self.labels_categorical, + color_map=color_map) + # Get all PathCollections (one per category) + collections = [c for c in ax.collections if + isinstance(c, matplotlib.collections.PathCollection)] + import matplotlib.colors as mcolors + # Map collections to their labels + collection_map = {c.get_label(): c for c in collections} + for cluster in ['A', 'B', 'C']: + expected_rgba = mcolors.to_rgba(color_map[cluster]) + facecolors = collection_map[cluster].get_facecolors() + # All points in this collection should have the same color + for fc in facecolors: + self.assertTrue( + all(abs(a - b) < 1e-3 for a, b in zip(fc, expected_rgba)), + f"Color mismatch for {cluster}: expected {expected_rgba}," + f"got {fc}" + ) + if __name__ == '__main__': unittest.main()