Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 149 additions & 75 deletions src/spac/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,46 +538,151 @@ def _select_values_anndata(data, annotation, values, exclude_values):
return filtered_data


def downsample_cells(input_data, annotations, n_samples=None, stratify=False,
rand=False, combined_col_name='_combined_',
min_threshold=5):
def downsample_cells(input_data, annotations, n_samples=None,
stratify=False, rand=False, combined_col_name='_combined_', min_threshold=5):

"""
Custom downsampling of data based on one or more annotations.

This function offers two primary modes of operation:
1. **Grouping (stratify=False)**:
- For a single annotation: The data is grouped by unique values of the
annotation, and 'n_samples' rows are selected from each group.
- For multiple annotations: The data is grouped based on unique
combinations of the annotations, and 'n_samples' rows are selected
from each combined group.

2. **Stratification (stratify=True)**:
- Annotations (single or multiple) are combined into a new column.
- Proportionate stratified sampling is performed based on the unique
combinations in the new column, ensuring that the downsampled dataset
maintains the proportionate representation of each combined group
from the original dataset.

Parameters
----------
input_data : pd.DataFrame or anndata.AnnData
The input data. Must be either a pandas dataframe or an anndata object.
annotations : str or list of str
The column name(s) to downsample on. If multiple column names are
provided, their values are combined using an underscore as a separator.
n_samples : int, default=None
The number of samples to return. Behavior differs based on the
'stratify' parameter:
- stratify=False: Returns 'n_samples' for each unique value (or
combination) of annotations.
- stratify=True: Returns a total of 'n_samples' stratified by the
frequency of every label or combined labels in the annotation(s).
stratify : bool, default=False
If true, perform proportionate stratified sampling based on the unique
combinations of annotations. This ensures that the downsampled dataset
maintains the proportionate representation of each combined group from
the original dataset.
rand : bool, default=False
If true and stratify is True, randomly select the returned cells.
Otherwise, choose the first n cells.
combined_col_name : str, default='_combined_'
Name of the column that will store combined values when multiple
annotation columns are provided.
min_threshold : int, default=5
The minimum number of samples a combined group should have in the
original dataset to be considered in the downsampled dataset. Groups
with fewer samples than this threshold will be excluded from the
stratification process. Adjusting this parameter determines the
minimum presence a combined group should have in the original dataset
to appear in the downsampled version.

Returns
-------
output_data: pd.DataFrame or anndata.Anndata
The proportionately stratified downsampled dataset.
If the input is a dataframe, a dataframe is returned.
If the input is a anndata object, an anndata object is returned.

Notes
-----
This function emphasizes proportionate stratified sampling, ensuring that
the downsampled dataset is a representative subset of the original data
with respect to the combined annotations. Due to this proportionate nature,
not all unique combinations from the original dataset might be present in
the downsampled dataset, especially if a particular combination has very
few samples in the original dataset. The `min_threshold` parameter can be
adjusted to determine the minimum number of samples a combined group
should have in the original dataset to appear in the downsampled version.
"""

# If n_samples is None, return the input data without processing
if n_samples is None:
return input_data

# Convert annotations to list if it's a string
if isinstance(annotations, str):
annotations = [annotations]

# Extract the cell information for downsampling process
if isinstance(input_data, anndata.AnnData):
cell_data = input_data.obs[annotations].copy()

elif isinstance(input_data, pd.DataFrame):
cell_data = input_data[annotations].copy()

else:
raise TypeError("Input data must be a Pandas DataFrame or Anndata Object.")

Comment on lines +629 to +631
Copy link

Copilot AI Jun 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency and clarity, update the error message to refer to 'AnnData' (with proper casing) instead of 'Anndata'.

Suggested change
else:
raise TypeError("Input data must be a Pandas DataFrame or Anndata Object.")
else:
raise TypeError("Input data must be a Pandas DataFrame or AnnData Object.")

Copilot uses AI. Check for mistakes.
# Prepare cell annotation column for downsample process:
# Check if the columns to downsample on exist
missing_columns = [
col for col in annotations if col not in cell_data.columns
]
if missing_columns:
raise ValueError(
f"Columns {missing_columns} do not exist in the dataframe"
)

# Determine which cells to keep or remove based on annotations column
cell_indexes = _get_downsampled_indexes(cell_data, annotations,
n_samples, stratify, rand, combined_col_name, min_threshold)

# Filter cell indexes to be removed
if isinstance(input_data, anndata.AnnData):
downsampled_data = input_data[cell_indexes,:].copy()
return downsampled_data

elif isinstance(input_data, pd.DataFrame):
downsampled_data = input_data.loc[cell_indexes,:].copy()
return downsampled_data

else:
raise TypeError("Input data must be a Pandas DataFrame or Anndata Object.")


