def downsample_cells(input_data, annotations, n_samples=None,
                     stratify=False, rand=False,
                     combined_col_name='_combined_', min_threshold=5):
    """
    Custom downsampling of data based on one or more annotations.

    This function offers two primary modes of operation:
    1. **Grouping (stratify=False)**:
       - For a single annotation: The data is grouped by unique values of
         the annotation, and 'n_samples' rows are selected from each group.
       - For multiple annotations: The data is grouped based on unique
         combinations of the annotations, and 'n_samples' rows are selected
         from each combined group.

    2. **Stratification (stratify=True)**:
       - Annotations (single or multiple) are combined into a new column.
       - Proportionate stratified sampling is performed based on the unique
         combinations in the new column, ensuring that the downsampled
         dataset maintains the proportionate representation of each
         combined group from the original dataset.

    Parameters
    ----------
    input_data : pd.DataFrame or anndata.AnnData
        The input data. For an AnnData object the annotations are read
        from ``input_data.obs``.
    annotations : str or list of str
        The column name(s) to downsample on. If multiple column names are
        provided, their values are combined using an underscore as a
        separator.
    n_samples : int, default=None
        The number of samples to return. Behavior differs based on the
        'stratify' parameter:
        - stratify=False: Returns 'n_samples' for each unique value (or
          combination) of annotations.
        - stratify=True: Returns a total of 'n_samples' stratified by the
          frequency of every label or combined labels in the annotation(s).
        If None, 'input_data' itself is returned unchanged (the same
        object, not a copy) and no validation is performed.
    stratify : bool, default=False
        If true, perform proportionate stratified sampling based on the
        unique combinations of annotations.
    rand : bool, default=False
        If true and stratify is True, randomly select the returned cells.
        Otherwise, choose the first n cells.
    combined_col_name : str, default='_combined_'
        Name of the column that will store combined values when multiple
        annotation columns are provided.
    min_threshold : int, default=5
        The minimum number of samples a combined group should have in the
        original dataset to be considered in the downsampled dataset.
        Groups with fewer samples than this threshold are excluded from
        the stratification process.

    Returns
    -------
    output_data : pd.DataFrame or anndata.AnnData
        The downsampled dataset, as a copy. If the input is a dataframe,
        a dataframe is returned. If the input is an AnnData object, an
        AnnData object is returned.

    Raises
    ------
    TypeError
        If 'input_data' is neither a pandas DataFrame nor an AnnData
        object.
    ValueError
        If any of the requested annotation columns does not exist.

    Notes
    -----
    With stratify=True the sampling is proportionate, so not all unique
    combinations from the original dataset might be present in the
    downsampled dataset, especially if a particular combination has very
    few samples in the original dataset. The `min_threshold` parameter can
    be adjusted to determine the minimum number of samples a combined
    group should have in the original dataset to appear in the downsampled
    version.
    """

    # If n_samples is None, return the input data without processing
    if n_samples is None:
        return input_data

    # Convert annotations to list if it's a string
    if isinstance(annotations, str):
        annotations = [annotations]

    # Locate the table that holds the cell annotations
    if isinstance(input_data, pd.DataFrame):
        annotation_table = input_data
    elif isinstance(input_data, anndata.AnnData):
        annotation_table = input_data.obs
    else:
        raise TypeError(
            "Input data must be a Pandas DataFrame or Anndata Object."
        )

    # Validate the requested columns BEFORE selecting them: indexing with
    # a missing label raises KeyError, which would otherwise pre-empt the
    # ValueError below.
    missing_columns = [
        col for col in annotations if col not in annotation_table.columns
    ]
    if missing_columns:
        raise ValueError(
            f"Columns {missing_columns} do not exist in the dataframe"
        )

    # Work on a copy restricted to the annotation columns so the helper
    # can add its combined column without touching the caller's data
    cell_data = annotation_table[annotations].copy()

    # Determine which cells to keep based on the annotation columns
    cell_indexes = _get_downsampled_indexes(
        cell_data, annotations, n_samples, stratify, rand,
        combined_col_name, min_threshold
    )

    # Keep only the selected cells; the type was validated above, so
    # anything that is not a DataFrame here is an AnnData object
    if isinstance(input_data, pd.DataFrame):
        return input_data.loc[cell_indexes, :].copy()
    return input_data[cell_indexes, :].copy()
