-
Notifications
You must be signed in to change notification settings - Fork 11
feat(downsample):AnnData input and output for downsample_cells #349
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
feat(downsample):AnnData input and output for downsample_cells #349
Conversation
Modified downsample_cells function to accept anndata.AnnData objects as input. When an AnnData object is provided, .X` and .obs data are combined into a pandas DataFrame, before applying the rest of the downsampling function.
This commit modifies the 'downsample_cells' function and adds a helper function '_get_downsampled_indices' to provide cell downsampling capabilities for both AnnData objects and Pandas DataFrames.
This test now checks anndata objects are accepted, downsampled correctly, and returned as annadata objects.
Reordering the code to convert annotations to a list before extracting annotation information, ensuring it is in DataFrame format as required by subsequent downsample_cells code.
|
Summary: Changes:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR updates the downsample_cells function to support input as an AnnData object, converting its .X and .obs into a DataFrame for downsampling, while maintaining compatibility with pandas DataFrames.
- Added conversion logic for AnnData input
- Updated error handling and documentation for input types
- Introduced a new unit test to validate AnnData processing
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| tests/test_data_utils/test_downsample_cells.py | Added a new unit test to ensure downsample_cells correctly processes AnnData objects |
| src/spac/data_utils.py | Refactored downsample_cells logic to handle both pandas DataFrame and AnnData inputs, updating docstrings and internal variable usage |
| logging.basicConfig(level=logging.WARNING) | ||
| # Convert annotations to list if it's a string | ||
| if isinstance(annotations, str): | ||
| annotations = [annotations] | ||
|
|
||
| # Check if the columns to downsample on exist | ||
| missing_columns = [ | ||
| col for col in annotations if col not in input_data.columns | ||
| ] | ||
| if missing_columns: | ||
| raise ValueError( | ||
| f"Columns {missing_columns} do not exist in the dataframe" | ||
| ) | ||
|
|
||
| # If n_samples is None, return the input data without processing | ||
| if n_samples is None: | ||
| return input_data.copy() | ||
|
|
||
| # Combine annotations into a single column if multiple annotations | ||
| if len(annotations) > 1: |
Copilot
AI
Jun 24, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The repeated call to logging.basicConfig in both downsample_cells and _get_downsampled_indexes may cause configuration conflicts; consider configuring logging once at application startup.
| logging.basicConfig(level=logging.WARNING) | |
| # Convert annotations to list if it's a string | |
| if isinstance(annotations, str): | |
| annotations = [annotations] | |
| # Check if the columns to downsample on exist | |
| missing_columns = [ | |
| col for col in annotations if col not in input_data.columns | |
| ] | |
| if missing_columns: | |
| raise ValueError( | |
| f"Columns {missing_columns} do not exist in the dataframe" | |
| ) | |
| # If n_samples is None, return the input data without processing | |
| if n_samples is None: | |
| return input_data.copy() | |
| # Combine annotations into a single column if multiple annotations | |
| if len(annotations) > 1: | |
| # Combine annotations into a single column if multiple annotations | |
| if len(annotations) > 1: |
| else: | ||
| raise TypeError("Input data must be a Pandas DataFrame or Anndata Object.") | ||
|
|
Copilot
AI
Jun 24, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For consistency and clarity, update the error message to refer to 'AnnData' (with proper casing) instead of 'Anndata'.
| else: | |
| raise TypeError("Input data must be a Pandas DataFrame or Anndata Object.") | |
| else: | |
| raise TypeError("Input data must be a Pandas DataFrame or AnnData Object.") |
| Otherwise, choose the first n cells. | ||
| combined_col_name : str, default='_combined_' | ||
| Name of the column that will store combined values when multiple | ||
| annotation columns are provided. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The "combined_col_name" parameter is documented but never used in the code.
Since grouping_col is a pd.Series and not a new column in the cell_data DataFrame, the combined_col_name parameter isn't strictly necessary. May either remove or assign the name to the Series.
src/spac/data_utils.py
Outdated
| if len(annotations) > 1: | ||
| input_data[combined_col_name] = input_data[annotations].apply( | ||
| grouping_col = cell_data[annotations].apply( | ||
| lambda row: '_'.join(row.values.astype(str)), axis=1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The apply method for combining annotations is readable but can be slow on very large datasets. A more performant, vectorized approach is to use str.cat or agg.
grouping_col = cell_data[annotations].astype(str).agg('_'.join, axis=1)
(Ensure all columns are string type first)
| lambda row: '_'.join(row.values.astype(str)), axis=1) | ||
| grouping_col = combined_col_name | ||
| else: | ||
| grouping_col = annotations[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The variable grouping_col is used inconsistently.
Suggested:
if len(annotations) > 1:
cell_data[combined_col_name] = cell_data[annotations].apply(
lambda row: '_'.join(row.values.astype(str)), axis=1)
grouping_col = combined_col_name
else:
grouping_col = annotations[0] # This is a string, not a Series
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, I will go through & add these suggestions today
| combined_col_name= '_combined_', | ||
| min_threshold= 5 | ||
| ) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test calls the function with stratify=False and min_threshold=5. In the downsample_cells function, the min_threshold parameter is only used when stratify=True.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes should be complete now. I have tested the code and it looks like its working. Let me know if any other fixes need to be made.
When multiple annotations are provided, a new temporary column (named by `combined_col_name`) is now explicitly added to the `cell_data` DataFrame, making `grouping_col` consistently a column name (string).
Replaced slow `DataFrame.apply` with the vectorized `DataFrame.astype(str).agg('_'.join, axis=1)` for combining annotations
bb4b7a3 to
1915dca
Compare
Summary:
Modified downsample_cells function to accept anndata.AnnData objects as input. When an AnnData object is provided, .X and .obs data are combined into a pandas DataFrame, before applying the rest of the downsampling function.
Changes: