From 520934ef77c76a44f626036991fc6a3967d9fc18 Mon Sep 17 00:00:00 2001 From: Chloe-Thangavelu <83045403+Chloe-Thangavelu@users.noreply.github.com> Date: Wed, 28 May 2025 20:27:59 -0700 Subject: [PATCH 1/5] feat(downsample):anndata.AnnData input for downsample_cells Modified downsample_cells function to accept anndata.AnnData objects as input. When an AnnData object is provided, its `.X` and `.obs` data are combined into a pandas DataFrame before applying the rest of the downsampling function. --- src/spac/data_utils.py | 10 ++++ .../test_data_utils/test_downsample_cells.py | 55 +++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/src/spac/data_utils.py b/src/spac/data_utils.py index 736701ee..0e03bde0 100644 --- a/src/spac/data_utils.py +++ b/src/spac/data_utils.py @@ -610,6 +610,16 @@ def downsample_cells(input_data, annotations, n_samples=None, stratify=False, """ logging.basicConfig(level=logging.WARNING) + + # if input is an anndata object convert to pandas dataframe + if isinstance(input_data, anndata.AnnData): + counts_df = input_data.to_df() + input_data = pd.merge(counts_df, input_data.obs, left_index = True, right_index = True, how = 'left') + elif isinstance(input_data, pd.DataFrame): + pass + else: + raise TypeError("Input data must be a Pandas DataFrame or Anndata Object.") + # Convert annotations to list if it's a string if isinstance(annotations, str): annotations = [annotations] diff --git a/tests/test_data_utils/test_downsample_cells.py b/tests/test_data_utils/test_downsample_cells.py index 55907e45..e0ab6fe0 100644 --- a/tests/test_data_utils/test_downsample_cells.py +++ b/tests/test_data_utils/test_downsample_cells.py @@ -185,6 +185,61 @@ def test_downsampling_effect_multi_obs(self): ) self.assertDictEqual(expected_counts, actual_counts) + def test_anndata_input(self): + """ + Test that downsample_cells accepts anndata objects as input. + Check that it merges anndata .X and .obs into a dataframe + and performs downsampling correctly. 
+ """ + + # create anndata object + X_data = pd.DataFrame({ + 'feature1': [1, 3, 5, 7, 9, 12, 14, 16], + 'feature2': [2, 4, 6, 8, 10, 13, 15, 18], + 'feature3': [3, 5, 7, 9, 11, 14, 16, 19] + }) + + obs_data = pd.DataFrame({ + 'phenotype': [ + 'phenotype1', + 'phenotype1', + 'phenotype2', + 'phenotype2', + 'phenotype3', + 'phenotype3', + 'phenotype4', + 'phenotype4' + ] + }) + + anndata_obj = anndata.AnnData(X = X_data, obs = obs_data) + + # call downsample on the anndata object + downsampled_df = downsample_cells( + input_data = anndata_obj, + annotations = 'phenotype', + n_samples = 1, + stratify = False, + rand = True, + combined_col_name= '_combined_', + min_threshold= 5 + ) + + # confirm the downsampled_df is a pandas dataframe + self.assertTrue(isinstance(downsampled_df, pd.DataFrame)) + + # confirm number of samples after downsampling is correct + # (four groups with one sample each is four rows total) + self.assertTrue(len(downsampled_df) == 4) + + # confirm the number of groups (phenotypes) is still four + self.assertTrue(downsampled_df['phenotype'].nunique() == 4) + + # confirm original annotation column & feature columns are present + expected_feature_columns = X_data.columns.tolist() + for col in expected_feature_columns: + self.assertIn(col, downsampled_df.columns) + self.assertIn('phenotype', downsampled_df.columns) if __name__ == '__main__': unittest.main() From 507ef7f06af83a7599b46abfc3669df47cc2d19b Mon Sep 17 00:00:00 2001 From: Chloe-Thangavelu <83045403+Chloe-Thangavelu@users.noreply.github.com> Date: Thu, 12 Jun 2025 20:43:54 -0700 Subject: [PATCH 2/5] feat(downsample):AnnData input and output for downsample_cells This commit modifies the 'downsample_cells' function and adds a helper function '_get_downsampled_indices' to provide cell downsampling capabilities for both AnnData objects and Pandas DataFrames. 
--- src/spac/data_utils.py | 234 ++++++++++++++++++++++++++--------------- 1 file changed, 149 insertions(+), 85 deletions(-) diff --git a/src/spac/data_utils.py b/src/spac/data_utils.py index 0e03bde0..3376a53d 100644 --- a/src/spac/data_utils.py +++ b/src/spac/data_utils.py @@ -538,46 +538,151 @@ def _select_values_anndata(data, annotation, values, exclude_values): return filtered_data -def downsample_cells(input_data, annotations, n_samples=None, stratify=False, - rand=False, combined_col_name='_combined_', - min_threshold=5): +def downsample_cells(input_data, annotations, n_samples=None, + stratify=False, rand=False, combined_col_name='_combined_', min_threshold=5): + + """ + Custom downsampling of data based on one or more annotations. + + This function offers two primary modes of operation: + 1. **Grouping (stratify=False)**: + - For a single annotation: The data is grouped by unique values of the + annotation, and 'n_samples' rows are selected from each group. + - For multiple annotations: The data is grouped based on unique + combinations of the annotations, and 'n_samples' rows are selected + from each combined group. + + 2. **Stratification (stratify=True)**: + - Annotations (single or multiple) are combined into a new column. + - Proportionate stratified sampling is performed based on the unique + combinations in the new column, ensuring that the downsampled dataset + maintains the proportionate representation of each combined group + from the original dataset. + + Parameters + ---------- + input_data : pd.DataFrame or anndata.AnnData + The input data. Must be either a pandas dataframe or an anndata object. + annotations : str or list of str + The column name(s) to downsample on. If multiple column names are + provided, their values are combined using an underscore as a separator. + n_samples : int, default=None + The number of samples to return. 
Behavior differs based on the + 'stratify' parameter: + - stratify=False: Returns 'n_samples' for each unique value (or + combination) of annotations. + - stratify=True: Returns a total of 'n_samples' stratified by the + frequency of every label or combined labels in the annotation(s). + stratify : bool, default=False + If true, perform proportionate stratified sampling based on the unique + combinations of annotations. This ensures that the downsampled dataset + maintains the proportionate representation of each combined group from + the original dataset. + rand : bool, default=False + If true and stratify is True, randomly select the returned cells. + Otherwise, choose the first n cells. + combined_col_name : str, default='_combined_' + Name of the column that will store combined values when multiple + annotation columns are provided. + min_threshold : int, default=5 + The minimum number of samples a combined group should have in the + original dataset to be considered in the downsampled dataset. Groups + with fewer samples than this threshold will be excluded from the + stratification process. Adjusting this parameter determines the + minimum presence a combined group should have in the original dataset + to appear in the downsampled version. + + Returns + ------- + output_data: pd.DataFrame or anndata.AnnData + The proportionately stratified downsampled dataset. + If the input is a dataframe, a dataframe is returned. + If the input is an anndata object, an anndata object is returned. + + Notes + ----- + This function emphasizes proportionate stratified sampling, ensuring that + the downsampled dataset is a representative subset of the original data + with respect to the combined annotations. Due to this proportionate nature, + not all unique combinations from the original dataset might be present in + the downsampled dataset, especially if a particular combination has very + few samples in the original dataset. 
The `min_threshold` parameter can be + adjusted to determine the minimum number of samples a combined group + should have in the original dataset to appear in the downsampled version. + """ + + # If n_samples is None, return the input data without processing + if n_samples is None: + return input_data + + # Extract the cell information for downsampling process + if isinstance(input_data, anndata.AnnData): + cell_data = input_data.obs[annotations].copy() + + elif isinstance(input_data, pd.DataFrame): + cell_data = input_data[annotations].copy() + + else: + raise TypeError("Input data must be a Pandas DataFrame or Anndata Object.") + + # Prepare cell annotation column for downsample process: + # Convert annotations to list if it's a string + if isinstance(annotations, str): + annotations = [annotations] + + # Check if the columns to downsample on exist + missing_columns = [ + col for col in annotations if col not in cell_data.columns + ] + if missing_columns: + raise ValueError( + f"Columns {missing_columns} do not exist in the dataframe" + ) + + # Determine which cells to keep or remove based on annotations column + cell_indexes = _get_downsampled_indexes(cell_data, annotations, + n_samples, stratify, rand, combined_col_name, min_threshold) + + # Filter cell indexes to be removed + if isinstance(input_data, anndata.AnnData): + downsampled_data = input_data[cell_indexes,:].copy() + return downsampled_data + + elif isinstance(input_data, pd.DataFrame): + downsampled_data = input_data.loc[cell_indexes,:].copy() + return downsampled_data + + else: + raise TypeError("Input data must be a Pandas DataFrame or Anndata Object.") + + +def _get_downsampled_indexes(cell_data, annotations, + n_samples, stratify, rand, combined_col_name, min_threshold): """ - Custom downsampling of data based on one or more annotations. - - This function offers two primary modes of operation: - 1. 
**Grouping (stratify=False)**: - - For a single annotation: The data is grouped by unique values of the - annotation, and 'n_samples' rows are selected from each group. - - For multiple annotations: The data is grouped based on unique - combinations of the annotations, and 'n_samples' rows are selected - from each combined group. - - 2. **Stratification (stratify=True)**: - - Annotations (single or multiple) are combined into a new column. - - Proportionate stratified sampling is performed based on the unique - combinations in the new column, ensuring that the downsampled dataset - maintains the proportionate representation of each combined group - from the original dataset. + Helper function to compute cell indexes for downsampling. + + This function processes a DataFrame containing only cell annotations + to determine which cell indexes should be retained after downsampling. Parameters ---------- - input_data : pd.DataFrame - The input data frame. - annotations : str or list of str - The column name(s) to downsample on. If multiple column names are - provided, their values are combined using an underscore as a separator. + cell_data : pd.DataFrame + A dataframe containing cell IDs as indexes and the annotation columns + from the original dataset (AnnData.obs or DataFrame[annotations]). + annotations : list of str + The list of column names in 'cell_data' used for grouping or + stratification. If multiple names are provided, they are combined + into a single grouping column. n_samples : int, default=None The number of samples to return. Behavior differs based on the 'stratify' parameter: - stratify=False: Returns 'n_samples' for each unique value (or - combination) of annotations. + combination) of annotations. - stratify=True: Returns a total of 'n_samples' stratified by the - frequency of every label or combined labels in the annotation(s). + frequency of every label or combined labels in the annotation(s). 
stratify : bool, default=False - If true, perform proportionate stratified sampling based on the unique - combinations of annotations. This ensures that the downsampled dataset - maintains the proportionate representation of each combined group from - the original dataset. + If true, maintains the proportionate representation of each + combined group from the original dataset in downsampled version. rand : bool, default=False If true and stratify is True, randomly select the returned cells. Otherwise, choose the first n cells. @@ -586,72 +691,32 @@ def downsample_cells(input_data, annotations, n_samples=None, stratify=False, annotation columns are provided. min_threshold : int, default=5 The minimum number of samples a combined group should have in the - original dataset to be considered in the downsampled dataset. Groups - with fewer samples than this threshold will be excluded from the - stratification process. Adjusting this parameter determines the - minimum presence a combined group should have in the original dataset - to appear in the downsampled version. + original dataset to not be excluded in the downsampled dataset. Returns ------- - output_data: pd.DataFrame - The proportionately stratified downsampled data frame. - - Notes - ----- - This function emphasizes proportionate stratified sampling, ensuring that - the downsampled dataset is a representative subset of the original data - with respect to the combined annotations. Due to this proportionate nature, - not all unique combinations from the original dataset might be present in - the downsampled dataset, especially if a particular combination has very - few samples in the original dataset. The `min_threshold` parameter can be - adjusted to determine the minimum number of samples a combined group - should have in the original dataset to appear in the downsampled version. + pd.Index + A pandas Index object containing the IDs of the selected cells. 
+ These IDs correspond to the original DataFrame's or AnnData object's + index. """ logging.basicConfig(level=logging.WARNING) - - # if input is an anndata object convert to pandas dataframe - if isinstance(input_data, anndata.AnnData): - counts_df = input_data.to_df() - input_data = pd.merge(counts_df, input_data.obs, left_index = True, right_index = True, how = 'left') - elif isinstance(input_data, pd.DataFrame): - pass - else: - raise TypeError("Input data must be a Pandas DataFrame or Anndata Object.") - - # Convert annotations to list if it's a string - if isinstance(annotations, str): - annotations = [annotations] - - # Check if the columns to downsample on exist - missing_columns = [ - col for col in annotations if col not in input_data.columns - ] - if missing_columns: - raise ValueError( - f"Columns {missing_columns} do not exist in the dataframe" - ) - - # If n_samples is None, return the input data without processing - if n_samples is None: - return input_data.copy() # Combine annotations into a single column if multiple annotations if len(annotations) > 1: - input_data[combined_col_name] = input_data[annotations].apply( + grouping_col = cell_data[annotations].apply( lambda row: '_'.join(row.values.astype(str)), axis=1) - grouping_col = combined_col_name else: grouping_col = annotations[0] # Stratify selection if stratify: # Calculate proportions - freqs = input_data[grouping_col].value_counts(normalize=True) + freqs = cell_data[grouping_col].value_counts(normalize=True) # Exclude groups with fewer samples than the min_threshold - filtered_freqs = freqs[freqs * len(input_data) >= min_threshold] + filtered_freqs = freqs[freqs * len(cell_data) >= min_threshold] # Log warning for groups that are excluded excluded_groups = freqs[~freqs.index.isin(filtered_freqs.index)] @@ -672,7 +737,7 @@ def downsample_cells(input_data, annotations, n_samples=None, stratify=False, zero_sample_groups = samples_per_group[samples_per_group == 0] groups_with_zero_samples = 
zero_sample_groups.index group_freqs = freqs[groups_with_zero_samples] - original_counts = group_freqs * len(input_data) + original_counts = group_freqs * len(cell_data) # Ensure each group has at least one sample if its frequency # is non-zero @@ -711,7 +776,7 @@ def downsample_cells(input_data, annotations, n_samples=None, stratify=False, # Sample data sampled_data = [] - for group, group_data in input_data.groupby(grouping_col): + for group, group_data in cell_data.groupby(grouping_col): sample_count = samples_per_group.get(group, 0) sample_size = min(sample_count, len(group_data)) if rand: @@ -723,9 +788,8 @@ def downsample_cells(input_data, annotations, n_samples=None, stratify=False, output_data = pd.concat(sampled_data) else: - output_data = input_data.groupby(grouping_col, group_keys=False).apply( - lambda x: x.head(min(n_samples, len(x))) - ).reset_index(drop=True) + output_data = cell_data.groupby(grouping_col, group_keys=False).apply( + lambda x: x.head(min(n_samples, len(x)))) # Log the final counts for each label in the downsampled dataset label_counts = output_data[grouping_col].value_counts() @@ -735,8 +799,8 @@ def downsample_cells(input_data, annotations, n_samples=None, stratify=False, # Log the total number of rows in the resulting data logging.info(f"Number of rows in the returned data: {len(output_data)}") - return output_data - + # Return the cell indexes of the selected downsampled data + return output_data.index def calculate_centroid( data, From 3e843831831d8661ece99de0939638af3512076b Mon Sep 17 00:00:00 2001 From: Chloe-Thangavelu <83045403+Chloe-Thangavelu@users.noreply.github.com> Date: Thu, 12 Jun 2025 22:13:31 -0700 Subject: [PATCH 3/5] Update test_downsample_cells.py This test now checks that anndata objects are accepted, downsampled correctly, and returned as anndata objects. 
--- .../test_data_utils/test_downsample_cells.py | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/tests/test_data_utils/test_downsample_cells.py b/tests/test_data_utils/test_downsample_cells.py index e0ab6fe0..33186f20 100644 --- a/tests/test_data_utils/test_downsample_cells.py +++ b/tests/test_data_utils/test_downsample_cells.py @@ -187,11 +187,11 @@ def test_downsampling_effect_multi_obs(self): def test_anndata_input(self): """ - Test that downsample_cells accepts anndata objects as input. - Check that it merges anndata .X and .obs into a dataframe - and performs downsampling correctly. + Test that downsample_cells accepts anndata objects as input, + returns an anndata object and performs downsampling correctly, + retaining features in .X and annotations in .obs. """ - + # create anndata object X_data = pd.DataFrame({ 'feature1': [1, 3, 5, 7, 9, 12, 14, 16], @@ -215,7 +215,7 @@ def test_anndata_input(self): anndata_obj = anndata.AnnData(X = X_data, obs = obs_data) # call downsample on the anndata object - downsampled_df = downsample_cells( + downsampled_adata = downsample_cells( input_data = anndata_obj, annotations = 'phenotype', n_samples = 1, @@ -224,22 +224,24 @@ def test_anndata_input(self): combined_col_name= '_combined_', min_threshold= 5 ) - - # confirm the downsampled_df is a pandas dataframe - self.assertTrue(isinstance(downsampled_df, pd.DataFrame)) + + # confirm the downsampled_df is an anndata object + self.assertTrue(isinstance(downsampled_adata, anndata.AnnData)) # confirm number of samples after downsampling is correct # (four groups with one sample each is four rows total) - self.assertTrue(len(downsampled_df) == 4) + self.assertEqual(downsampled_adata.shape[0], 4) # confirm the number of groups (phenotypes) is still four - self.assertTrue(downsampled_df['phenotype'].nunique() == 4) + self.assertEqual(downsampled_adata.obs['phenotype'].nunique(), 4) - # confirm original annotation column & feature columns 
are present - expected_feature_columns = X_data.columns.tolist() - for col in expected_feature_columns: - self.assertIn(col, downsampled_df.columns) - self.assertIn('phenotype', downsampled_df.columns) + # confirm original annotation column is present + self.assertIn('phenotype', downsampled_adata.obs.columns) + + # confirm feature columns are present in .var_names + expected_features = X_data.columns.tolist() + for feature in expected_features: + self.assertIn(feature, downsampled_adata.var_names) if __name__ == '__main__': unittest.main() From c110994cee46dbd69d29844aa539d7a59da2e212 Mon Sep 17 00:00:00 2001 From: Chloe-Thangavelu <83045403+Chloe-Thangavelu@users.noreply.github.com> Date: Thu, 12 Jun 2025 22:25:39 -0700 Subject: [PATCH 4/5] Fix downsample_cells in data_utils.py Reordering the code to convert annotations to a list before extracting annotation information, ensuring it is in DataFrame format as required by subsequent downsample_cells code. --- src/spac/data_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spac/data_utils.py b/src/spac/data_utils.py index 3376a53d..6169cd2e 100644 --- a/src/spac/data_utils.py +++ b/src/spac/data_utils.py @@ -614,7 +614,11 @@ def downsample_cells(input_data, annotations, n_samples=None, # If n_samples is None, return the input data without processing if n_samples is None: return input_data - + + # Convert annotations to list if it's a string + if isinstance(annotations, str): + annotations = [annotations] + # Extract the cell information for downsampling process if isinstance(input_data, anndata.AnnData): cell_data = input_data.obs[annotations].copy() @@ -626,10 +630,6 @@ def downsample_cells(input_data, annotations, n_samples=None, raise TypeError("Input data must be a Pandas DataFrame or Anndata Object.") # Prepare cell annotation column for downsample process: - # Convert annotations to list if it's a string - if isinstance(annotations, str): - annotations = [annotations] - 
# Check if the columns to downsample on exist missing_columns = [ col for col in annotations if col not in cell_data.columns From 1915dca7b85092658ab50051edab2f5f3f97ecdb Mon Sep 17 00:00:00 2001 From: Chloe-Thangavelu <83045403+Chloe-Thangavelu@users.noreply.github.com> Date: Wed, 25 Jun 2025 17:21:41 -0700 Subject: [PATCH 5/5] Fix: Improve downsample cells grouping When multiple annotations are provided, a new temporary column (named by `combined_col_name`) is now explicitly added to the `cell_data` DataFrame, making `grouping_col` consistently a column name (string). Replaced slow `DataFrame.apply` with the vectorized `DataFrame.astype(str).agg('_'.join, axis=1)` for combining annotations --- src/spac/data_utils.py | 4 ++-- tests/test_data_utils/test_downsample_cells.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spac/data_utils.py b/src/spac/data_utils.py index 6169cd2e..633fd77f 100644 --- a/src/spac/data_utils.py +++ b/src/spac/data_utils.py @@ -705,8 +705,8 @@ def _get_downsampled_indexes(cell_data, annotations, # Combine annotations into a single column if multiple annotations if len(annotations) > 1: - grouping_col = cell_data[annotations].apply( - lambda row: '_'.join(row.values.astype(str)), axis=1) + cell_data[combined_col_name] = cell_data[annotations].astype(str).agg('_'.join, axis=1) + grouping_col = combined_col_name else: grouping_col = annotations[0] diff --git a/tests/test_data_utils/test_downsample_cells.py b/tests/test_data_utils/test_downsample_cells.py index 33186f20..8ba0793c 100644 --- a/tests/test_data_utils/test_downsample_cells.py +++ b/tests/test_data_utils/test_downsample_cells.py @@ -221,8 +221,7 @@ def test_anndata_input(self): n_samples = 1, stratify = False, rand = True, - combined_col_name= '_combined_', - min_threshold= 5 + combined_col_name= '_combined_' ) # confirm the downsampled_df is an anndata object