diff --git a/src/spac/visualization.py b/src/spac/visualization.py index ef46f7c5..e08bfaa4 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -704,7 +704,8 @@ def heatmap(adata, column, layer=None, **kwargs): def hierarchical_heatmap(adata, annotation, features=None, layer=None, cluster_feature=False, cluster_annotations=False, standard_scale=None, z_score="annotation", - swap_axes=False, rotate_label=False, **kwargs): + swap_axes=False, rotate_label=False, + show_counts=False, **kwargs): """ Generates a hierarchical clustering heatmap and dendrogram. @@ -745,6 +746,9 @@ def hierarchical_heatmap(adata, annotation, features=None, layer=None, axis (rows) and features are on the horizontal axis (columns). When set to True, features will be on the vertical axis and annotations on the horizontal axis. Default is False. + show_counts : bool, optional + If True, adds the amount of cells to each cluster in the heatmap. + Default is False. rotate_label : bool, optional If True, rotate x-axis labels by 45 degrees. Default is False. **kwargs: @@ -910,6 +914,31 @@ def hierarchical_heatmap(adata, annotation, features=None, layer=None, 'col_linkage': dendro_col_data } + if show_counts: + + # Add cell counts to labels + # Retrieve the number of cells in each group + cell_counts = labels.value_counts() + cell_counts = dict(cell_counts) + + # Retrieve the cluster labels from the heatmap + cluster_labels = clustergrid.ax_heatmap.get_yticklabels() + + # Append the cell number to each cluster + numbered_labels = [] + for x in cluster_labels: + key = int(x.get_text()) + value = cell_counts[key] + numbered_label = f'cluster {key}\n{value} cells' + numbered_labels.append(numbered_label) + + # add updated labels with cell counts to the heatmap + clustergrid.ax_heatmap.set_yticklabels(numbered_labels) + plt.setp(clustergrid.ax_heatmap.get_yticklabels(), rotation=0) + + # adjust so labels don't get cut off + clustergrid.ax_heatmap.figure.subplots_adjust(right=0.9, left=0.1) + return mean_intensity, clustergrid, dendrogram_data diff --git a/tests/test_visualization/test_hierarchical_heatmap.py b/tests/test_visualization/test_hierarchical_heatmap.py index 4b2b4863..de320d89 100644 --- a/tests/test_visualization/test_hierarchical_heatmap.py +++ b/tests/test_visualization/test_hierarchical_heatmap.py @@ -118,6 +118,26 @@ def test_axes_switching(self): mean_intensity_swapped.shape ) + def test_cell_count_labels(self): + '''This test confirms the cell count labels are correct''' + # Set up AnnData object with known cell counts (3 cells in 1 cluster) + X_data = pd.DataFrame( + {'gene1': [1, 2, 3], 'gene2': [1, 2, 3], 'gene3': [1, 2, 3]} + ) + obs_data = pd.DataFrame({'cluster': [1, 1, 1]}) + self.adata = anndata.AnnData(X=X_data, obs=obs_data) + + # Use hierarchical_heatmap function on AnnData object and get y-axis labels + _, clustergrid, _ = hierarchical_heatmap(self.adata, annotation='cluster', show_counts=True) + actual_labels = [] + for label in clustergrid.ax_heatmap.get_yticklabels(): + actual_labels.append(label.get_text()) + + # Confirm actual labels are in the expected labels + # dendrogram order varies if multiple clusters + expected_labels = ['cluster 1\n3 cells'] + self.assertIn(expected_labels[0], actual_labels) + if __name__ == "__main__": unittest.main()