From eb7e73bfe4e8bcaaf37759019c81b609f4c4e0a4 Mon Sep 17 00:00:00 2001 From: Florian Pfaff Date: Tue, 1 Apr 2025 20:44:37 +0200 Subject: [PATCH] Fixed pdf of partially wrapped normal distribution --- .../partially_wrapped_normal_distribution.py | 136 ++++++++++++------ ..._abstract_hypercylindrical_distribution.py | 10 +- ...t_partially_wrapped_normal_distribution.py | 98 ++++++++++++- 3 files changed, 194 insertions(+), 50 deletions(-) diff --git a/pyrecest/distributions/cart_prod/partially_wrapped_normal_distribution.py b/pyrecest/distributions/cart_prod/partially_wrapped_normal_distribution.py index d01d63ccb..84d06e468 100644 --- a/pyrecest/distributions/cart_prod/partially_wrapped_normal_distribution.py +++ b/pyrecest/distributions/cart_prod/partially_wrapped_normal_distribution.py @@ -1,8 +1,10 @@ import copy from typing import Union +from itertools import product +# pylint: disable=no-name-in-module,no-member +import pyrecest.backend # pylint: disable=redefined-builtin,no-name-in-module,no-member -# pylint: disable=no-name-in-module,no-member from pyrecest.backend import ( allclose, arange, @@ -11,12 +13,12 @@ concatenate, cos, diag, + empty, exp, hstack, int32, int64, linalg, - meshgrid, mod, ndim, pi, @@ -25,8 +27,8 @@ sin, stack, sum, - tile, where, + zeros, ) from scipy.stats import multivariate_normal @@ -56,49 +58,95 @@ def __init__(self, mu, C, bound_dim: Union[int, int32, int64]): self.mu = where(arange(mu.shape[0]) < bound_dim, mod(mu, 2.0 * pi), mu) self.C = C - def pdf(self, xs, m: Union[int, int32, int64] = 3): - xs = atleast_2d(xs) - condition = ( - arange(xs.shape[1]) < self.bound_dim - ) # Create a condition based on column indices - xs = where( - # Broadcast the condition to match the shape of xs - condition[None, :], # noqa: E203 - mod(xs, 2.0 * pi), # Compute the modulus where the condition is True - xs, # Keep the original values where the condition is False - ) - - assert xs.shape[-1] == self.input_dim - - # generate multiples for wrapping - multiples = array(range(-m, m + 1)) * 2.0 * pi - - # create meshgrid for all combinations of multiples - mesh = array(meshgrid(*[multiples] * self.bound_dim)).reshape( - -1, self.bound_dim - ) - - # reshape xs for broadcasting - xs_reshaped = tile(xs[:, : self.bound_dim], (mesh.shape[0], 1)) # noqa: E203 - - # prepare data for wrapping (not applied to linear dimensions) - xs_wrapped = xs_reshaped + repeat(mesh, xs.shape[0], axis=0) - xs_wrapped = concatenate( - [ - xs_wrapped, - tile(xs[:, self.bound_dim :], (mesh.shape[0], 1)), # noqa: E203 - ], - axis=1, - ) + # pylint: disable=too-many-locals + def pdf(self, xs, m=3): + """ + Evaluate the PDF of the Hypercylindrical Wrapped Normal Distribution at given points. - # evaluate normal for all xs_wrapped - mvn = multivariate_normal(self.mu, self.C) - evals = array(mvn.pdf(xs_wrapped)) # For being compatible with all backends + Parameters: + xs (array-like): Input points of shape (n, d), where d = bound_dim + lin_dim. + m (int, optional): Number of summands in each direction for wrapping. Default is 3. - # sum evaluations for the wrapped dimensions - summed_evals = sum(evals.reshape(-1, (2 * m + 1) ** self.bound_dim), axis=1) + Returns: + p (ndarray): PDF values at each input point of shape (n,). + """ + assert ( + pyrecest.backend.__backend_name__ == "numpy" + ), "Only supported for numpy backend" - return summed_evals + xs = atleast_2d(xs) # Ensure xs is 2D + n, d = xs.shape + assert ( + d == self.dim + ), f"Input dimensionality {d} does not match distribution dimensionality {self.dim}." + + # Initialize the PDF values array + p = zeros(n) + + # Define batch size to manage memory usage + batch_size = 1000 + + # Generate all possible offset combinations for periodic dimensions + multiples = arange(-m, m + 1) * 2.0 * pi + offset_combinations = list( + product(multiples, repeat=self.bound_dim) + ) # Total combinations: (2m+1)^bound_dim + num_offsets = len(offset_combinations) + + # Pre-convert offset combinations to a NumPy array for efficient computation + offset_array = array(offset_combinations) # Shape: (num_offsets, bound_dim) + + # Process input data in batches + for start in range(0, n, batch_size): + end = min(start + batch_size, n) + batch = xs[start:end] # Shape: (batch_size, d) + + # Wrap periodic dimensions using modulus + batch_wrapped = batch.copy() + if self.bound_dim > 0: + batch_wrapped[:, : self.bound_dim] = mod( + batch_wrapped[:, : self.bound_dim], 2.0 * pi + ) # noqa: E203 + + if self.bound_dim > 0: + # Correct broadcasting: batch_wrapped becomes (batch_size, 1, bound_dim) + # offset_array becomes (1, num_offsets, bound_dim) + wrapped_periodic = batch_wrapped[:, :self.bound_dim][:, None, :] + offset_array[None, :, :] + # Now wrapped_periodic has shape (batch_size, num_offsets, bound_dim) + wrapped_periodic = wrapped_periodic.reshape(-1, self.bound_dim) + else: + wrapped_periodic = empty((0, 0)) # No periodic dimensions + + # Repeat linear dimensions for each offset + if self.lin_dim > 0: + linear_part = repeat( + batch_wrapped[:, self.bound_dim :], # noqa: E203 + num_offsets, + axis=0, + ) # Shape: (batch_size * num_offsets, lin_dim) + # Concatenate wrapped periodic and linear parts + if self.bound_dim > 0: + wrapped_points = hstack( + (wrapped_periodic, linear_part) + ) # Shape: (batch_size * num_offsets, d) + else: + wrapped_points = linear_part # Shape: (batch_size * num_offsets, d) + else: + wrapped_points = ( + wrapped_periodic # Shape: (batch_size * num_offsets, d) + ) + + mvn = multivariate_normal(mean=self.mu, cov=self.C) + # Evaluate the multivariate normal PDF at all wrapped points + pdf_vals = mvn.pdf(wrapped_points) # Shape: (batch_size * num_offsets,) + + # Reshape and sum the PDF values for each original point + pdf_vals = pdf_vals.reshape( + end - start, num_offsets + ) # Shape: (batch_size, num_offsets) + p[start:end] = sum(pdf_vals, axis=1) # Shape: (batch_size,) + + return p def mode(self): """ @@ -148,7 +196,7 @@ def sample(self, n: int): """ assert n > 0, "n must be positive" s = random.multivariate_normal(mean=self.mu, cov=self.C, size=(n,)) - wrapped_values = mod(s[:, : self.bound_dim], 2.0 * pi) + wrapped_values = mod(s[:, : self.bound_dim], 2.0 * pi) # noqa: E203 unbounded_values = s[:, self.bound_dim :] # noqa: E203 # Concatenate the modified section with the unmodified section diff --git a/pyrecest/tests/distributions/test_abstract_hypercylindrical_distribution.py b/pyrecest/tests/distributions/test_abstract_hypercylindrical_distribution.py index fd55430e8..b90a0901f 100644 --- a/pyrecest/tests/distributions/test_abstract_hypercylindrical_distribution.py +++ b/pyrecest/tests/distributions/test_abstract_hypercylindrical_distribution.py @@ -15,15 +15,19 @@ class AbstractHypercylindricalDistributionTest(unittest.TestCase): @unittest.skipIf( - pyrecest.backend.__backend_name__ == "jax", - reason="Not supported on jax backend", + pyrecest.backend.__backend_name__ in ("pytorch", "jax"), + reason="Not supported on this backend", ) def test_mode_numerical_gaussian_2D(self): mu = array([5.0, 1.0]) C = array([[2.0, 1.0], [1.0, 1.0]]) g = PartiallyWrappedNormalDistribution(mu, C, 1) npt.assert_allclose(g.mode_numerical(), mu, atol=5e-5) - + + @unittest.skipIf( + pyrecest.backend.__backend_name__ in ("pytorch", "jax"), + reason="Not supported on this backend", + ) def test_linear_mean_numerical(self): hwn = PartiallyWrappedNormalDistribution( array([1.0, 2.0]), array([[2.0, 0.3], [0.3, 1.0]]), 1 diff --git a/pyrecest/tests/distributions/test_partially_wrapped_normal_distribution.py b/pyrecest/tests/distributions/test_partially_wrapped_normal_distribution.py index 9a281402a..e81a423b5 100644 --- a/pyrecest/tests/distributions/test_partially_wrapped_normal_distribution.py +++ b/pyrecest/tests/distributions/test_partially_wrapped_normal_distribution.py @@ -1,10 +1,14 @@ import unittest +from math import pi import numpy.testing as npt + +# pylint: disable=no-name-in-module,no-member +import pyrecest.backend import scipy.linalg # pylint: disable=no-name-in-module,no-member -from pyrecest.backend import array, ones +from pyrecest.backend import array, column_stack, diag, linspace, meshgrid from pyrecest.distributions.cart_prod.partially_wrapped_normal_distribution import ( PartiallyWrappedNormalDistribution, ) @@ -16,8 +20,96 @@ def setUp(self) -> None: self.C = array([[2.0, 1.0], [1.0, 1.0]]) self.dist_2d = PartiallyWrappedNormalDistribution(self.mu, self.C, 1) - def test_pdf(self): - self.assertEqual(self.dist_2d.pdf(ones((10, 2))).shape, (10,)) + @unittest.skipIf( + pyrecest.backend.__backend_name__ in ("pytorch", "jax"), + reason="Not supported on this backend", + ) + def test_pdf_2d(self): + expected_vals = array( + [ # From Matlab implementation + 0.00719442236938856, + 0.0251110014500013, + 0.0531599904868136, + 0.0682587789359472, + 0.0531599904868136, + 0.0100784602259792, + 0.0351772826718058, + 0.0744703080006016, + 0.0956217682613369, + 0.0744703080006016, + 0.00119956714181477, + 0.00418690072543581, + 0.00886366890530323, + 0.0113811761595142, + 0.00886366890530323, + 0.000447592726560109, + 0.00156225212096022, + 0.00330728776602597, + 0.00424664155187776, + 0.00330728776602597, + 0.00719442236938856, + 0.0251110014500013, + 0.0531599904868136, + 0.0682587789359472, + 0.0531599904868136, + ] + ) + + hwn = PartiallyWrappedNormalDistribution( + array([1.0, 2.0]), diag(array([1.0, 2.0])), 1 + ) + x, y = meshgrid(linspace(0.0, 2.0 * pi, 5), linspace(-1.0, 3.0, 5)) + points = column_stack([x.T.ravel(), y.T.ravel()]) + npt.assert_allclose(hwn.pdf(points), expected_vals, atol=1e-7) + + @unittest.skipIf( + pyrecest.backend.__backend_name__ in ("pytorch", "jax"), + reason="Not supported on this backend", + ) + def test_pdf_3d(self): + expected_vals = array( + [ + 1.385492786310657e-07, + 1.473370096411339e-05, + 4.130095787341451e-04, + 2.310111272798270e-08, + 2.456634132168293e-06, + 6.886344649603137e-05, + 1.385492786310657e-07, + 1.473370096411339e-05, + 4.130095787341451e-04, + 2.650620509537817e-07, + 2.818740764494994e-05, + 7.901388378523353e-04, + 4.419530999723758e-08, + 4.699847505157542e-06, + 1.317443623260505e-04, + 2.650620509537817e-07, + 2.818740764494994e-05, + 7.901388378523353e-04, + 1.385492786310657e-07, + 1.473370096411339e-05, + 4.130095787341451e-04, + 2.310111272798270e-08, + 2.456634132168293e-06, + 6.886344649603137e-05, + 1.385492786310657e-07, + 1.473370096411339e-05, + 4.130095787341451e-04, + ] + ) + + hwn = PartiallyWrappedNormalDistribution( + array([1.0, 2.0, 7.0]), diag(array([1.0, 2.0, 3.0])), 2 + ) + x, y, z = meshgrid( + linspace(0.0, 2.0 * pi, 3), + linspace(0.0, 2.0 * pi, 3), + linspace(-1.0, 3.0, 3), + ) + points = column_stack([x.ravel(), y.ravel(), z.ravel()]) + npt.assert_allclose(hwn.pdf(points), expected_vals, atol=1e-7) + def test_hybrid_mean_2d(self): npt.assert_allclose(self.dist_2d.hybrid_mean(), self.mu)