assign_mixture_model() is really slow, even using JAX with CUDA support #881

@jdeu1023

Please make sure these conditions are met

  • I have checked that this issue has not already been reported.
  • I have confirmed this bug exists on the latest version of pertpy.
  • (optional) I have confirmed this bug exists on the main branch.

Report

Hi all,

I am trying to do gRNA assignment with pertpy's assign_mixture_model(), but it is extremely slow, so I would be very grateful for any advice! I am running it on an interactive GPU node on Wharton's HPC, where JAX is installed with CUDA support. The input is a subset of the Gasperini data with the first 2500 gRNAs and the first 50k cells. A run on this dataset takes hours, which seems unexpectedly long to me. With JAX_LOG_COMPILES=1 set, I get a constant stream of JAX/JIT compilation notices that does not slow down as the run progresses. I'd have expected a flurry of them at the start, but they are non-stop, so I'm wondering whether this reveals unintended behavior. Thanks in advance for any help here!
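For reference, JAX caches compiled programs by input shape and dtype, so a non-stop stream of compile notices is what one would see if array shapes kept changing between calls. Whether that is what happens inside assign_mixture_model() is an assumption here, not something confirmed from the pertpy source. The sketch below only illustrates the general JAX behavior; `neg_log_sum` and the pad length of 512 are hypothetical and chosen for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

traced_shapes = []  # filled only while JAX traces (i.e. compiles) the function


@jax.jit
def neg_log_sum(x):
    # Python side effects run at trace time only, so each append marks one
    # new compilation rather than one call.
    traced_shapes.append(x.shape)
    return -jnp.sum(jnp.log1p(x))


lengths = (364, 365, 366, 366)

# Varying input lengths: one trace/compile per distinct shape.
for n in lengths:
    neg_log_sum(np.ones(n, dtype=np.float32))
unpadded_compiles = len(traced_shapes)

# Zero-padding on the host to a fixed length: a single trace/compile.
# log1p(0) == 0, so the padding does not change the sum.
traced_shapes.clear()
PAD_TO = 512  # hypothetical upper bound on the input length
for n in lengths:
    padded = np.zeros(PAD_TO, dtype=np.float32)
    padded[:n] = 1.0
    neg_log_sum(padded)
padded_compiles = len(traced_shapes)

print(unpadded_compiles, padded_compiles)
```

In this toy case the unpadded loop compiles once per distinct length, while the padded loop compiles once in total. If the per-guide arrays in a real run differ in length, a similar effect could plausibly explain compile notices that never taper off.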

Below are the exact steps I follow to reproduce this.

First, I'm starting a fresh GPU session (via $ qlogin -q gpu.q), confirming that I do indeed have a GPU, and determining which version of CUDA I have.

$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Wed_Apr_17_19:19:55_PDT_2024
Cuda compilation tools, release 12.5, V12.5.40
Build cuda_12.5.r12.5/compiler.34177558_0

$ nvidia-smi
Wed Nov 12 13:40:59 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.02              Driver Version: 555.42.02      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    Off |   00000000:30:00.0 Off |                    0 |
| N/A   22C    P8             20W /  350W |       4MiB /  46068MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

Next, I'm making my conda environment and setting a few shell variables:

conda create -n pertpy-jax-issue python=3.12 -y
conda activate pertpy-jax-issue

python -m pip install -U pip setuptools wheel
pip install -U pertpy scanpy anndata h5py filelock

pip install -U "jax[cuda12]"

unset LD_LIBRARY_PATH # prevents use of system CUDA, which lacks cuSPARSE

export JAX_PLATFORM_NAME=gpu
export JAX_LOG_COMPILES=1 # give detailed logs

Finally, this is the code I run to actually do gRNA assignment, along with some checks to verify that everything seems as expected.

import time
import sys

import jax
import jax.numpy as jnp
import pertpy as pt
import scanpy as sc

# This is a subset of the (high MOI) Gasperini data,
# with 2500 gRNAs and 50,000 cells
dataset_name = "grna_matrix.h5ad"

# verify python and pertpy versions
print("python version:", sys.version) # shows `3.12.12`
print("pertpy version:", pt.__version__) # shows `1.0.3`

# Checking that we see the GPU
print("JAX backend:", jax.default_backend()) # shows `JAX backend: gpu`
print("JAX devices:", jax.devices()) # shows `JAX devices: [CudaDevice(id=0)]`
# quick warmup to force GPU work, just in case this helps
x = jnp.ones((4096, 4096), dtype=jnp.float32)
t0 = time.time()
(x @ x.T).block_until_ready()
print(f"Warmup matmul completed in {time.time() - t0:.2f}s")

# Loading the data and running gRNA assignment
adata = sc.read_h5ad(dataset_name)

ga = pt.pp.GuideAssignment()
print("Running mixture-model assignment...")
t0 = time.time()
ga.assign_mixture_model(
    adata,
    max_assignments_per_cell=100,  # I tried this with 40 too
    show_progress=True,
)
dt = time.time() - t0
print(f"Finished in {dt:.1f}s")

Here is an example of the stream of logging output that I get. This is after letting it run for one hour, so I don't think this is just things warming up! The run is only ~20% done after an hour (the progress bar reads 18%), on a dataset with just 2500 gRNAs and 50,000 cells.

Working... ━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  18% 5:14:50WARNING:2025-11-12 15:16:02,874:jax._src.dispatch:222: Finished XLA compilation of jit(div) in 0.038060904 sec
WARNING:2025-11-12 15:16:02,877:jax._src.interpreters.pxla:1960: Compiling jit(_where) with global shapes and types (ShapedArray(bool[2,366]), ShapedArray(float64[2,366])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:02,885:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(_where) in 0.007384777 sec
WARNING:2025-11-12 15:16:02,922:jax._src.dispatch:222: Finished XLA compilation of jit(_where) in 0.036580563 sec
WARNING:2025-11-12 15:16:02,926:jax._src.interpreters.pxla:1960: Compiling jit(neg) with global shapes and types (ShapedArray(float64[366]),). Argument mapping: (UnspecifiedValue,).
WARNING:2025-11-12 15:16:02,933:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(neg) in 0.006520033 sec
Working... ━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  18% 5:14:50WARNING:2025-11-12 15:16:02,964:jax._src.dispatch:222: Finished XLA compilation of jit(neg) in 0.030367374 sec
WARNING:2025-11-12 15:16:02,969:jax._src.interpreters.pxla:1960: Compiling jit(reduce_sum) with global shapes and types (ShapedArray(float64[366]),). Argument mapping: (UnspecifiedValue,).
WARNING:2025-11-12 15:16:02,977:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(reduce_sum) in 0.007550478 sec
WARNING:2025-11-12 15:16:03,021:jax._src.dispatch:222: Finished XLA compilation of jit(reduce_sum) in 0.043447256 sec
WARNING:2025-11-12 15:16:03,024:jax._src.interpreters.pxla:1960: Compiling jit(mul) with global shapes and types (ShapedArray(float64[]), ShapedArray(float64[366])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:03,032:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(mul) in 0.007074356 sec
Working... ━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  18% 5:14:50WARNING:2025-11-12 15:16:03,066:jax._src.dispatch:222: Finished XLA compilation of jit(mul) in 0.034173489 sec
WARNING:2025-11-12 15:16:03,070:jax._src.interpreters.pxla:1960: Compiling jit(mul) with global shapes and types (ShapedArray(float64[366]), ShapedArray(float64[366])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:03,077:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(mul) in 0.006629467 sec
WARNING:2025-11-12 15:16:03,108:jax._src.dispatch:222: Finished XLA compilation of jit(mul) in 0.031239033 sec
WARNING:2025-11-12 15:16:03,112:jax._src.interpreters.pxla:1960: Compiling jit(mul) with global shapes and types (ShapedArray(float64[366]), ShapedArray(float64[])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:03,119:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(mul) in 0.006843328 sec
Working... ━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  18% 5:14:50WARNING:2025-11-12 15:16:03,152:jax._src.dispatch:222: Finished XLA compilation of jit(mul) in 0.032195091 sec
WARNING:2025-11-12 15:16:03,156:jax._src.interpreters.pxla:1960: Compiling jit(div) with global shapes and types (ShapedArray(float64[366]), ShapedArray(float64[])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:03,163:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(div) in 0.006853342 sec
WARNING:2025-11-12 15:16:03,201:jax._src.dispatch:222: Finished XLA compilation of jit(div) in 0.037773132 sec
WARNING:2025-11-12 15:16:03,207:jax._src.interpreters.pxla:1960: Compiling jit(mul) with global shapes and types (ShapedArray(float64[366,2]), ShapedArray(float64[366,2])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:03,216:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(mul) in 0.007925987 sec
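
One way to tell whether these compilations are repeats of the same few kernels or new shapes each time is to tally the `Compiling jit(...)` lines by operation and argument shapes. The sketch below is a minimal standard-library parser that assumes the log format shown above; `tally_compiles` is a hypothetical helper, not part of JAX or pertpy, and the sample lines are copied from the output above.

```python
import re
from collections import Counter

# Matches the "Compiling jit(<op>) with global shapes and types (...)" lines
# that JAX emits when JAX_LOG_COMPILES=1 is set.
COMPILE_RE = re.compile(
    r"Compiling jit\((?P<op>[^)]+)\) with global shapes and types "
    r"\((?P<sig>.*)\)\. Argument mapping"
)
SHAPE_RE = re.compile(r"ShapedArray\(([^)]*)\)")


def tally_compiles(lines):
    """Count compile events per (op, argument shapes) signature."""
    counts = Counter()
    for line in lines:
        match = COMPILE_RE.search(line)
        if match:
            shapes = tuple(SHAPE_RE.findall(match.group("sig")))
            counts[(match.group("op"), shapes)] += 1
    return counts


sample = [
    "WARNING:2025-11-12 15:16:02,926:jax._src.interpreters.pxla:1960: Compiling jit(neg) with global shapes and types (ShapedArray(float64[366]),). Argument mapping: (UnspecifiedValue,).",
    "WARNING:2025-11-12 15:16:02,933:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(neg) in 0.006520033 sec",
    "WARNING:2025-11-12 15:16:03,024:jax._src.interpreters.pxla:1960: Compiling jit(mul) with global shapes and types (ShapedArray(float64[]), ShapedArray(float64[366])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).",
    "WARNING:2025-11-12 15:16:03,070:jax._src.interpreters.pxla:1960: Compiling jit(mul) with global shapes and types (ShapedArray(float64[366]), ShapedArray(float64[366])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).",
]

counts = tally_compiles(sample)
for (op, shapes), n in counts.most_common():
    print(op, shapes, n)
```

Run over a full log, a tally that keeps gaining new shape signatures (for example, a length other than 366 for each guide) would suggest that the time is going into per-shape compilation rather than into the model fitting itself.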

Thanks again for any insights!

-Louis

Versions

| Package | Version |
| ------- | ------- |
| pertpy  | 1.0.3   |

| Dependency          | Version       |
| ------------------- | ------------- |
| numpy               | 2.3.4         |
| multipledispatch    | 1.0.0 (0.6.0) |
| toolz               | 1.1.0         |
| h5py                | 3.15.1        |
| Pygments            | 2.19.2        |
| scipy               | 1.16.3        |
| llvmlite            | 0.45.1        |
| idna                | 3.11          |
| anndata             | 0.12.6        |
| jax-cuda12-plugin   | 0.8.0         |
| pyparsing           | 3.2.5         |
| ply                 | 3.11          |
| blitzgsea           | 1.3.54        |
| cycler              | 0.12.1        |
| kiwisolver          | 1.4.9         |
| joblib              | 1.5.2         |
| natsort             | 8.4.0         |
| jax-cuda12-pjrt     | 0.8.0         |
| importlib_resources | 6.5.2         |
| certifi             | 2025.11.12    |
| msgpack             | 1.1.2         |
| numpyro             | 0.19.0        |
| equinox             | 0.13.2        |
| statsmodels         | 0.14.5        |
| donfig              | 0.8.1.post1   |
| xarray              | 2025.10.1     |
| opt_einsum          | 3.4.0         |
| fsspec              | 2025.10.0     |
| fast-array-utils    | 1.3           |
| setuptools          | 80.9.0        |
| requests            | 2.32.5        |
| charset-normalizer  | 3.4.4         |
| pillow              | 12.0.0        |
| filelock            | 3.20.0        |
| python-dateutil     | 2.9.0.post0   |
| zarr                | 3.1.3         |
| six                 | 1.17.0        |
| psutil              | 7.1.3         |
| jaxlib              | 0.8.0         |
| legacy-api-wrap     | 1.5           |
| urllib3             | 2.5.0         |
| absl-py             | 2.3.1         |
| rich                | 14.2.0        |
| pandas              | 2.3.3         |
| flax                | 0.12.0        |
| pyarrow             | 22.0.0        |
| simplejson          | 3.20.2        |
| jax                 | 0.8.0         |
| ott-jax             | 0.6.0         |
| typing_extensions   | 4.15.0        |
| scikit-learn        | 1.7.2         |
| optax               | 0.2.6         |
| pyomo               | 6.9.5         |
| crc32c              | 2.8           |
| lineax              | 0.0.8         |
| adjustText          | 1.3.0         |
| PubChemPy           | 1.0.5         |
| sparsecca           | 0.3.1         |
| mudata              | 0.3.2         |
| pytz                | 2025.2        |
| PyYAML              | 6.0.3         |
| scanpy              | 1.11.5        |
| packaging           | 25.0          |
| wadler_lindig       | 0.1.7         |
| threadpoolctl       | 3.6.0         |
| matplotlib          | 3.10.7        |
| seaborn             | 0.13.2        |
| lamin_utils         | 0.15.0        |
| numba               | 0.62.1        |
| chex                | 0.1.91        |
| etils               | 1.13.0        |
| jaxopt              | 0.8.5         |
| numcodecs           | 0.16.3        |
| jaxtyping           | 0.3.3         |
| ml_dtypes           | 0.5.3         |
| scikit-misc         | 0.5.2         |
| tqdm                | 4.67.1        |
| session-info2       | 0.2.3         |
| patsy               | 1.0.2         |
| mpmath              | 1.3.0         |

| Component | Info                                                                            |
| --------- | ------------------------------------------------------------------------------- |
| Python    | 3.12.12 \| packaged by Anaconda, Inc. \| (main, Oct 21 2025, 20:16:04) [GCC 11.2.0] |
| OS        | Linux-4.18.0-553.5.1.el8_10.x86_64-x86_64-with-glibc2.28                        |
| Updated   | 2025-11-12 20:27                                                                |

Metadata

Labels

bug (Something isn't working)
