Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion src/spac/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,8 @@ def heatmap(adata, column, layer=None, **kwargs):
def hierarchical_heatmap(adata, annotation, features=None, layer=None,
cluster_feature=False, cluster_annotations=False,
standard_scale=None, z_score="annotation",
swap_axes=False, rotate_label=False, **kwargs):
swap_axes=False, rotate_label=False,
show_counts=False, **kwargs):

"""
Generates a hierarchical clustering heatmap and dendrogram.
Expand Down Expand Up @@ -745,6 +746,9 @@ def hierarchical_heatmap(adata, annotation, features=None, layer=None,
axis (rows) and features are on the horizontal axis (columns).
When set to True, features will be on the vertical axis and
annotations on the horizontal axis. Default is False.
show_counts : bool, optional
If True, adds the amount of cells to each cluster in the heatmap.
Default is False.
rotate_label : bool, optional
If True, rotate x-axis labels by 45 degrees. Default is False.
**kwargs:
Expand Down Expand Up @@ -910,6 +914,31 @@ def hierarchical_heatmap(adata, annotation, features=None, layer=None,
'col_linkage': dendro_col_data
}

if show_counts:

# Add cell counts to labels
# Retrieve the number of cells in each group
cell_counts = labels.value_counts()
cell_counts = dict(cell_counts)

# Retrieve the cluster labels from the heatmap
cluster_labels = clustergrid.ax_heatmap.get_yticklabels()

# Append the cell number to each cluster
numbered_labels = []
for x in cluster_labels:
key = int(x.get_text())
value = cell_counts[key]
numbered_label = f'cluster {key}\n{value} cells'
numbered_labels.append(numbered_label)

# add updated labels with cell counts to the heatmap
clustergrid.ax_heatmap.set_yticklabels(numbered_labels)
plt.setp(clustergrid.ax_heatmap.get_yticklabels(), rotation=0)

# adjust so labels don't get cut off
clustergrid.ax_heatmap.figure.subplots_adjust(right=0.9, left=0.1)

return mean_intensity, clustergrid, dendrogram_data


Expand Down
20 changes: 20 additions & 0 deletions tests/test_visualization/test_hierarchical_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,26 @@ def test_axes_switching(self):
mean_intensity_swapped.shape
)

def test_cell_count_labels(self):
'''This test confirms the cell count labels are correct'''
# Set up AnnData object with known cell counts (3 cells in 1 cluster)
X_data = pd.DataFrame(
{'gene1': [1, 2, 3], 'gene2': [1, 2, 3], 'gene3': [1, 2, 3]}
)
obs_data = pd.DataFrame({'cluster': [1, 1, 1]})
self.adata = anndata.AnnData(X=X_data, obs=obs_data)

# Use hierarchical_heatmap function on AnnData object and get y-axis labels
_, clustergrid, _ = hierarchical_heatmap(self.adata, annotation='cluster', show_counts=True)
actual_labels = []
for label in clustergrid.ax_heatmap.get_yticklabels():
actual_labels.append(label.get_text())

# Confirm actual labels are in the expected labels
# dendrogram order varies if multiple clusters
expected_labels = ['cluster 1\n3 cells']
self.assertIn(expected_labels[0], actual_labels)


if __name__ == "__main__":
unittest.main()