Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import copy
from typing import Union
from itertools import product
# pylint: disable=no-name-in-module,no-member
import pyrecest.backend

# pylint: disable=redefined-builtin,no-name-in-module,no-member
# pylint: disable=no-name-in-module,no-member
from pyrecest.backend import (
allclose,
arange,
Expand All @@ -11,12 +13,12 @@
concatenate,
cos,
diag,
empty,
exp,
hstack,
int32,
int64,
linalg,
meshgrid,
mod,
ndim,
pi,
Expand All @@ -25,8 +27,8 @@
sin,
stack,
sum,
tile,
where,
zeros,
)
from scipy.stats import multivariate_normal

Expand Down Expand Up @@ -56,49 +58,95 @@ def __init__(self, mu, C, bound_dim: Union[int, int32, int64]):
self.mu = where(arange(mu.shape[0]) < bound_dim, mod(mu, 2.0 * pi), mu)
self.C = C

def pdf(self, xs, m: Union[int, int32, int64] = 3):
xs = atleast_2d(xs)
condition = (
arange(xs.shape[1]) < self.bound_dim
) # Create a condition based on column indices
xs = where(
# Broadcast the condition to match the shape of xs
condition[None, :], # noqa: E203
mod(xs, 2.0 * pi), # Compute the modulus where the condition is True
xs, # Keep the original values where the condition is False
)

assert xs.shape[-1] == self.input_dim

# generate multiples for wrapping
multiples = array(range(-m, m + 1)) * 2.0 * pi

# create meshgrid for all combinations of multiples
mesh = array(meshgrid(*[multiples] * self.bound_dim)).reshape(
-1, self.bound_dim
)

# reshape xs for broadcasting
xs_reshaped = tile(xs[:, : self.bound_dim], (mesh.shape[0], 1)) # noqa: E203

# prepare data for wrapping (not applied to linear dimensions)
xs_wrapped = xs_reshaped + repeat(mesh, xs.shape[0], axis=0)
xs_wrapped = concatenate(
[
xs_wrapped,
tile(xs[:, self.bound_dim :], (mesh.shape[0], 1)), # noqa: E203
],
axis=1,
)
# pylint: disable=too-many-locals
def pdf(self, xs, m=3):
"""
Evaluate the PDF of the Hypercylindrical Wrapped Normal Distribution at given points.

# evaluate normal for all xs_wrapped
mvn = multivariate_normal(self.mu, self.C)
evals = array(mvn.pdf(xs_wrapped)) # For being compatible with all backends
Parameters:
xs (array-like): Input points of shape (n, d), where d = bound_dim + lin_dim.
m (int, optional): Number of summands in each direction for wrapping. Default is 3.

# sum evaluations for the wrapped dimensions
summed_evals = sum(evals.reshape(-1, (2 * m + 1) ** self.bound_dim), axis=1)
Returns:
p (ndarray): PDF values at each input point of shape (n,).
"""
assert (
pyrecest.backend.__backend_name__ == "numpy"
), "Only supported for numpy backend"

return summed_evals
xs = atleast_2d(xs) # Ensure xs is 2D
n, d = xs.shape
assert (
d == self.dim
), f"Input dimensionality {d} does not match distribution dimensionality {self.dim}."

# Initialize the PDF values array
p = zeros(n)

# Define batch size to manage memory usage
batch_size = 1000

# Generate all possible offset combinations for periodic dimensions
multiples = arange(-m, m + 1) * 2.0 * pi
offset_combinations = list(
product(multiples, repeat=self.bound_dim)
) # Total combinations: (2m+1)^bound_dim
num_offsets = len(offset_combinations)

# Pre-convert offset combinations to a NumPy array for efficient computation
offset_array = array(offset_combinations) # Shape: (num_offsets, bound_dim)

# Process input data in batches
for start in range(0, n, batch_size):
end = min(start + batch_size, n)
batch = xs[start:end] # Shape: (batch_size, d)

# Wrap periodic dimensions using modulus
batch_wrapped = batch.copy()
if self.bound_dim > 0:
batch_wrapped[:, : self.bound_dim] = mod(
batch_wrapped[:, : self.bound_dim], 2.0 * pi
) # noqa: E203

if self.bound_dim > 0:
# Correct broadcasting: batch_wrapped becomes (batch_size, 1, bound_dim)
# offset_array becomes (1, num_offsets, bound_dim)
wrapped_periodic = batch_wrapped[:, :self.bound_dim][:, None, :] + offset_array[None, :, :]
# Now wrapped_periodic has shape (batch_size, num_offsets, bound_dim)
wrapped_periodic = wrapped_periodic.reshape(-1, self.bound_dim)
else:
wrapped_periodic = empty((0, 0)) # No periodic dimensions

# Repeat linear dimensions for each offset
if self.lin_dim > 0:
linear_part = repeat(
batch_wrapped[:, self.bound_dim :], # noqa: E203
num_offsets,
axis=0,
) # Shape: (batch_size * num_offsets, lin_dim)
# Concatenate wrapped periodic and linear parts
if self.bound_dim > 0:
wrapped_points = hstack(
(wrapped_periodic, linear_part)
) # Shape: (batch_size * num_offsets, d)
else:
wrapped_points = linear_part # Shape: (batch_size * num_offsets, d)
else:
wrapped_points = (
wrapped_periodic # Shape: (batch_size * num_offsets, d)
)

mvn = multivariate_normal(mean=self.mu, cov=self.C)
# Evaluate the multivariate normal PDF at all wrapped points
pdf_vals = mvn.pdf(wrapped_points) # Shape: (batch_size * num_offsets,)

