Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 120 additions & 1 deletion src/spac/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from spac.data_utils import select_values
import logging
import warnings
import datashader as ds
import datashader.transfer_functions as tf

import base64
import time
import json
Expand All @@ -32,11 +35,126 @@
from functools import partial
from collections import OrderedDict


# Configure logging
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s')

def heatmap_datashader(x, y, labels=None, theme=None,
x_axis_title="Component 1", y_axis_title="Component 2",
plot_title=None, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@LizaShch, can we use kwargs to pass optional arguments to ds.Canvas to control the figure? This way, you don't need to explicty control only for plot width and plot height, but for all other input argument that function would tak.

"""
Generates a heatmap using Datashader for large-scale scatter data.

Parameters
----------
x : iterable
X-axis coordinates.
y : iterable
Y-axis coordinates.
labels : iterable, optional
Categorical labels for subgrouping data.
theme : str, optional, default='viridis'
Colormap theme for visualization.
x_axis_title : str, optional, default='Component 1'
Label for the x-axis.
y_axis_title : str, optional, default='Component 2'
Label for the y-axis.
plot_title : str, optional
Title of the plot.
**kwargs : dict, optional
Additional keyword arguments (e.g., 'fig_width', 'fig_height').

Returns
-------
matplotlib.figure.Figure
A Matplotlib figure containing the heatmap visualization.
"""

# Ensure x and y are iterable
if not hasattr(x, "__iter__") or not hasattr(y, "__iter__"):
raise ValueError("x and y must be array-like.")
if len(x) != len(y):
raise ValueError("x and y must have the same length.")
if labels is not None and len(labels) != len(x):
raise ValueError("Labels length should match x and y length.")

# Define available color themes
themes = {
'Elevation': ds.colors.Elevation,
'viridis': ds.colors.viridis,
'Hot': ds.colors.Hot,
'Set1': ds.colors.Set1,
'Set2': ds.colors.Set2,
'Set3': ds.colors.Set3,
'Sets1to3': ds.colors.Sets1to3,
'inferno': ds.colors.inferno,
'color_lookup': ds.colors.color_lookup,
}
cmap = themes.get(theme, ds.colors.viridis) # Default to 'viridis' if theme is not specified

# Create a DataFrame for processing
coords = pd.DataFrame({"x": x, "y": y})
if labels is not None:
coords["labels"] = labels

# Determine plot boundaries
x_min, x_max = coords["x"].min(), coords["x"].max()
y_min, y_max = coords["y"].min(), coords["y"].max()

# Supply default ranges if not in kwargs
canvas_kwargs = {
'plot_width': 600,
'plot_height': 400,
'x_range': (x_min, x_max),
'y_range': (y_min, y_max)
}
canvas_kwargs.update(kwargs)

create_canvas = partial(ds.Canvas, **canvas_kwargs)

if labels is not None:
categories = pd.Categorical(coords["labels"]).categories
num_categories = len(categories)

rows = (num_categories // 3) + (1 if num_categories % 3 != 0 else 0)
fig, axes = plt.subplots(rows, 3, figsize=(12, 4 * rows))
axes = axes.flatten()

for i, cat in enumerate(categories):
subset = coords[coords["labels"] == cat]
canvas = create_canvas()
agg = canvas.points(subset, x="x", y="y", agg=ds.count())
img = tf.shade(agg, cmap=cmap).to_pil()

ax = axes[i]
ax.imshow(img, origin='lower', extent=(x_min, x_max, y_min, y_max))
ax.set_title(f"{plot_title} - {cat}" if plot_title else str(cat))
ax.set_xlabel(x_axis_title)
ax.set_ylabel(y_axis_title)

for j in range(i + 1, len(axes)):
fig.delaxes(axes[j])
else:
canvas = create_canvas()
agg = canvas.points(coords, x="x", y="y", agg=ds.count())
img = tf.shade(agg, cmap=cmap).to_pil()

fig, ax = plt.subplots(
figsize=(
canvas_kwargs["plot_width"] / 100,
canvas_kwargs["plot_height"] / 100
)
)
ax.imshow(img, origin='lower', extent=(x_min, x_max, y_min, y_max))
ax.set_title(plot_title if plot_title else "Density Plot")
ax.set_xlabel(x_axis_title)
ax.set_ylabel(y_axis_title)

plt.tight_layout()
return fig




def visualize_2D_scatter(
x, y, labels=None, point_size=None, theme=None,
Expand Down Expand Up @@ -538,6 +656,7 @@ def histogram(adata, feature=None, annotation=None, layer=None,

# Check for negative values and apply log1p transformation if
# x_log_scale is True

if x_log_scale:
if (df[data_column] < 0).any():
print(
Expand Down
82 changes: 82 additions & 0 deletions tests/test_visualization/test_datashader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import unittest
import numpy as np
import pandas as pd
from spac.visualization import heatmap_datashader
import matplotlib

matplotlib.use('Agg') # Set the backend to 'Agg' to suppress plot window


class TestDataShaderHeatMap(unittest.TestCase):
def setUp(self):
"""Prepare data for testing."""
self.x = np.random.rand(10)
self.y = np.random.rand(10)
# Fixed categorical labels to ensure representation of each category
fixed_labels = ['A', 'B', 'C', 'A', 'B', 'C', 'A', 'B', 'C', 'A']
self.labels_categorical = pd.Series(fixed_labels, dtype="category")

def test_invalid_input_type(self):
"""Test handling of invalid input types."""
with self.assertRaises(ValueError) as context_manager:
heatmap_datashader(1, self.y, labels=self.labels_categorical)
self.assertIn("x and y must be array-like",
str(context_manager.exception))

def test_labels_length_mismatch(self):
"""Test handling of mismatched lengths between data and labels."""
wrong_labels = pd.Series(['A'] * 9) # Shorter than x and y
with self.assertRaises(ValueError) as context_manager:
heatmap_datashader(self.x, self.y, labels=wrong_labels)
self.assertIn("Labels length should match x and y length",
str(context_manager.exception))

def test_valid_input_returns_figure_basic(self):
"""Test that valid input returns a matplotlib figure with expected subplots."""
fig = heatmap_datashader(self.x, self.y,
labels=self.labels_categorical)
self.assertIsInstance(fig, matplotlib.figure.Figure)

num_axes = len(fig.axes)
expected_axes = self.labels_categorical.nunique()
self.assertEqual(num_axes, expected_axes)

def test_labels_not_multiple_of_three(self):
"""Test heatmap generation when the number of labels is not a multiple of 3."""
x = np.random.rand(7)
y = np.random.rand(7)
labels = pd.Series(['A', 'B', 'C', 'D', 'E', 'F', 'G'], dtype="category") # 7 labels

fig = heatmap_datashader(x, y, labels=labels)
self.assertIsInstance(fig, matplotlib.figure.Figure)

num_axes = len(fig.axes)
expected_axes = labels.nunique()
self.assertEqual(num_axes, expected_axes)

for ax in fig.axes:
images = [child for child in ax.get_children()
if isinstance(child, matplotlib.image.AxesImage)]
self.assertGreater(len(images), 0,
"Expected at least one image in each subplot.")

def test_valid_input_returns_figure(self):
"""Test that valid input returns a matplotlib figure with expected subplots and images."""
fig = heatmap_datashader(self.x, self.y,
labels=self.labels_categorical)
self.assertIsInstance(fig, matplotlib.figure.Figure)

# Check number of axes matches number of unique labels
num_axes = len(fig.axes)
expected_axes = self.labels_categorical.nunique()
self.assertEqual(num_axes, expected_axes)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@LizaShch, any other assertion you can add to make sure the plot got generated correctly besides the number of axes? For example, the content of the plot?


# Check that each axis has an image plotted
for ax in fig.axes:
images = [child for child in ax.get_children()
if isinstance(child, matplotlib.image.AxesImage)]
self.assertGreater(len(images), 0,
"Expected at least one image in each subplot.")

if __name__ == "__main__":
unittest.main()