diff --git a/.pylintrc b/.pylintrc index e8a2d0027..3e7914e54 100644 --- a/.pylintrc +++ b/.pylintrc @@ -63,7 +63,7 @@ ignore-patterns=^\.# # (useful for modules/projects where namespaces are manipulated during runtime # and thus existing member attributes cannot be deduced by static analysis). It # supports qualified module names, as well as Unix pattern matching. -ignored-modules=pyrecest.backend +ignored-modules=pyrecest.backend, jax # Python code to execute, usually for sys.path manipulation such as # pygtk.require(). diff --git a/pyrecest/_backend/__init__.py b/pyrecest/_backend/__init__.py index c64e4bc28..aa2b32187 100644 --- a/pyrecest/_backend/__init__.py +++ b/pyrecest/_backend/__init__.py @@ -246,6 +246,9 @@ def get_backend_name(): "randint", "seed", "uniform", + # For PyRecEst + "get_state", + "set_state", ], "fft": [ # For PyRecEst "rfft", diff --git a/pyrecest/_backend/jax/random.py b/pyrecest/_backend/jax/random.py index 9a06369d6..f95f35ac1 100644 --- a/pyrecest/_backend/jax/random.py +++ b/pyrecest/_backend/jax/random.py @@ -25,12 +25,13 @@ def create_random_state(seed = 0): def global_random_state(): return backend.jax_global_random_state - def set_global_random_state(state): backend.jax_global_random_state = state +get_state = global_random_state +set_state = set_global_random_state -def get_state(**kwargs): +def _get_state(**kwargs): has_state = 'state' in kwargs state = kwargs.pop('state', backend.jax_global_random_state) return state, has_state, kwargs @@ -51,7 +52,7 @@ def _rand(state, size, *args, **kwargs): def rand(size, *args, **kwargs): size = size if hasattr(size, "__iter__") else (size,) - state, has_state, kwargs = get_state(**kwargs) + state, has_state, kwargs = _get_state(**kwargs) state, res = _rand(state, size, *args, **kwargs) return set_state_return(has_state, state, res) @@ -66,7 +67,7 @@ def _randint(state, size, *args, **kwargs): def randint(size, *args, **kwargs): size = size if hasattr(size, "__iter__") else (size,) - state, has_state, kwargs = get_state(**kwargs) + state, has_state, kwargs = _get_state(**kwargs) state, res = _randint(state, size, *args, **kwargs) return set_state_return(has_state, state, res) @@ -78,7 +79,7 @@ def _normal(state, size, *args, **kwargs): def normal(size, *args, **kwargs): size = size if hasattr(size, "__iter__") else (size,) - state, has_state, kwargs = get_state(**kwargs) + state, has_state, kwargs = _get_state(**kwargs) # Check and remove 'mean' and 'cov' from kwargs mean = kwargs.pop('mean', None) @@ -102,7 +103,7 @@ def _choice(state, a, n, *args, **kwargs): def choice(a, n, *args, **kwargs): - state, has_state, kwargs = get_state(**kwargs) + state, has_state, kwargs = _get_state(**kwargs) state, res = _choice(state, a, n, *args, **kwargs) return set_state_return(has_state, state, res) @@ -114,7 +115,7 @@ def _multivariate_normal(state, size, *args, **kwargs): def multivariate_normal(size, *args, **kwargs): size = size if hasattr(size, "__iter__") else (size,) - state, has_state, kwargs = get_state(**kwargs) + state, has_state, kwargs = _get_state(**kwargs) state, res = _multivariate_normal(state, size, *args, **kwargs) return set_state_return(has_state, state, res) diff --git a/pyrecest/_backend/numpy/random.py b/pyrecest/_backend/numpy/random.py index 33f479036..2822318af 100644 --- a/pyrecest/_backend/numpy/random.py +++ b/pyrecest/_backend/numpy/random.py @@ -3,5 +3,6 @@ import numpy as _np from numpy.random import default_rng as _default_rng from numpy.random import randint, seed, multinomial +from numpy.random import set_state, get_state # For PyRecEst from .._shared_numpy.random import choice, multivariate_normal, normal, rand, uniform diff --git a/pyrecest/_backend/pytorch/random.py b/pyrecest/_backend/pytorch/random.py index ca405228e..4db7aac0a 100644 --- a/pyrecest/_backend/pytorch/random.py +++ b/pyrecest/_backend/pytorch/random.py @@ -2,6 +2,8 @@ import torch as _torch from torch import rand, randint +from torch import get_rng_state as get_state # For PyRecEst +from torch import set_rng_state as set_state # For PyRecEst from torch.distributions.multivariate_normal import ( MultivariateNormal as _MultivariateNormal, ) diff --git a/pyrecest/distributions/abstract_manifold_specific_distribution.py b/pyrecest/distributions/abstract_manifold_specific_distribution.py index 2e9f2df5d..22aeccc30 100644 --- a/pyrecest/distributions/abstract_manifold_specific_distribution.py +++ b/pyrecest/distributions/abstract_manifold_specific_distribution.py @@ -1,10 +1,11 @@ from abc import ABC, abstractmethod from collections.abc import Callable from typing import Union +import inspect import pyrecest.backend -# pylint: disable=no-name-in-module,no-member +# pylint: disable=no-name-in-module,no-member,redefined-builtin from pyrecest.backend import empty, int32, int64, log, random, squeeze @@ -64,13 +65,14 @@ def set_mode(self, _): """ raise NotImplementedError("set_mode is not implemented for this distribution") + # Need to use Union instead of | to support torch.dtype # Need to use Union instead of | to support torch.dtype def sample(self, n: Union[int, int32, int64]): """Obtain n samples from the distribution.""" return self.sample_metropolis_hastings(n) # jscpd:ignore-start - # pylint: disable=too-many-positional-arguments + # pylint: disable=too-many-positional-arguments,too-many-locals def sample_metropolis_hastings( self, n: Union[int, int32, int64], @@ -81,30 +83,48 @@ def sample_metropolis_hastings( ): # jscpd:ignore-end """Metropolis Hastings sampling algorithm.""" - assert ( - pyrecest.backend.__backend_name__ != "jax" - ), "Not supported on this backend" + if pyrecest.backend.__backend_name__ == "jax": + # Get a key from your global JAX random state *outside* of lax.scan + import jax as _jax + + key = random.get_state() + key, key_for_mh = _jax.random.split(key) + # Optionally update global state for future calls + random.set_state(key) + + if proposal is None or start_point is None: + raise NotImplementedError( + "Default proposals and starting points should be set in inheriting classes." + ) + _assert_proposal_supports_key(proposal) + + samples, _ = sample_metropolis_hastings_jax( + key=key_for_mh, + log_pdf=self.ln_pdf, + proposal=proposal, # must be (key, x) -> x_prop for JAX + start_point=start_point, + n=int(n), + burn_in=int(burn_in), + skipping=int(skipping), + ) + # You could optionally stash `key_out` somewhere if you want chain continuation. + return squeeze(samples) + + # Non-JAX backends → your old NumPy/Torch code if proposal is None or start_point is None: raise NotImplementedError( "Default proposals and starting points should be set in inheriting classes." ) total_samples = burn_in + n * skipping - s = empty( - ( - total_samples, - self.input_dim, - ), - ) + s = empty((total_samples, self.input_dim)) x = start_point i = 0 pdfx = self.pdf(x) while i < total_samples: x_new = proposal(x) - assert ( - x_new.shape == x.shape - ), "Proposal must return a vector of same shape as input" + assert x_new.shape == x.shape, "Proposal must return a vector of same shape as input" pdfx_new = self.pdf(x_new) a = pdfx_new / pdfx if a.item() > 1 or a.item() > random.rand(1): @@ -115,3 +135,112 @@ def sample_metropolis_hastings( relevant_samples = s[burn_in::skipping, :] return squeeze(relevant_samples) + +# pylint: disable=too-many-positional-arguments,too-many-locals,too-many-arguments +def sample_metropolis_hastings_jax( + key, + log_pdf, # function: x -> log p(x) + proposal, # function: (key, x) -> x_prop + start_point, + n: int, + burn_in: int = 10, + skipping: int = 5, +): + """ + Metropolis-Hastings sampler in JAX. + + key: jax.random.PRNGKey + log_pdf: callable x -> log p(x) + proposal: callable (key, x) -> x_proposed + start_point: initial state (array) + n: number of samples to return (after burn-in and thinning) + """ + import jax.numpy as _jnp + from jax import lax as _lax + from jax import random as _random + + + start_point = _jnp.asarray(start_point) + total_steps = burn_in + n * skipping + + def one_step(carry, _): + key, x, log_px = carry + key, key_prop, key_u = _random.split(key, 3) + + # Propose new state + x_prop = proposal(key_prop, x) + log_px_prop = log_pdf(x_prop) + + # log_alpha = log p(x_prop) - log p(x) + log_alpha = log_px_prop - log_px + + # Draw u ~ Uniform(0, 1) + u = _random.uniform(key_u, shape=()) + log_u = _jnp.log(u) + + # Accept if log u < min(0, log_alpha) + # (equivalent to u < exp(min(0, log_alpha))) + log_alpha_capped = _jnp.minimum(0.0, log_alpha) + accept = log_u < log_alpha_capped # scalar bool + + # Branch without Python if + x_new = _jnp.where(accept, x_prop, x) + log_px_new = _jnp.where(accept, log_px_prop, log_px) + + return (key, x_new, log_px_new), x_new + + init_carry = (key, start_point, log_pdf(start_point)) + (key_out, _, _), chain = _lax.scan( + one_step, + init_carry, + xs=None, + length=total_steps, + ) + + samples = chain[burn_in::skipping] + return samples, key_out + + +def _assert_proposal_supports_key(proposal: Callable): + """ + Check that `proposal` can be called as proposal(key, x). + + Raises a TypeError with a helpful message if this is not the case. + """ + # Unwrap jitted / partial / decorated functions if possible + func = proposal + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + + try: + sig = inspect.signature(func) + except (TypeError, ValueError): + # Can't introspect (e.g. builtins); fall back to a generic error + raise TypeError( + "For the JAX backend, `proposal` must accept (key, x) as arguments, " + "but its signature could not be inspected." + ) from None + + params = list(sig.parameters.values()) + + # Count positional(-or-keyword) parameters + num_positional = sum( + p.kind in (inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD) + for p in params + ) + has_var_positional = any( + p.kind == inspect.Parameter.VAR_POSITIONAL + for p in params + ) + + if has_var_positional or num_positional >= 2: + # Looks compatible with (key, x) + return + + raise TypeError( + "For the JAX backend, `proposal` must accept `(key, x)` as arguments.\n" + f"Got signature: {sig}\n" + "Hint: change your proposal from `def proposal(x): ...` to\n" + "`def proposal(key, x): ...` and use `jax.random` with the passed key." + ) diff --git a/pyrecest/distributions/hypersphere_subset/abstract_hyperhemispherical_distribution.py b/pyrecest/distributions/hypersphere_subset/abstract_hyperhemispherical_distribution.py index a16e685b1..d1730c68d 100644 --- a/pyrecest/distributions/hypersphere_subset/abstract_hyperhemispherical_distribution.py +++ b/pyrecest/distributions/hypersphere_subset/abstract_hyperhemispherical_distribution.py @@ -19,6 +19,11 @@ random, vstack, zeros, + sqrt, + cos, + sin, + linalg, + stack, ) from scipy.optimize import minimize @@ -60,9 +65,44 @@ def sample_metropolis_hastings( from .hyperhemispherical_uniform_distribution import ( HyperhemisphericalUniformDistribution, ) + if pyrecest.backend.__backend_name__ in ("numpy", "pytorch"): + def proposal(_): + return HyperhemisphericalUniformDistribution(self.dim).sample(1) + else: + # JAX backend: proposal(key, x) -> x_prop + import jax as _jax + import jax.numpy as _jnp + + def proposal(key, _): + """JAX independence proposal: uniform on upper hemisphere.""" + if self.dim == 2: + # Explicit S² sampling + key, key_phi = _jax.random.split(key) + key, key_sz = _jax.random.split(key) + + phi = 2.0 * _jnp.pi * _jax.random.uniform(key_phi, shape=(1,)) + sz = 2.0 * _jax.random.uniform(key_sz, shape=(1,)) - 1.0 + r = _jnp.sqrt(1.0 - sz**2) + + # Shape (1, 3) + s = _jnp.stack( + [r * _jnp.cos(phi), r * _jnp.sin(phi), sz], + axis=1, + ) + else: + # General S^d: sample N(0, I) in R^{d+1} and normalize + key, subkey = _jax.random.split(key) + samples_unnorm = _jax.random.normal(subkey, shape=(1, self.dim + 1)) + norms = _jnp.linalg.norm(samples_unnorm, axis=1, keepdims=True) + s = samples_unnorm / norms + + # Project to upper hemisphere: last coordinate >= 0 + # s shape: (1, dim+1); last coord is s[..., -1:] + sign = _jnp.where(s[..., -1:] < 0.0, -1.0, 1.0) + s = sign * s + + return s - def proposal(_): - return HyperhemisphericalUniformDistribution(self.dim).sample(1) if start_point is None: start_point = HyperhemisphericalUniformDistribution(self.dim).sample(1) diff --git a/pyrecest/distributions/hypersphere_subset/abstract_hyperspherical_distribution.py b/pyrecest/distributions/hypersphere_subset/abstract_hyperspherical_distribution.py index 378b6fff7..03a0ef200 100644 --- a/pyrecest/distributions/hypersphere_subset/abstract_hyperspherical_distribution.py +++ b/pyrecest/distributions/hypersphere_subset/abstract_hyperspherical_distribution.py @@ -21,6 +21,9 @@ sin, vstack, zeros, + sqrt, + stack, + linalg, ) from scipy.optimize import minimize @@ -74,9 +77,37 @@ def sample_metropolis_hastings( HypersphericalUniformDistribution, ) - def proposal(_): - return HypersphericalUniformDistribution(self.dim).sample(1) - + if pyrecest.backend.__backend_name__ in ("numpy", "pytorch"): + def proposal(_): + return HypersphericalUniformDistribution(self.dim).sample(1) + else: + import jax as _jax + import jax.numpy as _jnp + def proposal(key, _): + """JAX independence proposal: uniform on hypersphere.""" + if self.dim == 2: + # Explicit S² sampling + key, key_phi = _jax.random.split(key) + key, key_sz = _jax.random.split(key) + + phi = 2.0 * _jnp.pi * _jax.random.uniform(key_phi, shape=(1,)) + sz = 2.0 * _jax.random.uniform(key_sz, shape=(1,)) - 1.0 + r = _jnp.sqrt(1.0 - sz**2) + + # Shape (1, 3) + s = _jnp.stack( + [r * _jnp.cos(phi), r * _jnp.sin(phi), sz], + axis=1, + ) + else: + # General S^d: sample N(0, I) in R^{d+1} and normalize + key, subkey = _jax.random.split(key) + samples_unnorm = _jax.random.normal(subkey, shape=(1, self.dim + 1)) + norms = _jnp.linalg.norm(samples_unnorm, axis=1, keepdims=True) + s = samples_unnorm / norms + + return s + if start_point is None: start_point = HypersphericalUniformDistribution(self.dim).sample(1) # Call the sample_metropolis_hastings method of AbstractDistribution diff --git a/pyrecest/tests/distributions/test_abstract_mixture.py b/pyrecest/tests/distributions/test_abstract_mixture.py index 950e1c62a..84b7960e1 100644 --- a/pyrecest/tests/distributions/test_abstract_mixture.py +++ b/pyrecest/tests/distributions/test_abstract_mixture.py @@ -25,10 +25,6 @@ def _test_sample(self, mix, n): self.assertEqual(s.shape, (n, mix.input_dim)) return s - @unittest.skipIf( - pyrecest.backend.__backend_name__ == "jax", - reason="Not supported on this backend", - ) def test_sample_metropolis_hastings_basics_only_t2(self): vmf = ToroidalWrappedNormalDistribution(array([1.0, 0.0]), eye(2)) mix = HypertoroidalMixture( @@ -37,28 +33,28 @@ def test_sample_metropolis_hastings_basics_only_t2(self): self._test_sample(mix, 10) @unittest.skipIf( - pyrecest.backend.__backend_name__ in ("pytorch", "jax"), + pyrecest.backend.__backend_name__ in ("pytorch",), reason="Not supported on this backend", ) def test_sample_metropolis_hastings_basics_only_s2(self): vmf1 = VonMisesFisherDistribution( array([1.0, 0.0, 0.0]), 2.0 - ) # Needs to be float for scipy + ) vmf2 = VonMisesFisherDistribution( array([0.0, 1.0, 0.0]), 2.0 - ) # Needs to be float for scipy + ) mix = HypersphericalMixture([vmf1, vmf2], array([0.5, 0.5])) s = self._test_sample(mix, 10) self.assertTrue(allclose(linalg.norm(s, axis=1), ones(10), rtol=1e-10)) @unittest.skipIf( - pyrecest.backend.__backend_name__ in ("pytorch", "jax"), + pyrecest.backend.__backend_name__ in ("pytorch",), reason="Not supported on this backend", ) def test_sample_metropolis_hastings_basics_only_h2(self): vmf = VonMisesFisherDistribution( array([1.0, 0.0, 0.0]), 2.0 - ) # Needs to be float for scipy + ) mix = CustomHyperhemisphericalDistribution( lambda x: vmf.pdf(x) + vmf.pdf(-x), 2 )