From e97cf089e47ba41b660d307b2aaaec690c113ff2 Mon Sep 17 00:00:00 2001 From: Marius Lange Date: Tue, 29 Jul 2025 13:17:23 +0200 Subject: [PATCH 1/6] Enable batched mode for jaccard --- src/cellmapper/constants.py | 4 + src/cellmapper/model/cellmapper.py | 5 + src/cellmapper/model/kernel.py | 34 ++-- src/cellmapper/utils.py | 162 +++++++++++++++++- .../model/test_query_to_reference_mapping.py | 57 ++++++ tests/model/test_self_mapping.py | 56 ++++++ 6 files changed, 306 insertions(+), 12 deletions(-) diff --git a/src/cellmapper/constants.py b/src/cellmapper/constants.py index dcac56a..dc7dfae 100644 --- a/src/cellmapper/constants.py +++ b/src/cellmapper/constants.py @@ -18,3 +18,7 @@ class PackageConstants: # Threshold for recommending spectral method over iterative for matrix powers SPECTRAL_METHOD_THRESHOLD: int = 10 + + # Batch processing thresholds for Jaccard-based methods + JACCARD_BATCH_WARNING_CELLS: int = 100_000 # Warn if >100k cells with Jaccard methods + JACCARD_BATCH_WARNING_NEIGHBORS: int = 20 # Warn if >20 neighbors with Jaccard methods diff --git a/src/cellmapper/model/cellmapper.py b/src/cellmapper/model/cellmapper.py index 076ab6b..a3833ed 100644 --- a/src/cellmapper/model/cellmapper.py +++ b/src/cellmapper/model/cellmapper.py @@ -241,6 +241,7 @@ def compute_mapping_matrix( self_edges: bool | None = None, n_eigenvectors: int = 50, eigen_solver: Literal["partial", "complete"] = "partial", + n_batches: int | None = None, ) -> None: """ Compute the mapping matrix for label transfer. @@ -257,6 +258,9 @@ def compute_mapping_matrix( Eigendecomposition method for spectral approach: - "partial": Uses sparse eigendecomposition, faster (default) - "complete": Uses complete eigendecomposition, exact for testing + n_batches + Number of batches to split Jaccard/HNOCA computation into for memory efficiency. + Only applies to Jaccard and HNOCA methods. If None, computation is done in a single batch. Returns ------- @@ -292,6 +296,7 @@ def compute_mapping_matrix( kernel_method=kernel_method, symmetrize=symmetrize, self_edges=self_edges, + n_batches=n_batches, ) # Validate expected shape before creating mapping operator diff --git a/src/cellmapper/model/kernel.py b/src/cellmapper/model/kernel.py index 07ea090..6a77988 100644 --- a/src/cellmapper/model/kernel.py +++ b/src/cellmapper/model/kernel.py @@ -8,7 +8,7 @@ from cellmapper.logging import logger from cellmapper.model._knn_backend import get_backend from cellmapper.model.neighbors import Neighbors -from cellmapper.utils import extract_neighbors_from_distances +from cellmapper.utils import compute_jaccard_kernel_matrix, extract_neighbors_from_distances class Kernel: @@ -218,6 +218,7 @@ def compute_kernel_matrix( symmetrize: bool = False, symmetrize_method: Literal["max", "mean"] = "max", self_edges: bool = False, + n_batches: int | None = None, **kwargs, ) -> None: """ @@ -232,6 +233,9 @@ def compute_kernel_matrix( - "max": Take element-wise maximum between matrix and transpose (preserves strongest connections) - "mean": Take element-wise average between matrix and transpose (smooths connections) %(self_edges)s + n_batches + Number of batches to use for Jaccard/HNOCA computation. If None (default), + compute in a single batch. Use batch processing to reduce memory usage for large datasets. **kwargs Additional keyword arguments for kernel computation. @@ -262,14 +266,21 @@ def compute_kernel_matrix( assert self.yx is not None, "yx neighbors must be computed" n_neighbors = self.yx.n_neighbors - # Compute kernel matrix - kernel_matrix = (yx @ xx.T) + (yy @ xy.T) - - if kernel_method == "jaccard": - kernel_matrix.data /= 4 * n_neighbors - kernel_matrix.data - elif kernel_method == "hnoca": - kernel_matrix.data /= 2 * n_neighbors - kernel_matrix.data - kernel_matrix.data = kernel_matrix.data**2 + # Check if batch processing might be beneficial and warn user + assert yx is not None, "yx adjacency matrix must be available" + n_query_cells = yx.shape[0] + if ( + n_batches is None + and n_query_cells > PackageConstants.JACCARD_BATCH_WARNING_CELLS + and n_neighbors > PackageConstants.JACCARD_BATCH_WARNING_NEIGHBORS + ): + logger.warning( + f"Computing {kernel_method} kernel for {n_query_cells:,} cells with {n_neighbors} neighbors. " + f"Consider using batch processing (n_batches parameter) to reduce memory usage." + ) + + # Compute kernel matrix with optional batching using utility function + kernel_matrix = compute_jaccard_kernel_matrix(xx, yy, xy, yx, kernel_method, n_neighbors, n_batches) elif kernel_method in PackageConstants.CONNECTIVITY_BASED_KERNELS: # Validate self-mapping-only kernels @@ -423,8 +434,9 @@ def __repr__(self): # Kernel matrix info if self.kernel_matrix is not None: # Calculate sparsity percentage - total_elements = self.kernel_matrix.shape[0] * self.kernel_matrix.shape[1] - sparsity = self.kernel_matrix.nnz / total_elements + kernel_matrix = cast(csr_matrix, self.kernel_matrix) # for type checker + total_elements = kernel_matrix.shape[0] * kernel_matrix.shape[1] + sparsity = kernel_matrix.nnz / total_elements kernel_info = [ f"kernel='{self.kernel_method}'", diff --git a/src/cellmapper/utils.py b/src/cellmapper/utils.py index fbfdee4..b5a8e7f 100644 --- a/src/cellmapper/utils.py +++ b/src/cellmapper/utils.py @@ -1,10 +1,12 @@ """Utility functions for the CellMapper package.""" +import gc + import anndata as ad import numpy as np import pandas as pd from anndata import AnnData -from scipy.sparse import csr_matrix, issparse +from scipy.sparse import csr_matrix, issparse, vstack from scipy.sparse.linalg import LinearOperator, svds from sklearn.utils.extmath import randomized_svd @@ -367,3 +369,161 @@ def rmatvec(v): vt = vt[idx, :] return u, s, vt + + +def apply_jaccard_transformation(kernel_matrix: csr_matrix, kernel_method: str, n_neighbors: int) -> None: + """ + Apply Jaccard or HNOCA transformation to a kernel matrix in-place. + + Parameters + ---------- + kernel_matrix + Sparse matrix to transform in-place + kernel_method + Method to use: "jaccard" or "hnoca" + n_neighbors + Number of neighbors used for normalization + + Notes + ----- + This function modifies the kernel_matrix in-place to save memory. + For Jaccard: kernel = intersection / (4*k - intersection) + For HNOCA: kernel = (intersection / (2*k - intersection))^2 + """ + if kernel_method == "jaccard": + kernel_matrix.data /= 4 * n_neighbors - kernel_matrix.data + elif kernel_method == "hnoca": + kernel_matrix.data /= 2 * n_neighbors - kernel_matrix.data + kernel_matrix.data = kernel_matrix.data**2 + else: + raise ValueError(f"Unknown kernel method: {kernel_method}. Expected 'jaccard' or 'hnoca'.") + + +def compute_jaccard_kernel_matrix( + xx: csr_matrix, + yy: csr_matrix, + xy: csr_matrix, + yx: csr_matrix, + kernel_method: str, + n_neighbors: int, + n_batches: int | None = None, +) -> csr_matrix: + """ + Compute Jaccard or HNOCA kernel matrix with optional batch processing. + + Parameters + ---------- + xx, yy, xy, yx + Adjacency matrices from neighbor computations + kernel_method + Kernel method to use: "jaccard" or "hnoca" + n_neighbors + Number of neighbors for normalization + n_batches + Number of batches for memory-efficient computation. If None, use standard computation. + + Returns + ------- + csr_matrix + Computed kernel matrix + + Notes + ----- + This function implements both standard and batched computation modes. + Batched mode reduces memory usage for large datasets by processing + the computation in chunks. + """ + if n_batches is None: + # Standard computation + kernel_matrix = (yx @ xx.T) + (yy @ xy.T) + apply_jaccard_transformation(kernel_matrix, kernel_method, n_neighbors) + return kernel_matrix + else: + # Batched computation + return _compute_jaccard_kernel_batched(xx, yy, xy, yx, kernel_method, n_neighbors, n_batches) + + +def _compute_jaccard_kernel_batched( + xx: csr_matrix, + yy: csr_matrix, + xy: csr_matrix, + yx: csr_matrix, + kernel_method: str, + n_neighbors: int, + n_batches: int, +) -> csr_matrix: + """ + Compute Jaccard or HNOCA kernel matrix using batch processing. + + Parameters + ---------- + xx, yy, xy, yx + Adjacency matrices from neighbor computations + kernel_method + Kernel method to use: "jaccard" or "hnoca" + n_neighbors + Number of neighbors for normalization + n_batches + Number of batches to split the computation into + + Returns + ------- + csr_matrix + Computed kernel matrix + """ + # Calculate batch size based on query dataset (yx rows) + n_query = yx.shape[0] + batch_size = int(np.ceil(n_query / n_batches)) + + logger.info( + "Computing %s kernel with %s batches (~%s query cells per batch)", kernel_method, n_batches, f"{batch_size:,}" + ) + + # Pre-allocate list to store batch results + batch_results = [] + + # Pre-compute xx.T once to avoid repeated transposition + xx_T = xx.T + + for batch_idx in range(n_batches): + start_idx = batch_idx * batch_size + end_idx = min((batch_idx + 1) * batch_size, n_query) + + logger.info("Processing batch %s/%s: cells %s-%s", batch_idx + 1, n_batches, f"{start_idx:,}", f"{end_idx:,}") + + # Extract batch slices - minimize memory footprint + yx_batch = yx[start_idx:end_idx] + yy_batch = yy[start_idx:end_idx, :] + + # Compute first term: yx_batch @ xx.T + term1 = yx_batch @ xx_T + + # Compute second term: yy_batch @ xy.T + term2 = yy_batch @ xy.T + + # Combine terms + batch_kernel = term1 + term2 + + # Apply Jaccard/HNOCA transformation + apply_jaccard_transformation(batch_kernel, kernel_method, n_neighbors) + + # Store result in list + batch_results.append(batch_kernel) + + # Cleanup batch variables immediately to save memory + del yx_batch, yy_batch, term1, term2, batch_kernel + gc.collect() + + # Combine batch results using vstack + logger.info("Combining batch results...") + kernel_matrix = vstack(batch_results, format="csr") + + # Ensure we return a csr_matrix (not csr_array) + if not isinstance(kernel_matrix, csr_matrix): + kernel_matrix = csr_matrix(kernel_matrix) + + # Final cleanup + del batch_results, xx_T + gc.collect() + + return kernel_matrix diff --git a/tests/model/test_query_to_reference_mapping.py b/tests/model/test_query_to_reference_mapping.py index 6dc82a9..c4b5fd9 100644 --- a/tests/model/test_query_to_reference_mapping.py +++ b/tests/model/test_query_to_reference_mapping.py @@ -439,3 +439,60 @@ def test_map_method_with_subset_categories(self, query_reference_adata): assert "leiden_pred" in cmap.query.obs predicted_categories = set(cmap.query.obs["leiden_pred"].dropna().unique()) assert predicted_categories.issubset(set(subset_cats)) + + +class TestBatchProcessingCrossMapping: + """Test batch processing functionality for cross-mapping mode.""" + + def test_jaccard_batch_vs_standard_identical_results(self, query_reference_adata): + """Test that batched and standard Jaccard computation give identical results in cross-mapping mode.""" + query, reference = query_reference_adata + + # Test standard computation + cm_standard = CellMapper(query, reference) + cm_standard.compute_neighbors(n_neighbors=15, use_rep="X_pca", only_yx=False) + cm_standard.compute_mapping_matrix(kernel_method="jaccard", n_batches=None) + + standard_kernel = cm_standard.knn.kernel_matrix.copy() + + # Test batched computation + cm_batch = CellMapper(query, reference) + cm_batch.compute_neighbors(n_neighbors=15, use_rep="X_pca", only_yx=False) + cm_batch.compute_mapping_matrix(kernel_method="jaccard", n_batches=3) + + batch_kernel = cm_batch.knn.kernel_matrix + + # Verify matrices are identical + assert standard_kernel.shape == batch_kernel.shape, "Kernel matrix shapes should match" + assert standard_kernel.nnz == batch_kernel.nnz, "Number of non-zero elements should match" + assert (standard_kernel - batch_kernel).nnz == 0, "Kernel matrices should be identical" + + # Verify mapping results are identical + cm_standard.map_obs(key="leiden") + cm_batch.map_obs(key="leiden") + + # Check that predictions are identical + assert cm_standard.query.obs["leiden_pred"].equals(cm_batch.query.obs["leiden_pred"]), ( + "Label predictions should be identical between standard and batch computation" + ) + + @pytest.mark.parametrize( + "kernel_method,n_batches", [("jaccard", None), ("jaccard", 2), ("hnoca", None), ("hnoca", 3)] + ) + def test_jaccard_hnoca_batch_parametrized(self, query_reference_adata, kernel_method, n_batches): + """Test both Jaccard and HNOCA kernels with different batch configurations.""" + query, reference = query_reference_adata + + cm = CellMapper(query, reference) + cm.compute_neighbors(n_neighbors=12, use_rep="X_pca", only_yx=False) + cm.compute_mapping_matrix(kernel_method=kernel_method, n_batches=n_batches) + + # Verify kernel matrix properties + expected_shape = (query.n_obs, reference.n_obs) + assert cm.knn.kernel_matrix is not None, "Kernel matrix should be computed" + assert cm.knn.kernel_matrix.shape == expected_shape, f"Shape should match {expected_shape}" + assert cm.knn.kernel_matrix.nnz > 0, "Kernel matrix should have non-zero elements" + + # Verify mapping works + cm.map_obs(key="leiden") + assert "leiden_pred" in cm.query.obs, "Label predictions should be generated" diff --git a/tests/model/test_self_mapping.py b/tests/model/test_self_mapping.py index a6e6297..08e0aa5 100644 --- a/tests/model/test_self_mapping.py +++ b/tests/model/test_self_mapping.py @@ -502,3 +502,59 @@ def test_iterative_vs_partial_spectral_similarity(self, adata_pbmc3k): # Assert high similarity assert cell_corrs.mean() > 0.99, f"Cell correlations not high enough: {cell_corrs.mean():.6f}" assert gene_corrs.mean() > 0.95, f"Gene correlations not high enough: {gene_corrs.mean():.6f}" + + +class TestBatchProcessingSelfMapping: + """Test batch processing functionality for self-mapping mode.""" + + def test_jaccard_batch_vs_standard_identical_results(self, adata_pbmc3k): + """Test that batched and standard Jaccard computation give identical results in self-mapping mode.""" + # Use subset for faster testing + adata_subset = adata_pbmc3k[:1000].copy() + + # Test standard computation + cm_standard = CellMapper(adata_subset) + cm_standard.compute_neighbors(n_neighbors=15, use_rep="X_pca", only_yx=False) + cm_standard.compute_mapping_matrix(kernel_method="jaccard", n_batches=None) + + standard_kernel = cm_standard.knn.kernel_matrix.copy() + + # Test batched computation + cm_batch = CellMapper(adata_subset) + cm_batch.compute_neighbors(n_neighbors=15, use_rep="X_pca", only_yx=False) + cm_batch.compute_mapping_matrix(kernel_method="jaccard", n_batches=4) + + batch_kernel = cm_batch.knn.kernel_matrix + + # Verify matrices are identical + assert standard_kernel.shape == batch_kernel.shape, "Kernel matrix shapes should match" + assert standard_kernel.nnz == batch_kernel.nnz, "Number of non-zero elements should match" + assert (standard_kernel - batch_kernel).nnz == 0, "Kernel matrices should be identical" + + # Verify mapping results are identical + cm_standard.map_obs(key="leiden") + cm_batch.map_obs(key="leiden") + + # Check that predictions are identical + assert cm_standard.query.obs["leiden_pred"].equals(cm_batch.query.obs["leiden_pred"]), ( + "Label predictions should be identical between standard and batch computation" + ) + + @pytest.mark.parametrize("n_batches", [None, 2, 4]) + def test_hnoca_kernel_batch_variants(self, adata_pbmc3k, n_batches): + """Test HNOCA kernel computation with different batch configurations.""" + # Use subset for faster testing + adata_subset = adata_pbmc3k[:800].copy() + + cm = CellMapper(adata_subset) + cm.compute_neighbors(n_neighbors=10, use_rep="X_pca", only_yx=False) + cm.compute_mapping_matrix(kernel_method="hnoca", n_batches=n_batches) + + # Verify kernel matrix properties + assert cm.knn.kernel_matrix is not None, "Kernel matrix should be computed" + assert cm.knn.kernel_matrix.shape == (adata_subset.n_obs, adata_subset.n_obs), "Shape should match dataset size" + assert cm.knn.kernel_matrix.nnz > 0, "Kernel matrix should have non-zero elements" + + # Verify mapping works + cm.map_obs(key="leiden") + assert "leiden_pred" in cm.query.obs, "Label predictions should be generated" From 028ac17643d1968e9434032547a8ff659e3ba05b Mon Sep 17 00:00:00 2001 From: Marius Lange Date: Tue, 29 Jul 2025 13:22:20 +0200 Subject: [PATCH 2/6] Collect duplicated docstrings --- src/cellmapper/_docs.py | 6 ++++++ src/cellmapper/model/cellmapper.py | 4 +--- src/cellmapper/model/kernel.py | 4 +--- src/cellmapper/utils.py | 9 +++++---- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/cellmapper/_docs.py b/src/cellmapper/_docs.py index 296bf58..a6381c2 100644 --- a/src/cellmapper/_docs.py +++ b/src/cellmapper/_docs.py @@ -98,6 +98,11 @@ will be mapped, and others will be ignored. For numerical data, this parameter is ignored with a warning. Can be a single category string or a list of category strings.""" +_n_batches = """\ +n_batches + Number of batches to use for Jaccard-based mapping matrix computation. If None (default), + compute in a single batch. Use batch processing to reduce memory usage for large datasets.""" + d = DocstringProcessor( t=_t, @@ -114,4 +119,5 @@ use_rep=_use_rep, knn_dist_metric=_knn_dist_metric, subset_categories=_subset_categories, + n_batches=_n_batches, ) diff --git a/src/cellmapper/model/cellmapper.py b/src/cellmapper/model/cellmapper.py index a3833ed..f620625 100644 --- a/src/cellmapper/model/cellmapper.py +++ b/src/cellmapper/model/cellmapper.py @@ -258,9 +258,7 @@ def compute_mapping_matrix( Eigendecomposition method for spectral approach: - "partial": Uses sparse eigendecomposition, faster (default) - "complete": Uses complete eigendecomposition, exact for testing - n_batches - Number of batches to split Jaccard/HNOCA computation into for memory efficiency. - Only applies to Jaccard and HNOCA methods. If None, computation is done in a single batch. + %(n_batches)s Returns ------- diff --git a/src/cellmapper/model/kernel.py b/src/cellmapper/model/kernel.py index 6a77988..2a749f5 100644 --- a/src/cellmapper/model/kernel.py +++ b/src/cellmapper/model/kernel.py @@ -233,9 +233,7 @@ def compute_kernel_matrix( - "max": Take element-wise maximum between matrix and transpose (preserves strongest connections) - "mean": Take element-wise average between matrix and transpose (smooths connections) %(self_edges)s - n_batches - Number of batches to use for Jaccard/HNOCA computation. If None (default), - compute in a single batch. Use batch processing to reduce memory usage for large datasets. + %(n_batches)s **kwargs Additional keyword arguments for kernel computation. diff --git a/src/cellmapper/utils.py b/src/cellmapper/utils.py index b5a8e7f..22b5fd9 100644 --- a/src/cellmapper/utils.py +++ b/src/cellmapper/utils.py @@ -10,6 +10,7 @@ from scipy.sparse.linalg import LinearOperator, svds from sklearn.utils.extmath import randomized_svd +from cellmapper._docs import d from cellmapper.constants import PackageConstants from cellmapper.logging import logger @@ -399,6 +400,7 @@ def apply_jaccard_transformation(kernel_matrix: csr_matrix, kernel_method: str, raise ValueError(f"Unknown kernel method: {kernel_method}. Expected 'jaccard' or 'hnoca'.") +@d.dedent def compute_jaccard_kernel_matrix( xx: csr_matrix, yy: csr_matrix, @@ -419,8 +421,7 @@ def compute_jaccard_kernel_matrix( Kernel method to use: "jaccard" or "hnoca" n_neighbors Number of neighbors for normalization - n_batches - Number of batches for memory-efficient computation. If None, use standard computation. + %(n_batches)s Returns ------- @@ -443,6 +444,7 @@ def compute_jaccard_kernel_matrix( return _compute_jaccard_kernel_batched(xx, yy, xy, yx, kernel_method, n_neighbors, n_batches) +@d.dedent def _compute_jaccard_kernel_batched( xx: csr_matrix, yy: csr_matrix, @@ -463,8 +465,7 @@ def _compute_jaccard_kernel_batched( Kernel method to use: "jaccard" or "hnoca" n_neighbors Number of neighbors for normalization - n_batches - Number of batches to split the computation into + %(n_batches)s Returns ------- From 2235fde97c957956eb853bcbf3aa3ca03c99c251 Mon Sep 17 00:00:00 2001 From: Marius Lange Date: Tue, 29 Jul 2025 16:14:53 +0200 Subject: [PATCH 3/6] Improve logging --- src/cellmapper/model/kernel.py | 8 +++++--- src/cellmapper/utils.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/cellmapper/model/kernel.py b/src/cellmapper/model/kernel.py index 2a749f5..e416ac5 100644 --- a/src/cellmapper/model/kernel.py +++ b/src/cellmapper/model/kernel.py @@ -266,14 +266,16 @@ def compute_kernel_matrix( # Check if batch processing might be beneficial and warn user assert yx is not None, "yx adjacency matrix must be available" - n_query_cells = yx.shape[0] + n_query_cells = self.yrep.shape[0] + n_reference_cells = self.xrep.shape[0] if ( n_batches is None - and n_query_cells > PackageConstants.JACCARD_BATCH_WARNING_CELLS + and (n_query_cells > PackageConstants.JACCARD_BATCH_WARNING_CELLS) + or (n_reference_cells > PackageConstants.JACCARD_BATCH_WARNING_CELLS) and n_neighbors > PackageConstants.JACCARD_BATCH_WARNING_NEIGHBORS ): logger.warning( - f"Computing {kernel_method} kernel for {n_query_cells:,} cells with {n_neighbors} neighbors. " + f"Computing {kernel_method} kernel for {n_query_cells:,} x {n_reference_cells:,} cells with {n_neighbors} neighbors. " f"Consider using batch processing (n_batches parameter) to reduce memory usage." ) diff --git a/src/cellmapper/utils.py b/src/cellmapper/utils.py index 22b5fd9..cdec31b 100644 --- a/src/cellmapper/utils.py +++ b/src/cellmapper/utils.py @@ -490,7 +490,7 @@ def _compute_jaccard_kernel_batched( start_idx = batch_idx * batch_size end_idx = min((batch_idx + 1) * batch_size, n_query) - logger.info("Processing batch %s/%s: cells %s-%s", batch_idx + 1, n_batches, f"{start_idx:,}", f"{end_idx:,}") + logger.debug("Processing batch %s/%s: cells %s-%s", batch_idx + 1, n_batches, f"{start_idx:,}", f"{end_idx:,}") # Extract batch slices - minimize memory footprint yx_batch = yx[start_idx:end_idx] From 985ea6fc1315b2346b50058706c827d48781bbd8 Mon Sep 17 00:00:00 2001 From: Marius Lange Date: Tue, 29 Jul 2025 16:20:26 +0200 Subject: [PATCH 4/6] Be more explicit abotu matrix types --- src/cellmapper/model/mapping_operator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/cellmapper/model/mapping_operator.py b/src/cellmapper/model/mapping_operator.py index f1b5be4..ab3d89b 100644 --- a/src/cellmapper/model/mapping_operator.py +++ b/src/cellmapper/model/mapping_operator.py @@ -94,7 +94,7 @@ def __init__( is_self_mapping = kernel_obj._is_self_mapping kernel_matrix = actual_matrix - else: + elif isinstance(kernel_matrix, csr_matrix | coo_matrix | csc_matrix | np.ndarray): # This is a raw matrix actual_matrix = kernel_matrix @@ -103,6 +103,8 @@ def __init__( n_rows, n_cols = actual_matrix.shape is_self_mapping = n_rows == n_cols logger.info("Inferred is_self_mapping=%s from matrix shape %s", is_self_mapping, actual_matrix.shape) + else: + raise ValueError(f"Unknown kernel_matrix type: {type(kernel_matrix)}") self.is_self_mapping = is_self_mapping self.eigen_solver = eigen_solver From 800b54d064d4b03aaa8dc1dde964fd8899758db6 Mon Sep 17 00:00:00 2001 From: Marius Lange Date: Tue, 29 Jul 2025 16:23:53 +0200 Subject: [PATCH 5/6] Fix the warning --- src/cellmapper/model/kernel.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/cellmapper/model/kernel.py b/src/cellmapper/model/kernel.py index e416ac5..87f8f0f 100644 --- a/src/cellmapper/model/kernel.py +++ b/src/cellmapper/model/kernel.py @@ -270,8 +270,10 @@ def compute_kernel_matrix( n_reference_cells = self.xrep.shape[0] if ( n_batches is None - and (n_query_cells > PackageConstants.JACCARD_BATCH_WARNING_CELLS) - or (n_reference_cells > PackageConstants.JACCARD_BATCH_WARNING_CELLS) + and ( + (n_query_cells > PackageConstants.JACCARD_BATCH_WARNING_CELLS) + or (n_reference_cells > PackageConstants.JACCARD_BATCH_WARNING_CELLS) + ) and n_neighbors > PackageConstants.JACCARD_BATCH_WARNING_NEIGHBORS ): logger.warning( From 3eb057e920ca186b8bd79d2dca476e42c104e9c3 Mon Sep 17 00:00:00 2001 From: Marius Lange Date: Tue, 29 Jul 2025 16:34:00 +0200 Subject: [PATCH 6/6] Add an extra check for type validation --- src/cellmapper/model/mapping_operator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/cellmapper/model/mapping_operator.py b/src/cellmapper/model/mapping_operator.py index ab3d89b..227cb83 100644 --- a/src/cellmapper/model/mapping_operator.py +++ b/src/cellmapper/model/mapping_operator.py @@ -80,7 +80,9 @@ def __init__( """ # Extract matrix and metadata from Kernel object if provided - if isinstance(kernel_matrix, Kernel): + # We check the type name as a string to avoid issues with module reloading + # where `isinstance` can fail unexpectedly. + if type(kernel_matrix).__name__ == "Kernel" or isinstance(kernel_matrix, Kernel): # This is a Kernel object kernel_obj = kernel_matrix actual_matrix = kernel_obj.kernel_matrix @@ -93,7 +95,6 @@ def __init__( if is_self_mapping is None: is_self_mapping = kernel_obj._is_self_mapping - kernel_matrix = actual_matrix elif isinstance(kernel_matrix, csr_matrix | coo_matrix | csc_matrix | np.ndarray): # This is a raw matrix actual_matrix = kernel_matrix