# Reshape and sum the PDF values for each original point
pdf_vals = pdf_vals.reshape(
end - start, num_offsets
) # Shape: (batch_size, num_offsets)
p[start:end] = sum(pdf_vals, axis=1) # Shape: (batch_size,)

return p

def mode(self):
"""
Expand Down Expand Up @@ -148,7 +196,7 @@ def sample(self, n: int):
"""
assert n > 0, "n must be positive"
s = random.multivariate_normal(mean=self.mu, cov=self.C, size=(n,))
wrapped_values = mod(s[:, : self.bound_dim], 2.0 * pi)
wrapped_values = mod(s[:, : self.bound_dim], 2.0 * pi) # noqa: E203
unbounded_values = s[:, self.bound_dim :] # noqa: E203

# Concatenate the modified section with the unmodified section
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,19 @@

class AbstractHypercylindricalDistributionTest(unittest.TestCase):
@unittest.skipIf(
pyrecest.backend.__backend_name__ == "jax",
reason="Not supported on jax backend",
pyrecest.backend.__backend_name__ in ("pytorch", "jax"),
reason="Not supported on this backend",
)
def test_mode_numerical_gaussian_2D(self):
mu = array([5.0, 1.0])
C = array([[2.0, 1.0], [1.0, 1.0]])
g = PartiallyWrappedNormalDistribution(mu, C, 1)
npt.assert_allclose(g.mode_numerical(), mu, atol=5e-5)


@unittest.skipIf(
pyrecest.backend.__backend_name__ in ("pytorch", "jax"),
reason="Not supported on this backend",
)
def test_linear_mean_numerical(self):
hwn = PartiallyWrappedNormalDistribution(
array([1.0, 2.0]), array([[2.0, 0.3], [0.3, 1.0]]), 1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import unittest
from math import pi

import numpy.testing as npt

# pylint: disable=no-name-in-module,no-member
import pyrecest.backend
import scipy.linalg

# pylint: disable=no-name-in-module,no-member
from pyrecest.backend import array, ones
from pyrecest.backend import array, column_stack, diag, linspace, meshgrid
from pyrecest.distributions.cart_prod.partially_wrapped_normal_distribution import (
PartiallyWrappedNormalDistribution,
)
Expand All @@ -16,8 +20,96 @@ def setUp(self) -> None:
self.C = array([[2.0, 1.0], [1.0, 1.0]])
self.dist_2d = PartiallyWrappedNormalDistribution(self.mu, self.C, 1)

def test_pdf(self):
self.assertEqual(self.dist_2d.pdf(ones((10, 2))).shape, (10,))
@unittest.skipIf(
pyrecest.backend.__backend_name__ in ("pytorch", "jax"),
reason="Not supported on this backend",
)
def test_pdf_2d(self):
expected_vals = array(
[ # From Matlab implementation
0.00719442236938856,
0.0251110014500013,
0.0531599904868136,
0.0682587789359472,
0.0531599904868136,
0.0100784602259792,
0.0351772826718058,
0.0744703080006016,
0.0956217682613369,
0.0744703080006016,
0.00119956714181477,
0.00418690072543581,
0.00886366890530323,
0.0113811761595142,
0.00886366890530323,
0.000447592726560109,
0.00156225212096022,
0.00330728776602597,
0.00424664155187776,
0.00330728776602597,
0.00719442236938856,
0.0251110014500013,
0.0531599904868136,
0.0682587789359472,
0.0531599904868136,
]
)

hwn = PartiallyWrappedNormalDistribution(
array([1.0, 2.0]), diag(array([1.0, 2.0])), 1
)
x, y = meshgrid(linspace(0.0, 2.0 * pi, 5), linspace(-1.0, 3.0, 5))
points = column_stack([x.T.ravel(), y.T.ravel()])
npt.assert_allclose(hwn.pdf(points), expected_vals, atol=1e-7)

@unittest.skipIf(
pyrecest.backend.__backend_name__ in ("pytorch", "jax"),
reason="Not supported on this backend",
)
def test_pdf_3d(self):
expected_vals = array(
[
1.385492786310657e-07,
1.473370096411339e-05,
4.130095787341451e-04,
2.310111272798270e-08,
2.456634132168293e-06,
6.886344649603137e-05,
1.385492786310657e-07,
1.473370096411339e-05,
4.130095787341451e-04,
2.650620509537817e-07,
2.818740764494994e-05,
7.901388378523353e-04,
4.419530999723758e-08,
4.699847505157542e-06,
1.317443623260505e-04,
2.650620509537817e-07,
2.818740764494994e-05,
7.901388378523353e-04,
1.385492786310657e-07,
1.473370096411339e-05,
4.130095787341451e-04,
2.310111272798270e-08,
2.456634132168293e-06,
6.886344649603137e-05,
1.385492786310657e-07,
1.473370096411339e-05,
4.130095787341451e-04,
]
)

hwn = PartiallyWrappedNormalDistribution(
array([1.0, 2.0, 7.0]), diag(array([1.0, 2.0, 3.0])), 2
)
x, y, z = meshgrid(
linspace(0.0, 2.0 * pi, 3),
linspace(0.0, 2.0 * pi, 3),
linspace(-1.0, 3.0, 3),
)
points = column_stack([x.ravel(), y.ravel(), z.ravel()])
npt.assert_allclose(hwn.pdf(points), expected_vals, atol=1e-7)


def test_hybrid_mean_2d(self):
npt.assert_allclose(self.dist_2d.hybrid_mean(), self.mu)
Expand Down
Loading