- - For multiple annotations: The data is grouped based on unique - combinations of the annotations, and 'n_samples' rows are selected - from each combined group. - - 2. **Stratification (stratify=True)**: - - Annotations (single or multiple) are combined into a new column. - - Proportionate stratified sampling is performed based on the unique - combinations in the new column, ensuring that the downsampled dataset - maintains the proportionate representation of each combined group - from the original dataset. + Helper function to compute cell indexes for downsampling. + + This function processes a DataFrame containing only cell annotations + to determine which cell indexes should be retained after downsampling. Parameters ---------- - input_data : pd.DataFrame - The input data frame. - annotations : str or list of str - The column name(s) to downsample on. If multiple column names are - provided, their values are combined using an underscore as a separator. + cell_data : pd.DataFrame + A dataframe containing cell IDs as indexes and the annotation columns + from the original dataset (AnnData.obs or DataFrame[annotations]). + annotations : list of str + The list of column names in 'cell_data' used for grouping or + stratification. If multiple names are provided, they are combined + into a single grouping column. n_samples : int, default=None The number of samples to return. Behavior differs based on the 'stratify' parameter: - stratify=False: Returns 'n_samples' for each unique value (or - combination) of annotations. + combination) of annotations. - stratify=True: Returns a total of 'n_samples' stratified by the - frequency of every label or combined labels in the annotation(s). + frequency of every label or combined labels in the annotation(s). stratify : bool, default=False - If true, perform proportionate stratified sampling based on the unique - combinations of annotations. 
This ensures that the downsampled dataset - maintains the proportionate representation of each combined group from - the original dataset. + If true, maintains the proportionate representation of each + combined group from the original dataset in downsampled version. rand : bool, default=False If true and stratify is True, randomly select the returned cells. Otherwise, choose the first n cells. @@ -586,51 +691,21 @@ def downsample_cells(input_data, annotations, n_samples=None, stratify=False, annotation columns are provided. min_threshold : int, default=5 The minimum number of samples a combined group should have in the - original dataset to be considered in the downsampled dataset. Groups - with fewer samples than this threshold will be excluded from the - stratification process. Adjusting this parameter determines the - minimum presence a combined group should have in the original dataset - to appear in the downsampled version. + original dataset to not be excluded in the downsampled dataset. Returns ------- - output_data: pd.DataFrame - The proportionately stratified downsampled data frame. - - Notes - ----- - This function emphasizes proportionate stratified sampling, ensuring that - the downsampled dataset is a representative subset of the original data - with respect to the combined annotations. Due to this proportionate nature, - not all unique combinations from the original dataset might be present in - the downsampled dataset, especially if a particular combination has very - few samples in the original dataset. The `min_threshold` parameter can be - adjusted to determine the minimum number of samples a combined group - should have in the original dataset to appear in the downsampled version. + pd.Index + A pandas Index object containing the IDs of the selected cells. + These IDs correspond to the original DataFrame's or AnnData object's + index. 
""" logging.basicConfig(level=logging.WARNING) - # Convert annotations to list if it's a string - if isinstance(annotations, str): - annotations = [annotations] - - # Check if the columns to downsample on exist - missing_columns = [ - col for col in annotations if col not in input_data.columns - ] - if missing_columns: - raise ValueError( - f"Columns {missing_columns} do not exist in the dataframe" - ) - - # If n_samples is None, return the input data without processing - if n_samples is None: - return input_data.copy() # Combine annotations into a single column if multiple annotations if len(annotations) > 1: - input_data[combined_col_name] = input_data[annotations].apply( - lambda row: '_'.join(row.values.astype(str)), axis=1) + cell_data[combined_col_name] = cell_data[annotations].astype(str).agg('_'.join, axis=1) grouping_col = combined_col_name else: grouping_col = annotations[0] @@ -638,10 +713,10 @@ def downsample_cells(input_data, annotations, n_samples=None, stratify=False, # Stratify selection if stratify: # Calculate proportions - freqs = input_data[grouping_col].value_counts(normalize=True) + freqs = cell_data[grouping_col].value_counts(normalize=True) # Exclude groups with fewer samples than the min_threshold - filtered_freqs = freqs[freqs * len(input_data) >= min_threshold] + filtered_freqs = freqs[freqs * len(cell_data) >= min_threshold] # Log warning for groups that are excluded excluded_groups = freqs[~freqs.index.isin(filtered_freqs.index)] @@ -662,7 +737,7 @@ def downsample_cells(input_data, annotations, n_samples=None, stratify=False, zero_sample_groups = samples_per_group[samples_per_group == 0] groups_with_zero_samples = zero_sample_groups.index group_freqs = freqs[groups_with_zero_samples] - original_counts = group_freqs * len(input_data) + original_counts = group_freqs * len(cell_data) # Ensure each group has at least one sample if its frequency # is non-zero @@ -701,7 +776,7 @@ def downsample_cells(input_data, annotations, n_samples=None, 
def test_anndata_input(self):
    """
    Test that downsample_cells accepts an AnnData object as input,
    returns an AnnData object and performs downsampling correctly,
    retaining features in .X and annotations in .obs.
    """

    # Build a small AnnData object: 8 cells, 3 features, 4 phenotypes
    # with two cells each
    X_data = pd.DataFrame({
        'feature1': [1, 3, 5, 7, 9, 12, 14, 16],
        'feature2': [2, 4, 6, 8, 10, 13, 15, 18],
        'feature3': [3, 5, 7, 9, 11, 14, 16, 19]
    })

    obs_data = pd.DataFrame({
        'phenotype': [
            'phenotype1',
            'phenotype1',
            'phenotype2',
            'phenotype2',
            'phenotype3',
            'phenotype3',
            'phenotype4',
            'phenotype4'
        ]
    })

    anndata_obj = anndata.AnnData(X=X_data, obs=obs_data)

    # Call downsample on the AnnData object
    downsampled_adata = downsample_cells(
        input_data=anndata_obj,
        annotations='phenotype',
        n_samples=1,
        stratify=False,
        rand=True,
        combined_col_name='_combined_'
    )

    # Confirm the result is an AnnData object
    self.assertIsInstance(downsampled_adata, anndata.AnnData)

    # Confirm the number of samples after downsampling is correct
    # (four groups with one sample each is four rows total)
    self.assertEqual(downsampled_adata.shape[0], 4)

    # Confirm the number of groups (phenotypes) is still four
    self.assertEqual(downsampled_adata.obs['phenotype'].nunique(), 4)

    # Confirm the original annotation column is present
    self.assertIn('phenotype', downsampled_adata.obs.columns)

    # Confirm the feature columns are exactly the original ones
    self.assertListEqual(
        list(downsampled_adata.var_names), X_data.columns.tolist()
    )

    # Confirm the values in .X still belong to the retained cells:
    # each kept row must equal the row with the same name in the input
    expected_X = anndata_obj.to_df().loc[downsampled_adata.obs_names]
    self.assertEqual(
        downsampled_adata.to_df().values.tolist(),
        expected_X.values.tolist()
    )

    # Confirm the input object was not modified by the call
    self.assertEqual(anndata_obj.shape, (8, 3))