diff --git a/docs/source/api/docstrings/pydeseq2.dds.DeseqDataSet.rst b/docs/source/api/docstrings/pydeseq2.dds.DeseqDataSet.rst
index e4388bba..376125d2 100644
--- a/docs/source/api/docstrings/pydeseq2.dds.DeseqDataSet.rst
+++ b/docs/source/api/docstrings/pydeseq2.dds.DeseqDataSet.rst
@@ -18,6 +18,7 @@
       ~DeseqDataSet.fit_genewise_dispersions
       ~DeseqDataSet.fit_size_factors
       ~DeseqDataSet.plot_dispersions
+      ~DeseqDataSet.plot_rle
       ~DeseqDataSet.refit
       ~DeseqDataSet.vst
 
diff --git a/docs/source/api/docstrings/pydeseq2.ds.DeseqStats.rst b/docs/source/api/docstrings/pydeseq2.ds.DeseqStats.rst
index 22275d83..f55afe29 100644
--- a/docs/source/api/docstrings/pydeseq2.ds.DeseqStats.rst
+++ b/docs/source/api/docstrings/pydeseq2.ds.DeseqStats.rst
@@ -12,6 +12,7 @@
       ~DeseqStats.lfc_shrink
       ~DeseqStats.run_wald_test
       ~DeseqStats.summary
+      ~DeseqStats.plot_MA
 
 
 
diff --git a/pydeseq2/dds.py b/pydeseq2/dds.py
index b421a504..eb759d2c 100644
--- a/pydeseq2/dds.py
+++ b/pydeseq2/dds.py
@@ -21,6 +21,7 @@
 from pydeseq2.preprocessing import deseq2_norm_transform
 from pydeseq2.utils import build_design_matrix
 from pydeseq2.utils import dispersion_trend
+from pydeseq2.utils import make_rle_plot
 from pydeseq2.utils import make_scatter
 from pydeseq2.utils import mean_absolute_deviation
 from pydeseq2.utils import n_or_more_replicates
@@ -1010,6 +1011,36 @@ def plot_dispersions(
             **kwargs,
         )
 
+    def plot_rle(
+        self,
+        normalize: bool = False,
+        save_path: str | None = None,
+        **kwargs,
+    ):
+        """Plot the relative log expression (RLE) of each sample.
+
+        Useful for visualizing sample to sample variation.
+
+        Parameters
+        ----------
+        normalize : bool, optional
+            Whether to normalize the counts before plotting. (default: ``False``).
+
+        save_path : str, optional
+            The path where to save the plot. If left None, the plot won't be saved
+            (default: ``None``).
+
+        **kwargs
+            Keyword arguments for the box plot.
+        """
+        make_rle_plot(
+            count_matrix=self.X,
+            normalize=normalize,
+            sample_ids=self.obs_names,
+            save_path=save_path,
+            **kwargs,
+        )
+
     def _fit_parametric_dispersion_trend(self, vst: bool = False):
         r"""Fit the dispersion curve according to a parametric model.
 
diff --git a/pydeseq2/utils.py b/pydeseq2/utils.py
index 95d93ac1..f4e9f80d 100644
--- a/pydeseq2/utils.py
+++ b/pydeseq2/utils.py
@@ -1608,3 +1608,73 @@ def lowess(
         delta = (1 - delta**2) ** 2
 
     return yest
+
+
+def make_rle_plot(
+    count_matrix: np.ndarray,
+    sample_ids: np.ndarray,
+    normalize: bool = False,
+    save_path: str | None = None,
+    **kwargs,
+) -> None:
+    """
+    Create a relative log expression (RLE) plot using matplotlib.
+
+    One box is drawn per sample, summarizing the log2 ratios between its counts
+    and the per-gene median counts. Non-finite ratios (zero counts or zero gene
+    medians) are left out of the boxes.
+
+    Parameters
+    ----------
+    count_matrix : ndarray
+        An mxn matrix of count data, where m is the number of samples (rows),
+        and n is the number of genes (columns).
+
+    sample_ids : ndarray
+        An array of sample identifiers.
+
+    normalize : bool
+        Whether to normalize the count matrix before plotting. (default: ``False``).
+
+    save_path : str, optional
+        The path where to save the plot. If left None, the plot won't be saved
+        (default: ``None``).
+
+    **kwargs :
+        Additional keyword arguments passed to matplotlib's boxplot function.
+    """
+    if normalize:
+        geometric_mean = np.exp(np.mean(np.log(count_matrix + 1), axis=0))
+        size_factors = np.median(count_matrix / geometric_mean, axis=1)
+        count_matrix = count_matrix / size_factors[:, np.newaxis]
+
+    # Calculate median expression across samples
+    gene_medians = np.median(count_matrix, axis=0)
+    # Zero counts or zero medians give -inf / inf / nan log ratios.
+    with np.errstate(divide="ignore", invalid="ignore"):
+        rle_values = np.asarray(np.log2(count_matrix / gene_medians))
+    # Keep only finite ratios, otherwise the box statistics turn into nan.
+    rle_per_sample = [row[np.isfinite(row)] for row in rle_values]
+
+    boxprops = {"facecolor": "lightgray", "alpha": kwargs.pop("alpha", 0.5)}
+    # boxplot draws at positions 1..m unless the caller provides its own.
+    positions = kwargs.get("positions", np.arange(1, len(sample_ids) + 1))
+
+    # Scope the font size to this figure instead of changing it globally.
+    with plt.rc_context({"font.size": 10}):
+        fig, ax = plt.subplots(figsize=(15, 8), dpi=600)
+
+        ax.boxplot(rle_per_sample, patch_artist=True, boxprops=boxprops, **kwargs)
+
+        ax.axhline(0, color="red", linestyle="--", linewidth=1, alpha=0.5, zorder=3)
+        ax.set_xlabel("Sample")
+        ax.set_ylabel("Relative Log Expression")
+        # Put the ticks under the boxes so that each label matches its sample.
+        ax.set_xticks(positions)
+        ax.set_xticklabels(sample_ids, rotation=90)
+        plt.tight_layout()
+
+        if save_path:
+            plt.savefig(save_path, bbox_inches="tight")
+        else:
+            plt.show()
diff --git a/tests/test_pydeseq2.py b/tests/test_pydeseq2.py
index 18e57011..2ba15821 100644
--- a/tests/test_pydeseq2.py
+++ b/tests/test_pydeseq2.py
@@ -875,3 +875,16 @@ def assert_res_almost_equal(py_res, r_res, tol=0.02):
     ).max() < tol
     assert (abs(r_res.pvalue - py_res.pvalue) / r_res.pvalue).max() < tol
     assert (abs(r_res.padj - py_res.padj) / r_res.padj).max() < tol
+
+
+def test_plot_rle(train_counts, train_metadata):
+    """Test that the RLE plot is generated without error."""
+
+    dds = DeseqDataSet(
+        counts=train_counts,
+        metadata=train_metadata,
+        design="~condition",
+    )
+
+    dds.plot_rle(normalize=False)
+    dds.plot_rle(normalize=True)