From 67929acb1b8205e546524eeabe9b86b4c90e3943 Mon Sep 17 00:00:00 2001 From: Liza Shchehlik Date: Tue, 25 Feb 2025 03:11:07 -0500 Subject: [PATCH 1/9] Added datashader 2d scatter function with base usability --- src/spac/visualization.py | 138 +++++++++++++++++++++++++++++++++++++- 1 file changed, 137 insertions(+), 1 deletion(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 1d909baf..a8a797fd 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -23,10 +23,147 @@ import copy +import xarray as xr +import datashader as ds +import datashader.colors as dc +import datashader.transfer_functions as tf +import datashader.utils as du +import matplotlib.cm as cm +import matplotlib.colors as mcolors # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +def datashader_2D_scatter(x, y, point_size=None, ax=None, labels=None, theme=None, + x_axis_title="Component 1", y_axis_title="Component 2", + plot_title=None, **kwargs): + """ + Generates a density scatter plot using Datashader and returns a Matplotlib figure and axes. + + Parameters: + - x: array-like, x coordinates. + - y: array-like, y coordinates. + - point_size: int or None, if set then spreads each point by this many pixels (default: None). + - ax: Matplotlib Axes object or None. If provided, the plot will be drawn on this Axes. + - theme: str or None, color theme for the plot. If provided, must be one of the supported themes. + Supported themes: 'Elevation', 'viridis', 'Hot', 'Set1', 'Set2', 'Set3', 'Sets1to3', + 'inferno', 'color_lookup'. Defaults to 'viridis' if not recognized or not provided. + - x_axis_title: str, label for the x-axis (default: "Component 1"). + - y_axis_title: str, label for the y-axis (default: "Component 2"). + - plot_title: str, title of the plot (default: None). + - labels: array-like, optional. Array of labels for the data points. Can be numerical or categorical. + - **kwargs: additional keyword arguments passed to tf.shade and/or used for figure size: + - fig_width: width of the Datashader canvas in pixels (default: 600) + - fig_height: height of the Datashader canvas in pixels (default: 400) + + Returns: + - fig: Matplotlib Figure object. + - ax: Matplotlib Axes object. + """ + # Input validation + if not hasattr(x, "__iter__") or not hasattr(y, "__iter__"): + raise ValueError("x and y must be array-like.") + if len(x) != len(y): + raise ValueError("x and y must have the same length.") + if labels is not None and len(labels) != len(x): + raise ValueError("Labels length should match x and y length.") + + # Define available themes + themes = { + 'Elevation': ds.colors.Elevation, + 'viridis': ds.colors.viridis, + 'Hot': ds.colors.Hot, + 'Set1': ds.colors.Set1, + 'Set2': ds.colors.Set2, + 'Set3': ds.colors.Set3, + 'Sets1to3': ds.colors.Sets1to3, + 'inferno': ds.colors.inferno, + 'color_lookup': ds.colors.color_lookup, + } + + if theme and theme not in themes: + error_msg = f"Theme '{theme}' not recognized. Please use a valid theme." + raise ValueError(error_msg) + cmap = themes.get(theme, ds.colors.viridis) + + # Create a DataFrame with the coordinates + coords = pd.DataFrame({"x": x, "y": y}) + + # If labels are provided, add them to the DataFrame + if labels is not None: + coords["labels"] = labels + + # Determine the ranges for x and y + x_min, x_max = coords["x"].min(), coords["x"].max() + y_min, y_max = coords["y"].min(), coords["y"].max() + + # Get figure dimensions from kwargs, with defaults + plot_width = kwargs.get('fig_width', 600) + plot_height = kwargs.get('fig_height', 400) + + # Create a Datashader canvas + canvas = ds.Canvas( + plot_width=plot_width, + plot_height=plot_height, + x_range=(x_min, x_max), + y_range=(y_min, y_max) + ) + + # Aggregate points into bins + if labels is not None: + # If labels are not categorical, convert them + if not pd.api.types.is_categorical_dtype(coords["labels"]): + coords["labels"] = pd.Categorical(coords["labels"]) + # Aggregate by category + agg = canvas.points(coords, x="x", y="y", agg=ds.count_cat("labels")) + else: + agg = canvas.points(coords, x="x", y="y", agg=ds.count()) + + # Shading + if labels is not None: + # Determine categories from labels + if pd.api.types.is_categorical_dtype(coords["labels"]): + categories = (coords["labels"].categories + if isinstance(coords["labels"], pd.Categorical) + else coords["labels"].cat.categories) + else: + raise TypeError("Expected labels to be of type Series[Categorical] or Categorical.") + + n = len(categories) + # Build a color_key dictionary. + if isinstance(cmap, (list, tuple)): + color_key = {cat: cmap[i % len(cmap)] for i, cat in enumerate(categories)} + else: + color_palette = ds.colors.colormap_select(cmap, start=0, end=1.0, reverse=False) + color_key = {cat: color_palette[i % len(color_palette)] for i, cat in enumerate(categories)} + + img = tf.shade(agg, color_key=color_key, **kwargs) + else: + img = tf.shade(agg, cmap=cmap, **kwargs) + + + # Customize point size by spreading if provided + if point_size is not None: + img = tf.spread(img, px=point_size) + + # Convert the image to a PIL image + img = img.to_pil() + + # Create a Matplotlib figure and axes (or use the provided ax) + if ax is None: + fig, ax = plt.subplots(figsize=(plot_width / 100, plot_height / 100)) + else: + fig = ax.figure + + ax.imshow(img, origin='lower', extent=(x_min, x_max, y_min, y_max)) + ax.set_xlabel(x_axis_title) + ax.set_ylabel(y_axis_title) + + if plot_title is not None: + ax.set_title(plot_title) + + return fig, ax + def visualize_2D_scatter( x, y, labels=None, point_size=None, theme=None, @@ -651,7 +788,6 @@ def cal_bin_num( else: return fig, axs - def heatmap(adata, column, layer=None, **kwargs): """ Plot the heatmap of the mean feature of cells that belong to a `column`. From 636ae5addf70656206d61601ebe8efefd994df4a Mon Sep 17 00:00:00 2001 From: Liza Shchehlik Date: Thu, 27 Mar 2025 16:15:27 -0400 Subject: [PATCH 2/9] Added docstrings and comments to datashader function --- src/spac/visualization.py | 159 ++++++++++++++++---------------------- 1 file changed, 65 insertions(+), 94 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index a8a797fd..46f3e808 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -21,46 +21,34 @@ import warnings import re import copy - - -import xarray as xr import datashader as ds -import datashader.colors as dc import datashader.transfer_functions as tf -import datashader.utils as du -import matplotlib.cm as cm -import matplotlib.colors as mcolors + # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -def datashader_2D_scatter(x, y, point_size=None, ax=None, labels=None, theme=None, - x_axis_title="Component 1", y_axis_title="Component 2", - plot_title=None, **kwargs): +def heatmap_datashader(x, y, labels=None, theme=None, + x_axis_title="Component 1", y_axis_title="Component 2", + plot_title=None, **kwargs): """ - Generates a density scatter plot using Datashader and returns a Matplotlib figure and axes. + Generates a heatmap using Datashader for large-scale scatter data. Parameters: - - x: array-like, x coordinates. - - y: array-like, y coordinates. - - point_size: int or None, if set then spreads each point by this many pixels (default: None). - - ax: Matplotlib Axes object or None. If provided, the plot will be drawn on this Axes. - - theme: str or None, color theme for the plot. If provided, must be one of the supported themes. - Supported themes: 'Elevation', 'viridis', 'Hot', 'Set1', 'Set2', 'Set3', 'Sets1to3', - 'inferno', 'color_lookup'. Defaults to 'viridis' if not recognized or not provided. - - x_axis_title: str, label for the x-axis (default: "Component 1"). - - y_axis_title: str, label for the y-axis (default: "Component 2"). - - plot_title: str, title of the plot (default: None). - - labels: array-like, optional. Array of labels for the data points. Can be numerical or categorical. - - **kwargs: additional keyword arguments passed to tf.shade and/or used for figure size: - - fig_width: width of the Datashader canvas in pixels (default: 600) - - fig_height: height of the Datashader canvas in pixels (default: 400) + x (iterable): X-axis coordinates. + y (iterable): Y-axis coordinates. + labels (iterable, optional): Categorical labels for subgrouping data. + theme (str, optional): Colormap theme for visualization. Defaults to 'viridis'. + x_axis_title (str, optional): Label for the x-axis. Default is "Component 1". + y_axis_title (str, optional): Label for the y-axis. Default is "Component 2". + plot_title (str, optional): Title of the plot. + **kwargs: Additional keyword arguments (e.g., 'fig_width', 'fig_height'). Returns: - - fig: Matplotlib Figure object. - - ax: Matplotlib Axes object. + matplotlib.figure.Figure: A Matplotlib figure containing the heatmap visualization. """ - # Input validation + + # Ensure x and y are iterable if not hasattr(x, "__iter__") or not hasattr(y, "__iter__"): raise ValueError("x and y must be array-like.") if len(x) != len(y): @@ -68,7 +56,7 @@ def datashader_2D_scatter(x, y, point_size=None, ax=None, labels=None, theme=Non if labels is not None and len(labels) != len(x): raise ValueError("Labels length should match x and y length.") - # Define available themes + # Define available color themes themes = { 'Elevation': ds.colors.Elevation, 'viridis': ds.colors.viridis, @@ -80,89 +68,72 @@ def datashader_2D_scatter(x, y, point_size=None, ax=None, labels=None, theme=Non 'inferno': ds.colors.inferno, 'color_lookup': ds.colors.color_lookup, } + cmap = themes.get(theme, ds.colors.viridis) # Default to 'viridis' if theme is not specified - if theme and theme not in themes: - error_msg = f"Theme '{theme}' not recognized. Please use a valid theme." - raise ValueError(error_msg) - cmap = themes.get(theme, ds.colors.viridis) - - # Create a DataFrame with the coordinates + # Create a DataFrame for processing coords = pd.DataFrame({"x": x, "y": y}) - - # If labels are provided, add them to the DataFrame if labels is not None: coords["labels"] = labels - # Determine the ranges for x and y + # Determine plot boundaries x_min, x_max = coords["x"].min(), coords["x"].max() y_min, y_max = coords["y"].min(), coords["y"].max() - # Get figure dimensions from kwargs, with defaults + # Set figure dimensions plot_width = kwargs.get('fig_width', 600) plot_height = kwargs.get('fig_height', 400) - # Create a Datashader canvas - canvas = ds.Canvas( - plot_width=plot_width, - plot_height=plot_height, - x_range=(x_min, x_max), - y_range=(y_min, y_max) - ) - - # Aggregate points into bins if labels is not None: - # If labels are not categorical, convert them - if not pd.api.types.is_categorical_dtype(coords["labels"]): - coords["labels"] = pd.Categorical(coords["labels"]) - # Aggregate by category - agg = canvas.points(coords, x="x", y="y", agg=ds.count_cat("labels")) + # Create multiple subplots for each category + categories = pd.Categorical(coords["labels"]).categories + num_categories = len(categories) + + # Arrange subplots in rows of 3 + rows = (num_categories // 3) + (1 if num_categories % 3 != 0 else 0) + fig, axes = plt.subplots(rows, 3, figsize=(12, 4 * rows)) + axes = axes.flatten() + + for i, cat in enumerate(categories): + subset = coords[coords["labels"] == cat] + canvas = ds.Canvas( + plot_width=plot_width, + plot_height=plot_height, + x_range=(x_min, x_max), + y_range=(y_min, y_max) + ) + agg = canvas.points(subset, x="x", y="y", agg=ds.count()) + img = tf.shade(agg, cmap=cmap).to_pil() + + ax = axes[i] + ax.imshow(img, origin='lower', extent=(x_min, x_max, y_min, y_max)) + ax.set_title(f"{plot_title} - {cat}" if plot_title else str(cat)) + ax.set_xlabel(x_axis_title) + ax.set_ylabel(y_axis_title) + + # Remove unused subplot axes if they exist + for j in range(i + 1, len(axes)): + fig.delaxes(axes[j]) else: + # Generate single heatmap if no labels are provided + canvas = ds.Canvas( + plot_width=plot_width, + plot_height=plot_height, + x_range=(x_min, x_max), + y_range=(y_min, y_max) + ) agg = canvas.points(coords, x="x", y="y", agg=ds.count()) - - # Shading - if labels is not None: - # Determine categories from labels - if pd.api.types.is_categorical_dtype(coords["labels"]): - categories = (coords["labels"].categories - if isinstance(coords["labels"], pd.Categorical) - else coords["labels"].cat.categories) - else: - raise TypeError("Expected labels to be of type Series[Categorical] or Categorical.") - - n = len(categories) - # Build a color_key dictionary. - if isinstance(cmap, (list, tuple)): - color_key = {cat: cmap[i % len(cmap)] for i, cat in enumerate(categories)} - else: - color_palette = ds.colors.colormap_select(cmap, start=0, end=1.0, reverse=False) - color_key = {cat: color_palette[i % len(color_palette)] for i, cat in enumerate(categories)} - - img = tf.shade(agg, color_key=color_key, **kwargs) - else: - img = tf.shade(agg, cmap=cmap, **kwargs) + img = tf.shade(agg, cmap=cmap).to_pil() - - # Customize point size by spreading if provided - if point_size is not None: - img = tf.spread(img, px=point_size) - - # Convert the image to a PIL image - img = img.to_pil() - - # Create a Matplotlib figure and axes (or use the provided ax) - if ax is None: fig, ax = plt.subplots(figsize=(plot_width / 100, plot_height / 100)) - else: - fig = ax.figure + ax.imshow(img, origin='lower', extent=(x_min, x_max, y_min, y_max)) + ax.set_title(plot_title if plot_title else "Density Plot") + ax.set_xlabel(x_axis_title) + ax.set_ylabel(y_axis_title) + + plt.tight_layout() + return fig - ax.imshow(img, origin='lower', extent=(x_min, x_max, y_min, y_max)) - ax.set_xlabel(x_axis_title) - ax.set_ylabel(y_axis_title) - if plot_title is not None: - ax.set_title(plot_title) - - return fig, ax def visualize_2D_scatter( From 439df299427ce67e72001499dab877be197557e8 Mon Sep 17 00:00:00 2001 From: Liza Shchehlik Date: Tue, 1 Apr 2025 13:05:58 -0400 Subject: [PATCH 3/9] Test Cases: Added test cases for datashaders --- src/spac/visualization.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 46f3e808..16af8018 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -47,7 +47,7 @@ def heatmap_datashader(x, y, labels=None, theme=None, Returns: matplotlib.figure.Figure: A Matplotlib figure containing the heatmap visualization. """ - + # Ensure x and y are iterable if not hasattr(x, "__iter__") or not hasattr(y, "__iter__"): raise ValueError("x and y must be array-like.") @@ -91,7 +91,7 @@ def heatmap_datashader(x, y, labels=None, theme=None, # Arrange subplots in rows of 3 rows = (num_categories // 3) + (1 if num_categories % 3 != 0 else 0) fig, axes = plt.subplots(rows, 3, figsize=(12, 4 * rows)) - axes = axes.flatten() + axes = axe`s.flatten() for i, cat in enumerate(categories): subset = coords[coords["labels"] == cat] @@ -651,6 +651,7 @@ def histogram(adata, feature=None, annotation=None, layer=None, # Bin calculation section # The default bin calculation used by sns.histo take quite # some time to compute for large number of points, + # DMAP implemented the Rice rule for bin computation def cal_bin_num( From 128b60fa8021cbf71fee3feeefe69bf71e41adc8 Mon Sep 17 00:00:00 2001 From: Liza Shchehlik Date: Thu, 3 Apr 2025 15:17:01 -0400 Subject: [PATCH 4/9] Fixed syntax of unit-tests --- tests/test_visualization/test_datashader.py | 37 +++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 tests/test_visualization/test_datashader.py diff --git a/tests/test_visualization/test_datashader.py b/tests/test_visualization/test_datashader.py new file mode 100644 index 00000000..6c8c95d6 --- /dev/null +++ b/tests/test_visualization/test_datashader.py @@ -0,0 +1,37 @@ +import unittest +import numpy as np +import pandas as pd +from spac.visualization import heatmap_datashader +import matplotlib + +matplotlib.use('Agg') # Set the backend to 'Agg' to suppress plot window + + +class TestDataShaderHeatMap(unittest.TestCase): + def setUp(self): + """Prepare data for testing.""" + self.x = np.random.rand(10) + self.y = np.random.rand(10) + # Fixed categorical labels to ensure representation of each category + fixed_labels = ['A', 'B', 'C', 'A', 'B', 'C', 'A', 'B', 'C', 'A'] + self.labels_categorical = pd.Series(fixed_labels, dtype="category") + self.labels_continuous = pd.Series(np.random.rand(10)) + + def test_invalid_input_type(self): + """Test handling of invalid input types.""" + with self.assertRaises(ValueError) as context_manager: + heatmap_datashader(1, self.y) + self.assertEqual(str(context_manager.exception), + "x and y must be array-like.") + + def test_labels_length_mismatch(self): + """Test handling of mismatched lengths between data and labels.""" + wrong_labels = pd.Series(['A'] * 9) # Shorter than x and y + with self.assertRaises(ValueError) as context_manager: + heatmap_datashader(self.x, self.y, labels=wrong_labels) + expected_msg = "Labels length should match x and y length." + self.assertEqual(str(context_manager.exception), expected_msg) + + +if __name__ == "__main__": + unittest.main() From 151b52243afd65b8ba7d7945948036afcf3c081e Mon Sep 17 00:00:00 2001 From: Liza Shchehlik Date: Thu, 3 Apr 2025 15:38:11 -0400 Subject: [PATCH 5/9] Minor Error in Code Syntax --- src/spac/visualization.py | 2 +- tests/test_visualization/test_datashader.py | 11 ++++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 549847b1..66465517 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -91,7 +91,7 @@ def heatmap_datashader(x, y, labels=None, theme=None, # Arrange subplots in rows of 3 rows = (num_categories // 3) + (1 if num_categories % 3 != 0 else 0) fig, axes = plt.subplots(rows, 3, figsize=(12, 4 * rows)) - axes = axe`s.flatten() + axes = axes.flatten() for i, cat in enumerate(categories): subset = coords[coords["labels"] == cat] diff --git a/tests/test_visualization/test_datashader.py b/tests/test_visualization/test_datashader.py index 6c8c95d6..66c3b14d 100644 --- a/tests/test_visualization/test_datashader.py +++ b/tests/test_visualization/test_datashader.py @@ -6,7 +6,6 @@ matplotlib.use('Agg') # Set the backend to 'Agg' to suppress plot window - class TestDataShaderHeatMap(unittest.TestCase): def setUp(self): """Prepare data for testing.""" @@ -20,18 +19,16 @@ def setUp(self): def test_invalid_input_type(self): """Test handling of invalid input types.""" with self.assertRaises(ValueError) as context_manager: - heatmap_datashader(1, self.y) - self.assertEqual(str(context_manager.exception), - "x and y must be array-like.") + heatmap_datashader(1, self.y, labels=self.labels_categorical) + self.assertIn("x and y must be array-like", str(context_manager.exception)) def test_labels_length_mismatch(self): """Test handling of mismatched lengths between data and labels.""" wrong_labels = pd.Series(['A'] * 9) # Shorter than x and y with self.assertRaises(ValueError) as context_manager: heatmap_datashader(self.x, self.y, labels=wrong_labels) - expected_msg = "Labels length should match x and y length." - self.assertEqual(str(context_manager.exception), expected_msg) + self.assertIn("Labels length should match x and y length", str(context_manager.exception)) if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file From 61b927c2337c0c33588921d99319ae82cff632bc Mon Sep 17 00:00:00 2001 From: LizaShch Date: Sun, 13 Apr 2025 23:21:57 +0000 Subject: [PATCH 6/9] Minor changes: refined docstrings and added unit test --- src/spac/visualization.py | 37 +++++++++++++-------- tests/test_visualization/test_datashader.py | 10 ++++++ 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 6b18a885..29f8638b 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -39,19 +39,30 @@ def heatmap_datashader(x, y, labels=None, theme=None, plot_title=None, **kwargs): """ Generates a heatmap using Datashader for large-scale scatter data. - - Parameters: - x (iterable): X-axis coordinates. - y (iterable): Y-axis coordinates. - labels (iterable, optional): Categorical labels for subgrouping data. - theme (str, optional): Colormap theme for visualization. Defaults to 'viridis'. - x_axis_title (str, optional): Label for the x-axis. Default is "Component 1". - y_axis_title (str, optional): Label for the y-axis. Default is "Component 2". - plot_title (str, optional): Title of the plot. - **kwargs: Additional keyword arguments (e.g., 'fig_width', 'fig_height'). - - Returns: - matplotlib.figure.Figure: A Matplotlib figure containing the heatmap visualization. + + Parameters + ---------- + x : iterable + X-axis coordinates. + y : iterable + Y-axis coordinates. + labels : iterable, optional + Categorical labels for subgrouping data. + theme : str, optional, default='viridis' + Colormap theme for visualization. + x_axis_title : str, optional, default='Component 1' + Label for the x-axis. + y_axis_title : str, optional, default='Component 2' + Label for the y-axis. + plot_title : str, optional + Title of the plot. + **kwargs : dict, optional + Additional keyword arguments (e.g., 'fig_width', 'fig_height'). + + Returns + ------- + matplotlib.figure.Figure + A Matplotlib figure containing the heatmap visualization. """ # Ensure x and y are iterable diff --git a/tests/test_visualization/test_datashader.py b/tests/test_visualization/test_datashader.py index 66c3b14d..bc09a6d4 100644 --- a/tests/test_visualization/test_datashader.py +++ b/tests/test_visualization/test_datashader.py @@ -29,6 +29,16 @@ def test_labels_length_mismatch(self): heatmap_datashader(self.x, self.y, labels=wrong_labels) self.assertIn("Labels length should match x and y length", str(context_manager.exception)) + def test_valid_input_returns_figure(self): + """Test that valid input returns a matplotlib figure with expected subplots.""" + fig = heatmap_datashader(self.x, self.y, labels=self.labels_categorical) + self.assertIsInstance(fig, matplotlib.figure.Figure) + + # There should be as many subplots as unique labels (3 in this case) + num_axes = len(fig.axes) + expected_axes = self.labels_categorical.nunique() + self.assertEqual(num_axes, expected_axes) + if __name__ == "__main__": unittest.main() \ No newline at end of file From f75e5e433a62f32928a70ad4040161cc9807dd12 Mon Sep 17 00:00:00 2001 From: LizaShch Date: Thu, 17 Apr 2025 20:38:04 +0000 Subject: [PATCH 7/9] style(lines): fixed accidental deletion of lines as well as removed the import of unneccesary libraries --- src/spac/visualization.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 29f8638b..ddadac3b 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -21,8 +21,6 @@ from spac.data_utils import select_values import logging import warnings -import re -import copy import datashader as ds import datashader.transfer_functions as tf @@ -653,6 +651,7 @@ def histogram(adata, feature=None, annotation=None, layer=None, # Check for negative values and apply log1p transformation if # x_log_scale is True + if x_log_scale: if (df[data_column] < 0).any(): print( @@ -675,7 +674,6 @@ def histogram(adata, feature=None, annotation=None, layer=None, # Bin calculation section # The default bin calculation used by sns.histo take quite # some time to compute for large number of points, - # DMAP implemented the Rice rule for bin computation def cal_bin_num( @@ -774,9 +772,9 @@ def calculate_histogram(data, bins, bin_edges=None): kwargs.setdefault("multiple", "stack") kwargs.setdefault("element", "bars") + sns.histplot(data=hist_data, x='bin_center', weights='count', hue=group_by, ax=ax, **kwargs) - # If plotting feature specify which layer if feature: ax.set_title(f'Layer: {layer}') @@ -800,7 +798,6 @@ def calculate_histogram(data, bins, bin_edges=None): sns.histplot(data=hist_data, x="bin_center", ax=ax_i, weights='count', **kwargs) - # If plotting feature specify which layer if feature: ax_i.set_title(f'{groups[i]} with Layer: {layer}') From 7f09eb8c184115f5eec6185b02e07d3f96379d7c Mon Sep 17 00:00:00 2001 From: Liza Shchehlik Date: Sun, 27 Apr 2025 15:24:32 -0400 Subject: [PATCH 8/9] refactor(auth): Changed parts of datashader to use partial as well as added unit tests --- src/spac/visualization.py | 38 ++++++++++----------- tests/test_visualization/test_datashader.py | 37 ++++++++++++++++++-- 2 files changed, 52 insertions(+), 23 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index ddadac3b..3bce0692 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -94,28 +94,28 @@ def heatmap_datashader(x, y, labels=None, theme=None, x_min, x_max = coords["x"].min(), coords["x"].max() y_min, y_max = coords["y"].min(), coords["y"].max() - # Set figure dimensions - plot_width = kwargs.get('fig_width', 600) - plot_height = kwargs.get('fig_height', 400) + # Supply default ranges if not in kwargs + canvas_kwargs = { + 'plot_width': 600, + 'plot_height': 400, + 'x_range': (x_min, x_max), + 'y_range': (y_min, y_max) + } + canvas_kwargs.update(kwargs) + + create_canvas = partial(ds.Canvas, **canvas_kwargs) if labels is not None: - # Create multiple subplots for each category categories = pd.Categorical(coords["labels"]).categories num_categories = len(categories) - # Arrange subplots in rows of 3 rows = (num_categories // 3) + (1 if num_categories % 3 != 0 else 0) fig, axes = plt.subplots(rows, 3, figsize=(12, 4 * rows)) axes = axes.flatten() for i, cat in enumerate(categories): subset = coords[coords["labels"] == cat] - canvas = ds.Canvas( - plot_width=plot_width, - plot_height=plot_height, - x_range=(x_min, x_max), - y_range=(y_min, y_max) - ) + canvas = create_canvas() agg = canvas.points(subset, x="x", y="y", agg=ds.count()) img = tf.shade(agg, cmap=cmap).to_pil() @@ -125,21 +125,19 @@ def heatmap_datashader(x, y, labels=None, theme=None, ax.set_xlabel(x_axis_title) ax.set_ylabel(y_axis_title) - # Remove unused subplot axes if they exist for j in range(i + 1, len(axes)): fig.delaxes(axes[j]) else: - # Generate single heatmap if no labels are provided - canvas = ds.Canvas( - plot_width=plot_width, - plot_height=plot_height, - x_range=(x_min, x_max), - y_range=(y_min, y_max) - ) + canvas = create_canvas() agg = canvas.points(coords, x="x", y="y", agg=ds.count()) img = tf.shade(agg, cmap=cmap).to_pil() - fig, ax = plt.subplots(figsize=(plot_width / 100, plot_height / 100)) + fig, ax = plt.subplots( + figsize=( + canvas_kwargs["plot_width"] / 100, + canvas_kwargs["plot_height"] / 100 + ) + ) ax.imshow(img, origin='lower', extent=(x_min, x_max, y_min, y_max)) ax.set_title(plot_title if plot_title else "Density Plot") ax.set_xlabel(x_axis_title) diff --git a/tests/test_visualization/test_datashader.py b/tests/test_visualization/test_datashader.py index bc09a6d4..00f2e5e3 100644 --- a/tests/test_visualization/test_datashader.py +++ b/tests/test_visualization/test_datashader.py @@ -14,7 +14,6 @@ def setUp(self): # Fixed categorical labels to ensure representation of each category fixed_labels = ['A', 'B', 'C', 'A', 'B', 'C', 'A', 'B', 'C', 'A'] self.labels_categorical = pd.Series(fixed_labels, dtype="category") - self.labels_continuous = pd.Series(np.random.rand(10)) def test_invalid_input_type(self): """Test handling of invalid input types.""" @@ -29,15 +28,47 @@ def test_labels_length_mismatch(self): heatmap_datashader(self.x, self.y, labels=wrong_labels) self.assertIn("Labels length should match x and y length", str(context_manager.exception)) - def test_valid_input_returns_figure(self): + def test_valid_input_returns_figure_basic(self): """Test that valid input returns a matplotlib figure with expected subplots.""" fig = heatmap_datashader(self.x, self.y, labels=self.labels_categorical) self.assertIsInstance(fig, matplotlib.figure.Figure) - # There should be as many subplots as unique labels (3 in this case) num_axes = len(fig.axes) expected_axes = self.labels_categorical.nunique() self.assertEqual(num_axes, expected_axes) + + def test_labels_not_multiple_of_three(self): + """Test heatmap generation when the number of labels is not a multiple of 3.""" + x = np.random.rand(7) + y = np.random.rand(7) + labels = pd.Series(['A', 'B', 'C', 'D', 'E', 'F', 'G'], dtype="category") # 7 labels + + fig = heatmap_datashader(x, y, labels=labels) + self.assertIsInstance(fig, matplotlib.figure.Figure) + + num_axes = len(fig.axes) + expected_axes = labels.nunique() + self.assertEqual(num_axes, expected_axes) + + for ax in fig.axes: + images = [child for child in ax.get_children() if isinstance(child, matplotlib.image.AxesImage)] + self.assertGreater(len(images), 0, "Expected at least one image in each subplot.") + + def test_valid_input_returns_figure(self): + """Test that valid input returns a matplotlib figure with expected subplots and images.""" + fig = heatmap_datashader(self.x, self.y, labels=self.labels_categorical) + self.assertIsInstance(fig, matplotlib.figure.Figure) + + # Check number of axes matches number of unique labels + num_axes = len(fig.axes) + expected_axes = self.labels_categorical.nunique() + self.assertEqual(num_axes, expected_axes) + + # Check that each axis has an image plotted + for ax in fig.axes: + images = [child for child in ax.get_children() if isinstance(child, matplotlib.image.AxesImage)] + self.assertGreater(len(images), 0, "Expected at least one image in each subplot.") + if __name__ == "__main__": From 578ad36c523e7e95384ab08b2eb35820d469e7b6 Mon Sep 17 00:00:00 2001 From: Sam Ying Date: Thu, 29 May 2025 15:43:47 -0400 Subject: [PATCH 9/9] additional formatting to datashader --- tests/test_visualization/test_datashader.py | 31 +++++++++++++-------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/tests/test_visualization/test_datashader.py b/tests/test_visualization/test_datashader.py index 00f2e5e3..c6c3a7df 100644 --- a/tests/test_visualization/test_datashader.py +++ b/tests/test_visualization/test_datashader.py @@ -6,6 +6,7 @@ matplotlib.use('Agg') # Set the backend to 'Agg' to suppress plot window + class TestDataShaderHeatMap(unittest.TestCase): def setUp(self): """Prepare data for testing.""" @@ -19,24 +20,27 @@ def test_invalid_input_type(self): """Test handling of invalid input types.""" with self.assertRaises(ValueError) as context_manager: heatmap_datashader(1, self.y, labels=self.labels_categorical) - self.assertIn("x and y must be array-like", str(context_manager.exception)) + self.assertIn("x and y must be array-like", + str(context_manager.exception)) def test_labels_length_mismatch(self): """Test handling of mismatched lengths between data and labels.""" wrong_labels = pd.Series(['A'] * 9) # Shorter than x and y with self.assertRaises(ValueError) as context_manager: heatmap_datashader(self.x, self.y, labels=wrong_labels) - self.assertIn("Labels length should match x and y length", str(context_manager.exception)) + self.assertIn("Labels length should match x and y length", + str(context_manager.exception)) def test_valid_input_returns_figure_basic(self): """Test that valid input returns a matplotlib figure with expected subplots.""" - fig = heatmap_datashader(self.x, self.y, labels=self.labels_categorical) + fig = heatmap_datashader(self.x, self.y, + labels=self.labels_categorical) self.assertIsInstance(fig, matplotlib.figure.Figure) num_axes = len(fig.axes) expected_axes = self.labels_categorical.nunique() self.assertEqual(num_axes, expected_axes) - + def test_labels_not_multiple_of_three(self): """Test heatmap generation when the number of labels is not a multiple of 3.""" x = np.random.rand(7) @@ -51,12 +55,15 @@ def test_labels_not_multiple_of_three(self): self.assertEqual(num_axes, expected_axes) for ax in fig.axes: - images = [child for child in ax.get_children() if isinstance(child, matplotlib.image.AxesImage)] - self.assertGreater(len(images), 0, "Expected at least one image in each subplot.") + images = [child for child in ax.get_children() + if isinstance(child, matplotlib.image.AxesImage)] + self.assertGreater(len(images), 0, + "Expected at least one image in each subplot.") def test_valid_input_returns_figure(self): """Test that valid input returns a matplotlib figure with expected subplots and images.""" - fig = heatmap_datashader(self.x, self.y, labels=self.labels_categorical) + fig = heatmap_datashader(self.x, self.y, + labels=self.labels_categorical) self.assertIsInstance(fig, matplotlib.figure.Figure) # Check number of axes matches number of unique labels @@ -66,10 +73,10 @@ def test_valid_input_returns_figure(self): # Check that each axis has an image plotted for ax in fig.axes: - images = [child for child in ax.get_children() if isinstance(child, matplotlib.image.AxesImage)] - self.assertGreater(len(images), 0, "Expected at least one image in each subplot.") - - + images = [child for child in ax.get_children() + if isinstance(child, matplotlib.image.AxesImage)] + self.assertGreater(len(images), 0, + "Expected at least one image in each subplot.") if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main()