-
Notifications
You must be signed in to change notification settings - Fork 11
feat(downsample):AnnData input and output for downsample_cells #349
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
520934e
507ef7f
3e84383
c110994
1915dca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -538,46 +538,151 @@ def _select_values_anndata(data, annotation, values, exclude_values): | |||||||||||||||||||||||||||||||||||||||||||||
| return filtered_data | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def downsample_cells(input_data, annotations, n_samples=None, stratify=False, | ||||||||||||||||||||||||||||||||||||||||||||||
| rand=False, combined_col_name='_combined_', | ||||||||||||||||||||||||||||||||||||||||||||||
| min_threshold=5): | ||||||||||||||||||||||||||||||||||||||||||||||
| def downsample_cells(input_data, annotations, n_samples=None, | ||||||||||||||||||||||||||||||||||||||||||||||
| stratify=False, rand=False, combined_col_name='_combined_', min_threshold=5): | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
| Custom downsampling of data based on one or more annotations. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| This function offers two primary modes of operation: | ||||||||||||||||||||||||||||||||||||||||||||||
| 1. **Grouping (stratify=False)**: | ||||||||||||||||||||||||||||||||||||||||||||||
| - For a single annotation: The data is grouped by unique values of the | ||||||||||||||||||||||||||||||||||||||||||||||
| annotation, and 'n_samples' rows are selected from each group. | ||||||||||||||||||||||||||||||||||||||||||||||
| - For multiple annotations: The data is grouped based on unique | ||||||||||||||||||||||||||||||||||||||||||||||
| combinations of the annotations, and 'n_samples' rows are selected | ||||||||||||||||||||||||||||||||||||||||||||||
| from each combined group. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| 2. **Stratification (stratify=True)**: | ||||||||||||||||||||||||||||||||||||||||||||||
| - Annotations (single or multiple) are combined into a new column. | ||||||||||||||||||||||||||||||||||||||||||||||
| - Proportionate stratified sampling is performed based on the unique | ||||||||||||||||||||||||||||||||||||||||||||||
| combinations in the new column, ensuring that the downsampled dataset | ||||||||||||||||||||||||||||||||||||||||||||||
| maintains the proportionate representation of each combined group | ||||||||||||||||||||||||||||||||||||||||||||||
| from the original dataset. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||||||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||||||||||||||||||||||||||
| input_data : pd.DataFrame or anndata.AnnData | ||||||||||||||||||||||||||||||||||||||||||||||
| The input data. Must be either a pandas dataframe or an anndata object. | ||||||||||||||||||||||||||||||||||||||||||||||
| annotations : str or list of str | ||||||||||||||||||||||||||||||||||||||||||||||
| The column name(s) to downsample on. If multiple column names are | ||||||||||||||||||||||||||||||||||||||||||||||
| provided, their values are combined using an underscore as a separator. | ||||||||||||||||||||||||||||||||||||||||||||||
| n_samples : int, default=None | ||||||||||||||||||||||||||||||||||||||||||||||
| The number of samples to return. Behavior differs based on the | ||||||||||||||||||||||||||||||||||||||||||||||
| 'stratify' parameter: | ||||||||||||||||||||||||||||||||||||||||||||||
| - stratify=False: Returns 'n_samples' for each unique value (or | ||||||||||||||||||||||||||||||||||||||||||||||
| combination) of annotations. | ||||||||||||||||||||||||||||||||||||||||||||||
| - stratify=True: Returns a total of 'n_samples' stratified by the | ||||||||||||||||||||||||||||||||||||||||||||||
| frequency of every label or combined labels in the annotation(s). | ||||||||||||||||||||||||||||||||||||||||||||||
| stratify : bool, default=False | ||||||||||||||||||||||||||||||||||||||||||||||
| If true, perform proportionate stratified sampling based on the unique | ||||||||||||||||||||||||||||||||||||||||||||||
| combinations of annotations. This ensures that the downsampled dataset | ||||||||||||||||||||||||||||||||||||||||||||||
| maintains the proportionate representation of each combined group from | ||||||||||||||||||||||||||||||||||||||||||||||
| the original dataset. | ||||||||||||||||||||||||||||||||||||||||||||||
| rand : bool, default=False | ||||||||||||||||||||||||||||||||||||||||||||||
| If true and stratify is True, randomly select the returned cells. | ||||||||||||||||||||||||||||||||||||||||||||||
| Otherwise, choose the first n cells. | ||||||||||||||||||||||||||||||||||||||||||||||
| combined_col_name : str, default='_combined_' | ||||||||||||||||||||||||||||||||||||||||||||||
| Name of the column that will store combined values when multiple | ||||||||||||||||||||||||||||||||||||||||||||||
| annotation columns are provided. | ||||||||||||||||||||||||||||||||||||||||||||||
| min_threshold : int, default=5 | ||||||||||||||||||||||||||||||||||||||||||||||
| The minimum number of samples a combined group should have in the | ||||||||||||||||||||||||||||||||||||||||||||||
| original dataset to be considered in the downsampled dataset. Groups | ||||||||||||||||||||||||||||||||||||||||||||||
| with fewer samples than this threshold will be excluded from the | ||||||||||||||||||||||||||||||||||||||||||||||
| stratification process. Adjusting this parameter determines the | ||||||||||||||||||||||||||||||||||||||||||||||
| minimum presence a combined group should have in the original dataset | ||||||||||||||||||||||||||||||||||||||||||||||
| to appear in the downsampled version. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Returns | ||||||||||||||||||||||||||||||||||||||||||||||
| ------- | ||||||||||||||||||||||||||||||||||||||||||||||
| output_data: pd.DataFrame or anndata.Anndata | ||||||||||||||||||||||||||||||||||||||||||||||
| The proportionately stratified downsampled dataset. | ||||||||||||||||||||||||||||||||||||||||||||||
| If the input is a dataframe, a dataframe is returned. | ||||||||||||||||||||||||||||||||||||||||||||||
| If the input is a anndata object, an anndata object is returned. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Notes | ||||||||||||||||||||||||||||||||||||||||||||||
| ----- | ||||||||||||||||||||||||||||||||||||||||||||||
| This function emphasizes proportionate stratified sampling, ensuring that | ||||||||||||||||||||||||||||||||||||||||||||||
| the downsampled dataset is a representative subset of the original data | ||||||||||||||||||||||||||||||||||||||||||||||
| with respect to the combined annotations. Due to this proportionate nature, | ||||||||||||||||||||||||||||||||||||||||||||||
| not all unique combinations from the original dataset might be present in | ||||||||||||||||||||||||||||||||||||||||||||||
| the downsampled dataset, especially if a particular combination has very | ||||||||||||||||||||||||||||||||||||||||||||||
| few samples in the original dataset. The `min_threshold` parameter can be | ||||||||||||||||||||||||||||||||||||||||||||||
| adjusted to determine the minimum number of samples a combined group | ||||||||||||||||||||||||||||||||||||||||||||||
| should have in the original dataset to appear in the downsampled version. | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # If n_samples is None, return the input data without processing | ||||||||||||||||||||||||||||||||||||||||||||||
| if n_samples is None: | ||||||||||||||||||||||||||||||||||||||||||||||
| return input_data | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # Convert annotations to list if it's a string | ||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(annotations, str): | ||||||||||||||||||||||||||||||||||||||||||||||
| annotations = [annotations] | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # Extract the cell information for downsampling process | ||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(input_data, anndata.AnnData): | ||||||||||||||||||||||||||||||||||||||||||||||
| cell_data = input_data.obs[annotations].copy() | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(input_data, pd.DataFrame): | ||||||||||||||||||||||||||||||||||||||||||||||
| cell_data = input_data[annotations].copy() | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| raise TypeError("Input data must be a Pandas DataFrame or Anndata Object.") | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # Prepare cell annotation column for downsample process: | ||||||||||||||||||||||||||||||||||||||||||||||
| # Check if the columns to downsample on exist | ||||||||||||||||||||||||||||||||||||||||||||||
| missing_columns = [ | ||||||||||||||||||||||||||||||||||||||||||||||
| col for col in annotations if col not in cell_data.columns | ||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||
| if missing_columns: | ||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||
| f"Columns {missing_columns} do not exist in the dataframe" | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # Determine which cells to keep or remove based on annotations column | ||||||||||||||||||||||||||||||||||||||||||||||
| cell_indexes = _get_downsampled_indexes(cell_data, annotations, | ||||||||||||||||||||||||||||||||||||||||||||||
| n_samples, stratify, rand, combined_col_name, min_threshold) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # Filter cell indexes to be removed | ||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(input_data, anndata.AnnData): | ||||||||||||||||||||||||||||||||||||||||||||||
| downsampled_data = input_data[cell_indexes,:].copy() | ||||||||||||||||||||||||||||||||||||||||||||||
| return downsampled_data | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(input_data, pd.DataFrame): | ||||||||||||||||||||||||||||||||||||||||||||||
| downsampled_data = input_data.loc[cell_indexes,:].copy() | ||||||||||||||||||||||||||||||||||||||||||||||
| return downsampled_data | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| raise TypeError("Input data must be a Pandas DataFrame or Anndata Object.") | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def _get_downsampled_indexes(cell_data, annotations, | ||||||||||||||||||||||||||||||||||||||||||||||
| n_samples, stratify, rand, combined_col_name, min_threshold): | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
| Custom downsampling of data based on one or more annotations. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| This function offers two primary modes of operation: | ||||||||||||||||||||||||||||||||||||||||||||||
| 1. **Grouping (stratify=False)**: | ||||||||||||||||||||||||||||||||||||||||||||||
| - For a single annotation: The data is grouped by unique values of the | ||||||||||||||||||||||||||||||||||||||||||||||
| annotation, and 'n_samples' rows are selected from each group. | ||||||||||||||||||||||||||||||||||||||||||||||
| - For multiple annotations: The data is grouped based on unique | ||||||||||||||||||||||||||||||||||||||||||||||
| combinations of the annotations, and 'n_samples' rows are selected | ||||||||||||||||||||||||||||||||||||||||||||||
| from each combined group. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| 2. **Stratification (stratify=True)**: | ||||||||||||||||||||||||||||||||||||||||||||||
| - Annotations (single or multiple) are combined into a new column. | ||||||||||||||||||||||||||||||||||||||||||||||
| - Proportionate stratified sampling is performed based on the unique | ||||||||||||||||||||||||||||||||||||||||||||||
| combinations in the new column, ensuring that the downsampled dataset | ||||||||||||||||||||||||||||||||||||||||||||||
| maintains the proportionate representation of each combined group | ||||||||||||||||||||||||||||||||||||||||||||||
| from the original dataset. | ||||||||||||||||||||||||||||||||||||||||||||||
| Helper function to compute cell indexes for downsampling. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| This function processes a DataFrame containing only cell annotations | ||||||||||||||||||||||||||||||||||||||||||||||
| to determine which cell indexes should be retained after downsampling. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||||||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||||||||||||||||||||||||||
| input_data : pd.DataFrame | ||||||||||||||||||||||||||||||||||||||||||||||
| The input data frame. | ||||||||||||||||||||||||||||||||||||||||||||||
| annotations : str or list of str | ||||||||||||||||||||||||||||||||||||||||||||||
| The column name(s) to downsample on. If multiple column names are | ||||||||||||||||||||||||||||||||||||||||||||||
| provided, their values are combined using an underscore as a separator. | ||||||||||||||||||||||||||||||||||||||||||||||
| cell_data : pd.DataFrame | ||||||||||||||||||||||||||||||||||||||||||||||
| A dataframe containing cell IDs as indexes and the annotation columns | ||||||||||||||||||||||||||||||||||||||||||||||
| from the original dataset (AnnData.obs or DataFrame[annotations]). | ||||||||||||||||||||||||||||||||||||||||||||||
| annotations : list of str | ||||||||||||||||||||||||||||||||||||||||||||||
| The list of column names in 'cell_data' used for grouping or | ||||||||||||||||||||||||||||||||||||||||||||||
| stratification. If multiple names are provided, they are combined | ||||||||||||||||||||||||||||||||||||||||||||||
| into a single grouping column. | ||||||||||||||||||||||||||||||||||||||||||||||
| n_samples : int, default=None | ||||||||||||||||||||||||||||||||||||||||||||||
| The number of samples to return. Behavior differs based on the | ||||||||||||||||||||||||||||||||||||||||||||||
| 'stratify' parameter: | ||||||||||||||||||||||||||||||||||||||||||||||
| - stratify=False: Returns 'n_samples' for each unique value (or | ||||||||||||||||||||||||||||||||||||||||||||||
| combination) of annotations. | ||||||||||||||||||||||||||||||||||||||||||||||
| combination) of annotations. | ||||||||||||||||||||||||||||||||||||||||||||||
| - stratify=True: Returns a total of 'n_samples' stratified by the | ||||||||||||||||||||||||||||||||||||||||||||||
| frequency of every label or combined labels in the annotation(s). | ||||||||||||||||||||||||||||||||||||||||||||||
| frequency of every label or combined labels in the annotation(s). | ||||||||||||||||||||||||||||||||||||||||||||||
| stratify : bool, default=False | ||||||||||||||||||||||||||||||||||||||||||||||
| If true, perform proportionate stratified sampling based on the unique | ||||||||||||||||||||||||||||||||||||||||||||||
| combinations of annotations. This ensures that the downsampled dataset | ||||||||||||||||||||||||||||||||||||||||||||||
| maintains the proportionate representation of each combined group from | ||||||||||||||||||||||||||||||||||||||||||||||
| the original dataset. | ||||||||||||||||||||||||||||||||||||||||||||||
| If true, maintains the proportionate representation of each | ||||||||||||||||||||||||||||||||||||||||||||||
| combined group from the original dataset in downsampled version. | ||||||||||||||||||||||||||||||||||||||||||||||
| rand : bool, default=False | ||||||||||||||||||||||||||||||||||||||||||||||
| If true and stratify is True, randomly select the returned cells. | ||||||||||||||||||||||||||||||||||||||||||||||
| Otherwise, choose the first n cells. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -586,62 +691,32 @@ def downsample_cells(input_data, annotations, n_samples=None, stratify=False, | |||||||||||||||||||||||||||||||||||||||||||||
| annotation columns are provided. | ||||||||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The "combined_col_name" parameter is documented but never used in the code. |
||||||||||||||||||||||||||||||||||||||||||||||
| min_threshold : int, default=5 | ||||||||||||||||||||||||||||||||||||||||||||||
| The minimum number of samples a combined group should have in the | ||||||||||||||||||||||||||||||||||||||||||||||
| original dataset to be considered in the downsampled dataset. Groups | ||||||||||||||||||||||||||||||||||||||||||||||
| with fewer samples than this threshold will be excluded from the | ||||||||||||||||||||||||||||||||||||||||||||||
| stratification process. Adjusting this parameter determines the | ||||||||||||||||||||||||||||||||||||||||||||||
| minimum presence a combined group should have in the original dataset | ||||||||||||||||||||||||||||||||||||||||||||||
| to appear in the downsampled version. | ||||||||||||||||||||||||||||||||||||||||||||||
| original dataset to not be excluded in the downsampled dataset. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Returns | ||||||||||||||||||||||||||||||||||||||||||||||
| ------- | ||||||||||||||||||||||||||||||||||||||||||||||
| output_data: pd.DataFrame | ||||||||||||||||||||||||||||||||||||||||||||||
| The proportionately stratified downsampled data frame. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Notes | ||||||||||||||||||||||||||||||||||||||||||||||
| ----- | ||||||||||||||||||||||||||||||||||||||||||||||
| This function emphasizes proportionate stratified sampling, ensuring that | ||||||||||||||||||||||||||||||||||||||||||||||
| the downsampled dataset is a representative subset of the original data | ||||||||||||||||||||||||||||||||||||||||||||||
| with respect to the combined annotations. Due to this proportionate nature, | ||||||||||||||||||||||||||||||||||||||||||||||
| not all unique combinations from the original dataset might be present in | ||||||||||||||||||||||||||||||||||||||||||||||
| the downsampled dataset, especially if a particular combination has very | ||||||||||||||||||||||||||||||||||||||||||||||
| few samples in the original dataset. The `min_threshold` parameter can be | ||||||||||||||||||||||||||||||||||||||||||||||
| adjusted to determine the minimum number of samples a combined group | ||||||||||||||||||||||||||||||||||||||||||||||
| should have in the original dataset to appear in the downsampled version. | ||||||||||||||||||||||||||||||||||||||||||||||
| pd.Index | ||||||||||||||||||||||||||||||||||||||||||||||
| A pandas Index object containing the IDs of the selected cells. | ||||||||||||||||||||||||||||||||||||||||||||||
| These IDs correspond to the original DataFrame's or AnnData object's | ||||||||||||||||||||||||||||||||||||||||||||||
| index. | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| logging.basicConfig(level=logging.WARNING) | ||||||||||||||||||||||||||||||||||||||||||||||
| # Convert annotations to list if it's a string | ||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(annotations, str): | ||||||||||||||||||||||||||||||||||||||||||||||
| annotations = [annotations] | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # Check if the columns to downsample on exist | ||||||||||||||||||||||||||||||||||||||||||||||
| missing_columns = [ | ||||||||||||||||||||||||||||||||||||||||||||||
| col for col in annotations if col not in input_data.columns | ||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||
| if missing_columns: | ||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||
| f"Columns {missing_columns} do not exist in the dataframe" | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # If n_samples is None, return the input data without processing | ||||||||||||||||||||||||||||||||||||||||||||||
| if n_samples is None: | ||||||||||||||||||||||||||||||||||||||||||||||
| return input_data.copy() | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # Combine annotations into a single column if multiple annotations | ||||||||||||||||||||||||||||||||||||||||||||||
| if len(annotations) > 1: | ||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
704
to
707
|
||||||||||||||||||||||||||||||||||||||||||||||
| logging.basicConfig(level=logging.WARNING) | |
| # Convert annotations to list if it's a string | |
| if isinstance(annotations, str): | |
| annotations = [annotations] | |
| # Check if the columns to downsample on exist | |
| missing_columns = [ | |
| col for col in annotations if col not in input_data.columns | |
| ] | |
| if missing_columns: | |
| raise ValueError( | |
| f"Columns {missing_columns} do not exist in the dataframe" | |
| ) | |
| # If n_samples is None, return the input data without processing | |
| if n_samples is None: | |
| return input_data.copy() | |
| # Combine annotations into a single column if multiple annotations | |
| if len(annotations) > 1: | |
| # Combine annotations into a single column if multiple annotations | |
| if len(annotations) > 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The variable grouping_col is used inconsistently.
Suggested:
if len(annotations) > 1:
cell_data[combined_col_name] = cell_data[annotations].apply(
lambda row: '_'.join(row.values.astype(str)), axis=1)
grouping_col = combined_col_name
else:
grouping_col = annotations[0] # This is a string, not a Series
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, I will go through & add these suggestions today
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -185,6 +185,62 @@ def test_downsampling_effect_multi_obs(self): | |
| ) | ||
| self.assertDictEqual(expected_counts, actual_counts) | ||
|
|
||
| def test_anndata_input(self): | ||
| """ | ||
| Test that downsample_cells accepts anndata objects as input, | ||
| returns an anndata object and performs downsampling correctly, | ||
| retaining features in .X and annotations in .obs. | ||
| """ | ||
|
|
||
| # create anndata object | ||
| X_data = pd.DataFrame({ | ||
| 'feature1': [1, 3, 5, 7, 9, 12, 14, 16], | ||
| 'feature2': [2, 4, 6, 8, 10, 13, 15, 18], | ||
| 'feature3': [3, 5, 7, 9, 11, 14, 16, 19] | ||
| }) | ||
|
|
||
| obs_data = pd.DataFrame({ | ||
| 'phenotype': [ | ||
| 'phenotype1', | ||
| 'phenotype1', | ||
| 'phenotype2', | ||
| 'phenotype2', | ||
| 'phenotype3', | ||
| 'phenotype3', | ||
| 'phenotype4', | ||
| 'phenotype4' | ||
| ] | ||
| }) | ||
|
|
||
| anndata_obj = anndata.AnnData(X = X_data, obs = obs_data) | ||
|
|
||
| # call downsample on the anndata object | ||
| downsampled_adata = downsample_cells( | ||
| input_data = anndata_obj, | ||
| annotations = 'phenotype', | ||
| n_samples = 1, | ||
| stratify = False, | ||
| rand = True, | ||
| combined_col_name= '_combined_' | ||
| ) | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The test calls the function with stratify=False and min_threshold=5. In the downsample_cells function, the min_threshold parameter is only used when stratify=True.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The changes should be complete now. I have tested the code and it looks like its working. Let me know if any other fixes need to be made. |
||
| # confirm the downsampled_df is an anndata object | ||
| self.assertTrue(isinstance(downsampled_adata, anndata.AnnData)) | ||
|
|
||
| # confirm number of samples after downsampling is correct | ||
| # (four groups with one sample each is four rows total) | ||
| self.assertEqual(downsampled_adata.shape[0], 4) | ||
|
|
||
| # confirm the number of groups (phenotypes) is still four | ||
| self.assertEqual(downsampled_adata.obs['phenotype'].nunique(), 4) | ||
|
|
||
| # confirm original annotation column is present | ||
| self.assertIn('phenotype', downsampled_adata.obs.columns) | ||
|
|
||
| # confirm feature columns are present in .var_names | ||
| expected_features = X_data.columns.tolist() | ||
| for feature in expected_features: | ||
| self.assertIn(feature, downsampled_adata.var_names) | ||
|
|
||
| if __name__ == '__main__': | ||
| unittest.main() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For consistency and clarity, update the error message to refer to 'AnnData' (with proper casing) instead of 'Anndata'.