Thank you so much for this interesting package! While reading the paper, I noticed that you compute a shifted gene-expression covariance matrix. You then use the Euclidean distance between the square roots of these matrices (MSQR), ||sqrt(sigma_i) - sqrt(sigma_j)||_2^2, as an AOT measure because it is computationally tractable.
I'm sure I'm missing something obvious, but I'm confused about how the package handles negative covariances. Is it as simple as allowing complex numbers throughout the computation, or is there a computational trick that keeps JAX from choking when it takes the square root of a negative number?
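To make the question concrete, here is a minimal NumPy sketch of my own (not taken from the package) contrasting the two readings of sqrt(sigma) I can think of: an elementwise square root and a matrix square root computed by eigendecomposition. It assumes each sigma_i is a symmetric positive semidefinite matrix, as any sample covariance is, and that the norm is the Frobenius norm. The names `psd_sqrt` and `aot_distance` are hypothetical. The same two calls exist in jax.numpy (`jnp.sqrt`, `jnp.linalg.eigh`).

```python
import numpy as np

def psd_sqrt(cov):
    """Matrix square root of a symmetric PSD matrix (hypothetical helper)."""
    # eigh assumes a symmetric input and returns real eigenvalues/eigenvectors
    evals, evecs = np.linalg.eigh(cov)
    # clip tiny negative eigenvalues that can appear from floating-point round-off
    evals = np.clip(evals, 0.0, None)
    # V diag(sqrt(lambda)) V^T, written with broadcasting over the columns of V
    return (evecs * np.sqrt(evals)) @ evecs.T

def aot_distance(cov_i, cov_j):
    """Squared Frobenius distance between matrix square roots (hypothetical)."""
    diff = psd_sqrt(cov_i) - psd_sqrt(cov_j)
    return float(np.sum(diff ** 2))

# A valid covariance matrix with a negative off-diagonal entry
cov = np.array([[2.0, -1.0],
                [-1.0, 2.0]])

# Reading 1: elementwise sqrt gives NaN wherever an entry is negative
with np.errstate(invalid="ignore"):
    elementwise = np.sqrt(cov)

# Reading 2: matrix sqrt only square-roots the eigenvalues, which are >= 0 here
root = psd_sqrt(cov)
print(elementwise)
print(np.allclose(root @ root, cov))
```

If the matrix reading is the intended one, the negative off-diagonal entries never go through a scalar square root; only the eigenvalues do, and those are non-negative for a PSD matrix up to round-off, which is why the sketch clips them.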