From 3140fc035575a0e3de01c260c1d7238ccc11e493 Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Sun, 2 Mar 2025 22:29:34 -0800 Subject: [PATCH 1/9] added filter_table_by_query --- pyproject.toml | 1 + src/spatialdata/__init__.py | 2 + .../_core/query/relational_query.py | 28 ++++++++ src/spatialdata/_core/spatialdata.py | 71 +++++++++++++++++++ 4 files changed, 102 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 17d33bb16..c37fea41c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ license = {file = "LICENSE"} readme = "README.md" dependencies = [ "anndata>=0.9.1", + "annsel>=0.0.10", "click", "dask-image", "dask>=2024.4.1,<=2024.11.2", diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index 9ddfea32d..ca290da37 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -41,6 +41,7 @@ "match_element_to_table", "match_table_to_element", "match_sdata_to_table", + "filter_by_table_query", "SpatialData", "get_extent", "get_centroids", @@ -68,6 +69,7 @@ from spatialdata._core.operations.vectorize import to_circles, to_polygons from spatialdata._core.query._utils import get_bounding_box_corners from spatialdata._core.query.relational_query import ( + filter_by_table_query, get_element_annotators, get_element_instances, get_values, diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index b84d43c1b..eb275e1bf 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -12,6 +12,7 @@ import numpy as np import pandas as pd from anndata import AnnData +from annsel.core.typing import Predicates from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame from xarray import DataArray, DataTree @@ -823,6 +824,33 @@ def match_sdata_to_table( return SpatialData.init_from_elements(filtered_elements | {table_name: filtered_table}) +def filter_by_table_query( 
+ sdata: SpatialData, + table_name: str, + filter_tables: bool = True, + include_orphan_tables: bool = False, + elements: list[str] | None = None, + obs_expr: Predicates | None = None, + var_expr: Predicates | None = None, + x_expr: Predicates | None = None, + obs_names_expr: Predicates | None = None, + var_names_expr: Predicates | None = None, + layer: str | None = None, + how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right", +) -> SpatialData: + sdata_subset = ( + sdata.subset(element_names=elements, filter_tables=filter_tables, include_orphan_tables=include_orphan_tables) + if elements + else sdata + ) + + filtered_table = sdata_subset.tables[table_name].an.filter( + obs=obs_expr, var=var_expr, x=x_expr, obs_names=obs_names_expr, var_names=var_names_expr, layer=layer + ) + + return match_sdata_to_table(sdata=sdata_subset, table_name=table_name, table=filtered_table, how=how) + + @dataclass class _ValueOrigin: origin: str diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index f011d08f8..8d8932ff2 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -12,6 +12,7 @@ import pandas as pd import zarr from anndata import AnnData +from annsel.core.typing import Predicates from dask.dataframe import DataFrame as DaskDataFrame from dask.dataframe import read_parquet from dask.delayed import Delayed @@ -2455,6 +2456,76 @@ def attrs(self, value: Mapping[Any, Any]) -> None: else: self._attrs = dict(value) + def filter_by_table_query( + self, + table_name: str, + filter_tables: bool = True, + include_orphan_tables: bool = False, + elements: list[str] | None = None, + obs_expr: Predicates | None = None, + var_expr: Predicates | None = None, + x_expr: Predicates | None = None, + obs_names_expr: Predicates | None = None, + var_names_expr: Predicates | None = None, + layer: str | None = None, + how: Literal["left", "left_exclusive", "inner", "right", 
"right_exclusive"] = "right", + ) -> SpatialData: + """Filter the SpatialData object based on a set of table queries. (:class:`anndata.AnnData`. + + Parameters + ---------- + table_name + The name of the table to filter the SpatialData object by. + filter_tables, optional + If True (default), the table is filtered to only contain rows that are annotating regions + contained within the element_names. + include_orphan_tables, optional + If True (not default), include tables that do not annotate SpatialElement(s). Only has an effect if + `filter_tables` is also set to True. + elements, optional + The names of the elements to filter the SpatialData object by. + obs_expr, optional + A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.obs` by. + var_expr, optional + A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.var` by. + x_expr, optional + A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.X` by. + obs_names_expr, optional + A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.obs_names` by. + var_names_expr, optional + A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.var_names` by. + layer, optional + The layer of the :class:`anndata.AnnData` to filter the SpatialData object by, only used with `x_expr`. + how, optional + The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right". + + Returns + ------- + The filtered SpatialData object. + + Notes + ----- + This function calls :func:`spatialdata.filter_by_table_query` with the convenience that `sdata` is the current + SpatialData object. 
+ + """ + from spatialdata._core.query.relational_query import filter_by_table_query + + return filter_by_table_query( + self, + table_name, + filter_tables, + include_orphan_tables, + elements, + obs_expr, + var_expr, + x_expr, + obs_names_expr, + var_names_expr, + layer, + how, + ) + class QueryManager: """Perform queries on SpatialData objects.""" From 7b9ec9b377c0ec6e877d6ce7ce79d20fb6e5caaa Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Wed, 5 Mar 2025 15:54:59 -0800 Subject: [PATCH 2/9] added tests --- .../_core/query/relational_query.py | 49 +++- src/spatialdata/_core/spatialdata.py | 16 +- tests/conftest.py | 136 ++++++++++++ tests/core/query/test_relational_query.py | 209 ++++++++++++++++++ 4 files changed, 392 insertions(+), 18 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index eb275e1bf..2553ddb4c 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -828,8 +828,7 @@ def filter_by_table_query( sdata: SpatialData, table_name: str, filter_tables: bool = True, - include_orphan_tables: bool = False, - elements: list[str] | None = None, + element_names: list[str] | None = None, obs_expr: Predicates | None = None, var_expr: Predicates | None = None, x_expr: Predicates | None = None, @@ -838,13 +837,49 @@ def filter_by_table_query( layer: str | None = None, how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right", ) -> SpatialData: - sdata_subset = ( - sdata.subset(element_names=elements, filter_tables=filter_tables, include_orphan_tables=include_orphan_tables) - if elements - else sdata + """Filter the SpatialData object based on a set of table queries (:class:`anndata.AnnData`). + + Parameters + ---------- + sdata + The SpatialData object to filter. + table_name + The name of the table to filter the SpatialData object by. 
+ filter_tables, optional + If True (default), the table is filtered to only contain rows that are annotating regions + contained within the element_names. + element_names, optional + The names of the elements to filter the SpatialData object by. + obs_expr, optional + A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.obs` by. + var_expr, optional + A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.var` by. + x_expr, optional + A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.X` by. + obs_names_expr, optional + A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.obs_names` by. + var_names_expr, optional + A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.var_names` by. + layer, optional + The layer of the :class:`anndata.AnnData` to filter the SpatialData object by, only used with `x_expr`. + how, optional + The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right". + + Returns + ------- + The filtered SpatialData object. + + Notes + ----- + This function calls :func:`spatialdata.filter_by_table_query` with the convenience that `sdata` is the current + SpatialData object. 
+ + """ + sdata_subset: SpatialData = ( + sdata.subset(element_names=element_names, filter_tables=filter_tables) if element_names else sdata ) - filtered_table = sdata_subset.tables[table_name].an.filter( + filtered_table: AnnData = sdata_subset.tables[table_name].an.filter( obs=obs_expr, var=var_expr, x=x_expr, obs_names=obs_names_expr, var_names=var_names_expr, layer=layer ) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 8d8932ff2..4dc3ea6ba 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -2460,8 +2460,7 @@ def filter_by_table_query( self, table_name: str, filter_tables: bool = True, - include_orphan_tables: bool = False, - elements: list[str] | None = None, + element_names: list[str] | None = None, obs_expr: Predicates | None = None, var_expr: Predicates | None = None, x_expr: Predicates | None = None, @@ -2479,10 +2478,7 @@ def filter_by_table_query( filter_tables, optional If True (default), the table is filtered to only contain rows that are annotating regions contained within the element_names. - include_orphan_tables, optional - If True (not default), include tables that do not annotate SpatialElement(s). Only has an effect if - `filter_tables` is also set to True. - elements, optional + element_names, optional The names of the elements to filter the SpatialData object by. obs_expr, optional A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.obs` by. @@ -2505,9 +2501,8 @@ def filter_by_table_query( Notes ----- - This function calls :func:`spatialdata.filter_by_table_query` with the convenience that `sdata` is the current - SpatialData object. - + You can also use :method:`spatialdata.SpatialData.filter_by_table_query` with the convenience that `sdata` is + the current SpatialData object. 
""" from spatialdata._core.query.relational_query import filter_by_table_query @@ -2515,8 +2510,7 @@ def filter_by_table_query( self, table_name, filter_tables, - include_orphan_tables, - elements, + element_names, obs_expr, var_expr, x_expr, diff --git a/tests/conftest.py b/tests/conftest.py index 5ced646e2..afc30c85b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -495,3 +495,139 @@ def adata_labels() -> AnnData: "tensor_copy": rng.integers(0, blobs.shape[0], size=(n_obs_labels, 2)), } return generate_adata(n_var, obs_labels, obsm_labels, uns_labels) + + +@pytest.fixture() +def complex_sdata() -> SpatialData: + """ + Create a complex SpatialData object with multiple data types for comprehensive testing. + + Contains: + - Images (2D and 3D) + - Labels (2D and 3D) + - Shapes (polygons and circles) + - Points + - Multiple tables with different annotations + - Categorical and numerical values in both obs and var + + Returns + ------- + SpatialData + A complex SpatialData object for testing. 
+ """ + # Get basic components using existing functions + images = _get_images() + labels = _get_labels() + shapes = _get_shapes() + points = _get_points() + + # Create tables with enhanced var data + n_var = 10 + + # Table 1: Basic table annotating labels2d + obs1 = pd.DataFrame( + { + "region": pd.Categorical(["labels2d"] * 50), + "instance_id": range(1, 51), # Skip background (0) + "cell_type": pd.Categorical(RNG.choice(["T cell", "B cell", "Macrophage"], size=50)), + "size": RNG.uniform(10, 100, size=50), + } + ) + + var1 = pd.DataFrame( + { + "feature_type": pd.Categorical(["gene", "protein", "gene", "protein", "gene"] * 2), + "importance": RNG.uniform(0, 10, size=n_var), + "is_marker": RNG.choice([True, False], size=n_var), + }, + index=[f"feature_{i}" for i in range(n_var)], + ) + + X1 = RNG.normal(size=(50, n_var)) + uns1 = { + "spatialdata_attrs": { + "region": "labels2d", + "region_key": "region", + "instance_key": "instance_id", + } + } + + table1 = AnnData(X=X1, obs=obs1, var=var1, uns=uns1) + + # Table 2: Annotating both polygons and circles from shapes + n_polygons = len(shapes["poly"]) + n_circles = len(shapes["circles"]) + total_items = n_polygons + n_circles + + obs2 = pd.DataFrame( + { + "region": pd.Categorical(["poly"] * n_polygons + ["circles"] * n_circles), + "instance_id": np.concatenate([range(n_polygons), range(n_circles)]), + "category": pd.Categorical(RNG.choice(["A", "B", "C"], size=total_items)), + "value": RNG.normal(size=total_items), + "count": RNG.poisson(10, size=total_items), + } + ) + + var2 = pd.DataFrame( + { + "feature_type": pd.Categorical( + ["feature_type1", "feature_type2", "feature_type1", "feature_type2", "feature_type1"] * 2 + ), + "score": RNG.exponential(2, size=n_var), + "detected": RNG.choice([True, False], p=[0.7, 0.3], size=n_var), + }, + index=[f"metric_{i}" for i in range(n_var)], + ) + + X2 = RNG.normal(size=(total_items, n_var)) + uns2 = { + "spatialdata_attrs": { + "region": ["poly", "circles"], + 
"region_key": "region", + "instance_key": "instance_id", + } + } + + table2 = AnnData(X=X2, obs=obs2, var=var2, uns=uns2) + + # Table 3: Orphan table not annotating any elements + obs3 = pd.DataFrame( + { + "cluster": pd.Categorical(RNG.choice(["cluster_1", "cluster_2", "cluster_3"], size=40)), + "sample": pd.Categorical(["sample_A"] * 20 + ["sample_B"] * 20), + "qc_pass": RNG.choice([True, False], p=[0.8, 0.2], size=40), + } + ) + + var3 = pd.DataFrame( + { + "feature_type": pd.Categorical(["gene", "protein", "gene", "protein", "gene"] * 2), + "mean_expression": RNG.uniform(0, 20, size=n_var), + "variance": RNG.gamma(2, 2, size=n_var), + }, + index=[f"feature_{i}" for i in range(n_var)], + ) + + X3 = RNG.normal(size=(40, n_var)) + table3 = AnnData(X=X3, obs=obs3, var=var3) + + # Create additional coordinate system in one of the shapes for testing + # Modified copy of circles with an additional coordinate system + circles_alt_coords = shapes["circles"].copy() + circles_alt_coords["coordinate_system"] = "alt_system" + + # Add everything to a SpatialData object + sdata = SpatialData( + images=images, + labels=labels, + shapes={**shapes, "circles_alt_coords": circles_alt_coords}, + points=points, + tables={"labels_table": table1, "shapes_table": table2, "orphan_table": table3}, + ) + + # Add layers to tables for testing layer-specific operations + sdata.tables["labels_table"].layers["scaled"] = sdata.tables["labels_table"].X * 2 + sdata.tables["labels_table"].layers["log"] = np.log1p(np.abs(sdata.tables["labels_table"].X)) + + return sdata diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index f0b4da7e0..beaa6a53d 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -1,3 +1,4 @@ +import annsel as an import numpy as np import pandas as pd import pytest @@ -7,6 +8,7 @@ from spatialdata._core.query.relational_query import ( _locate_value, _ValueOrigin, + 
filter_by_table_query, get_element_annotators, join_spatialelement_table, ) @@ -1052,3 +1054,210 @@ def test_get_element_annotators(full_sdata): full_sdata.tables["another_table"] = another_table names = get_element_annotators(full_sdata, "labels2d") assert names == {"another_table", "table"} + + +def test_filter_by_table_query(complex_sdata): + """Test basic filtering functionality of filter_by_table_query.""" + sdata = complex_sdata + + # Test 1: Basic filtering on categorical obs column + result = filter_by_table_query(sdata=sdata, table_name="labels_table", obs_expr=an.col("cell_type") == "T cell") + + # Check that the table was filtered properly + assert all(result["labels_table"].obs["cell_type"] == "T cell") + # Check that result has fewer rows than original + assert result["labels_table"].n_obs < sdata["labels_table"].n_obs + # Check that labels2d element is still present + assert "labels2d" in result.labels + + # Test 2: Filtering on numerical obs column + result = filter_by_table_query(sdata=sdata, table_name="labels_table", obs_expr=an.col("size") > 50) + + # Check that the table was filtered properly + assert all(result["labels_table"].obs["size"] > 50) + # Check that labels2d element is still present + assert "labels2d" in result.labels + + # Test 3: Filtering with var expressions + result = filter_by_table_query( + sdata=sdata, table_name="shapes_table", var_expr=an.col("feature_type") == "feature_type1" + ) + + # Check that the filtered var dataframe only has 'feature_type1' feature_type + assert all(result["shapes_table"].var["feature_type"] == "feature_type1") + # Check that the filtered var dataframe has fewer rows than the original + assert result["shapes_table"].n_vars < sdata["shapes_table"].n_vars + + # Test 4: Multiple filtering conditions (obs and var) + result = filter_by_table_query( + sdata=sdata, table_name="shapes_table", obs_expr=an.col("category") == "A", var_expr=an.col("score") > 2 + ) + + # Check that both filters were applied + assert 
all(result["shapes_table"].obs["category"] == "A") + assert all(result["shapes_table"].var["score"] > 2) + + # Test 5: Using X expressions + result = filter_by_table_query(sdata=sdata, table_name="labels_table", x_expr=an.col("feature_1") > 0.5) + + # Check that the filter was applied to X + assert np.all(result["labels_table"][:, "feature_1"].X > 0.5) + + # Test 6: Using different join types + # Test with inner join + result = filter_by_table_query( + sdata=sdata, table_name="shapes_table", obs_expr=an.col("category") == "A", how="inner" + ) + + # The elements should be filtered to only include instances in the table + assert "poly" in result.shapes + assert "circles" in result.shapes + + # Test with left join + result = filter_by_table_query( + sdata=sdata, table_name="shapes_table", obs_expr=an.col("category") == "A", how="left" + ) + + # Elements should be preserved but table should be filtered + assert "poly" in result.shapes + assert "circles" in result.shapes + assert all(result["shapes_table"].obs["category"] == "A") + + # Test 7: Filtering with specific element_names + result = filter_by_table_query( + sdata=sdata, + table_name="shapes_table", + element_names=["poly"], # Only include poly, not circles + obs_expr=an.col("category") == "A", + ) + + # Only specified elements should be in the result + assert "poly" in result.shapes + assert "circles" not in result.shapes + + # Test 8: Testing orphan table handling + # Only the queried table is carried into the result (there is no include_orphan_tables parameter) + result = filter_by_table_query( + sdata=sdata, + table_name="shapes_table", + obs_expr=an.col("category") == "A", + filter_tables=True, + ) + + # Orphan table should not be in the result + assert "orphan_table" not in result.tables + + +def test_filter_by_table_query_with_layers(complex_sdata): + """Test filtering using different layers.""" + sdata = complex_sdata + + # Test filtering using a specific layer + result = filter_by_table_query( + sdata=sdata, + table_name="labels_table", + 
x_expr=an.col("feature_1") > 1.0, + layer="scaled", # The 'scaled' layer has values 2x the original X + ) + + # Values in the scaled layer's feature_1 column should be > 1.0 + assert np.all(result["labels_table"][:, "feature_1"].layers["scaled"] > 1.0) + + +def test_filter_by_table_query_edge_cases(complex_sdata): + """Test edge cases for filter_by_table_query.""" + sdata = complex_sdata + + # Test 1: Empty result from filtering raises AssertionError + with pytest.raises(AssertionError, match="No valid element to join"): + # This should raise an error because there are no elements matching this category + filter_by_table_query( + sdata=sdata, + table_name="shapes_table", + obs_expr=an.col("category") == "NonExistentCategory", + how="inner", # Inner join requires matching elements + ) + + # Test 2: Filter by obs_names + result = filter_by_table_query( + sdata=sdata, + table_name="shapes_table", + obs_names_expr=an.obs_names.str.starts_with("0"), # Only rows with index starting with '0' + ) + + # Check that filtered table only has obs names starting with '0' + assert all(str(idx).startswith("0") for idx in result["shapes_table"].obs_names) + + # Test 3: Invalid table name raises KeyError + with pytest.raises(KeyError, match="nonexistent_table"): + filter_by_table_query(sdata=sdata, table_name="nonexistent_table", obs_expr=an.col("category") == "A") + + # Test 4: Invalid column name in expression + with pytest.raises(KeyError): # The exact exception type may vary + filter_by_table_query(sdata=sdata, table_name="shapes_table", obs_expr=an.col("nonexistent_column") == "A") + + # Test 5: Using layer that doesn't exist + with pytest.raises(KeyError): + filter_by_table_query( + sdata=sdata, table_name="labels_table", x_expr=an.col("feature_1") > 0.5, layer="nonexistent_layer" + ) + + # Test 6: Filter by var_names + result = filter_by_table_query( + sdata=sdata, + table_name="labels_table", + var_names_expr=an.var_names.str.contains("feature_[0-4]"), # Only features 0-4 + ) 
+ + # Check that filtered table only has var names matching the pattern + for idx in result["labels_table"].var_names: + var_name = str(idx) + assert var_name.startswith("feature_") and int(var_name.split("_")[1]) < 5 + + # Test 7: Invalid element_names (element doesn't exist) + with pytest.raises(AssertionError, match="elements_dict must not be empty"): + filter_by_table_query( + sdata=sdata, + table_name="shapes_table", + element_names=["nonexistent_element"], + obs_expr=an.col("category") == "A", + ) + + # Test 8: Invalid join type raises ValueError + with pytest.raises(TypeError, match="not a valid type of join."): + filter_by_table_query( + sdata=sdata, + table_name="shapes_table", + how="invalid_join_type", # Invalid join type + obs_expr=an.col("category") == "A", + ) + + +def test_filter_by_table_query_complex_combination(complex_sdata): + """Test complex combinations of filters.""" + sdata = complex_sdata + + # Combine multiple filtering criteria + result = filter_by_table_query( + sdata=sdata, + table_name="shapes_table", + obs_expr=(an.col("category") == "A", an.col("value") > 0), + var_expr=an.col("feature_type") == "feature_type1", + how="inner", + ) + + # Validate the combined filtering results + assert all(result["shapes_table"].obs["category"] == "A") + assert all(result["shapes_table"].obs["value"] > 0) + assert all(result["shapes_table"].var["feature_type"] == "feature_type1") + + # Both elements should be present but filtered + assert "circles" in result.shapes + + # The filtered shapes should only contain the instances from the filtered table + table_instance_ids = set( + zip(result["shapes_table"].obs["region"], result["shapes_table"].obs["instance_id"], strict=True) + ) + if "circles" in result.shapes: + for idx in result["circles"].index: + assert ("circles", idx) in table_instance_ids From 559cc597d385dbef08259dd2a0130101a3c9d36a Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Mon, 10 Mar 2025 22:30:59 -0700 Subject: [PATCH 3/9] 
updated annsel, adjusted test --- pyproject.toml | 2 +- tests/core/query/test_relational_query.py | 24 +++++++---------------- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c37fea41c..cd95df7f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ license = {file = "LICENSE"} readme = "README.md" dependencies = [ "anndata>=0.9.1", - "annsel>=0.0.10", + "annsel>=0.1.0", "click", "dask-image", "dask>=2024.4.1,<=2024.11.2", diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index beaa6a53d..2bd79036a 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -1168,17 +1168,7 @@ def test_filter_by_table_query_edge_cases(complex_sdata): """Test edge cases for filter_by_table_query.""" sdata = complex_sdata - # Test 1: Empty result from filtering raises AssertionError - with pytest.raises(AssertionError, match="No valid element to join"): - # This should raise an error because there are no elements matching this category - filter_by_table_query( - sdata=sdata, - table_name="shapes_table", - obs_expr=an.col("category") == "NonExistentCategory", - how="inner", # Inner join requires matching elements - ) - - # Test 2: Filter by obs_names + # Test 1: Filter by obs_names result = filter_by_table_query( sdata=sdata, table_name="shapes_table", @@ -1188,21 +1178,21 @@ def test_filter_by_table_query_edge_cases(complex_sdata): # Check that filtered table only has obs names starting with '0' assert all(str(idx).startswith("0") for idx in result["shapes_table"].obs_names) - # Test 3: Invalid table name raises KeyError + # Test 2: Invalid table name raises KeyError with pytest.raises(KeyError, match="nonexistent_table"): filter_by_table_query(sdata=sdata, table_name="nonexistent_table", obs_expr=an.col("category") == "A") - # Test 4: Invalid column name in expression + # Test 3: Invalid column name in expression with 
pytest.raises(KeyError): # The exact exception type may vary filter_by_table_query(sdata=sdata, table_name="shapes_table", obs_expr=an.col("nonexistent_column") == "A") - # Test 5: Using layer that doesn't exist + # Test 4: Using layer that doesn't exist with pytest.raises(KeyError): filter_by_table_query( sdata=sdata, table_name="labels_table", x_expr=an.col("feature_1") > 0.5, layer="nonexistent_layer" ) - # Test 6: Filter by var_names + # Test 5: Filter by var_names result = filter_by_table_query( sdata=sdata, table_name="labels_table", @@ -1214,7 +1204,7 @@ def test_filter_by_table_query_edge_cases(complex_sdata): var_name = str(idx) assert var_name.startswith("feature_") and int(var_name.split("_")[1]) < 5 - # Test 7: Invalid element_names (element doesn't exist) + # Test 6: Invalid element_names (element doesn't exist) with pytest.raises(AssertionError, match="elements_dict must not be empty"): filter_by_table_query( sdata=sdata, @@ -1223,7 +1213,7 @@ def test_filter_by_table_query_edge_cases(complex_sdata): obs_expr=an.col("category") == "A", ) - # Test 8: Invalid join type raises ValueError + # Test 7: Invalid join type raises TypeError with pytest.raises(TypeError, match="not a valid type of join."): filter_by_table_query( sdata=sdata, From add58808218afd2fd5ab64c5b8a0f44b5ab9ebae Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Mon, 10 Mar 2025 22:38:13 -0700 Subject: [PATCH 4/9] fixed docstring: func instead of method --- src/spatialdata/_core/spatialdata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 4dc3ea6ba..3cd672f8c 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -2501,7 +2501,7 @@ def filter_by_table_query( Notes ----- - You can also use :method:`spatialdata.SpatialData.filter_by_table_query` with the convenience that `sdata` is + You can also use 
:func:`spatialdata.SpatialData.filter_by_table_query` with the convenience that `sdata` is the current SpatialData object. """ from spatialdata._core.query.relational_query import filter_by_table_query From d6a9b955fb4165e552d4aa289b5184a856c517cb Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Mon, 10 Mar 2025 22:40:08 -0700 Subject: [PATCH 5/9] using SpatialData method for one test for ci --- tests/core/query/test_relational_query.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index 2bd79036a..6c1b81cab 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -1228,8 +1228,7 @@ def test_filter_by_table_query_complex_combination(complex_sdata): sdata = complex_sdata # Combine multiple filtering criteria - result = filter_by_table_query( - sdata=sdata, + result = sdata.filter_by_table_query( table_name="shapes_table", obs_expr=(an.col("category") == "A", an.col("value") > 0), var_expr=an.col("feature_type") == "feature_type1", From 877d85c4c7a2fe0322cf3148cc5f4465ac61b868 Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Mon, 10 Mar 2025 22:46:56 -0700 Subject: [PATCH 6/9] removed explicit optional in docstrings --- .../_core/query/relational_query.py | 18 +++++++++--------- src/spatialdata/_core/spatialdata.py | 18 +++++++++--------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 2553ddb4c..b68539021 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -845,24 +845,24 @@ def filter_by_table_query( The SpatialData object to filter. table_name The name of the table to filter the SpatialData object by. 
- filter_tables, optional + filter_tables If True (default), the table is filtered to only contain rows that are annotating regions contained within the element_names. - element_names, optional + element_names The names of the elements to filter the SpatialData object by. - obs_expr, optional + obs_expr A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.obs` by. - var_expr, optional + var_expr A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.var` by. - x_expr, optional + x_expr A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.X` by. - obs_names_expr, optional + obs_names_expr A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.obs_names` by. - var_names_expr, optional + var_names_expr A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.var_names` by. - layer, optional + layer The layer of the :class:`anndata.AnnData` to filter the SpatialData object by, only used with `x_expr`. - how, optional + how The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right". Returns diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 3cd672f8c..19d2e4fcc 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -2475,24 +2475,24 @@ def filter_by_table_query( ---------- table_name The name of the table to filter the SpatialData object by. - filter_tables, optional + filter_tables If True (default), the table is filtered to only contain rows that are annotating regions contained within the element_names. - element_names, optional + element_names The names of the elements to filter the SpatialData object by. - obs_expr, optional + obs_expr A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.obs` by. - var_expr, optional + var_expr A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.var` by. 
- x_expr, optional + x_expr A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.X` by. - obs_names_expr, optional + obs_names_expr A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.obs_names` by. - var_names_expr, optional + var_names_expr A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.var_names` by. - layer, optional + layer The layer of the :class:`anndata.AnnData` to filter the SpatialData object by, only used with `x_expr`. - how, optional + how The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right". Returns From 79985414bbe3ef208c60280be8df985a9b97492f Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Mon, 10 Mar 2025 22:48:56 -0700 Subject: [PATCH 7/9] updated docstring Notes --- src/spatialdata/_core/query/relational_query.py | 5 ++--- src/spatialdata/_core/spatialdata.py | 3 +-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index b68539021..68262960d 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -871,9 +871,8 @@ def filter_by_table_query( Notes ----- - This function calls :func:`spatialdata.filter_by_table_query` with the convenience that `sdata` is the current - SpatialData object. - + You can also use :func:`spatialdata.SpatialData.filter_by_table_query` with the convenience that `sdata` is the + current SpatialData object. 
""" sdata_subset: SpatialData = ( sdata.subset(element_names=element_names, filter_tables=filter_tables) if element_names else sdata diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 19d2e4fcc..fe238fd2b 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -2501,8 +2501,7 @@ def filter_by_table_query( Notes ----- - You can also use :func:`spatialdata.SpatialData.filter_by_table_query` with the convenience that `sdata` is - the current SpatialData object. + You can also use :func:`query.relational_query.filter_by_table_query`. """ from spatialdata._core.query.relational_query import filter_by_table_query From 3fb202eb3f4a6482d8e20778c30e0790503a2209 Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Mon, 10 Mar 2025 23:02:15 -0700 Subject: [PATCH 8/9] updated api/operations.md --- docs/api/operations.md | 1 + docs/conf.py | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/api/operations.md b/docs/api/operations.md index 937b8dbca..c4c4b373e 100644 --- a/docs/api/operations.md +++ b/docs/api/operations.md @@ -15,6 +15,7 @@ Operations on `SpatialData` objects. .. autofunction:: match_element_to_table .. autofunction:: match_table_to_element .. autofunction:: match_sdata_to_table +.. autofunction:: filter_by_table_query .. autofunction:: concatenate .. autofunction:: transform .. 
autofunction:: rasterize diff --git a/docs/conf.py b/docs/conf.py index 6efe4c54a..4c9ff4dce 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -103,6 +103,7 @@ "datatree": ("https://datatree.readthedocs.io/en/latest/", None), "dask": ("https://docs.dask.org/en/latest/", None), "shapely": ("https://shapely.readthedocs.io/en/stable", None), + "annsel": ("https://annsel.readthedocs.io/en/latest/", None), } From 33423f61a4eac029cb0eba9c0a666bfc3ae18824 Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Thu, 10 Apr 2025 11:39:41 -0700 Subject: [PATCH 9/9] updated annsel version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index cd95df7f5..ee6e6524a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ license = {file = "LICENSE"} readme = "README.md" dependencies = [ "anndata>=0.9.1", - "annsel>=0.1.0", + "annsel>=0.1.1", "click", "dask-image", "dask>=2024.4.1,<=2024.11.2",