Description
Please make sure these conditions are met
- I have checked that this issue has not already been reported.
- I have confirmed this bug exists on the latest version of pertpy.
- (optional) I have confirmed this bug exists on the main branch.
Report
Hi all,
I am trying to do gRNA assignment with pertpy's assign_mixture_model(), but it is extremely slow, and I would be very grateful for any advice! I am running this on an interactive GPU node on Wharton's HPC, where JAX is installed with CUDA support. I am using a subset of the Gasperini data with the first 2500 gRNAs and the first 50k cells. It takes hours to run on this dataset, which seems far too long for data of this size. With JAX_LOG_COMPILES=1 I get a constant stream of JAX/JIT compilation notices that doesn't let up as the run progresses. I'd have expected a flurry of them at the start, but they are non-stop, so I'm wondering whether this reveals unintended behavior. Thanks in advance for any help!
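To illustrate what I suspect is happening, here is a standalone toy sketch (mine, not pertpy's actual code): when jitted JAX work runs on arrays whose shape changes from call to call, every new shape triggers a fresh trace and XLA compilation, which would match the per-guide shapes like float64[366] in the logs further down.
import time
import jax
import jax.numpy as jnp
# A trivial jitted function; the computation itself is negligible.
@jax.jit
def normalize(x):
    return x / jnp.sum(x)
# Each new input shape forces a fresh trace + XLA compile, so every call
# pays compilation cost instead of reusing a cached executable.
for n in (364, 365, 366, 367):
    x = jnp.ones((n,), dtype=jnp.float32)
    t0 = time.time()
    normalize(x).block_until_ready()
    print(f"shape ({n},): {time.time() - t0:.3f}s")
If assign_mixture_model() fits guides one at a time, each with a different number of cells, something like this would explain why the compilation messages never stop.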
Below are the exact steps I'm running to reproduce this.
First, I'm starting a fresh GPU session (via $ qlogin -q gpu.q), confirming that I do indeed have a GPU, and determining which version of CUDA I have.
$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Wed_Apr_17_19:19:55_PDT_2024
Cuda compilation tools, release 12.5, V12.5.40
Build cuda_12.5.r12.5/compiler.34177558_0
$ nvidia-smi
Wed Nov 12 13:40:59 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.02 Driver Version: 555.42.02 CUDA Version: 12.5 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA L40S Off | 00000000:30:00.0 Off | 0 |
| N/A 22C P8 20W / 350W | 4MiB / 46068MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+

Next, I'm making my conda environment and setting a few shell variables:
conda create -n pertpy-jax-issue python=3.12 -y
conda activate pertpy-jax-issue
python -m pip install -U pip setuptools wheel
pip install -U pertpy scanpy anndata h5py filelock
pip install -U "jax[cuda12]"
unset LD_LIBRARY_PATH # prevents use of system CUDA, which lacks cuSPARSE
export JAX_PLATFORM_NAME=gpu
export JAX_LOG_COMPILES=1 # give detailed logs

Finally, this is the code I run to actually do gRNA assignment, along with some checks to verify that everything seems as expected.
import time
import sys
import jax, jax.numpy as jnp
import pertpy as pt
import scanpy as sc
# This is a subset of the (high MOI) Gasperini data,
# with 2500 gRNAs and 50,000 cells
dataset_name = "grna_matrix.h5ad"
# verify python and pertpy versions
print("python version:", sys.version) # shows `3.12.12`
print("pertpy version:", pt.__version__) # shows `1.0.3`
# Checking that we see the GPU
print("JAX backend:", jax.default_backend()) # shows `JAX backend: gpu`
print("JAX devices:", jax.devices()) # shows `JAX devices: [CudaDevice(id=0)]`
# quick warmup to force GPU work, just in case this helps
x = jnp.ones((4096, 4096), dtype=jnp.float32)
t0 = time.time()
(x @ x.T).block_until_ready()
print(f"Warmup matmul completed in {time.time() - t0:.2f}s")
# Loading the data and running gRNA assignment
adata = sc.read_h5ad(dataset_name)
ga = pt.pp.GuideAssignment()
print("Running mixture-model assignment...")
t0 = time.time()
ga.assign_mixture_model(
adata,
max_assignments_per_cell=100, # I tried this with 40 too
show_progress=True
)
dt = time.time() - t0
print(f"Finished in {dt:.1f}s")Here is an example of the stream of logging output that I get. This is after letting it run for one hour, so I don't think this is just things warming up! And it is only ~ 20% done after an hour, on a dataset with just 2500 gRNAs and 50,000 cells.
Working... ━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18% 5:14:50
WARNING:2025-11-12 15:16:02,874:jax._src.dispatch:222: Finished XLA compilation of jit(div) in 0.038060904 sec
WARNING:2025-11-12 15:16:02,877:jax._src.interpreters.pxla:1960: Compiling jit(_where) with global shapes and types (ShapedArray(bool[2,366]), ShapedArray(float64[2,366])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:02,885:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(_where) in 0.007384777 sec
WARNING:2025-11-12 15:16:02,922:jax._src.dispatch:222: Finished XLA compilation of jit(_where) in 0.036580563 sec
WARNING:2025-11-12 15:16:02,926:jax._src.interpreters.pxla:1960: Compiling jit(neg) with global shapes and types (ShapedArray(float64[366]),). Argument mapping: (UnspecifiedValue,).
WARNING:2025-11-12 15:16:02,933:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(neg) in 0.006520033 sec
Working... ━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18% 5:14:50
WARNING:2025-11-12 15:16:02,964:jax._src.dispatch:222: Finished XLA compilation of jit(neg) in 0.030367374 sec
WARNING:2025-11-12 15:16:02,969:jax._src.interpreters.pxla:1960: Compiling jit(reduce_sum) with global shapes and types (ShapedArray(float64[366]),). Argument mapping: (UnspecifiedValue,).
WARNING:2025-11-12 15:16:02,977:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(reduce_sum) in 0.007550478 sec
WARNING:2025-11-12 15:16:03,021:jax._src.dispatch:222: Finished XLA compilation of jit(reduce_sum) in 0.043447256 sec
WARNING:2025-11-12 15:16:03,024:jax._src.interpreters.pxla:1960: Compiling jit(mul) with global shapes and types (ShapedArray(float64[]), ShapedArray(float64[366])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:03,032:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(mul) in 0.007074356 sec
Working... ━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18% 5:14:50
WARNING:2025-11-12 15:16:03,066:jax._src.dispatch:222: Finished XLA compilation of jit(mul) in 0.034173489 sec
WARNING:2025-11-12 15:16:03,070:jax._src.interpreters.pxla:1960: Compiling jit(mul) with global shapes and types (ShapedArray(float64[366]), ShapedArray(float64[366])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:03,077:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(mul) in 0.006629467 sec
WARNING:2025-11-12 15:16:03,108:jax._src.dispatch:222: Finished XLA compilation of jit(mul) in 0.031239033 sec
WARNING:2025-11-12 15:16:03,112:jax._src.interpreters.pxla:1960: Compiling jit(mul) with global shapes and types (ShapedArray(float64[366]), ShapedArray(float64[])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:03,119:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(mul) in 0.006843328 sec
Working... ━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18% 5:14:50
WARNING:2025-11-12 15:16:03,152:jax._src.dispatch:222: Finished XLA compilation of jit(mul) in 0.032195091 sec
WARNING:2025-11-12 15:16:03,156:jax._src.interpreters.pxla:1960: Compiling jit(div) with global shapes and types (ShapedArray(float64[366]), ShapedArray(float64[])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:03,163:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(div) in 0.006853342 sec
WARNING:2025-11-12 15:16:03,201:jax._src.dispatch:222: Finished XLA compilation of jit(div) in 0.037773132 sec
WARNING:2025-11-12 15:16:03,207:jax._src.interpreters.pxla:1960: Compiling jit(mul) with global shapes and types (ShapedArray(float64[366,2]), ShapedArray(float64[366,2])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:03,216:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(mul) in 0.007925987 sec

Thanks again for any insights!
-Louis
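P.S. One more data point: the shapes in the compilation messages (e.g. float64[366]) look like per-guide cell counts, so my guess is that each guide produces arrays of a new length and hence a fresh XLA compile. As a stopgap I'm going to try JAX's persistent compilation cache; these are standard JAX config options rather than anything pertpy-specific, and I haven't verified yet that they help here:
import jax
# Persist compiled XLA programs to disk so that repeated shapes, within
# and across runs, can reuse an existing executable.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
# Also cache fast compilations; by default only compiles taking >= 1 s are kept.
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
Of course, if every guide really does have a unique shape, a first run would still compile once per shape, so this wouldn't address the root cause.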
Versions
| Package | Version |
| ------- | ------- |
| pertpy | 1.0.3 |

| Dependency | Version |
| ------------------- | ------------- |
| numpy | 2.3.4 |
| multipledispatch | 1.0.0 (0.6.0) |
| toolz | 1.1.0 |
| h5py | 3.15.1 |
| Pygments | 2.19.2 |
| scipy | 1.16.3 |
| llvmlite | 0.45.1 |
| idna | 3.11 |
| anndata | 0.12.6 |
| jax-cuda12-plugin | 0.8.0 |
| pyparsing | 3.2.5 |
| ply | 3.11 |
| blitzgsea | 1.3.54 |
| cycler | 0.12.1 |
| kiwisolver | 1.4.9 |
| joblib | 1.5.2 |
| natsort | 8.4.0 |
| jax-cuda12-pjrt | 0.8.0 |
| importlib_resources | 6.5.2 |
| certifi | 2025.11.12 |
| msgpack | 1.1.2 |
| numpyro | 0.19.0 |
| equinox | 0.13.2 |
| statsmodels | 0.14.5 |
| donfig | 0.8.1.post1 |
| xarray | 2025.10.1 |
| opt_einsum | 3.4.0 |
| fsspec | 2025.10.0 |
| fast-array-utils | 1.3 |
| setuptools | 80.9.0 |
| requests | 2.32.5 |
| charset-normalizer | 3.4.4 |
| pillow | 12.0.0 |
| filelock | 3.20.0 |
| python-dateutil | 2.9.0.post0 |
| zarr | 3.1.3 |
| six | 1.17.0 |
| psutil | 7.1.3 |
| jaxlib | 0.8.0 |
| legacy-api-wrap | 1.5 |
| urllib3 | 2.5.0 |
| absl-py | 2.3.1 |
| rich | 14.2.0 |
| pandas | 2.3.3 |
| flax | 0.12.0 |
| pyarrow | 22.0.0 |
| simplejson | 3.20.2 |
| jax | 0.8.0 |
| ott-jax | 0.6.0 |
| typing_extensions | 4.15.0 |
| scikit-learn | 1.7.2 |
| optax | 0.2.6 |
| pyomo | 6.9.5 |
| crc32c | 2.8 |
| lineax | 0.0.8 |
| adjustText | 1.3.0 |
| PubChemPy | 1.0.5 |
| sparsecca | 0.3.1 |
| mudata | 0.3.2 |
| pytz | 2025.2 |
| PyYAML | 6.0.3 |
| scanpy | 1.11.5 |
| packaging | 25.0 |
| wadler_lindig | 0.1.7 |
| threadpoolctl | 3.6.0 |
| matplotlib | 3.10.7 |
| seaborn | 0.13.2 |
| lamin_utils | 0.15.0 |
| numba | 0.62.1 |
| chex | 0.1.91 |
| etils | 1.13.0 |
| jaxopt | 0.8.5 |
| numcodecs | 0.16.3 |
| jaxtyping | 0.3.3 |
| ml_dtypes | 0.5.3 |
| scikit-misc | 0.5.2 |
| tqdm | 4.67.1 |
| session-info2 | 0.2.3 |
| patsy | 1.0.2 |
| mpmath | 1.3.0 |

| Component | Info |
| --------- | --------------------------------------------------------------------------------- |
| Python | 3.12.12, packaged by Anaconda, Inc. (main, Oct 21 2025, 20:16:04) [GCC 11.2.0] |
| OS | Linux-4.18.0-553.5.1.el8_10.x86_64-x86_64-with-glibc2.28 |
| Updated | 2025-11-12 20:27 |