def _get_downsampled_indexes(cell_data, annotations,
n_samples, stratify, rand, combined_col_name, min_threshold):
"""
Custom downsampling of data based on one or more annotations.

This function offers two primary modes of operation:
1. **Grouping (stratify=False)**:
- For a single annotation: The data is grouped by unique values of the
annotation, and 'n_samples' rows are selected from each group.
- For multiple annotations: The data is grouped based on unique
combinations of the annotations, and 'n_samples' rows are selected
from each combined group.

2. **Stratification (stratify=True)**:
- Annotations (single or multiple) are combined into a new column.
- Proportionate stratified sampling is performed based on the unique
combinations in the new column, ensuring that the downsampled dataset
maintains the proportionate representation of each combined group
from the original dataset.
Helper function to compute cell indexes for downsampling.

This function processes a DataFrame containing only cell annotations
to determine which cell indexes should be retained after downsampling.

Parameters
----------
input_data : pd.DataFrame
The input data frame.
annotations : str or list of str
The column name(s) to downsample on. If multiple column names are
provided, their values are combined using an underscore as a separator.
cell_data : pd.DataFrame
A dataframe containing cell IDs as indexes and the annotation columns
from the original dataset (AnnData.obs or DataFrame[annotations]).
annotations : list of str
The list of column names in 'cell_data' used for grouping or
stratification. If multiple names are provided, they are combined
into a single grouping column.
n_samples : int, default=None
The number of samples to return. Behavior differs based on the
'stratify' parameter:
- stratify=False: Returns 'n_samples' for each unique value (or
combination) of annotations.
combination) of annotations.
- stratify=True: Returns a total of 'n_samples' stratified by the
frequency of every label or combined labels in the annotation(s).
frequency of every label or combined labels in the annotation(s).
stratify : bool, default=False
If true, perform proportionate stratified sampling based on the unique
combinations of annotations. This ensures that the downsampled dataset
maintains the proportionate representation of each combined group from
the original dataset.
If true, maintains the proportionate representation of each
combined group from the original dataset in downsampled version.
rand : bool, default=False
If true and stratify is True, randomly select the returned cells.
Otherwise, choose the first n cells.
Expand All @@ -586,62 +691,32 @@ def downsample_cells(input_data, annotations, n_samples=None, stratify=False,
annotation columns are provided.
Copy link
Collaborator

@fangliu117 fangliu117 Jun 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The "combined_col_name" parameter is documented but never used in the code.
Since grouping_col is a pd.Series and not a new column in the cell_data DataFrame, the combined_col_name parameter isn't strictly necessary. Either remove the parameter or assign it as the name of the Series.

min_threshold : int, default=5
The minimum number of samples a combined group should have in the
original dataset to be considered in the downsampled dataset. Groups
with fewer samples than this threshold will be excluded from the
stratification process. Adjusting this parameter determines the
minimum presence a combined group should have in the original dataset
to appear in the downsampled version.
original dataset to not be excluded in the downsampled dataset.

Returns
-------
output_data: pd.DataFrame
The proportionately stratified downsampled data frame.

Notes
-----
This function emphasizes proportionate stratified sampling, ensuring that
the downsampled dataset is a representative subset of the original data
with respect to the combined annotations. Due to this proportionate nature,
not all unique combinations from the original dataset might be present in
the downsampled dataset, especially if a particular combination has very
few samples in the original dataset. The `min_threshold` parameter can be
adjusted to determine the minimum number of samples a combined group
should have in the original dataset to appear in the downsampled version.
pd.Index
A pandas Index object containing the IDs of the selected cells.
These IDs correspond to the original DataFrame's or AnnData object's
index.
"""

logging.basicConfig(level=logging.WARNING)
# Convert annotations to list if it's a string
if isinstance(annotations, str):
annotations = [annotations]

# Check if the columns to downsample on exist
missing_columns = [
col for col in annotations if col not in input_data.columns
]
if missing_columns:
raise ValueError(
f"Columns {missing_columns} do not exist in the dataframe"
)

# If n_samples is None, return the input data without processing
if n_samples is None:
return input_data.copy()

# Combine annotations into a single column if multiple annotations
if len(annotations) > 1:
Comment on lines 704 to 707
Copy link

Copilot AI Jun 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The repeated call to logging.basicConfig in both downsample_cells and _get_downsampled_indexes may cause configuration conflicts; consider configuring logging once at application startup.

Suggested change
logging.basicConfig(level=logging.WARNING)
# Convert annotations to list if it's a string
if isinstance(annotations, str):
annotations = [annotations]
# Check if the columns to downsample on exist
missing_columns = [
col for col in annotations if col not in input_data.columns
]
if missing_columns:
raise ValueError(
f"Columns {missing_columns} do not exist in the dataframe"
)
# If n_samples is None, return the input data without processing
if n_samples is None:
return input_data.copy()
# Combine annotations into a single column if multiple annotations
if len(annotations) > 1:
# Combine annotations into a single column if multiple annotations
if len(annotations) > 1:

Copilot uses AI. Check for mistakes.
input_data[combined_col_name] = input_data[annotations].apply(
lambda row: '_'.join(row.values.astype(str)), axis=1)
cell_data[combined_col_name] = cell_data[annotations].astype(str).agg('_'.join, axis=1)
grouping_col = combined_col_name
else:
grouping_col = annotations[0]
Copy link
Collaborator

