From 2efc9c3190caa6f1d0efb3fe94df6ef379c75706 Mon Sep 17 00:00:00 2001 From: Liza Shchehlik Date: Thu, 17 Apr 2025 17:30:21 -0400 Subject: [PATCH 1/4] feat(auth): added pin_colors implementation to 2d_scatter function --- src/spac/visualization.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 0ab0ee11..98fc4544 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -35,7 +35,7 @@ def visualize_2D_scatter( x, y, labels=None, point_size=None, theme=None, ax=None, annotate_centers=False, x_axis_title='Component 1', y_axis_title='Component 2', plot_title=None, - color_representation=None, **kwargs + color_representation=None, defined_color_map=None, **kwargs ): """ Visualize 2D data using plt.scatter. @@ -65,6 +65,8 @@ def visualize_2D_scatter( Title for the plot. color_representation : str, optional Description of what the colors represent. + defined_color_map : dictionary, optional + Dictionary containing colors for label annotations. **kwargs Additional keyword arguments passed to plt.scatter. @@ -83,6 +85,10 @@ def visualize_2D_scatter( raise ValueError("x and y must have the same length.") if labels is not None and len(labels) != len(x): raise ValueError("Labels length should match x and y length.") + if defined_color_map is not None: + if not isinstance(defined_color_map, dict): + raise ValueError("`defined_color_map` must be a dict mapping label→color.") + color_dict = defined_color_map # Define color themes themes = { @@ -136,20 +142,25 @@ def visualize_2D_scatter( "Categorical." ) - # Combine colors from multiple colormaps - cmap1 = plt.get_cmap('tab20') - cmap2 = plt.get_cmap('tab20b') - cmap3 = plt.get_cmap('tab20c') - colors = cmap1.colors + cmap2.colors + cmap3.colors - - # Use the number of unique clusters to set the colormap length - cmap = ListedColormap(colors[:len(unique_clusters)]) + if defined_color_map is not None: + cluster_to_color = color_dict + else: + # fall back to your combined tab20 palettes + cmap1 = plt.get_cmap('tab20') + cmap2 = plt.get_cmap('tab20b') + cmap3 = plt.get_cmap('tab20c') + colors = cmap1.colors + cmap2.colors + cmap3.colors + cluster_to_color = { + str(cluster): colors[i % len(colors)] + for i, cluster in enumerate(unique_clusters) + } for idx, cluster in enumerate(unique_clusters): mask = np.array(labels) == cluster + color = cluster_to_color.get(str(cluster), 'gray') ax.scatter( x[mask], y[mask], - color=cmap(idx), + color=color, label=cluster, s=point_size ) From 8262ebc9c75edf5f64945f300e691f40f028287b Mon Sep 17 00:00:00 2001 From: Liza Shchehlik Date: Thu, 17 Apr 2025 18:01:13 -0400 Subject: [PATCH 2/4] style(header): changed parameter name to color_map instead of defined_color_map --- src/spac/visualization.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 98fc4544..b257beb7 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -35,7 +35,7 @@ def visualize_2D_scatter( x, y, labels=None, point_size=None, theme=None, ax=None, annotate_centers=False, x_axis_title='Component 1', y_axis_title='Component 2', plot_title=None, - color_representation=None, defined_color_map=None, **kwargs + color_representation=None, color_map=None, **kwargs ): """ Visualize 2D data using plt.scatter. @@ -65,7 +65,7 @@ def visualize_2D_scatter( Title for the plot. color_representation : str, optional Description of what the colors represent. - defined_color_map : dictionary, optional + color_map : dictionary, optional Dictionary containing colors for label annotations. **kwargs Additional keyword arguments passed to plt.scatter. @@ -85,10 +85,10 @@ def visualize_2D_scatter( raise ValueError("x and y must have the same length.") if labels is not None and len(labels) != len(x): raise ValueError("Labels length should match x and y length.") - if defined_color_map is not None: - if not isinstance(defined_color_map, dict): - raise ValueError("`defined_color_map` must be a dict mapping label→color.") - color_dict = defined_color_map + if color_map is not None: + if not isinstance(color_map, dict): + raise ValueError("`color_map` must be a dict mapping label→color.") + color_dict = color_map # Define color themes themes = { @@ -142,7 +142,7 @@ def visualize_2D_scatter( "Categorical." ) - if defined_color_map is not None: + if color_map is not None: cluster_to_color = color_dict else: # fall back to your combined tab20 palettes From b31ae3534184d18716549c8d23a61332dc1a2eed Mon Sep 17 00:00:00 2001 From: LizaShch Date: Thu, 24 Apr 2025 20:35:46 +0000 Subject: [PATCH 3/4] refactor(auth): changed format of lines of code --- src/spac/visualization.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index b257beb7..8a0d7782 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -85,10 +85,8 @@ def visualize_2D_scatter( raise ValueError("x and y must have the same length.") if labels is not None and len(labels) != len(x): raise ValueError("Labels length should match x and y length.") - if color_map is not None: - if not isinstance(color_map, dict): - raise ValueError("`color_map` must be a dict mapping label→color.") - color_dict = color_map + if not isinstance(color_map, dict): + raise ValueError("`color_map` must be a dict mapping label→color.") # Define color themes themes = { @@ -142,18 +140,14 @@ def visualize_2D_scatter( "Categorical." ) - if color_map is not None: - cluster_to_color = color_dict - else: - # fall back to your combined tab20 palettes - cmap1 = plt.get_cmap('tab20') - cmap2 = plt.get_cmap('tab20b') - cmap3 = plt.get_cmap('tab20c') - colors = cmap1.colors + cmap2.colors + cmap3.colors - cluster_to_color = { - str(cluster): colors[i % len(colors)] - for i, cluster in enumerate(unique_clusters) - } + cmap1 = plt.get_cmap('tab20') + cmap2 = plt.get_cmap('tab20b') + cmap3 = plt.get_cmap('tab20c') + colors = cmap1.colors + cmap2.colors + cmap3.colors + cluster_to_color = color_map if color_map is not None else { + str(cluster): colors[i % len(colors)] + for i, cluster in enumerate(unique_clusters) + } for idx, cluster in enumerate(unique_clusters): mask = np.array(labels) == cluster From 1967d3433ce5c10968497779da9efc2bcc623d9f Mon Sep 17 00:00:00 2001 From: Liza Shchehlik Date: Thu, 24 Apr 2025 16:46:17 -0400 Subject: [PATCH 4/4] fix(tests): Edited condition which caused unit_tests to fail --- src/spac/visualization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 8a0d7782..50842cc9 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -85,7 +85,7 @@ def visualize_2D_scatter( raise ValueError("x and y must have the same length.") if labels is not None and len(labels) != len(x): raise ValueError("Labels length should match x and y length.") - if not isinstance(color_map, dict): + if color_map is not None and not isinstance(color_map, dict): raise ValueError("`color_map` must be a dict mapping label→color.") # Define color themes