diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 9003c163..db44e510 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -22,6 +22,9 @@ from spac.data_utils import select_values import logging import warnings +import datashader as ds +import datashader.transfer_functions as tf + import base64 import time import json @@ -32,11 +35,126 @@ from functools import partial from collections import OrderedDict - # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +def heatmap_datashader(x, y, labels=None, theme=None, + x_axis_title="Component 1", y_axis_title="Component 2", + plot_title=None, **kwargs): + """ + Generates a heatmap using Datashader for large-scale scatter data. + + Parameters + ---------- + x : iterable + X-axis coordinates. + y : iterable + Y-axis coordinates. + labels : iterable, optional + Categorical labels for subgrouping data. + theme : str, optional, default='viridis' + Colormap theme for visualization. + x_axis_title : str, optional, default='Component 1' + Label for the x-axis. + y_axis_title : str, optional, default='Component 2' + Label for the y-axis. + plot_title : str, optional + Title of the plot. + **kwargs : dict, optional + Additional keyword arguments (e.g., 'fig_width', 'fig_height'). + + Returns + ------- + matplotlib.figure.Figure + A Matplotlib figure containing the heatmap visualization. + """ + + # Ensure x and y are iterable + if not hasattr(x, "__iter__") or not hasattr(y, "__iter__"): + raise ValueError("x and y must be array-like.") + if len(x) != len(y): + raise ValueError("x and y must have the same length.") + if labels is not None and len(labels) != len(x): + raise ValueError("Labels length should match x and y length.") + + # Define available color themes + themes = { + 'Elevation': ds.colors.Elevation, + 'viridis': ds.colors.viridis, + 'Hot': ds.colors.Hot, + 'Set1': ds.colors.Set1, + 'Set2': ds.colors.Set2, + 'Set3': ds.colors.Set3, + 'Sets1to3': ds.colors.Sets1to3, + 'inferno': ds.colors.inferno, + 'color_lookup': ds.colors.color_lookup, + } + cmap = themes.get(theme, ds.colors.viridis) # Default to 'viridis' if theme is not specified + + # Create a DataFrame for processing + coords = pd.DataFrame({"x": x, "y": y}) + if labels is not None: + coords["labels"] = labels + + # Determine plot boundaries + x_min, x_max = coords["x"].min(), coords["x"].max() + y_min, y_max = coords["y"].min(), coords["y"].max() + + # Supply default ranges if not in kwargs + canvas_kwargs = { + 'plot_width': 600, + 'plot_height': 400, + 'x_range': (x_min, x_max), + 'y_range': (y_min, y_max) + } + canvas_kwargs.update(kwargs) + + create_canvas = partial(ds.Canvas, **canvas_kwargs) + + if labels is not None: + categories = pd.Categorical(coords["labels"]).categories + num_categories = len(categories) + + rows = (num_categories // 3) + (1 if num_categories % 3 != 0 else 0) + fig, axes = plt.subplots(rows, 3, figsize=(12, 4 * rows)) + axes = axes.flatten() + + for i, cat in enumerate(categories): + subset = coords[coords["labels"] == cat] + canvas = create_canvas() + agg = canvas.points(subset, x="x", y="y", agg=ds.count()) + img = tf.shade(agg, cmap=cmap).to_pil() + + ax = axes[i] + ax.imshow(img, origin='lower', extent=(x_min, x_max, y_min, y_max)) + ax.set_title(f"{plot_title} - {cat}" if plot_title else str(cat)) + ax.set_xlabel(x_axis_title) + ax.set_ylabel(y_axis_title) + + for j in range(i + 1, len(axes)): + fig.delaxes(axes[j]) + else: + canvas = create_canvas() + agg = canvas.points(coords, x="x", y="y", agg=ds.count()) + img = tf.shade(agg, cmap=cmap).to_pil() + + fig, ax = plt.subplots( + figsize=( + canvas_kwargs["plot_width"] / 100, + canvas_kwargs["plot_height"] / 100 + ) + ) + ax.imshow(img, origin='lower', extent=(x_min, x_max, y_min, y_max)) + ax.set_title(plot_title if plot_title else "Density Plot") + ax.set_xlabel(x_axis_title) + ax.set_ylabel(y_axis_title) + + plt.tight_layout() + return fig + + + def visualize_2D_scatter( x, y, labels=None, point_size=None, theme=None, @@ -538,6 +656,7 @@ def histogram(adata, feature=None, annotation=None, layer=None, # Check for negative values and apply log1p transformation if # x_log_scale is True + if x_log_scale: if (df[data_column] < 0).any(): print( diff --git a/tests/test_visualization/test_datashader.py b/tests/test_visualization/test_datashader.py new file mode 100644 index 00000000..c6c3a7df --- /dev/null +++ b/tests/test_visualization/test_datashader.py @@ -0,0 +1,82 @@ +import unittest +import numpy as np +import pandas as pd +from spac.visualization import heatmap_datashader +import matplotlib + +matplotlib.use('Agg') # Set the backend to 'Agg' to suppress plot window + + +class TestDataShaderHeatMap(unittest.TestCase): + def setUp(self): + """Prepare data for testing.""" + self.x = np.random.rand(10) + self.y = np.random.rand(10) + # Fixed categorical labels to ensure representation of each category + fixed_labels = ['A', 'B', 'C', 'A', 'B', 'C', 'A', 'B', 'C', 'A'] + self.labels_categorical = pd.Series(fixed_labels, dtype="category") + + def test_invalid_input_type(self): + """Test handling of invalid input types.""" + with self.assertRaises(ValueError) as context_manager: + heatmap_datashader(1, self.y, labels=self.labels_categorical) + self.assertIn("x and y must be array-like", + str(context_manager.exception)) + + def test_labels_length_mismatch(self): + """Test handling of mismatched lengths between data and labels.""" + wrong_labels = pd.Series(['A'] * 9) # Shorter than x and y + with self.assertRaises(ValueError) as context_manager: + heatmap_datashader(self.x, self.y, labels=wrong_labels) + self.assertIn("Labels length should match x and y length", + str(context_manager.exception)) + + def test_valid_input_returns_figure_basic(self): + """Test that valid input returns a matplotlib figure with expected subplots.""" + fig = heatmap_datashader(self.x, self.y, + labels=self.labels_categorical) + self.assertIsInstance(fig, matplotlib.figure.Figure) + + num_axes = len(fig.axes) + expected_axes = self.labels_categorical.nunique() + self.assertEqual(num_axes, expected_axes) + + def test_labels_not_multiple_of_three(self): + """Test heatmap generation when the number of labels is not a multiple of 3.""" + x = np.random.rand(7) + y = np.random.rand(7) + labels = pd.Series(['A', 'B', 'C', 'D', 'E', 'F', 'G'], dtype="category") # 7 labels + + fig = heatmap_datashader(x, y, labels=labels) + self.assertIsInstance(fig, matplotlib.figure.Figure) + + num_axes = len(fig.axes) + expected_axes = labels.nunique() + self.assertEqual(num_axes, expected_axes) + + for ax in fig.axes: + images = [child for child in ax.get_children() + if isinstance(child, matplotlib.image.AxesImage)] + self.assertGreater(len(images), 0, + "Expected at least one image in each subplot.") + + def test_valid_input_returns_figure(self): + """Test that valid input returns a matplotlib figure with expected subplots and images.""" + fig = heatmap_datashader(self.x, self.y, + labels=self.labels_categorical) + self.assertIsInstance(fig, matplotlib.figure.Figure) + + # Check number of axes matches number of unique labels + num_axes = len(fig.axes) + expected_axes = self.labels_categorical.nunique() + self.assertEqual(num_axes, expected_axes) + + # Check that each axis has an image plotted + for ax in fig.axes: + images = [child for child in ax.get_children() + if isinstance(child, matplotlib.image.AxesImage)] + self.assertGreater(len(images), 0, + "Expected at least one image in each subplot.") + +if __name__ == "__main__": + unittest.main()