@fangliu117 fangliu117 Jun 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable grouping_col is used inconsistently.
Suggested:
if len(annotations) > 1:
cell_data[combined_col_name] = cell_data[annotations].apply(
lambda row: '_'.join(row.values.astype(str)), axis=1)
grouping_col = combined_col_name
else:
grouping_col = annotations[0] # This is a string, not a Series

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, I will go through and add these suggestions today.


# Stratify selection
if stratify:
# Calculate proportions
freqs = input_data[grouping_col].value_counts(normalize=True)
freqs = cell_data[grouping_col].value_counts(normalize=True)

# Exclude groups with fewer samples than the min_threshold
filtered_freqs = freqs[freqs * len(input_data) >= min_threshold]
filtered_freqs = freqs[freqs * len(cell_data) >= min_threshold]

# Log warning for groups that are excluded
excluded_groups = freqs[~freqs.index.isin(filtered_freqs.index)]
Expand All @@ -662,7 +737,7 @@ def downsample_cells(input_data, annotations, n_samples=None, stratify=False,
zero_sample_groups = samples_per_group[samples_per_group == 0]
groups_with_zero_samples = zero_sample_groups.index
group_freqs = freqs[groups_with_zero_samples]
original_counts = group_freqs * len(input_data)
original_counts = group_freqs * len(cell_data)

# Ensure each group has at least one sample if its frequency
# is non-zero
Expand Down Expand Up @@ -701,7 +776,7 @@ def downsample_cells(input_data, annotations, n_samples=None, stratify=False,

# Sample data
sampled_data = []
for group, group_data in input_data.groupby(grouping_col):
for group, group_data in cell_data.groupby(grouping_col):
sample_count = samples_per_group.get(group, 0)
sample_size = min(sample_count, len(group_data))
if rand:
Expand All @@ -713,9 +788,8 @@ def downsample_cells(input_data, annotations, n_samples=None, stratify=False,
output_data = pd.concat(sampled_data)

else:
output_data = input_data.groupby(grouping_col, group_keys=False).apply(
lambda x: x.head(min(n_samples, len(x)))
).reset_index(drop=True)
output_data = cell_data.groupby(grouping_col, group_keys=False).apply(
lambda x: x.head(min(n_samples, len(x))))

# Log the final counts for each label in the downsampled dataset
label_counts = output_data[grouping_col].value_counts()
Expand All @@ -725,8 +799,8 @@ def downsample_cells(input_data, annotations, n_samples=None, stratify=False,
# Log the total number of rows in the resulting data
logging.info(f"Number of rows in the returned data: {len(output_data)}")

return output_data

# Return the cell indexes of the selected downsampled data
return output_data.index

def calculate_centroid(
data,
Expand Down
56 changes: 56 additions & 0 deletions tests/test_data_utils/test_downsample_cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,62 @@ def test_downsampling_effect_multi_obs(self):
)
self.assertDictEqual(expected_counts, actual_counts)

def test_anndata_input(self):
"""
Test that downsample_cells accepts anndata objects as input,
returns an anndata object and performs downsampling correctly,
retaining features in .X and annotations in .obs.
"""

# create anndata object
X_data = pd.DataFrame({
'feature1': [1, 3, 5, 7, 9, 12, 14, 16],
'feature2': [2, 4, 6, 8, 10, 13, 15, 18],
'feature3': [3, 5, 7, 9, 11, 14, 16, 19]
})

obs_data = pd.DataFrame({
'phenotype': [
'phenotype1',
'phenotype1',
'phenotype2',
'phenotype2',
'phenotype3',
'phenotype3',
'phenotype4',
'phenotype4'
]
})

anndata_obj = anndata.AnnData(X = X_data, obs = obs_data)

# call downsample on the anndata object
downsampled_adata = downsample_cells(
input_data = anndata_obj,
annotations = 'phenotype',
n_samples = 1,
stratify = False,
rand = True,
combined_col_name= '_combined_'
)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test calls the function with stratify=False and min_threshold=5. In the downsample_cells function, the min_threshold parameter is only used when stratify=True.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes should be complete now. I have tested the code and it looks like it's working. Let me know if any other fixes need to be made.

# confirm the downsampled_df is an anndata object
self.assertTrue(isinstance(downsampled_adata, anndata.AnnData))

# confirm number of samples after downsampling is correct
# (four groups with one sample each is four rows total)
self.assertEqual(downsampled_adata.shape[0], 4)

# confirm the number of groups (phenotypes) is still four
self.assertEqual(downsampled_adata.obs['phenotype'].nunique(), 4)

# confirm original annotation column is present
self.assertIn('phenotype', downsampled_adata.obs.columns)

# confirm feature columns are present in .var_names
expected_features = X_data.columns.tolist()
for feature in expected_features:
self.assertIn(feature, downsampled_adata.var_names)

if __name__ == '__main__':
unittest.main()