From 35bba26a1c7ef22bc80892cadf126ff80867e929 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 22 Dec 2025 14:24:06 -0500 Subject: [PATCH 01/33] added generic observation processes and unit tests --- pyrenew/observation/__init__.py | 23 ++ pyrenew/observation/base.py | 324 +++++++++++++++++ pyrenew/observation/count_observations.py | 327 +++++++++++++++++ pyrenew/observation/measurements.py | 133 +++++++ pyrenew/observation/noise.py | 344 ++++++++++++++++++ pyrenew/randomvariable/__init__.py | 8 + pyrenew/randomvariable/hierarchical.py | 336 ++++++++++++++++++ test/conftest.py | 406 ++++++++++++++++++++++ test/test_observation_counts.py | 378 ++++++++++++++++++++ test/test_observation_measurements.py | 234 +++++++++++++ test/test_randomvariable_hierarchical.py | 179 ++++++++++ 11 files changed, 2692 insertions(+) create mode 100644 pyrenew/observation/base.py create mode 100644 pyrenew/observation/count_observations.py create mode 100644 pyrenew/observation/measurements.py create mode 100644 pyrenew/observation/noise.py create mode 100644 pyrenew/randomvariable/hierarchical.py create mode 100644 test/conftest.py create mode 100644 test/test_observation_counts.py create mode 100644 test/test_observation_measurements.py create mode 100644 test/test_randomvariable_hierarchical.py diff --git a/pyrenew/observation/__init__.py b/pyrenew/observation/__init__.py index b0e04e69..46cfbb67 100644 --- a/pyrenew/observation/__init__.py +++ b/pyrenew/observation/__init__.py @@ -1,9 +1,32 @@ # numpydoc ignore=GL08 +from pyrenew.observation.base import BaseObservationProcess +from pyrenew.observation.count_observations import Counts, CountsBySite +from pyrenew.observation.measurements import Measurements from pyrenew.observation.negativebinomial import NegativeBinomialObservation +from pyrenew.observation.noise import ( + CountNoise, + HierarchicalNormalNoise, + MeasurementNoise, + NegativeBinomialNoise, + PoissonNoise, +) from pyrenew.observation.poisson import 
PoissonObservation __all__ = [ + # Existing (kept for backward compatibility) "NegativeBinomialObservation", "PoissonObservation", + # New base classes + "BaseObservationProcess", + # New noise models + "CountNoise", + "PoissonNoise", + "NegativeBinomialNoise", + "MeasurementNoise", + "HierarchicalNormalNoise", + # New observation processes + "Counts", + "CountsBySite", + "Measurements", ] diff --git a/pyrenew/observation/base.py b/pyrenew/observation/base.py new file mode 100644 index 00000000..d84125a1 --- /dev/null +++ b/pyrenew/observation/base.py @@ -0,0 +1,324 @@ +# numpydoc ignore=GL08 +""" +Abstract base class for observation processes. + +Provides common functionality for observation processes that use convolution +with temporal distributions to connect infections to observed data. +""" + +from __future__ import annotations + +from abc import abstractmethod + +import jax.numpy as jnp +import numpyro +from jax.typing import ArrayLike + +from pyrenew.convolve import compute_delay_ascertained_incidence +from pyrenew.metaclass import RandomVariable + + +class BaseObservationProcess(RandomVariable): + """ + Abstract base class for observation processes that use convolution + with temporal distributions. + + This class provides common functionality for connecting infections + to observed data (e.g., hospital admissions, wastewater concentrations) + through temporal convolution operations. + + Key features provided: + + - PMF validation (sum to 1, non-negative) + - Minimum observation day calculation + - Convolution wrapper with timeline alignment + - Deterministic quantity tracking + + Notes + ----- + Computing expected observations on day t requires infection history + from previous days (determined by the temporal PMF length). + The first ``len(pmf) - 1`` days have insufficient history and return NaN. 
+ + See Also + -------- + pyrenew.convolve.compute_delay_ascertained_incidence : + Underlying convolution function + pyrenew.metaclass.RandomVariable : + Base class for all random variables + """ + + def __init__(self, temporal_pmf_rv: RandomVariable) -> None: + """ + Initialize base observation process. + + Parameters + ---------- + temporal_pmf_rv : RandomVariable + The temporal distribution PMF (e.g., delay or shedding distribution). + Must sample to a 1D array that sums to ~1.0 with non-negative values. + Subclasses may have additional parameters. + + Notes + ----- + Subclasses should call ``super().__init__(temporal_pmf_rv)`` + in their constructors and may add additional parameters. + """ + self.temporal_pmf_rv = temporal_pmf_rv + + @abstractmethod + def validate(self) -> None: + """ + Validate observation process parameters. + + Subclasses must implement this method to validate all parameters. + Typically this involves calling ``_validate_pmf()`` for the PMF + and adding any additional parameter-specific validation. + + Raises + ------ + ValueError + If any parameters fail validation. + """ + pass + + @abstractmethod + def get_required_lookback(self) -> int: + """ + Return the number of days this observation process needs to look back. + + This determines the minimum n_initialization_points required by the + latent process when this observation is included in a multi-signal model. + + Returns + ------- + int + Number of days of infection history required. + Typically the length of the delay or shedding PMF. + + Notes + ----- + This is used by model builders to automatically compute + n_initialization_points as: + ``max(gen_int_length, max(all lookbacks)) - 1`` + """ + pass + + @abstractmethod + def infection_resolution(self) -> str: + """ + Return the resolution of infections this observation uses. 
+ + Returns one of: + + - ``"jurisdiction"``: Uses jurisdiction-level aggregated infections + - ``"site"``: Uses site-level disaggregated infections + + Returns + ------- + str + Either ``"jurisdiction"`` or ``"site"`` + + Examples + -------- + >>> # Aggregated count observations use jurisdiction-level + >>> hosp_obs.infection_resolution() # Returns "jurisdiction" + >>> + >>> # Wastewater uses site-level + >>> ww_obs.infection_resolution() # Returns "site" + + Notes + ----- + This is used by multi-signal models to route the correct infection + output to each observation process. + """ + pass + + def _validate_pmf( + self, + pmf: ArrayLike, + param_name: str, + atol: float = 1e-6, + ) -> None: + """ + Validate that an array is a valid probability mass function. + + Checks: + + - Non-empty array + - Sums to 1.0 (within tolerance) + - All non-negative values + + Parameters + ---------- + pmf : ArrayLike + The PMF array to validate + param_name : str + Name of the parameter (for error messages) + atol : float, default 1e-6 + Absolute tolerance for sum-to-one check + + Raises + ------ + ValueError + If PMF is empty, doesn't sum to 1.0 (within tolerance), + or contains negative values. + """ + if pmf.size == 0: + raise ValueError(f"{param_name} must return non-empty array") + + pmf_sum = jnp.sum(pmf) + if not jnp.isclose(pmf_sum, 1.0, atol=atol): + raise ValueError( + f"{param_name} must sum to 1.0 (±{atol}), got {float(pmf_sum):.6f}" + ) + + if jnp.any(pmf < 0): + raise ValueError(f"{param_name} must have non-negative values") + + def get_minimum_observation_day(self) -> int: + """ + Get the first day with valid (non-NaN) convolution results. + + Due to the convolution operation requiring a history window, + the first ``len(pmf) - 1`` days will have NaN values in the + output. This method returns the index of the first valid day. + + Returns + ------- + int + Day index (0-based) of first valid observation. + Equal to ``len(pmf) - 1``. 
+ """ + pmf = self.temporal_pmf_rv() + return int(len(pmf) - 1) + + def _convolve_with_alignment( + self, + latent_incidence: ArrayLike, + pmf: ArrayLike, + p_observed: float = 1.0, + ) -> tuple[ArrayLike, int]: + """ + Convolve latent incidence with PMF while maintaining timeline alignment. + + This is a wrapper around ``compute_delay_ascertained_incidence`` that + always uses ``pad=True`` to ensure day t in the output corresponds to + day t in the input. The first ``len(pmf) - 1`` days will be NaN. + + Parameters + ---------- + latent_incidence : ArrayLike + Latent incidence time series (infections, prevalence, etc.). + Shape: (n_days,) + pmf : ArrayLike + Delay or shedding PMF. Shape: (n_pmf,) + p_observed : float, default 1.0 + Observation probability multiplier. Scales the convolution result. + + Returns + ------- + tuple[ArrayLike, int] + - convolved_array : ArrayLike + Convolved time series with same length as input. + First ``len(pmf) - 1`` days are NaN. + Shape: (n_days,) + - offset : int + Always 0 when pad=True (maintained for API compatibility) + + Notes + ----- + For t < len(pmf)-1, there is insufficient history, so output[t] = NaN. + + See Also + -------- + pyrenew.convolve.compute_delay_ascertained_incidence : + Underlying function + """ + return compute_delay_ascertained_incidence( + latent_incidence=latent_incidence, + delay_incidence_to_observation_pmf=pmf, + p_observed_given_incident=p_observed, + pad=True, # Maintains timeline alignment + ) + + def _deterministic(self, name: str, value: ArrayLike) -> None: + """ + Track a deterministic quantity in the numpyro execution trace. + + This is a convenience wrapper around ``numpyro.deterministic`` for + tracking intermediate quantities (e.g., latent admissions, expected + concentrations) that are useful for diagnostics and model checking. + These quantities are stored in MCMC samples and can be used for + model diagnostics and posterior predictive checks. 
+ + Parameters + ---------- + name : str + Name for the tracked quantity. Will appear in MCMC samples. + value : ArrayLike + Value to track. Can be any shape. + """ + numpyro.deterministic(name, value) + + @abstractmethod + def _expected_signal( + self, + infections: ArrayLike, + ) -> ArrayLike: + """ + Transform infections to expected observation values. + + This is the core transformation that each observation process must + implement. It converts infections (from the infection process) + to expected values for the observation model. + + Parameters + ---------- + infections : ArrayLike + Infections from the infection process. + Shape: (n_days,) for jurisdiction-level observations + Shape: (n_days, n_sites) for site-level observations + + Returns + ------- + ArrayLike + Expected observation values (counts, log-concentrations, etc.). + Same shape as input, with first len(pmf)-1 days as NaN. + + Notes + ----- + The transformation is observation-specific: + + - Count observations: ascertainment x delay convolution -> expected counts + - Wastewater: shedding convolution -> genome scaling -> dilution -> log + + See Also + -------- + sample : Uses this method then applies noise model + """ + pass + + @abstractmethod + def sample(self, **kwargs) -> ArrayLike: + """ + Sample from the observation process. + + Subclasses must implement this method to define the specific + observation model. Typically calls ``_expected_signal`` first, + then applies the noise model. + + Parameters + ---------- + **kwargs + Subclass-specific parameters. At minimum, should include: + + - infections from the infection process + - Observed data (or None for prior predictive sampling) + + Returns + ------- + ArrayLike + Observed or sampled values from the observation process. 
+ """ + pass diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py new file mode 100644 index 00000000..437c37e6 --- /dev/null +++ b/pyrenew/observation/count_observations.py @@ -0,0 +1,327 @@ +# numpydoc ignore=GL08 +""" +Count observations with composable noise models. + +Ascertainment x delay convolution with pluggable noise (Poisson, Negative Binomial, etc.). +""" + +from __future__ import annotations + +import jax +import jax.numpy as jnp +from jax.typing import ArrayLike + +from pyrenew.metaclass import RandomVariable +from pyrenew.observation.base import BaseObservationProcess +from pyrenew.observation.noise import CountNoise + + +class _CountBase(BaseObservationProcess): + """ + Internal base for count observation processes. + + Implements ascertainment x delay convolution with pluggable noise model. + """ + + def __init__( + self, + ascertainment_rate_rv: RandomVariable, + delay_distribution_rv: RandomVariable, + noise: CountNoise, + ) -> None: + """ + Initialize count observation base. + + Parameters + ---------- + ascertainment_rate_rv : RandomVariable + Ascertainment rate in [0, 1] (e.g., IHR, IER). + delay_distribution_rv : RandomVariable + Delay distribution PMF (must sum to ~1.0). + noise : CountNoise + Noise model for count observations (Poisson, NegBin, etc.). + """ + super().__init__(temporal_pmf_rv=delay_distribution_rv) + self.ascertainment_rate_rv = ascertainment_rate_rv + self.noise = noise + + def validate(self) -> None: + """ + Validate observation parameters. + + Raises + ------ + ValueError + If delay PMF invalid, ascertainment rate outside [0,1], + or noise params invalid. 
+ """ + delay_pmf = self.temporal_pmf_rv() + self._validate_pmf(delay_pmf, "delay_distribution_rv") + + ascertainment_rate = self.ascertainment_rate_rv() + if jnp.any(ascertainment_rate < 0) or jnp.any(ascertainment_rate > 1): + raise ValueError( + "ascertainment_rate_rv must be in [0, 1], " + "got value(s) outside this range" + ) + + self.noise.validate() + + def get_required_lookback(self) -> int: + """ + Return delay PMF length. + + Returns + ------- + int + Length of delay distribution PMF. + """ + return len(self.temporal_pmf_rv()) + + def infection_resolution(self) -> str: + """ + Return required infection resolution. + + Returns + ------- + str + "jurisdiction" for aggregated, "site" for disaggregated. + """ + raise NotImplementedError("Subclasses must implement infection_resolution()") + + def _expected_signal( + self, + infections: ArrayLike, + ) -> ArrayLike: + """ + Compute expected counts via ascertainment x delay convolution. + + Parameters + ---------- + infections : ArrayLike + Infections from the infection process. + Shape: (n_days,) for jurisdiction-level + Shape: (n_days, n_sites) for site-level + + Returns + ------- + ArrayLike + Expected counts with timeline alignment. + Same shape as input. + First len(delay_pmf)-1 days are NaN. + """ + delay_pmf = self.temporal_pmf_rv() + ascertainment_rate = self.ascertainment_rate_rv() + + is_1d = infections.ndim == 1 + if is_1d: + infections = infections[:, jnp.newaxis] + + def convolve_col(col): # numpydoc ignore=GL08 + return self._convolve_with_alignment(col, delay_pmf, ascertainment_rate)[0] + + expected_counts = jax.vmap(convolve_col, in_axes=1, out_axes=1)(infections) + + return expected_counts[:, 0] if is_1d else expected_counts + + +class Counts(_CountBase): + """ + Aggregated count observation for jurisdiction-level data. + + Maps jurisdiction-level infections to aggregated counts through + ascertainment x delay convolution with composable noise model. 
+ + Parameters + ---------- + ascertainment_rate_rv : RandomVariable + Ascertainment rate in [0, 1] (e.g., IHR, IER). + delay_distribution_rv : RandomVariable + Delay distribution PMF (must sum to ~1.0). + noise : CountNoise + Noise model (PoissonNoise, NegativeBinomialNoise, etc.). + + Notes + ----- + Output preserves input timeline. For the first len(delay_pmf)-1 days the + NaN-padded expected counts are replaced with 0 before noise is applied. + + Examples + -------- + >>> from pyrenew.deterministic import DeterministicVariable, DeterministicPMF + >>> from pyrenew.observation import Counts, NegativeBinomialNoise + >>> import jax.numpy as jnp + >>> import numpyro + >>> + >>> delay_pmf = jnp.array([0.2, 0.5, 0.3]) + >>> counts_obs = Counts( + ... ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + ... delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + ... noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ... ) + >>> + >>> with numpyro.handlers.seed(rng_seed=42): + ... infections = jnp.ones(30) * 1000 + ... sampled_counts = counts_obs.sample(infections=infections, counts=None) + """ + + def infection_resolution(self) -> str: + """ + Return "jurisdiction" for aggregated observations. + + Returns + ------- + str + The string "jurisdiction". + """ + return "jurisdiction" + + def sample( + self, + infections: ArrayLike, + counts: ArrayLike | None = None, + times: ArrayLike | None = None, + ) -> ArrayLike: + """ + Sample aggregated counts with dense or sparse observations. + + Validation is performed before JAX tracing at runtime, + prior to calling this method. + + Parameters + ---------- + infections : ArrayLike + Jurisdiction-level infections from the infection process. + Shape: (n_days,) + counts : ArrayLike | None + Observed counts. Dense: (n_days,), Sparse: (n_obs,), None: prior. + times : ArrayLike | None + Day indices for sparse observations. None for dense observations. + + Returns + ------- + ArrayLike + Observed or sampled counts. 
+ Dense: (n_days,), Sparse: (n_obs,) + """ + expected_counts = self._expected_signal(infections) + self._deterministic("expected_counts", expected_counts) + expected_counts_safe = jnp.nan_to_num(expected_counts, nan=0.0) + + # Only use sparse indexing when conditioning on observations + if times is not None and counts is not None: + expected_obs = expected_counts_safe[times] + else: + expected_obs = expected_counts_safe + + return self.noise.sample( + name="counts", + expected=expected_obs, + obs=counts, + ) + + +class CountsBySite(_CountBase): + """ + Disaggregated count observation for site-specific data. + + Maps site-level infections to site-specific counts through + ascertainment x delay convolution with composable noise model. + + Parameters + ---------- + ascertainment_rate_rv : RandomVariable + Ascertainment rate in [0, 1]. + delay_distribution_rv : RandomVariable + Delay distribution PMF (must sum to ~1.0). + noise : CountNoise + Noise model (PoissonNoise, NegativeBinomialNoise, etc.). + + Notes + ----- + Output preserves input timeline. First len(delay_pmf)-1 days are NaN. + + Examples + -------- + >>> from pyrenew.deterministic import DeterministicVariable, DeterministicPMF + >>> from pyrenew.observation import CountsBySite, PoissonNoise + >>> import jax.numpy as jnp + >>> import numpyro + >>> + >>> delay_pmf = jnp.array([0.3, 0.4, 0.3]) + >>> counts_obs = CountsBySite( + ... ascertainment_rate_rv=DeterministicVariable("ihr", 0.02), + ... delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + ... noise=PoissonNoise(), + ... ) + >>> + >>> with numpyro.handlers.seed(rng_seed=42): + ... infections = jnp.ones((30, 3)) * 500 # 30 days, 3 sites + ... times = jnp.array([10, 15, 10, 15]) + ... subpop_indices = jnp.array([0, 0, 1, 1]) + ... sampled = counts_obs.sample( + ... infections=infections, + ... subpop_indices=subpop_indices, + ... times=times, + ... counts=None, + ... 
) + """ + + def infection_resolution(self) -> str: + """ + Return "site" for disaggregated observations. + + Returns + ------- + str + The string "site". + """ + return "site" + + def sample( + self, + infections: ArrayLike, + subpop_indices: ArrayLike, + times: ArrayLike, + counts: ArrayLike | None = None, + ) -> ArrayLike: + """ + Sample disaggregated counts with flexible indexing. + + Validation is performed before JAX tracing at runtime, + prior to calling this method. + + Parameters + ---------- + infections : ArrayLike + Site-level infections from the infection process. + Shape: (n_days, n_sites) + subpop_indices : ArrayLike + Subpopulation index for each observation (0-indexed). + Shape: (n_obs,) + times : ArrayLike + Day index for each observation (0-indexed). + Shape: (n_obs,) + counts : ArrayLike | None + Observed counts (n_obs,), or None for prior sampling. + + Returns + ------- + ArrayLike + Observed or sampled counts. + Shape: (n_obs,) + """ + # Compute expected counts for all sites + expected_counts_all = self._expected_signal(infections) + + self._deterministic("expected_counts_by_site", expected_counts_all) + + # Replace NaN padding with 0 for distribution creation + expected_counts_safe = jnp.nan_to_num(expected_counts_all, nan=0.0) + expected_obs = expected_counts_safe[times, subpop_indices] + + return self.noise.sample( + name="counts_by_site", + expected=expected_obs, + obs=counts, + ) diff --git a/pyrenew/observation/measurements.py b/pyrenew/observation/measurements.py new file mode 100644 index 00000000..65a95fa9 --- /dev/null +++ b/pyrenew/observation/measurements.py @@ -0,0 +1,133 @@ +# numpydoc ignore=GL08 +""" +Continuous measurement observation processes. + +Abstract base for any population-level continuous measurements (wastewater, +air quality, serology, etc.) with signal-specific processing. 
+""" + +from jax.typing import ArrayLike + +from pyrenew.metaclass import RandomVariable +from pyrenew.observation.base import BaseObservationProcess +from pyrenew.observation.noise import MeasurementNoise + + +class Measurements(BaseObservationProcess): + """ + Abstract base for continuous measurement observations. + + Subclasses implement signal-specific transformations from infections + to expected measurement values, then add measurement noise. + + Parameters + ---------- + temporal_pmf_rv : RandomVariable + Temporal distribution PMF (e.g., shedding kinetics for wastewater). + noise : MeasurementNoise + Noise model for continuous measurements + (e.g., HierarchicalNormalNoise). + + Notes + ----- + Subclasses must implement ``_expected_signal()`` according to their + specific signal processing (e.g., wastewater shedding kinetics, + dilution factors, etc.). + + See Also + -------- + pyrenew.observation.noise.HierarchicalNormalNoise : + Suitable noise model for site-level measurements + pyrenew.observation.base.BaseObservationProcess : + Parent class with common observation utilities + """ + + def __init__( + self, + temporal_pmf_rv: RandomVariable, + noise: MeasurementNoise, + ) -> None: + """ + Initialize measurement observation base. + + Parameters + ---------- + temporal_pmf_rv : RandomVariable + Temporal distribution PMF (e.g., shedding kinetics). + noise : MeasurementNoise + Noise model (e.g., HierarchicalNormalNoise with site effects). + """ + super().__init__(temporal_pmf_rv=temporal_pmf_rv) + self.noise = noise + + def infection_resolution(self) -> str: + """ + Return "site" for measurement observations. + + Measurement observations require site-level infections + because each measurement site corresponds to a specific catchment. 
+ + Returns + ------- + str + ``"site"`` + """ + return "site" + + def sample( + self, + infections: ArrayLike, + subpop_indices: ArrayLike, + site_indices: ArrayLike, + times: ArrayLike, + concentrations: ArrayLike | None, + n_sites: int, + ) -> ArrayLike: + """ + Sample measurements from observed sites. + + This method does not perform runtime validation of index values + (times, subpop_indices, site_indices). Validate observation data + before sampling. + + Transforms infections to expected values via signal-specific processing + (``_expected_signal``), then applies noise model. + + Parameters + ---------- + infections : ArrayLike + Infections from the infection process. + Shape: (n_days, n_sites) + subpop_indices : ArrayLike + Subpopulation index for each observation (0-indexed). + Shape: (n_obs,) + site_indices : ArrayLike + Site index for each observation (0-indexed). + Shape: (n_obs,) + times : ArrayLike + Day index for each observation (0-indexed). + Shape: (n_obs,) + concentrations : ArrayLike | None + Observed measurements (n_obs,), or None for prior sampling. + n_sites : int + Total number of measurement sites. + + Returns + ------- + ArrayLike + Observed or sampled measurements. + Shape: (n_obs,) + """ + expected_values = self._expected_signal(infections) + + self._deterministic("expected_log_conc", expected_values) + + expected_obs = expected_values[times, subpop_indices] + + return self.noise.sample( + name="concentrations", + expected=expected_obs, + obs=concentrations, + site_indices=site_indices, + n_sites=n_sites, + ) diff --git a/pyrenew/observation/noise.py b/pyrenew/observation/noise.py new file mode 100644 index 00000000..7264b726 --- /dev/null +++ b/pyrenew/observation/noise.py @@ -0,0 +1,344 @@ +# numpydoc ignore=GL08 +""" +Noise models for observation processes. + +Provides composable noise strategies for count and measurement observations, +separating the noise distribution from the observation structure. 
+""" + +from __future__ import annotations + +from abc import ABC, abstractmethod + +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist +from jax.typing import ArrayLike + +from pyrenew.metaclass import RandomVariable + +_EPSILON = 1e-10 + + +class CountNoise(ABC): + """ + Abstract base for count observation noise models. + + Defines how discrete count observations are distributed around expected values. + """ + + @abstractmethod + def sample( + self, + name: str, + expected: ArrayLike, + obs: ArrayLike | None = None, + ) -> ArrayLike: + """ + Sample count observations given expected counts. + + Parameters + ---------- + name : str + Numpyro sample site name. + expected : ArrayLike + Expected count values (non-negative). + obs : ArrayLike | None + Observed counts for conditioning, or None for prior sampling. + + Returns + ------- + ArrayLike + Sampled or conditioned counts, same shape as expected. + """ + pass + + @abstractmethod + def validate(self) -> None: + """ + Validate noise model parameters. + + Raises + ------ + ValueError + If parameters are invalid. + """ + pass + + +class PoissonNoise(CountNoise): + """ + Poisson noise for equidispersed counts (variance = mean). + """ + + def __init__(self) -> None: + """Initialize Poisson noise (no parameters).""" + pass + + def validate(self) -> None: + """Validate Poisson noise (always valid).""" + pass + + def sample( + self, + name: str, + expected: ArrayLike, + obs: ArrayLike | None = None, + ) -> ArrayLike: + """ + Sample from Poisson distribution. + + Parameters + ---------- + name : str + Numpyro sample site name. + expected : ArrayLike + Expected count values. + obs : ArrayLike | None + Observed counts for conditioning. + + Returns + ------- + ArrayLike + Poisson-distributed counts. 
+ """ + return numpyro.sample( + name, + dist.Poisson(rate=expected + _EPSILON), + obs=obs, + ) + + +class NegativeBinomialNoise(CountNoise): + """ + Negative Binomial noise for overdispersed counts (variance > mean). + + Uses NB2 parameterization. Higher concentration reduces overdispersion. + + Parameters + ---------- + concentration_rv : RandomVariable + Concentration parameter (must be > 0). + Higher values reduce overdispersion. + + Notes + ----- + The NB2 parameterization has variance = mean + mean^2 / concentration. + As concentration -> infinity, this approaches Poisson. + """ + + def __init__(self, concentration_rv: RandomVariable) -> None: + """ + Initialize Negative Binomial noise. + + Parameters + ---------- + concentration_rv : RandomVariable + Concentration parameter (must be > 0). + Higher values reduce overdispersion. + """ + self.concentration_rv = concentration_rv + + def validate(self) -> None: + """ + Validate concentration is positive. + + Raises + ------ + ValueError + If concentration <= 0. + """ + concentration = self.concentration_rv() + if jnp.any(concentration <= 0): + raise ValueError( + f"NegativeBinomialNoise: concentration must be positive, " + f"got {float(concentration)}" + ) + + def sample( + self, + name: str, + expected: ArrayLike, + obs: ArrayLike | None = None, + ) -> ArrayLike: + """ + Sample from Negative Binomial distribution. + + Parameters + ---------- + name : str + Numpyro sample site name. + expected : ArrayLike + Expected count values. + obs : ArrayLike | None + Observed counts for conditioning. + + Returns + ------- + ArrayLike + Negative Binomial-distributed counts. + """ + concentration = self.concentration_rv() + return numpyro.sample( + name, + dist.NegativeBinomial2( + mean=expected + _EPSILON, + concentration=concentration, + ), + obs=obs, + ) + + +class MeasurementNoise(ABC): + """ + Abstract base for continuous measurement noise models. 
+ + Defines how continuous observations are distributed around expected values. + """ + + @abstractmethod + def sample( + self, + name: str, + expected: ArrayLike, + obs: ArrayLike | None = None, + **kwargs, + ) -> ArrayLike: + """ + Sample continuous observations given expected values. + + Parameters + ---------- + name : str + Numpyro sample site name. + expected : ArrayLike + Expected measurement values. + obs : ArrayLike | None + Observed measurements for conditioning, or None for prior sampling. + **kwargs + Additional context (e.g., site indices). + + Returns + ------- + ArrayLike + Sampled or conditioned measurements, same shape as expected. + """ + pass + + @abstractmethod + def validate(self) -> None: + """ + Validate noise model parameters. + + Raises + ------ + ValueError + If parameters are invalid. + """ + pass + + +class HierarchicalNormalNoise(MeasurementNoise): + """ + Normal noise with hierarchical site-level effects. + + Observation model: ``obs ~ Normal(expected + site_mode, site_sd)`` + where site_mode and site_sd are hierarchically modeled. + + Parameters + ---------- + site_mode_prior_rv : RandomVariable + Hierarchical prior for site-level modes (log-scale biases). + Must support ``sample(n_groups=...)`` interface. + site_sd_prior_rv : RandomVariable + Hierarchical prior for site-level SDs (must be > 0). + Must support ``sample(n_groups=...)`` interface. + + Notes + ----- + Expects data already on log scale for wastewater applications. + + See Also + -------- + pyrenew.randomvariable.HierarchicalNormalPrior : + Suitable prior for site_mode_prior_rv + pyrenew.randomvariable.GammaGroupSdPrior : + Suitable prior for site_sd_prior_rv + """ + + def __init__( + self, + site_mode_prior_rv: RandomVariable, + site_sd_prior_rv: RandomVariable, + ) -> None: + """ + Initialize hierarchical Normal noise. + + Parameters + ---------- + site_mode_prior_rv : RandomVariable + Hierarchical prior for site-level modes (log-scale biases). 
+ Must support ``sample(n_groups=...)`` interface. + site_sd_prior_rv : RandomVariable + Hierarchical prior for site-level SDs (must be > 0). + Must support ``sample(n_groups=...)`` interface. + """ + self.site_mode_prior_rv = site_mode_prior_rv + self.site_sd_prior_rv = site_sd_prior_rv + + def validate(self) -> None: + """ + Validate noise parameters. + + Notes + ----- + Full validation requires n_groups, which is only available during sample(). + """ + pass + + def sample( + self, + name: str, + expected: ArrayLike, + obs: ArrayLike | None = None, + *, + site_indices: ArrayLike, + n_sites: int, + ) -> ArrayLike: + """ + Sample from Normal distribution with site-level hierarchical effects. + + Parameters + ---------- + name : str + Numpyro sample site name. + expected : ArrayLike + Expected log-scale measurement values. + Shape: (n_obs,) + obs : ArrayLike | None + Observed log-scale measurements for conditioning. + Shape: (n_obs,) + site_indices : ArrayLike + Site index for each observation (0-indexed). + Shape: (n_obs,) + n_sites : int + Total number of sites. + + Returns + ------- + ArrayLike + Normal distributed measurements with hierarchical site effects. + Shape: (n_obs,) + + Notes + ----- + No positivity check on site_sd is performed in this method; + ``site_sd_prior_rv`` is expected to return values > 0. 
+ """ + site_mode = self.site_mode_prior_rv.sample(n_groups=n_sites) + site_sd = self.site_sd_prior_rv.sample(n_groups=n_sites) + + loc = expected + site_mode[site_indices] + scale = site_sd[site_indices] + + return numpyro.sample(name, dist.Normal(loc=loc, scale=scale), obs=obs) diff --git a/pyrenew/randomvariable/__init__.py b/pyrenew/randomvariable/__init__.py index c599d101..75594c58 100644 --- a/pyrenew/randomvariable/__init__.py +++ b/pyrenew/randomvariable/__init__.py @@ -5,6 +5,11 @@ DynamicDistributionalVariable, StaticDistributionalVariable, ) +from pyrenew.randomvariable.hierarchical import ( + GammaGroupSdPrior, + HierarchicalNormalPrior, + StudentTGroupModePrior, +) from pyrenew.randomvariable.transformedvariable import TransformedVariable __all__ = [ @@ -12,4 +17,7 @@ "StaticDistributionalVariable", "DynamicDistributionalVariable", "TransformedVariable", + "HierarchicalNormalPrior", + "GammaGroupSdPrior", + "StudentTGroupModePrior", ] diff --git a/pyrenew/randomvariable/hierarchical.py b/pyrenew/randomvariable/hierarchical.py new file mode 100644 index 00000000..e97126d4 --- /dev/null +++ b/pyrenew/randomvariable/hierarchical.py @@ -0,0 +1,336 @@ +# numpydoc ignore=GL08 +""" +Hierarchical prior distributions for group-level random effects. + +These classes provide random variables that sample from hierarchical +distributions with a `sample(n_groups=...)` interface, enabling +dynamic group sizes at sample time with proper numpyro plate contexts. +""" + +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist + +from pyrenew.metaclass import RandomVariable + + +class HierarchicalNormalPrior(RandomVariable): + """ + Zero-centered Normal prior for group-level effects. + + Samples n_groups values from Normal(0, sd) within a numpyro plate context. + + Parameters + ---------- + name : str + Unique name for the sampled parameter in numpyro. + sd_rv : RandomVariable + RandomVariable returning the standard deviation. 
class GammaGroupSdPrior(RandomVariable):
    """
    Gamma prior for group-level standard deviations, floored at ``sd_min``.

    Draws ``n_groups`` values from Gamma(concentration, rate) within a
    numpyro plate context and then applies an element-wise lower bound,
    ``max(raw_sd, sd_min)``.

    Parameters
    ----------
    name : str
        Unique name for the sampled parameter in numpyro.
    sd_mean_rv : RandomVariable
        RandomVariable returning the mean of the Gamma distribution.
    sd_concentration_rv : RandomVariable
        RandomVariable returning the concentration (shape) parameter of Gamma.
    sd_min : float, default=0.05
        Minimum SD value (lower bound).

    Notes
    -----
    This class parameterizes Gamma by mean and concentration rather than
    shape and rate, which is often more interpretable. The rate is computed as
    concentration / mean.

    The ``sd_min`` floor is a clamp, not a shift: nothing is added to the
    Gamma draws, and any draw below ``sd_min`` is replaced by ``sd_min``.
    The unclamped draws are recorded at the numpyro sample site
    ``f"{name}_raw"``; the clamped values are recorded as a numpyro
    deterministic site called ``name``.

    Examples
    --------
    >>> from pyrenew.deterministic import DeterministicVariable
    >>> from pyrenew.randomvariable import GammaGroupSdPrior
    >>> import numpyro
    >>>
    >>> mean_rv = DeterministicVariable("sd_mean", 0.3)
    >>> conc_rv = DeterministicVariable("sd_conc", 4.0)
    >>> prior = GammaGroupSdPrior("site_sd", mean_rv, conc_rv, sd_min=0.05)
    >>>
    >>> with numpyro.handlers.seed(rng_seed=42):
    ...     sds = prior.sample(n_groups=5)
    >>> all(sds >= 0.05)
    True
    """

    def __init__(
        self,
        name: str,
        sd_mean_rv: RandomVariable,
        sd_concentration_rv: RandomVariable,
        sd_min: float = 0.05,
    ) -> None:
        """
        Default constructor for GammaGroupSdPrior.

        Parameters
        ----------
        name : str
            Unique name for the sampled parameter in numpyro.
        sd_mean_rv : RandomVariable
            RandomVariable returning the mean of the Gamma distribution.
        sd_concentration_rv : RandomVariable
            RandomVariable returning the concentration (shape) parameter.
        sd_min : float, default=0.05
            Minimum SD value (lower bound).

        Returns
        -------
        None

        Raises
        ------
        TypeError
            If ``sd_mean_rv`` or ``sd_concentration_rv`` is not a
            RandomVariable.
        ValueError
            If ``sd_min`` is negative.
        """
        if not isinstance(sd_mean_rv, RandomVariable):
            raise TypeError(
                f"sd_mean_rv must be a RandomVariable, got {type(sd_mean_rv).__name__}. "
                "Use DeterministicVariable(name, value) to wrap a fixed value."
            )
        if not isinstance(sd_concentration_rv, RandomVariable):
            raise TypeError(
                f"sd_concentration_rv must be a RandomVariable, got {type(sd_concentration_rv).__name__}. "
                "Use DeterministicVariable(name, value) to wrap a fixed value."
            )
        if sd_min < 0:
            raise ValueError(f"sd_min must be non-negative, got {sd_min}")

        self.name = name
        self.sd_mean_rv = sd_mean_rv
        self.sd_concentration_rv = sd_concentration_rv
        self.sd_min = sd_min

    def validate(self):
        """Validate the random variable (no-op for this class)."""
        pass

    def sample(self, n_groups: int, **kwargs):
        """
        Sample group-level standard deviations.

        Parameters
        ----------
        n_groups : int
            Number of groups.
        **kwargs
            Additional keyword arguments (unused).

        Returns
        -------
        ArrayLike
            Array of shape (n_groups,) with values >= sd_min.
        """
        sd_mean = self.sd_mean_rv()
        concentration = self.sd_concentration_rv()
        # Mean/concentration parameterization: Gamma mean = concentration / rate.
        rate = concentration / sd_mean

        with numpyro.plate(f"n_{self.name}", n_groups):
            raw_sd = numpyro.sample(
                f"{self.name}_raw",
                dist.Gamma(concentration, rate),
            )

        # Clamp (do not shift) so every returned SD is at least sd_min.
        floored_sd = jnp.maximum(raw_sd, self.sd_min)
        return numpyro.deterministic(self.name, floored_sd)
+ Lower df values give heavier tails, making the prior more robust to outliers. + Common choices include df=3 (heavy tails) or df=7 (moderate tails). + + Examples + -------- + >>> from pyrenew.deterministic import DeterministicVariable + >>> from pyrenew.randomvariable import StudentTGroupModePrior + >>> import numpyro + >>> + >>> sd_rv = DeterministicVariable("scale", 0.5) + >>> df_rv = DeterministicVariable("df", 4.0) + >>> prior = StudentTGroupModePrior("site_modes", sd_rv, df_rv) + >>> + >>> with numpyro.handlers.seed(rng_seed=42): + ... modes = prior.sample(n_groups=5) + >>> modes.shape + (5,) + """ + + def __init__( + self, + name: str, + sd_rv: RandomVariable, + df_rv: RandomVariable, + ) -> None: + """ + Default constructor for StudentTGroupModePrior. + + Parameters + ---------- + name : str + Unique name for the sampled parameter in numpyro. + sd_rv : RandomVariable + RandomVariable returning the scale parameter. + df_rv : RandomVariable + RandomVariable returning the degrees of freedom. + + Returns + ------- + None + """ + if not isinstance(sd_rv, RandomVariable): + raise TypeError( + f"sd_rv must be a RandomVariable, got {type(sd_rv).__name__}. " + "Use DeterministicVariable(name, value) to wrap a fixed value." + ) + if not isinstance(df_rv, RandomVariable): + raise TypeError( + f"df_rv must be a RandomVariable, got {type(df_rv).__name__}. " + "Use DeterministicVariable(name, value) to wrap a fixed value." + ) + + self.name = name + self.sd_rv = sd_rv + self.df_rv = df_rv + + def validate(self): + """Validate the random variable (no-op for this class).""" + pass + + def sample(self, n_groups: int, **kwargs): + """ + Sample group-level modes. + + Parameters + ---------- + n_groups : int + Number of groups. + **kwargs + Additional keyword arguments (unused). + + Returns + ------- + ArrayLike + Array of shape (n_groups,) containing sampled modes. 
+ """ + sd = self.sd_rv() + df = self.df_rv() + + with numpyro.plate(f"n_{self.name}", n_groups): + effects = numpyro.sample( + self.name, + dist.StudentT(df=df, loc=0.0, scale=sd), + ) + return effects diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 00000000..73854684 --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,406 @@ +""" +Shared pytest fixtures for PyRenew tests. + +This module provides reusable fixtures for creating observation processes, +test data, and common configurations used across multiple test files. +""" + +import jax.numpy as jnp +import pytest + +from pyrenew.deterministic import DeterministicPMF, DeterministicVariable +from pyrenew.observation import Counts, NegativeBinomialNoise +from pyrenew.randomvariable import GammaGroupSdPrior, HierarchicalNormalPrior + +# ============================================================================= +# PMF Fixtures +# ============================================================================= + + +@pytest.fixture +def simple_delay_pmf(): + """ + Simple 1-day delay PMF (no delay). + + Returns + ------- + jnp.ndarray + A single-element PMF array representing no delay. + """ + return jnp.array([1.0]) + + +@pytest.fixture +def short_delay_pmf(): + """ + Short 2-day delay PMF. + + Returns + ------- + jnp.ndarray + A 2-element PMF array. + """ + return jnp.array([0.5, 0.5]) + + +@pytest.fixture +def medium_delay_pmf(): + """ + Medium 4-day delay PMF. + + Returns + ------- + jnp.ndarray + A 4-element PMF array. + """ + return jnp.array([0.1, 0.3, 0.4, 0.2]) + + +@pytest.fixture +def realistic_delay_pmf(): + """ + Realistic 10-day delay PMF (shifted gamma-like). + + Returns + ------- + jnp.ndarray + A 10-element PMF array with gamma-like shape. + """ + return jnp.array([0.01, 0.05, 0.10, 0.15, 0.20, 0.20, 0.15, 0.08, 0.04, 0.02]) + + +@pytest.fixture +def long_delay_pmf(): + """ + Long 10-day delay PMF for edge case testing. 
+ + Returns + ------- + jnp.ndarray + A 10-element PMF array. + """ + return jnp.array([0.05, 0.1, 0.15, 0.2, 0.2, 0.15, 0.1, 0.03, 0.01, 0.01]) + + +@pytest.fixture +def simple_shedding_pmf(): + """ + Simple 1-day shedding PMF (no delay). + + Returns + ------- + jnp.ndarray + A single-element PMF array representing no shedding delay. + """ + return jnp.array([1.0]) + + +@pytest.fixture +def short_shedding_pmf(): + """ + Short 3-day shedding PMF. + + Returns + ------- + jnp.ndarray + A 3-element PMF array. + """ + return jnp.array([0.3, 0.4, 0.3]) + + +@pytest.fixture +def medium_shedding_pmf(): + """ + Medium 5-day shedding PMF. + + Returns + ------- + jnp.ndarray + A 5-element PMF array. + """ + return jnp.array([0.1, 0.3, 0.3, 0.2, 0.1]) + + +# ============================================================================= +# Hierarchical Prior Fixtures +# ============================================================================= + + +@pytest.fixture +def site_mode_prior(): + """ + Standard hierarchical normal prior for site modes. + + Returns + ------- + HierarchicalNormalPrior + A hierarchical normal prior with standard deviation 0.5. + """ + return HierarchicalNormalPrior( + name="ww_site_mode", sd_rv=DeterministicVariable("mode_sd", 0.5) + ) + + +@pytest.fixture +def site_mode_prior_tight(): + """ + Tight hierarchical normal prior for deterministic-like behavior. + + Returns + ------- + HierarchicalNormalPrior + A hierarchical normal prior with small standard deviation 0.01. + """ + return HierarchicalNormalPrior( + name="ww_site_mode", sd_rv=DeterministicVariable("mode_sd_tight", 0.01) + ) + + +@pytest.fixture +def site_sd_prior(): + """ + Standard gamma prior for site standard deviations. + + Returns + ------- + GammaGroupSdPrior + A gamma prior for group standard deviations. 
+ """ + return GammaGroupSdPrior( + name="ww_site_sd", + sd_mean_rv=DeterministicVariable("sd_mean", 0.3), + sd_concentration_rv=DeterministicVariable("sd_concentration", 4.0), + sd_min=0.10, + ) + + +@pytest.fixture +def site_sd_prior_tight(): + """ + Tight gamma prior for deterministic-like behavior. + + Returns + ------- + GammaGroupSdPrior + A gamma prior with small mean for tight behavior. + """ + return GammaGroupSdPrior( + name="ww_site_sd", + sd_mean_rv=DeterministicVariable("sd_mean_tight", 0.01), + sd_concentration_rv=DeterministicVariable("sd_concentration_tight", 4.0), + sd_min=0.005, + ) + + +# ============================================================================= +# Counts Process Fixtures +# ============================================================================= + + +@pytest.fixture +def counts_process(simple_delay_pmf): + """ + Standard Counts observation process with simple delay. + + Returns + ------- + Counts + A Counts observation process with no delay. + """ + return Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + + +@pytest.fixture +def counts_process_medium_delay(medium_delay_pmf): + """ + Counts observation process with medium delay. + + Returns + ------- + Counts + A Counts observation process with 4-day delay. + """ + return Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", medium_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 50.0)), + ) + + +@pytest.fixture +def counts_process_realistic(realistic_delay_pmf): + """ + Counts observation process with realistic delay and ascertainment. + + Returns + ------- + Counts + A Counts observation process with realistic parameters. 
class CountsProcessFactory:
    """Build ``Counts`` observation processes from plain numeric parameters."""

    @staticmethod
    def create(
        delay_pmf=None,
        ascertainment_rate=0.01,
        concentration=10.0,
    ):
        """
        Assemble a Counts process with negative binomial noise.

        Parameters
        ----------
        delay_pmf : array-like, optional
            Delay PMF wrapped in ``DeterministicPMF("delay", ...)``.
            When None, the single-element PMF ``[1.0]`` is used.
        ascertainment_rate : float
            Value wrapped in ``DeterministicVariable("ihr", ...)``.
        concentration : float
            Value wrapped in ``DeterministicVariable("conc", ...)`` and
            handed to ``NegativeBinomialNoise``.

        Returns
        -------
        Counts
            A Counts observation process with the specified parameters.
        """
        pmf = jnp.array([1.0]) if delay_pmf is None else delay_pmf

        ihr_rv = DeterministicVariable("ihr", ascertainment_rate)
        delay_rv = DeterministicPMF("delay", pmf)
        noise_model = NegativeBinomialNoise(
            DeterministicVariable("conc", concentration)
        )
        return Counts(
            ascertainment_rate_rv=ihr_rv,
            delay_distribution_rv=delay_rv,
            noise=noise_model,
        )
def make_spike_infections(n_days, spike_day, spike_value=1000.0, n_subpops=None):
    """
    Build an all-zero infection array with a single non-zero day.

    Parameters
    ----------
    n_days : int
        Number of days (length of the first axis).
    spike_day : int
        Index of the day that receives ``spike_value``.
    spike_value : float
        Value written at ``spike_day``.
    n_subpops : int, optional
        Number of subpopulations. None yields a 1D array; otherwise the
        result is 2D and the spike is written to every subpopulation.

    Returns
    -------
    jnp.ndarray
        Array of shape (n_days,) or (n_days, n_subpops).
    """
    dims = (n_days,) if n_subpops is None else (n_days, n_subpops)
    # Indexing the first axis only fills the entire row in the 2D case.
    return jnp.zeros(dims).at[spike_day].set(spike_value)
def create_mock_infections(
    n_days: int,
    peak_day: int = 10,
    peak_value: float = 1000.0,
    shape: str = "spike",
) -> jnp.ndarray:
    """
    Generate a synthetic infection time series for tests.

    Parameters
    ----------
    n_days : int
        Number of days in the series.
    peak_day : int
        Day index that receives ``peak_value`` (used by "spike" only).
    peak_value : float
        Spike height, constant level, or day-0 value of the decay curve.
    shape : str
        One of "spike" (zeros with one non-zero day), "constant"
        (every day equals ``peak_value``), or "decay"
        (``peak_value * exp(-t / 20)`` for t = 0 .. n_days - 1).

    Returns
    -------
    jnp.ndarray
        Array of infections of shape (n_days,)

    Raises
    ------
    ValueError
        If ``shape`` is not one of the supported names.
    """
    if shape == "spike":
        return jnp.zeros(n_days).at[peak_day].set(peak_value)
    if shape == "constant":
        return jnp.ones(n_days) * peak_value
    if shape == "decay":
        return peak_value * jnp.exp(-jnp.arange(n_days) / 20.0)
    raise ValueError(f"Unknown shape: {shape}")
class TestCountsWithPriors:
    """Test aggregated count observation with uncertain parameters."""

    def test_with_stochastic_ascertainment(self):
        """Test with uncertain ascertainment rate parameter."""
        # The delay PMF is defined inline, so no PMF fixture is requested.
        delay_pmf = jnp.array([0.2, 0.5, 0.3])
        delay = DeterministicPMF("delay", delay_pmf)
        ascertainment = DistributionalVariable("ihr", dist.Beta(2, 100))
        concentration = DeterministicVariable("conc", 10.0)

        process = Counts(
            ascertainment_rate_rv=ascertainment,
            delay_distribution_rv=delay,
            noise=NegativeBinomialNoise(concentration),
        )

        infections = jnp.ones(20) * 100

        with numpyro.handlers.seed(rng_seed=42):
            counts = process.sample(
                infections=infections,
                counts=None,
            )

        assert counts.shape[0] > 0
        # Skip NaN padding: the first len(delay_pmf) - 1 days lack history.
        valid_counts = counts[len(delay_pmf) - 1 :]
        assert jnp.all(valid_counts >= 0)

    def test_with_stochastic_concentration(self, simple_delay_pmf):
        """Test with uncertain concentration parameter."""
        delay = DeterministicPMF("delay", simple_delay_pmf)
        ascertainment = DeterministicVariable("ihr", 0.01)
        concentration = DistributionalVariable("conc", dist.HalfNormal(10.0))

        process = Counts(
            ascertainment_rate_rv=ascertainment,
            delay_distribution_rv=delay,
            noise=NegativeBinomialNoise(concentration),
        )

        infections = jnp.ones(20) * 100

        with numpyro.handlers.seed(rng_seed=42):
            counts = process.sample(
                infections=infections,
                counts=None,
            )

        assert counts.shape[0] > 0
        assert jnp.all(counts >= 0)
class TestCountsBySite:
    """Checks for CountsBySite, the disaggregated count observation process."""

    def test_sample_returns_correct_shape(self):
        """Sampling yields one value per requested (time, subpopulation) pair."""
        process = CountsBySite(
            ascertainment_rate_rv=DeterministicVariable("ihr", 0.02),
            delay_distribution_rv=DeterministicPMF(
                "delay", jnp.array([0.3, 0.4, 0.3])
            ),
            noise=PoissonNoise(),
        )

        # 30 days, 3 sites
        site_infections = jnp.ones((30, 3)) * 500
        obs_times = jnp.array([10, 15, 10, 15])
        obs_subpops = jnp.array([0, 0, 1, 1])

        with numpyro.handlers.seed(rng_seed=42):
            sampled = process.sample(
                infections=site_infections,
                subpop_indices=obs_subpops,
                times=obs_times,
                counts=None,
            )

        assert sampled.shape == obs_times.shape

    def test_infection_resolution(self):
        """CountsBySite reports 'site' as its infection resolution."""
        process = CountsBySite(
            ascertainment_rate_rv=DeterministicVariable("ihr", 0.01),
            delay_distribution_rv=DeterministicPMF("delay", jnp.array([1.0])),
            noise=PoissonNoise(),
        )

        resolution = process.infection_resolution()
        assert resolution == "site"
class ConcreteMeasurements(Measurements):
    """Minimal concrete ``Measurements`` subclass used only by these tests."""

    def __init__(self, temporal_pmf_rv, noise, log10_scale=9.0):
        """Forward the PMF and noise model to the base class; keep the scale."""
        super().__init__(temporal_pmf_rv=temporal_pmf_rv, noise=noise)
        self.log10_scale = log10_scale

    def validate(self) -> None:
        """Run the base-class PMF validation on the temporal PMF."""
        self._validate_pmf(self.temporal_pmf_rv(), "temporal_pmf_rv")

    def get_required_lookback(self) -> int:
        """
        Return temporal PMF length.

        Returns
        -------
        int
            Length of the temporal PMF.
        """
        temporal_pmf = self.temporal_pmf_rv()
        return len(temporal_pmf)

    def _expected_signal(self, infections):
        """
        Compute the natural log of the convolved, scaled infections.

        Each infection column is convolved with the temporal PMF, then
        ``log(conv + 1e-10) + log10_scale * log(10)`` is returned, i.e. the
        natural log of the convolution multiplied by ``10 ** log10_scale``.

        Parameters
        ----------
        infections : jnp.ndarray
            1D or 2D infections; a 1D input is promoted to one column.

        Returns
        -------
        jnp.ndarray
            Log-transformed expected signal (always 2D).
        """
        import jax

        temporal_pmf = self.temporal_pmf_rv()

        # Promote 1D input to a single-column (n_days, 1) array.
        columns = infections if infections.ndim != 1 else infections[:, None]

        # Only the first element of the helper's return value is used here.
        convolve_all = jax.vmap(
            lambda col: self._convolve_with_alignment(col, temporal_pmf, 1.0)[0],
            in_axes=1,
            out_axes=1,
        )
        expected = convolve_all(columns)

        # The 1e-10 offset keeps the log finite when the convolution is zero.
        return jnp.log(expected + 1e-10) + self.log10_scale * jnp.log(10)
class TestConcreteMeasurements:
    """Exercise the concrete Measurements subclass end to end."""

    @staticmethod
    def _make_process(shedding_pmf, mode_sd, sd_mean, **sd_prior_kwargs):
        """Build a ConcreteMeasurements process with hierarchical normal noise."""
        site_mode_prior = HierarchicalNormalPrior(
            name="mode", sd_rv=DeterministicVariable("mode_sd", mode_sd)
        )
        site_sd_prior = GammaGroupSdPrior(
            name="sd",
            sd_mean_rv=DeterministicVariable("sd_mean", sd_mean),
            sd_concentration_rv=DeterministicVariable("sd_conc", 4.0),
            **sd_prior_kwargs,
        )
        return ConcreteMeasurements(
            temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf),
            noise=HierarchicalNormalNoise(site_mode_prior, site_sd_prior),
        )

    def test_sample_shape(self):
        """Sampling returns one value per requested observation."""
        process = self._make_process(
            jnp.array([0.3, 0.4, 0.3]), mode_sd=0.5, sd_mean=0.3
        )

        # 30 days, 2 sites
        infections = jnp.ones((30, 2)) * 1000
        obs_subpops = jnp.array([0, 0, 1, 1])
        obs_sites = jnp.array([0, 0, 1, 1])
        obs_times = jnp.array([10, 15, 10, 15])

        with numpyro.handlers.seed(rng_seed=42):
            drawn = process.sample(
                infections=infections,
                subpop_indices=obs_subpops,
                site_indices=obs_sites,
                times=obs_times,
                concentrations=None,
                n_sites=2,
            )

        assert drawn.shape == obs_times.shape

    def test_expected_signal_stored(self):
        """The trace of a sample call contains an 'expected_log_conc' site."""
        process = self._make_process(
            jnp.array([0.5, 0.5]), mode_sd=0.01, sd_mean=0.01, sd_min=0.001
        )

        infections = jnp.ones((20, 2)) * 1000
        obs_subpops = jnp.array([0, 1])
        obs_sites = jnp.array([0, 1])
        obs_times = jnp.array([10, 10])

        def run_model():  # numpydoc ignore=GL08
            return process.sample(
                infections=infections,
                subpop_indices=obs_subpops,
                site_indices=obs_sites,
                times=obs_times,
                concentrations=None,
                n_sites=2,
            )

        with numpyro.handlers.seed(rng_seed=42):
            trace = numpyro.handlers.trace(run_model).get_trace()

        assert "expected_log_conc" in trace
"""Test that sample returns correct shape.""" + prior = HierarchicalNormalPrior( + "effect", sd_rv=DeterministicVariable("sd", 1.0) + ) + + with numpyro.handlers.seed(rng_seed=42): + samples = prior.sample(n_groups=5) + + assert samples.shape == (5,) + + def test_smaller_sd_produces_tighter_distribution(self): + """Test that smaller sd produces samples closer to zero.""" + prior_tight = HierarchicalNormalPrior( + "a", sd_rv=DeterministicVariable("sd_tight", 0.1) + ) + prior_wide = HierarchicalNormalPrior( + "b", sd_rv=DeterministicVariable("sd_wide", 10.0) + ) + + n_samples = 1000 + with numpyro.handlers.seed(rng_seed=42): + samples_tight = prior_tight.sample(n_groups=n_samples) + with numpyro.handlers.seed(rng_seed=43): + samples_wide = prior_wide.sample(n_groups=n_samples) + + # Tight prior should have smaller standard deviation + assert jnp.std(samples_tight) < jnp.std(samples_wide) + + def test_rejects_non_random_variable_sd(self): + """Test that non-RandomVariable sd_rv is rejected.""" + with pytest.raises(TypeError, match="sd_rv must be a RandomVariable"): + HierarchicalNormalPrior("effect", sd_rv=1.0) + + def test_accepts_distributional_variable_for_sd(self): + """Test that DistributionalVariable can be used for sd_rv.""" + sd_rv = DistributionalVariable("sd", dist.HalfNormal(1.0)) + prior = HierarchicalNormalPrior("effect", sd_rv=sd_rv) + + with numpyro.handlers.seed(rng_seed=42): + samples = prior.sample(n_groups=5) + + assert samples.shape == (5,) + + +class TestGammaGroupSdPrior: + """Test GammaGroupSdPrior.""" + + def test_sample_shape(self): + """Test that sample returns correct shape.""" + prior = GammaGroupSdPrior( + "sd", + sd_mean_rv=DeterministicVariable("sd_mean", 0.5), + sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + ) + + with numpyro.handlers.seed(rng_seed=42): + samples = prior.sample(n_groups=5) + + assert samples.shape == (5,) + + def test_respects_sd_min(self): + """Test that sd_min is enforced as lower bound.""" + prior = 
GammaGroupSdPrior( + "sd", + sd_mean_rv=DeterministicVariable("sd_mean", 0.1), + sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + sd_min=0.5, + ) + + with numpyro.handlers.seed(rng_seed=42): + samples = prior.sample(n_groups=100) + + assert jnp.all(samples >= 0.5) + + def test_rejects_non_random_variable_params(self): + """Test that non-RandomVariable parameters are rejected.""" + with pytest.raises(TypeError, match="sd_mean_rv must be a RandomVariable"): + GammaGroupSdPrior( + "sd", + sd_mean_rv=0.5, + sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + ) + + with pytest.raises( + TypeError, match="sd_concentration_rv must be a RandomVariable" + ): + GammaGroupSdPrior( + "sd", + sd_mean_rv=DeterministicVariable("sd_mean", 0.5), + sd_concentration_rv=4.0, + ) + + def test_rejects_negative_sd_min(self): + """Test that negative sd_min is rejected.""" + with pytest.raises(ValueError, match="sd_min must be non-negative"): + GammaGroupSdPrior( + "sd", + sd_mean_rv=DeterministicVariable("sd_mean", 0.5), + sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + sd_min=-0.1, + ) + + +class TestStudentTGroupModePrior: + """Test StudentTGroupModePrior.""" + + def test_sample_shape(self): + """Test that sample returns correct shape.""" + prior = StudentTGroupModePrior( + "mode", + sd_rv=DeterministicVariable("sd", 1.0), + df_rv=DeterministicVariable("df", 4.0), + ) + + with numpyro.handlers.seed(rng_seed=42): + samples = prior.sample(n_groups=5) + + assert samples.shape == (5,) + + def test_heavier_tails_than_normal(self): + """Test Student-t produces more extreme values than Normal.""" + # df=2 gives very heavy tails + student_prior = StudentTGroupModePrior( + "s", + sd_rv=DeterministicVariable("sd_s", 1.0), + df_rv=DeterministicVariable("df", 2.0), + ) + normal_prior = HierarchicalNormalPrior( + "n", sd_rv=DeterministicVariable("sd_n", 1.0) + ) + + n_samples = 5000 + with numpyro.handlers.seed(rng_seed=42): + student_samples = 
student_prior.sample(n_groups=n_samples) + with numpyro.handlers.seed(rng_seed=42): + normal_samples = normal_prior.sample(n_groups=n_samples) + + # Student-t should have more extreme values (higher max absolute value) + assert jnp.max(jnp.abs(student_samples)) > jnp.max(jnp.abs(normal_samples)) + + def test_rejects_non_random_variable_params(self): + """Test that non-RandomVariable parameters are rejected.""" + with pytest.raises(TypeError, match="sd_rv must be a RandomVariable"): + StudentTGroupModePrior( + "mode", + sd_rv=1.0, + df_rv=DeterministicVariable("df", 4.0), + ) + + with pytest.raises(TypeError, match="df_rv must be a RandomVariable"): + StudentTGroupModePrior( + "mode", + sd_rv=DeterministicVariable("sd", 1.0), + df_rv=4.0, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 89d7aab5fcca52ecb52432e8413aace4df5d7f11 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 23 Dec 2025 10:56:50 -0500 Subject: [PATCH 02/33] adding tutorials --- docs/tutorials/.pages | 2 + .../observation_processes_counts.qmd | 425 +++++++++++++ .../observation_processes_measurements.qmd | 569 ++++++++++++++++++ pyrenew/observation/__init__.py | 19 + pyrenew/observation/base.py | 8 + pyrenew/observation/noise.py | 12 + 6 files changed, 1035 insertions(+) create mode 100644 docs/tutorials/observation_processes_counts.qmd create mode 100644 docs/tutorials/observation_processes_measurements.qmd diff --git a/docs/tutorials/.pages b/docs/tutorials/.pages index 875214fc..4f1079b3 100644 --- a/docs/tutorials/.pages +++ b/docs/tutorials/.pages @@ -3,5 +3,7 @@ nav: - basic_renewal_model.md - custom_randomvariables.md - hospital_admissions_model.md + - observation_processes_counts.md + - observation_processes_measurements.md - day_of_the_week.md - periodic_effects.md diff --git a/docs/tutorials/observation_processes_counts.qmd b/docs/tutorials/observation_processes_counts.qmd new file mode 100644 index 00000000..83279331 --- /dev/null +++ 
b/docs/tutorials/observation_processes_counts.qmd @@ -0,0 +1,425 @@ +--- +title: "Observation processes for count data" +format: gfm +engine: jupyter +--- + +This tutorial demonstrates how to use the `Counts` observation process to model count data such as hospital admissions, emergency department visits, or deaths. + +```{python} +# | label: setup +# | output: false +import jax.numpy as jnp +import numpy as np +import numpyro +import matplotlib.pyplot as plt + +from pyrenew.observation import Counts, NegativeBinomialNoise, PoissonNoise +from pyrenew.deterministic import DeterministicVariable, DeterministicPMF +from pyrenew import datasets +``` + +## Overview + +Count observation processes model the lag between infections and an observed outcome such as hospital admissions, emergency department visits, confirmed cases, or deaths. +Observed data can be aggregated at the jurisdiction level or available as disaggregated counts from reporting entities, which are modeled by classes `Counts` and `CountsBySite`, respectively. + +Count observation processes transform infections into expected observed counts by applying an ascertainment rate and convolving with a delay distribution. + +The expected observations on day $t$ are: + +$$\lambda_t = \alpha \cdot \sum_{d=0}^{D} I_{t-d} \cdot p_d$$ + +where: + +- $I_{t-d}$ is the number of incident (new) infections on day $t-d$ (i.e., $d$ days before day $t$) +- $\alpha$ is the ascertainment rate (e.g., infection-hospitalization ratio) +- $p_d$ is the delay distribution from infection to observation, conditional on an infection leading to an observation +- $D$ is the maximum delay + +Discrete observations are generated by sampling from a noise distribution—either Poisson or negative binomial—to model reporting variability. +Poisson assumes variance equals the mean; negative binomial accommodates the overdispersion common in surveillance data. 
+ +**Note on terminology:** In real-world inference, infections are *latent* (unobserved) and must be estimated from observed data like hospital admissions. In this tutorial, we simulate the observation process by specifying infections directly and showing how they produce hospital admissions through convolution and sampling. + +## Hospital admissions example + +For hospital admissions data, we construct a `Counts` observation process. +The delay is the key mechanism: infections from $d$ days ago ($I_{t-d}$) contribute to today's hospital admissions ($\lambda_t$) weighted by the probability ($p_d$) that an infection leads to hospitalization after exactly $d$ days. The convolution sums these contributions across all past days. + +The process generates hospital admissions by sampling from a negative binomial distribution: +$$Y_t \sim \text{NegativeBinomial}(\mu = \lambda_t, \text{concentration} = \phi)$$ + +The concentration parameter $\phi$ (sometimes called $k$ or the dispersion parameter) controls overdispersion: as $\phi \to \infty$, the distribution approaches Poisson; smaller values allow greater overdispersion. + +We use the negative binomial distribution because real-world hospital admission counts exhibit overdispersion—the variance exceeds the mean. +The Poisson distribution assumes variance equals the mean, which is too restrictive. The negative binomial adds an overdispersion term: +$$\text{Var}[Y_t] = \mu + \frac{\mu^2}{\phi}$$ + +In this example, we use fixed parameter values for illustration; in practice, these parameters would be estimated from data using weakly informative priors. + +## Infection-to-hospitalization delay distribution + +The delay distribution specifies the probability that an infected person is hospitalized $d$ days after infection, conditional on the infection leading to a hospitalization. 
+For example, if `hosp_delay_pmf[5] = 0.2`, then 20% of infections that result in hospitalization will appear as hospital admissions 5 days after infection. + +We load a delay distribution from PyRenew's datasets: + +```{python} +# | label: load-delay +inf_hosp_int = datasets.load_infection_admission_interval() +hosp_delay_pmf = jnp.array(inf_hosp_int["probability_mass"].to_numpy()) + +delay_rv = DeterministicPMF("inf_to_hosp_delay", hosp_delay_pmf) + +# Summary statistics +days = np.arange(len(hosp_delay_pmf)) +mean_delay = float(np.sum(days * hosp_delay_pmf)) +mode_delay = int(np.argmax(hosp_delay_pmf)) +print(f"Mode delay: {mode_delay} days, Mean delay: {mean_delay:.1f} days") +``` + +```{python} +# | label: fig-delay-distribution +# | fig-cap: Infection-to-hospitalization delay distribution +fig, ax = plt.subplots(figsize=(8, 4)) +ax.bar(days, hosp_delay_pmf, color="steelblue", alpha=0.7, edgecolor="black") +ax.axvline( + mode_delay, + color="purple", + linestyle="-", + linewidth=2, + label=f"Mode: {mode_delay}", +) +ax.axvline( + mean_delay, + color="red", + linestyle="--", + linewidth=2, + label=f"Mean: {mean_delay:.1f}", +) +ax.set_xlabel("Days from infection to hospitalization") +ax.set_ylabel("Probability") +ax.set_title("Infection-to-Hospitalization Delay Distribution") +ax.legend() +plt.tight_layout() +plt.show() +``` + +## Creating a Counts observation process + +A `Counts` object takes the following arguments: + +- **`ascertainment_rate_rv`**: the probability an infection results in an observation (e.g., IHR) +- **`delay_distribution_rv`**: delay distribution from infection to observation (PMF) +- **`noise`**: noise model (`PoissonNoise()` or `NegativeBinomialNoise(concentration_rv)`) + +```{python} +# | label: create-counts-process +# Infection-hospitalization ratio (1% of infections lead to hospitalization) +ihr_rv = DeterministicVariable("ihr", 0.01) + +# Overdispersion parameter for negative binomial +concentration_rv = 
DeterministicVariable("concentration", 10.0) + +# Create the observation process +hosp_process = Counts( + ascertainment_rate_rv=ihr_rv, + delay_distribution_rv=delay_rv, + noise=NegativeBinomialNoise(concentration_rv), +) + +print(f"Required lookback: {hosp_process.get_required_lookback()} days") +``` + +### Timeline alignment and lookback period + +The observation process convolves infections with a delay distribution, maintaining alignment between input and output: day $t$ in the output corresponds to day $t$ in the input. + +Hospital admissions on day $t$ depend on infections from day $t$ and from the preceding days (as many as the length of our delay distribution minus one). The method `get_required_lookback()` returns the length of the delay distribution, so the first valid observation day is at index `lookback - 1`. Earlier days are marked invalid. + +```{python} +# | label: helper-function +def first_valid_observation_day(obs_process) -> int: + """Return the first day index with complete infection history for convolution.""" + return obs_process.get_required_lookback() - 1 +``` + +## Simulating hospital admissions from infections + +To demonstrate how the observation process works, we simulate admissions from a spike of infections on a single day.
+ +```{python} +# | label: simulate-spike +n_days = 100 +lookback = hosp_process.get_required_lookback() + +# First valid observation day +day_one = lookback - 1 + +# Create infections with a spike +infection_spike_day = day_one + 10 +infections = jnp.zeros(n_days) +infections = infections.at[infection_spike_day].set(2000) + +# Sample hospital admissions +with numpyro.handlers.seed(rng_seed=42): + hosp_admissions = hosp_process.sample( + infections=infections, + counts=None, # Sample from prior (no observed data) + ) +``` + +```{python} +# | label: fig-spike-infections +# | fig-cap: Input infections with a single-day spike +# Plot relative to first valid observation day +spike_day = infection_spike_day - day_one +n_plot_days = n_days - day_one + +fig, ax = plt.subplots(figsize=(8, 4)) +ax.plot( + np.arange(n_plot_days), + np.array(infections[day_one:]), + "-o", + color="darkblue", +) +ax.axvline(spike_day, color="darkred", linestyle="--", alpha=0.7) +ax.annotate( + f"Infection spike\n(day {spike_day})", + xy=(spike_day, 1800), + xytext=(spike_day + 5, 1800), + fontsize=10, + color="darkred", +) +ax.set_xlabel("Day") +ax.set_ylabel("Daily Infections") +ax.set_title("Infections (Input)") +plt.tight_layout() +plt.show() +``` + +Because all infections occur on a single day, we can see how they spread into hospital admissions over subsequent days according to the delay distribution: + +```{python} +# | label: fig-spike-admissions +# | fig-cap: Hospital admissions from a single-day infection spike +fig, ax = plt.subplots(figsize=(8, 4)) +ax.plot( + np.arange(n_plot_days), + np.array(hosp_admissions[day_one:]), + "-o", + color="purple", +) +ax.axvline( + spike_day, + color="darkred", + linestyle="--", + alpha=0.5, + label="Infection spike", +) +ax.axvline( + spike_day + mode_delay, + color="purple", + linestyle="--", + alpha=0.5, + label="Peak admissions", +) +ax.set_xlabel("Day") +ax.set_ylabel("Hospital Admissions") +ax.set_title("Hospital Admissions (Output)") 
+ax.legend() +plt.tight_layout() +plt.show() +``` + +The admissions peak occurs `{python} mode_delay` days after the infection spike, matching the mode of the delay distribution. + +## Observation noise + +The negative binomial distribution adds stochastic variation. Sampling multiple times from the same infections shows the range of possible observations: + +```{python} +# | label: sample-realizations +n_samples = 50 +samples = [] + +for seed in range(n_samples): + with numpyro.handlers.seed(rng_seed=seed): + sample = hosp_process.sample(infections=infections, counts=None) + samples.append(np.array(sample[day_one:])) + +samples = np.array(samples) +sample_mean = samples.mean(axis=0) +``` + +```{python} +# | label: fig-sampled-admissions +# | fig-cap: Multiple realizations showing observation noise +fig, ax = plt.subplots(figsize=(8, 4)) + +# Plot all samples in light orange +for i in range(1, n_samples): + ax.plot( + np.arange(n_plot_days), + samples[i], + color="orange", + alpha=0.15, + linewidth=0.5, + ) + +# Highlight one sample +ax.plot( + np.arange(n_plot_days), + samples[0], + color="steelblue", + linewidth=1, + label="One realization", +) + +# Sample mean +ax.plot( + np.arange(n_plot_days), + sample_mean, + color="darkred", + linewidth=1.2, + label="Sample mean", +) + +ax.axvline(spike_day, color="darkblue", linestyle="--", alpha=0.5) +ax.set_xlabel("Day") +ax.set_ylabel("Hospital Admissions") +ax.set_title(f"Observation Noise: {n_samples} Samples from Same Infections") +ax.legend() +plt.tight_layout() +plt.show() +``` + +```{python} +# | label: timeline-stats +# Print timeline statistics +print("Timeline Analysis:") +print( + f" Infection spike on day {spike_day}: {infections[infection_spike_day]:.0f} people" +) +print(f" Mode delay from infection to hospitalization: {mode_delay} days") +print( + f" Expected hospitalization peak: day {spike_day + mode_delay} (= {spike_day} + {mode_delay})" +) +``` + +## Effect of the ascertainment rate + +The ascertainment 
rate directly scales expected hospital admissions. We compare IHR values of 0.5% and 2.5%: + +```{python} +# | label: compare-ihr +ihr_values = [0.005, 0.025] +infections_decay = 3000 * jnp.exp(-jnp.arange(n_days) / 20.0) + +results = {} +for ihr_val in ihr_values: + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", ihr_val), + delay_distribution_rv=delay_rv, + noise=NegativeBinomialNoise(concentration_rv), + ) + # Average over multiple samples to show the effect on the mean + samples = [] + for seed in range(20): + with numpyro.handlers.seed(rng_seed=seed): + sample = process.sample(infections=infections_decay, counts=None) + samples.append(np.array(sample[day_one:])) + results[ihr_val] = np.mean(samples, axis=0) +``` + +```{python} +# | label: fig-ihr-comparison +# | fig-cap: Effect of different infection-hospitalization rates +fig, ax = plt.subplots(figsize=(8, 4)) +colors = ["steelblue", "darkred"] + +for (ihr_val, mean_sample), color in zip(results.items(), colors): + ax.plot( + np.arange(n_plot_days), + mean_sample, + color=color, + linewidth=1.5, + label=f"IHR = {ihr_val:.1%}", + ) + +ax.set_xlabel("Day") +ax.set_ylabel("Hospital Admissions (mean of samples)") +ax.set_title("Effect of IHR on Hospital Admissions") +ax.legend() +plt.tight_layout() +plt.show() +``` + +## Negative binomial concentration parameter + +The concentration parameter $\phi$ controls overdispersion: + +- Higher $\phi$ → less overdispersion (approaches Poisson) +- Lower $\phi$ → more overdispersion (noisier data) + +```{python} +# | label: fig-concentration-comparison +# | fig-cap: Effect of concentration parameter on variability +infections_constant = 2000 * jnp.ones(n_days) +concentration_values = [1.0, 10.0, 100.0] +n_replicates = 10 + +fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True) + +for ax, conc_val in zip(axes, concentration_values): + process = Counts( + ascertainment_rate_rv=ihr_rv, + delay_distribution_rv=delay_rv, + 
noise=NegativeBinomialNoise(DeterministicVariable("conc", conc_val)), + ) + + for seed in range(n_replicates): + with numpyro.handlers.seed(rng_seed=seed): + sample = process.sample( + infections=infections_constant, counts=None + ) + ax.plot( + np.arange(n_plot_days), + np.array(sample[day_one:]), + color="steelblue", + alpha=0.5, + linewidth=0.8, + ) + + ax.set_xlabel("Day") + ax.set_title(f"φ = {int(conc_val)}") + +axes[0].set_ylabel("Hospital Admissions") +fig.suptitle("Effect of Concentration Parameter on Variability") +plt.tight_layout() +plt.show() +``` + +## Swapping noise models + +To use Poisson noise instead of negative binomial, change the noise model: + +```{python} +# | label: poisson-noise +hosp_process_poisson = Counts( + ascertainment_rate_rv=ihr_rv, + delay_distribution_rv=delay_rv, + noise=PoissonNoise(), +) + +with numpyro.handlers.seed(rng_seed=42): + poisson_sample = hosp_process_poisson.sample( + infections=infections, + counts=None, + ) + +print( + f"Sampled {len(poisson_sample)} days of hospital admissions with Poisson noise" +) +``` diff --git a/docs/tutorials/observation_processes_measurements.qmd b/docs/tutorials/observation_processes_measurements.qmd new file mode 100644 index 00000000..96138d0a --- /dev/null +++ b/docs/tutorials/observation_processes_measurements.qmd @@ -0,0 +1,569 @@ +--- +title: "Observation processes for continuous measurements" +format: gfm +engine: jupyter +--- + +This tutorial demonstrates how to use the `Measurements` observation process to model continuous measurement data such as wastewater viral concentrations. 
+ +```{python} +# | label: setup +# | output: false +import jax +import jax.numpy as jnp +import numpy as np +import numpyro +import matplotlib.pyplot as plt + +from pyrenew.observation import Measurements, HierarchicalNormalNoise +from pyrenew.randomvariable import HierarchicalNormalPrior, GammaGroupSdPrior +from pyrenew.deterministic import DeterministicVariable, DeterministicPMF +``` + +## Overview + +Measurement observation processes model continuous signals derived from infections, such as viral RNA concentrations in wastewater. Unlike count observations (hospital admissions, deaths), measurements are continuous values that span orders of magnitude. + +The expected measurement on day $t$ is: + +$$\lambda_t = \frac{G}{V} \cdot \sum_{d=0}^{D} I_{t-d} \cdot p_d$$ + +where: + +- $I_{t-d}$ is the number of incident (new) infections on day $t-d$ +- $G$ is a scaling factor (e.g., genome copies shed per infection) +- $V$ is a normalization factor (e.g., wastewater volume per person per day) +- $p_d$ is the temporal distribution (e.g., viral shedding kinetics) +- $D$ is the maximum duration + +Observed log-concentrations are generated by sampling from a normal distribution: + +$$\log(C_t) \sim \text{Normal}(\mu = \log(\lambda_t), \sigma)$$ + +The log-normal distribution is appropriate for concentration data because concentrations are strictly positive and often span several orders of magnitude. + +### Comparison with count observations + +The core convolution $\sum_{d=0}^{D} I_{t-d} \cdot p_d$ is the same as for count observations. 
The key differences are: + +| Aspect | Counts | Measurements | +|--------|--------|--------------| +| Scaling factor | Ascertainment rate $\alpha \in [0,1]$ | Domain-specific (e.g., $G/V$) | +| Temporal PMF | Delay distribution | Shedding/decay kinetics | +| Output space | Expected counts (linear) | Log-concentrations | +| Noise model | Poisson or Negative Binomial | Normal on log scale | +| Site structure | Optional (`CountsBySite`) | Inherent (hierarchical effects) | + +**Key features of measurement data:** + +- **Multiple sites**: Each jurisdiction has several measurement sites (e.g., wastewater treatment plants). +- **Irregular sampling**: Sites measure on different schedules. +- **Site-level variability**: Lab protocols and sampling methods vary across sites. +- **Temporal lag**: Signal peaks several days after infection. + +**Note on terminology:** In real-world inference, infections are *latent* (unobserved) and must be estimated from observed data. In this tutorial, we simulate the observation process by specifying infections directly and showing how they produce measurements through convolution and sampling. + +## Subclassing Measurements for wastewater + +The `Measurements` class is abstract—you must subclass it and implement `_expected_signal()` for your specific signal type. Here we create a `Wastewater` class for viral concentration measurements: + +```{python} +# | label: wastewater-class +from jax.typing import ArrayLike +from pyrenew.metaclass import RandomVariable +from pyrenew.observation.noise import MeasurementNoise + + +class Wastewater(Measurements): + """ + Wastewater viral concentration observation process. + + Transforms site-level infections into expected log-concentrations + via shedding kinetics convolution and genome/volume scaling. 
+ """ + + def __init__( + self, + shedding_kinetics_rv: RandomVariable, + log10_genome_per_infection_rv: RandomVariable, + ml_per_person_per_day: float, + noise: MeasurementNoise, + ) -> None: + """ + Initialize wastewater observation process. + + Parameters + ---------- + shedding_kinetics_rv : RandomVariable + Viral shedding PMF (fraction shed each day post-infection). + log10_genome_per_infection_rv : RandomVariable + Log10 genome copies shed per infection. + ml_per_person_per_day : float + Wastewater volume per person per day (mL). + noise : MeasurementNoise + Noise model (e.g., HierarchicalNormalNoise). + """ + super().__init__(temporal_pmf_rv=shedding_kinetics_rv, noise=noise) + self.log10_genome_per_infection_rv = log10_genome_per_infection_rv + self.ml_per_person_per_day = ml_per_person_per_day + + def validate(self) -> None: + """Validate parameters.""" + shedding_pmf = self.temporal_pmf_rv() + self._validate_pmf(shedding_pmf, "shedding_kinetics_rv") + self.noise.validate() + + def get_required_lookback(self) -> int: + """Return shedding PMF length.""" + return len(self.temporal_pmf_rv()) + + def _expected_signal(self, infections: ArrayLike) -> ArrayLike: + """ + Compute expected log-concentration from infections. + + Applies shedding kinetics convolution, then scales by + genome copies and volume to get concentration. 
+ """ + shedding_pmf = self.temporal_pmf_rv() + log10_genome = self.log10_genome_per_infection_rv() + + # Convolve each site's infections with shedding kinetics + def convolve_site(site_infections): + convolved, _ = self._convolve_with_alignment( + site_infections, shedding_pmf, p_observed=1.0 + ) + return convolved + + # Apply to all sites (infections shape: n_days x n_sites) + shedding_signal = jax.vmap(convolve_site, in_axes=1, out_axes=1)( + infections + ) + + # Convert to concentration: genomes per mL + genome_copies = 10**log10_genome + concentration = ( + shedding_signal * genome_copies / self.ml_per_person_per_day + ) + + # Return log-concentration (what we model) + return jnp.log(concentration) +``` + +## Viral shedding kinetics + +The shedding PMF describes what fraction of total viral shedding occurs on each day after infection: + +```{python} +# | label: shedding-pmf +# Peak shedding ~3 days after infection, continues for ~10 days +shedding_pmf = jnp.array( + [0.0, 0.05, 0.15, 0.25, 0.20, 0.15, 0.10, 0.05, 0.03, 0.02] +) +print(f"PMF sums to: {shedding_pmf.sum():.2f}") + +shedding_rv = DeterministicPMF("viral_shedding", shedding_pmf) + +# Summary statistics +days = np.arange(len(shedding_pmf)) +mean_shedding_day = float(np.sum(days * shedding_pmf)) +mode_shedding_day = int(np.argmax(shedding_pmf)) +print(f"Mode: {mode_shedding_day} days, Mean: {mean_shedding_day:.1f} days") +``` + +```{python} +# | label: fig-shedding +# | fig-cap: Viral shedding kinetics distribution +fig, ax = plt.subplots(figsize=(8, 4)) +ax.bar(days, shedding_pmf, color="steelblue", alpha=0.7, edgecolor="black") +ax.axvline( + mode_shedding_day, + color="purple", + linestyle="-", + linewidth=2, + label=f"Mode: {mode_shedding_day}", +) +ax.axvline( + mean_shedding_day, + color="red", + linestyle="--", + linewidth=2, + label=f"Mean: {mean_shedding_day:.1f}", +) +ax.set_xlabel("Days after infection") +ax.set_ylabel("Fraction of total shedding") +ax.set_title("Viral Shedding Kinetics") 
+ax.legend() +plt.tight_layout() +plt.show() +``` + +## Genome copies and wastewater volume + +```{python} +# | label: scaling-params +# Log10 genome copies shed per infection (typical: 8-10) +log10_genome_rv = DeterministicVariable("log10_genome", 9.0) + +# Wastewater volume per person per day (mL) +ml_per_person_per_day = 1000.0 +``` + +## Noise model with site-level effects + +Different sites have systematic differences in measured concentrations due to lab protocols, sampling methods, and flow estimation. We model this with hierarchical site effects: + +```{python} +# | label: noise-model +# Site-level mode: systematic differences between sites +site_mode_prior = HierarchicalNormalPrior( + name="ww_site_mode", + sd_rv=DeterministicVariable("mode_sd", 0.5), +) + +# Site-level SD: measurement variability within each site +site_sd_prior = GammaGroupSdPrior( + name="ww_site_sd", + sd_mean_rv=DeterministicVariable("sd_mean", 0.3), + sd_concentration_rv=DeterministicVariable("sd_concentration", 4.0), + sd_min=0.10, +) + +# Create the noise model +ww_noise = HierarchicalNormalNoise( + site_mode_prior_rv=site_mode_prior, + site_sd_prior_rv=site_sd_prior, +) +``` + +## Creating the observation process + +```{python} +# | label: create-process +ww_process = Wastewater( + shedding_kinetics_rv=shedding_rv, + log10_genome_per_infection_rv=log10_genome_rv, + ml_per_person_per_day=ml_per_person_per_day, + noise=ww_noise, +) + +print(f"Required lookback: {ww_process.get_required_lookback()} days") +``` + +### Timeline alignment and lookback period + +The observation process convolves infections with a shedding distribution, maintaining alignment between input and output: day $t$ in the output corresponds to day $t$ in the input. + +Wastewater concentrations on day $t$ depend on infections from day $t$ and from the preceding days (as many as the length of our shedding distribution minus one). The method `get_required_lookback()` returns the length of the shedding distribution, so the first valid observation day is at index `lookback - 1`.
Earlier days are marked invalid with NaN. + +```{python} +# | label: helper-function +def first_valid_observation_day(obs_process) -> int: + """Return the first day index with complete infection history for convolution.""" + return obs_process.get_required_lookback() - 1 +``` + +## Simulating wastewater observations from infections + +To demonstrate how the observation process works, we simulate concentrations from a spike of infections on a single day. + +```{python} +# | label: simulate-spike +n_days = 50 + +# First valid observation day +day_one = first_valid_observation_day(ww_process) + +# Create infections with a spike (shape: n_days x n_sites) +infection_spike_day = day_one + 10 +infections = jnp.zeros((n_days, 1)) # 1 subpopulation +infections = infections.at[infection_spike_day, 0].set(2000.0) + +# For plotting +spike_day = infection_spike_day - day_one +n_plot_days = n_days - day_one + +# Observation times and indices +observation_days = jnp.arange(day_one, 40, dtype=jnp.int32) +n_obs = len(observation_days) + +with numpyro.handlers.seed(rng_seed=42): + ww_obs = ww_process.sample( + infections=infections, + subpop_indices=jnp.zeros(n_obs, dtype=jnp.int32), + site_indices=jnp.zeros(n_obs, dtype=jnp.int32), + times=observation_days, + concentrations=None, # Sample from prior + n_sites=1, + ) +``` + +```{python} +# | label: fig-spike-infections +# | fig-cap: Input infections with a single-day spike +fig, ax = plt.subplots(figsize=(8, 4)) +ax.plot( + np.arange(n_plot_days), + np.array(infections[day_one:, 0]), + "-o", + color="darkblue", +) +ax.axvline(spike_day, color="darkred", linestyle="--", alpha=0.7) +ax.annotate( + f"Infection spike\n(day {spike_day})", + xy=(spike_day, 1800), + xytext=(spike_day + 3, 1800), + fontsize=10, + color="darkred", +) +ax.set_xlabel("Day") +ax.set_ylabel("Daily Infections") +ax.set_title("Infections (Input)") +plt.tight_layout() +plt.show() +``` + +Because all infections occur on a single day, we can see how they spread into 
wastewater concentrations over subsequent days according to the shedding kinetics. + +## Observation noise + +The log-normal noise model adds stochastic variation. Sampling multiple times from the same infections shows the range of possible observations: + +```{python} +# | label: sample-realizations +n_samples = 50 +samples_by_day = {int(d - day_one): [] for d in observation_days} + +for seed in range(n_samples): + with numpyro.handlers.seed(rng_seed=seed): + ww_sample = ww_process.sample( + infections=infections, + subpop_indices=jnp.zeros(n_obs, dtype=jnp.int32), + site_indices=jnp.zeros(n_obs, dtype=jnp.int32), + times=observation_days, + concentrations=None, + n_sites=1, + ) + for day_idx, conc in zip(observation_days, ww_sample): + samples_by_day[int(day_idx) - day_one].append(float(conc)) + +# Convert to arrays for plotting +plot_days = sorted(samples_by_day.keys()) +all_samples = np.array( + [[samples_by_day[d][i] for d in plot_days] for i in range(n_samples)] +) +sample_mean = all_samples.mean(axis=0) +``` + +```{python} +# | label: fig-sampled-concentrations +# | fig-cap: Multiple realizations showing observation noise +fig, ax = plt.subplots(figsize=(8, 4)) + +# Plot all samples +for i in range(1, n_samples): + ax.plot( + plot_days, all_samples[i], color="orange", alpha=0.15, linewidth=0.5 + ) + +# Highlight one sample +ax.plot( + plot_days, + all_samples[0], + color="steelblue", + linewidth=1, + label="One realization", +) + +# Sample mean +ax.plot( + plot_days, sample_mean, color="darkred", linewidth=1.2, label="Sample mean" +) + +ax.axvline(spike_day, color="darkblue", linestyle="--", alpha=0.5) +ax.set_xlabel("Day") +ax.set_ylabel("Log Viral Concentration") +ax.set_title(f"Observation Noise: {n_samples} Samples from Same Infections") +ax.legend() +plt.tight_layout() +plt.show() +``` + +```{python} +# | label: timeline-stats +print("Timeline Analysis:") +print( + f" Infection spike on day {spike_day}: {infections[infection_spike_day, 0]:.0f} people" 
+) +print( + f" Mode delay from infection to concentration peak: {mode_shedding_day} days" +) +print( + f" Expected concentration peak: day {spike_day + mode_shedding_day} (= {spike_day} + {mode_shedding_day})" +) +``` + +## Site-level variability + +Different sites measuring the same underlying infections will show systematic differences: + +```{python} +# | label: multi-site +n_sites = 4 +infections_constant = jnp.ones((40, 1)) * 1500.0 + +# Each site samples at multiple time points +times_per_site = 10 +observation_times = jnp.tile(jnp.arange(10, 30, 2, dtype=jnp.int32), n_sites) +site_ids = jnp.repeat(jnp.arange(n_sites, dtype=jnp.int32), times_per_site) +subpop_ids = jnp.zeros(n_sites * times_per_site, dtype=jnp.int32) + +with numpyro.handlers.seed(rng_seed=42): + ww_multi_site = ww_process.sample( + infections=infections_constant, + subpop_indices=subpop_ids, + site_indices=site_ids, + times=observation_times, + concentrations=None, + n_sites=n_sites, + ) +``` + +```{python} +# | label: fig-multi-site +# | fig-cap: Multiple sites observing the same infections +fig, ax = plt.subplots(figsize=(8, 4)) + +colors = ["steelblue", "coral", "green", "purple"] +for site in range(n_sites): + mask = np.array(site_ids) == site + ax.plot( + np.array(observation_times)[mask], + np.array(ww_multi_site)[mask], + "-o", + color=colors[site], + label=f"Site {site}", + ) + +ax.set_xlabel("Day") +ax.set_ylabel("Log Viral Concentration") +ax.set_title("Multiple Sites Observing Same Infections") +ax.legend() +plt.tight_layout() +plt.show() +``` + +Each site has a different baseline concentration (site mode) and different measurement variability (site SD). The hierarchical model learns these site-specific effects during inference. 
+ +## Effect of genome shedding parameter + +Higher genome shedding means higher observed concentrations: + +```{python} +# | label: compare-genome +genome_values = [8.5, 9.0, 9.5, 10.0] +infections_decay = 2000.0 * jnp.exp(-jnp.arange(40) / 15.0) +infections_decay = infections_decay.reshape(-1, 1) + +observation_days_cmp = jnp.arange(10, 30, dtype=jnp.int32) +n_obs_cmp = len(observation_days_cmp) + +results = {} +for genome_val in genome_values: + process_temp = Wastewater( + shedding_kinetics_rv=shedding_rv, + log10_genome_per_infection_rv=DeterministicVariable( + "log10_genome", genome_val + ), + ml_per_person_per_day=ml_per_person_per_day, + noise=ww_noise, + ) + + with numpyro.handlers.seed(rng_seed=42): + ww_temp = process_temp.sample( + infections=infections_decay, + subpop_indices=jnp.zeros(n_obs_cmp, dtype=jnp.int32), + site_indices=jnp.zeros(n_obs_cmp, dtype=jnp.int32), + times=observation_days_cmp, + concentrations=None, + n_sites=1, + ) + results[genome_val] = np.array(ww_temp) +``` + +```{python} +# | label: fig-genome-effect +# | fig-cap: Effect of genome shedding parameter on concentrations +fig, ax = plt.subplots(figsize=(8, 4)) + +colors = ["steelblue", "coral", "green", "purple"] +for (genome_val, conc), color in zip(results.items(), colors): + ax.plot( + np.array(observation_days_cmp), + conc, + "-o", + color=color, + label=f"log10 = {genome_val}", + markersize=4, + ) + +ax.set_xlabel("Day") +ax.set_ylabel("Log Viral Concentration") +ax.set_title("Effect of Genome Shedding Parameter") +ax.legend() +plt.tight_layout() +plt.show() +``` + +## Multiple subpopulations + +In hierarchical models, each measurement site serves a distinct subpopulation. 
Different subpopulations can have different infection levels: + +```{python} +# | label: multi-subpop +# Two subpopulations with different infection patterns +n_days_mp = 40 +infections_subpop1 = 1000.0 * jnp.exp( + -jnp.arange(n_days_mp) / 20.0 +) # Slow decay +infections_subpop2 = 2000.0 * jnp.exp( + -jnp.arange(n_days_mp) / 10.0 +) # Fast decay +infections_multi = jnp.stack([infections_subpop1, infections_subpop2], axis=1) + +# Two sites, each observing a different subpopulation +obs_days_mp = jnp.tile(jnp.arange(10, 30, 2, dtype=jnp.int32), 2) +subpop_ids_mp = jnp.array([0] * 10 + [1] * 10, dtype=jnp.int32) +site_ids_mp = jnp.array([0] * 10 + [1] * 10, dtype=jnp.int32) + +with numpyro.handlers.seed(rng_seed=42): + ww_multi_subpop = ww_process.sample( + infections=infections_multi, + subpop_indices=subpop_ids_mp, + site_indices=site_ids_mp, + times=obs_days_mp, + concentrations=None, + n_sites=2, + ) +``` + +```{python} +# | label: fig-multi-subpop +# | fig-cap: Sites observing different subpopulations with different infection dynamics +fig, ax = plt.subplots(figsize=(8, 4)) + +for subpop in [0, 1]: + mask = np.array(subpop_ids_mp) == subpop + ax.plot( + np.array(obs_days_mp)[mask], + np.array(ww_multi_subpop)[mask], + "-o", + label=f"Subpop {subpop}", + ) + +ax.set_xlabel("Day") +ax.set_ylabel("Log Viral Concentration") +ax.set_title("Multiple Subpopulations with Different Infection Dynamics") +ax.legend() +plt.tight_layout() +plt.show() +``` diff --git a/pyrenew/observation/__init__.py b/pyrenew/observation/__init__.py index 46cfbb67..dec3fadf 100644 --- a/pyrenew/observation/__init__.py +++ b/pyrenew/observation/__init__.py @@ -1,4 +1,23 @@ # numpydoc ignore=GL08 +""" +Observation processes for connecting infections to observed data. + +Architecture +------------ +``BaseObservationProcess`` is the abstract base. 
Concrete subclasses: + +- ``Counts``: Jurisdiction-level counts (admissions, deaths) +- ``CountsBySite``: Site-specific disaggregated counts +- ``Measurements``: Continuous site-level signals (e.g., wastewater) + +All observation processes implement: + +- ``_expected_signal(infections)``: transforms infections to expected values +- ``sample()``: calls ``_expected_signal()`` then applies noise model + +Noise models (``CountNoise``, ``MeasurementNoise``) are composable—pass them +to observation constructors to control the output distribution. +""" from pyrenew.observation.base import BaseObservationProcess from pyrenew.observation.count_observations import Counts, CountsBySite diff --git a/pyrenew/observation/base.py b/pyrenew/observation/base.py index d84125a1..5ee0c7e5 100644 --- a/pyrenew/observation/base.py +++ b/pyrenew/observation/base.py @@ -34,6 +34,14 @@ class BaseObservationProcess(RandomVariable): - Convolution wrapper with timeline alignment - Deterministic quantity tracking + Subclasses must implement: + + - ``validate()``: Validate parameters (call ``_validate_pmf()`` for PMFs) + - ``get_required_lookback()``: Return PMF length for initialization + - ``infection_resolution()``: Return ``"jurisdiction"`` or ``"site"`` + - ``_expected_signal()``: Transform infections to expected values + - ``sample()``: Apply noise model to expected signal + Notes ----- Computing expected observations on day t requires infection history diff --git a/pyrenew/observation/noise.py b/pyrenew/observation/noise.py index 7264b726..10f1e337 100644 --- a/pyrenew/observation/noise.py +++ b/pyrenew/observation/noise.py @@ -4,6 +4,18 @@ Provides composable noise strategies for count and measurement observations, separating the noise distribution from the observation structure. + +Count Noise +----------- +- ``PoissonNoise``: Equidispersed counts (variance = mean). No parameters. +- ``NegativeBinomialNoise``: Overdispersed counts (variance > mean). 
+ Takes ``concentration_rv`` (higher = less overdispersion). + +Measurement Noise +----------------- +- ``HierarchicalNormalNoise``: Normal noise with hierarchical site effects. + Takes ``site_mode_prior_rv`` and ``site_sd_prior_rv`` for site-level + bias and variability. """ from __future__ import annotations From a2e4630431e85a0f4b67fa183f26284d1370351d Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 23 Dec 2025 12:07:52 -0500 Subject: [PATCH 03/33] improve test coverage --- test/test_observation_counts.py | 163 +++++++++++++++++++++++ test/test_observation_measurements.py | 14 ++ test/test_observation_poisson.py | 7 + test/test_randomvariable_hierarchical.py | 25 ++++ 4 files changed, 209 insertions(+) diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index cd87d403..58f0b33e 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -16,6 +16,7 @@ NegativeBinomialNoise, PoissonNoise, ) +from pyrenew.observation.count_observations import _CountBase from pyrenew.randomvariable import DistributionalVariable @@ -374,5 +375,167 @@ def test_poisson_counts(self, simple_delay_pmf): assert jnp.all(counts >= 0) +class TestCountBaseInternalMethods: + """Test internal _CountBase methods for coverage.""" + + def test_count_base_infection_resolution_raises(self, simple_delay_pmf): + """Test that _CountBase.infection_resolution() raises NotImplementedError.""" + + # Create a subclass that doesn't override infection_resolution + class IncompleteCountProcess(_CountBase): + """Incomplete count process for testing.""" + + def sample(self, **kwargs): + """Sample method stub.""" + pass + + process = IncompleteCountProcess( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + with pytest.raises( + NotImplementedError, match="Subclasses must implement infection_resolution" + 
): + process.infection_resolution() + + +class TestValidationMethods: + """Test validation methods for coverage.""" + + def test_validate_calls_all_validations(self, simple_delay_pmf): + """Test that validate() calls all necessary validations.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + # Should not raise + process.validate() + + def test_validate_invalid_ascertainment_rate_negative(self, simple_delay_pmf): + """Test that validate raises for negative ascertainment rate.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", -0.1), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + with pytest.raises(ValueError, match="ascertainment_rate_rv must be in"): + process.validate() + + def test_validate_invalid_ascertainment_rate_greater_than_one( + self, simple_delay_pmf + ): + """Test that validate raises for ascertainment rate > 1.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 1.5), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + with pytest.raises(ValueError, match="ascertainment_rate_rv must be in"): + process.validate() + + def test_get_required_lookback(self, simple_delay_pmf, long_delay_pmf): + """Test get_required_lookback returns PMF length.""" + process_short = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + assert process_short.get_required_lookback() == 1 + + process_long = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + 
delay_distribution_rv=DeterministicPMF("delay", long_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + assert process_long.get_required_lookback() == 10 + + def test_infection_resolution_counts(self, simple_delay_pmf): + """Test that Counts returns 'jurisdiction' resolution.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + assert process.infection_resolution() == "jurisdiction" + + +class TestNoiseValidation: + """Test noise model validation methods.""" + + def test_poisson_noise_validate(self): + """Test PoissonNoise validate method.""" + noise = PoissonNoise() + # Should not raise - Poisson has no parameters to validate + noise.validate() + + def test_negative_binomial_noise_validate_success(self): + """Test NegativeBinomialNoise validate with valid concentration.""" + noise = NegativeBinomialNoise(DeterministicVariable("conc", 10.0)) + # Should not raise + noise.validate() + + def test_negative_binomial_noise_validate_zero_concentration(self): + """Test NegativeBinomialNoise validate with zero concentration.""" + noise = NegativeBinomialNoise(DeterministicVariable("conc", 0.0)) + with pytest.raises(ValueError, match="concentration must be positive"): + noise.validate() + + def test_negative_binomial_noise_validate_negative_concentration(self): + """Test NegativeBinomialNoise validate with negative concentration.""" + noise = NegativeBinomialNoise(DeterministicVariable("conc", -1.0)) + with pytest.raises(ValueError, match="concentration must be positive"): + noise.validate() + + +class TestBaseObservationProcessValidation: + """Test base observation process PMF validation.""" + + def test_validate_pmf_empty_array(self, simple_delay_pmf): + """Test that _validate_pmf raises for empty array.""" + process = Counts( + 
ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + empty_pmf = jnp.array([]) + with pytest.raises(ValueError, match="must return non-empty array"): + process._validate_pmf(empty_pmf, "test_pmf") + + def test_validate_pmf_sum_not_one(self, simple_delay_pmf): + """Test that _validate_pmf raises for PMF not summing to 1.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + bad_pmf = jnp.array([0.3, 0.3, 0.3]) # sums to 0.9 + with pytest.raises(ValueError, match="must sum to 1.0"): + process._validate_pmf(bad_pmf, "test_pmf") + + def test_validate_pmf_negative_values(self, simple_delay_pmf): + """Test that _validate_pmf raises for negative values.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + bad_pmf = jnp.array([1.5, -0.5]) # sums to 1.0 but has negative + with pytest.raises(ValueError, match="must have non-negative values"): + process._validate_pmf(bad_pmf, "test_pmf") + + def test_get_minimum_observation_day(self): + """Test get_minimum_observation_day returns correct value.""" + delay_pmf = jnp.array([0.2, 0.5, 0.3]) # length 3 + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + # First valid day should be len(pmf) - 1 = 2 + assert process.get_minimum_observation_day() == 2 + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/test/test_observation_measurements.py 
b/test/test_observation_measurements.py index 2ca11a22..5e69fd3d 100644 --- a/test/test_observation_measurements.py +++ b/test/test_observation_measurements.py @@ -97,6 +97,20 @@ def test_infection_resolution_is_site(self): class TestHierarchicalNormalNoise: """Test HierarchicalNormalNoise model.""" + def test_validate(self): + """Test HierarchicalNormalNoise validate method.""" + site_mode_prior = HierarchicalNormalPrior( + name="mode", sd_rv=DeterministicVariable("mode_sd", 0.5) + ) + site_sd_prior = GammaGroupSdPrior( + name="sd", + sd_mean_rv=DeterministicVariable("sd_mean", 0.3), + sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + ) + noise = HierarchicalNormalNoise(site_mode_prior, site_sd_prior) + # Should not raise - validation is deferred to sample time + noise.validate() + def test_sample_shape(self): """Test that HierarchicalNormalNoise produces correct shape.""" site_mode_prior = HierarchicalNormalPrior( diff --git a/test/test_observation_poisson.py b/test/test_observation_poisson.py index b9d975be..8b9c0716 100644 --- a/test/test_observation_poisson.py +++ b/test/test_observation_poisson.py @@ -20,3 +20,10 @@ def test_poisson_obs(): sim_pois = pois(mu=rates) testing.assert_array_equal(sim_pois, jnp.ceil(sim_pois)) + + +def test_poisson_validate(): + """ + Check that PoissonObservation.validate() runs without error. 
+ """ + PoissonObservation.validate() diff --git a/test/test_randomvariable_hierarchical.py b/test/test_randomvariable_hierarchical.py index 30384800..89a58c3a 100644 --- a/test/test_randomvariable_hierarchical.py +++ b/test/test_randomvariable_hierarchical.py @@ -46,6 +46,13 @@ def test_smaller_sd_produces_tighter_distribution(self): # Tight prior should have smaller standard deviation assert jnp.std(samples_tight) < jnp.std(samples_wide) + def test_validate(self): + """Test that validate() runs without error.""" + prior = HierarchicalNormalPrior( + "effect", sd_rv=DeterministicVariable("sd", 1.0) + ) + prior.validate() # Should not raise + def test_rejects_non_random_variable_sd(self): """Test that non-RandomVariable sd_rv is rejected.""" with pytest.raises(TypeError, match="sd_rv must be a RandomVariable"): @@ -120,6 +127,15 @@ def test_rejects_negative_sd_min(self): sd_min=-0.1, ) + def test_validate(self): + """Test that validate() runs without error.""" + prior = GammaGroupSdPrior( + "sd", + sd_mean_rv=DeterministicVariable("sd_mean", 0.5), + sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + ) + prior.validate() # Should not raise + class TestStudentTGroupModePrior: """Test StudentTGroupModePrior.""" @@ -174,6 +190,15 @@ def test_rejects_non_random_variable_params(self): df_rv=4.0, ) + def test_validate(self): + """Test that validate() runs without error.""" + prior = StudentTGroupModePrior( + "mode", + sd_rv=DeterministicVariable("sd", 1.0), + df_rv=DeterministicVariable("df", 4.0), + ) + prior.validate() # Should not raise + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 24096bc824dd03ab52e9e24cacda1a7782954900 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 23 Dec 2025 12:27:17 -0500 Subject: [PATCH 04/33] improve unit test coverage --- pyrenew/observation/base.py | 10 +++++----- pyrenew/observation/noise.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pyrenew/observation/base.py 
b/pyrenew/observation/base.py index 5ee0c7e5..8468566a 100644 --- a/pyrenew/observation/base.py +++ b/pyrenew/observation/base.py @@ -88,7 +88,7 @@ def validate(self) -> None: ValueError If any parameters fail validation. """ - pass + pass # pragma: no cover @abstractmethod def get_required_lookback(self) -> int: @@ -110,7 +110,7 @@ def get_required_lookback(self) -> int: n_initialization_points as: ``max(gen_int_length, max(all lookbacks)) - 1`` """ - pass + pass # pragma: no cover @abstractmethod def infection_resolution(self) -> str: @@ -140,7 +140,7 @@ def infection_resolution(self) -> str: This is used by multi-signal models to route the correct infection output to each observation process. """ - pass + pass # pragma: no cover def _validate_pmf( self, @@ -305,7 +305,7 @@ def _expected_signal( -------- sample : Uses this method then applies noise model """ - pass + pass # pragma: no cover @abstractmethod def sample(self, **kwargs) -> ArrayLike: @@ -329,4 +329,4 @@ def sample(self, **kwargs) -> ArrayLike: ArrayLike Observed or sampled values from the observation process. """ - pass + pass # pragma: no cover diff --git a/pyrenew/observation/noise.py b/pyrenew/observation/noise.py index 10f1e337..4065061e 100644 --- a/pyrenew/observation/noise.py +++ b/pyrenew/observation/noise.py @@ -63,7 +63,7 @@ def sample( ArrayLike Sampled or conditioned counts, same shape as expected. """ - pass + pass # pragma: no cover @abstractmethod def validate(self) -> None: @@ -75,7 +75,7 @@ def validate(self) -> None: ValueError If parameters are invalid. """ - pass + pass # pragma: no cover class PoissonNoise(CountNoise): @@ -235,7 +235,7 @@ def sample( ArrayLike Sampled or conditioned measurements, same shape as expected. """ - pass + pass # pragma: no cover @abstractmethod def validate(self) -> None: @@ -247,7 +247,7 @@ def validate(self) -> None: ValueError If parameters are invalid. 
""" - pass + pass # pragma: no cover class HierarchicalNormalNoise(MeasurementNoise): From 671d9d0075234fee8640e323ea4f46314090fa12 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Fri, 26 Dec 2025 16:05:05 -0500 Subject: [PATCH 05/33] consistent names: 'subpop' (not site), 'sensor' for site/lab measurements, 'aggregate' instead of 'jurisdiction' --- .../observation_processes_counts.qmd | 28 ++--- .../observation_processes_measurements.qmd | 108 ++++++++-------- pyrenew/observation/__init__.py | 23 ++-- pyrenew/observation/base.py | 28 +++-- pyrenew/observation/count_observations.py | 93 ++++++++------ pyrenew/observation/measurements.py | 57 +++++---- pyrenew/observation/noise.py | 82 +++++++----- test/conftest.py | 20 +-- test/test_observation_counts.py | 118 +++++++++--------- test/test_observation_measurements.py | 77 ++++++------ 10 files changed, 345 insertions(+), 289 deletions(-) diff --git a/docs/tutorials/observation_processes_counts.qmd b/docs/tutorials/observation_processes_counts.qmd index 83279331..b3b415c7 100644 --- a/docs/tutorials/observation_processes_counts.qmd +++ b/docs/tutorials/observation_processes_counts.qmd @@ -22,7 +22,7 @@ from pyrenew import datasets ## Overview Count observation processes model the lag between infections and an observed outcome such as hospital admissions, emergency department visits, confirmed cases, or deaths. -Observed data can be aggregated at the jurisdiction level or available as disaggregated counts from reporting entities, which are modeled by classes `Counts` and `CountsBySite`, respectively. +Observed data can be aggregated or available as subpopulation-level counts, which are modeled by classes `Counts` and `CountsBySubpop`, respectively. Count observation processes transform infections into expected observed counts by applying an ascertainment rate and convolving with a delay distribution. 
@@ -129,20 +129,20 @@ hosp_process = Counts( noise=NegativeBinomialNoise(concentration_rv), ) -print(f"Required lookback: {hosp_process.get_required_lookback()} days") +print(f"Required lookback: {hosp_process.lookback_days()} days") ``` ### Timeline alignment and lookback period The observation process convolves infections with a delay distribution, maintaining alignment between input and output: day $t$ in the output corresponds to day $t$ in the input. -Hospital admissions depend on infections from prior days (the length of our delay distribution minus one). The method `get_required_lookback()` returns this length; the first valid observation day is at index `lookback - 1`. Earlier days are marked invalid. +Hospital admissions depend on infections from prior days (the length of our delay distribution minus one). The method `lookback_days()` returns this length; the first valid observation day is at index `lookback - 1`. Earlier days are marked invalid. ```{python} # | label: helper-function def first_valid_observation_day(obs_process) -> int: """Return the first day index with complete infection history for convolution.""" - return obs_process.get_required_lookback() - 1 + return obs_process.lookback_days() - 1 ``` ## Simulating hospital admissions from infections @@ -152,7 +152,7 @@ To demonstrate how the observation process works, we simulate admissions from a ```{python} # | label: simulate-spike n_days = 100 -lookback = hosp_process.get_required_lookback() +lookback = hosp_process.lookback_days() # First valid observation day day_one = lookback - 1 @@ -207,7 +207,7 @@ Because all infections occur on a single day, we can see how they spread into ho fig, ax = plt.subplots(figsize=(8, 4)) ax.plot( np.arange(n_plot_days), - np.array(hosp_admissions[day_one:]), + np.array(hosp_admissions.observed[day_one:]), "-o", color="purple", ) @@ -246,8 +246,8 @@ samples = [] for seed in range(n_samples): with numpyro.handlers.seed(rng_seed=seed): - sample = 
hosp_process.sample(infections=infections, counts=None) - samples.append(np.array(sample[day_one:])) + result = hosp_process.sample(infections=infections, counts=None) + samples.append(np.array(result.observed[day_one:])) samples = np.array(samples) sample_mean = samples.mean(axis=0) @@ -328,8 +328,8 @@ for ihr_val in ihr_values: samples = [] for seed in range(20): with numpyro.handlers.seed(rng_seed=seed): - sample = process.sample(infections=infections_decay, counts=None) - samples.append(np.array(sample[day_one:])) + result = process.sample(infections=infections_decay, counts=None) + samples.append(np.array(result.observed[day_one:])) results[ihr_val] = np.mean(samples, axis=0) ``` @@ -381,12 +381,12 @@ for ax, conc_val in zip(axes, concentration_values): for seed in range(n_replicates): with numpyro.handlers.seed(rng_seed=seed): - sample = process.sample( + result = process.sample( infections=infections_constant, counts=None ) ax.plot( np.arange(n_plot_days), - np.array(sample[day_one:]), + np.array(result.observed[day_one:]), color="steelblue", alpha=0.5, linewidth=0.8, @@ -414,12 +414,12 @@ hosp_process_poisson = Counts( ) with numpyro.handlers.seed(rng_seed=42): - poisson_sample = hosp_process_poisson.sample( + poisson_result = hosp_process_poisson.sample( infections=infections, counts=None, ) print( - f"Sampled {len(poisson_sample)} days of hospital admissions with Poisson noise" + f"Sampled {len(poisson_result.observed)} days of hospital admissions with Poisson noise" ) ``` diff --git a/docs/tutorials/observation_processes_measurements.qmd b/docs/tutorials/observation_processes_measurements.qmd index 96138d0a..613e1176 100644 --- a/docs/tutorials/observation_processes_measurements.qmd +++ b/docs/tutorials/observation_processes_measurements.qmd @@ -52,7 +52,7 @@ The core convolution $\sum_{d=0}^{D} I_{t-d} \cdot p_d$ is the same as for count | Temporal PMF | Delay distribution | Shedding/decay kinetics | | Output space | Expected counts (linear) | 
Log-concentrations | | Noise model | Poisson or Negative Binomial | Normal on log scale | -| Site structure | Optional (`CountsBySite`) | Inherent (hierarchical effects) | +| Subpop structure | Optional (`CountsBySubpop`) | Inherent (hierarchical effects) | **Key features of measurement data:** @@ -113,7 +113,7 @@ class Wastewater(Measurements): self._validate_pmf(shedding_pmf, "shedding_kinetics_rv") self.noise.validate() - def get_required_lookback(self) -> int: + def lookback_days(self) -> int: """Return shedding PMF length.""" return len(self.temporal_pmf_rv()) @@ -134,7 +134,7 @@ class Wastewater(Measurements): ) return convolved - # Apply to all sites (infections shape: n_days x n_sites) + # Apply to all subpops (infections shape: n_days x n_subpops) shedding_signal = jax.vmap(convolve_site, in_axes=1, out_axes=1)( infections ) @@ -208,21 +208,21 @@ log10_genome_rv = DeterministicVariable("log10_genome", 9.0) ml_per_person_per_day = 1000.0 ``` -## Noise model with site-level effects +## Noise model with sensor-level effects -Different sites have systematic differences in measured concentrations due to lab protocols, sampling methods, and flow estimation. We model this with hierarchical site effects: +A measurement depends on both the thing being measured (e.g., wastewater from a treatment plant) and the calibration of the lab/instruments used to obtain that measurement. We call this combination a "sensor"—the WWTP/lab pair that determines the measurement characteristics. Different sensors have systematic biases and variabilities that we model with hierarchical effects. 
```{python} # | label: noise-model -# Site-level mode: systematic differences between sites -site_mode_prior = HierarchicalNormalPrior( - name="ww_site_mode", +# Sensor-level mode: systematic differences between sensors +sensor_mode_prior = HierarchicalNormalPrior( + name="ww_sensor_mode", sd_rv=DeterministicVariable("mode_sd", 0.5), ) -# Site-level SD: measurement variability within each site -site_sd_prior = GammaGroupSdPrior( - name="ww_site_sd", +# Sensor-level SD: measurement variability within each sensor +sensor_sd_prior = GammaGroupSdPrior( + name="ww_sensor_sd", sd_mean_rv=DeterministicVariable("sd_mean", 0.3), sd_concentration_rv=DeterministicVariable("sd_concentration", 4.0), sd_min=0.10, @@ -230,8 +230,8 @@ site_sd_prior = GammaGroupSdPrior( # Create the noise model ww_noise = HierarchicalNormalNoise( - site_mode_prior_rv=site_mode_prior, - site_sd_prior_rv=site_sd_prior, + sensor_mode_prior_rv=sensor_mode_prior, + sensor_sd_prior_rv=sensor_sd_prior, ) ``` @@ -246,20 +246,20 @@ ww_process = Wastewater( noise=ww_noise, ) -print(f"Required lookback: {ww_process.get_required_lookback()} days") +print(f"Required lookback: {ww_process.lookback_days()} days") ``` ### Timeline alignment and lookback period The observation process convolves infections with a shedding distribution, maintaining alignment between input and output: day $t$ in the output corresponds to day $t$ in the input. -Wastewater concentrations depend on infections from prior days (the length of our shedding distribution minus one). The method `get_required_lookback()` returns this length; the first valid observation day is at index `lookback - 1`. Earlier days are marked invalid with NaN. +Wastewater concentrations depend on infections from prior days (the length of our shedding distribution minus one). The method `lookback_days()` returns this length; the first valid observation day is at index `lookback - 1`. Earlier days are marked invalid with NaN. 
```{python} # | label: helper-function def first_valid_observation_day(obs_process) -> int: """Return the first day index with complete infection history for convolution.""" - return obs_process.get_required_lookback() - 1 + return obs_process.lookback_days() - 1 ``` ## Simulating wastewater observations from infections @@ -273,7 +273,7 @@ n_days = 50 # First valid observation day day_one = first_valid_observation_day(ww_process) -# Create infections with a spike (shape: n_days x n_sites) +# Create infections with a spike (shape: n_days x n_subpops) infection_spike_day = day_one + 10 infections = jnp.zeros((n_days, 1)) # 1 subpopulation infections = infections.at[infection_spike_day, 0].set(2000.0) @@ -290,10 +290,10 @@ with numpyro.handlers.seed(rng_seed=42): ww_obs = ww_process.sample( infections=infections, subpop_indices=jnp.zeros(n_obs, dtype=jnp.int32), - site_indices=jnp.zeros(n_obs, dtype=jnp.int32), + sensor_indices=jnp.zeros(n_obs, dtype=jnp.int32), times=observation_days, concentrations=None, # Sample from prior - n_sites=1, + n_sensors=1, ) ``` @@ -335,15 +335,15 @@ samples_by_day = {int(d - day_one): [] for d in observation_days} for seed in range(n_samples): with numpyro.handlers.seed(rng_seed=seed): - ww_sample = ww_process.sample( + ww_result = ww_process.sample( infections=infections, subpop_indices=jnp.zeros(n_obs, dtype=jnp.int32), - site_indices=jnp.zeros(n_obs, dtype=jnp.int32), + sensor_indices=jnp.zeros(n_obs, dtype=jnp.int32), times=observation_days, concentrations=None, - n_sites=1, + n_sensors=1, ) - for day_idx, conc in zip(observation_days, ww_sample): + for day_idx, conc in zip(observation_days, ww_result.observed): samples_by_day[int(day_idx) - day_one].append(float(conc)) # Convert to arrays for plotting @@ -402,57 +402,57 @@ print( ) ``` -## Site-level variability +## Sensor-level variability -Different sites measuring the same underlying infections will show systematic differences: +Different sensors measuring the same underlying 
infections will show systematic differences: ```{python} -# | label: multi-site -n_sites = 4 +# | label: multi-sensor +num_sensors = 4 infections_constant = jnp.ones((40, 1)) * 1500.0 -# Each site samples at multiple time points -times_per_site = 10 -observation_times = jnp.tile(jnp.arange(10, 30, 2, dtype=jnp.int32), n_sites) -site_ids = jnp.repeat(jnp.arange(n_sites, dtype=jnp.int32), times_per_site) -subpop_ids = jnp.zeros(n_sites * times_per_site, dtype=jnp.int32) +# Each sensor samples at multiple time points +times_per_sensor = 10 +observation_times = jnp.tile(jnp.arange(10, 30, 2, dtype=jnp.int32), num_sensors) +sensor_ids = jnp.repeat(jnp.arange(num_sensors, dtype=jnp.int32), times_per_sensor) +subpop_ids = jnp.zeros(num_sensors * times_per_sensor, dtype=jnp.int32) with numpyro.handlers.seed(rng_seed=42): - ww_multi_site = ww_process.sample( + ww_multi_sensor = ww_process.sample( infections=infections_constant, subpop_indices=subpop_ids, - site_indices=site_ids, + sensor_indices=sensor_ids, times=observation_times, concentrations=None, - n_sites=n_sites, + n_sensors=num_sensors, ) ``` ```{python} -# | label: fig-multi-site -# | fig-cap: Multiple sites observing the same infections +# | label: fig-multi-sensor +# | fig-cap: Multiple sensors observing the same infections fig, ax = plt.subplots(figsize=(8, 4)) colors = ["steelblue", "coral", "green", "purple"] -for site in range(n_sites): - mask = np.array(site_ids) == site +for sensor in range(num_sensors): + mask = np.array(sensor_ids) == sensor ax.plot( np.array(observation_times)[mask], - np.array(ww_multi_site)[mask], + np.array(ww_multi_sensor.observed)[mask], "-o", - color=colors[site], - label=f"Site {site}", + color=colors[sensor], + label=f"Sensor {sensor}", ) ax.set_xlabel("Day") ax.set_ylabel("Log Viral Concentration") -ax.set_title("Multiple Sites Observing Same Infections") +ax.set_title("Multiple Sensors Observing Same Infections") ax.legend() plt.tight_layout() plt.show() ``` -Each site has a 
different baseline concentration (site mode) and different measurement variability (site SD). The hierarchical model learns these site-specific effects during inference. +Each sensor has a different baseline concentration (sensor mode) and different measurement variability (sensor SD). The hierarchical model learns these sensor-specific effects during inference. ## Effect of genome shedding parameter @@ -482,12 +482,12 @@ for genome_val in genome_values: ww_temp = process_temp.sample( infections=infections_decay, subpop_indices=jnp.zeros(n_obs_cmp, dtype=jnp.int32), - site_indices=jnp.zeros(n_obs_cmp, dtype=jnp.int32), + sensor_indices=jnp.zeros(n_obs_cmp, dtype=jnp.int32), times=observation_days_cmp, concentrations=None, - n_sites=1, + n_sensors=1, ) - results[genome_val] = np.array(ww_temp) + results[genome_val] = np.array(ww_temp.observed) ``` ```{python} @@ -516,7 +516,7 @@ plt.show() ## Multiple subpopulations -In hierarchical models, each measurement site serves a distinct subpopulation. Different subpopulations can have different infection levels: +In hierarchical models, each sensor serves a distinct subpopulation (catchment area). 
Different subpopulations can have different infection levels: ```{python} # | label: multi-subpop @@ -530,32 +530,32 @@ infections_subpop2 = 2000.0 * jnp.exp( ) # Fast decay infections_multi = jnp.stack([infections_subpop1, infections_subpop2], axis=1) -# Two sites, each observing a different subpopulation +# Two sensors, each observing a different subpopulation obs_days_mp = jnp.tile(jnp.arange(10, 30, 2, dtype=jnp.int32), 2) subpop_ids_mp = jnp.array([0] * 10 + [1] * 10, dtype=jnp.int32) -site_ids_mp = jnp.array([0] * 10 + [1] * 10, dtype=jnp.int32) +sensor_ids_mp = jnp.array([0] * 10 + [1] * 10, dtype=jnp.int32) with numpyro.handlers.seed(rng_seed=42): ww_multi_subpop = ww_process.sample( infections=infections_multi, subpop_indices=subpop_ids_mp, - site_indices=site_ids_mp, + sensor_indices=sensor_ids_mp, times=obs_days_mp, concentrations=None, - n_sites=2, + n_sensors=2, ) ``` ```{python} # | label: fig-multi-subpop -# | fig-cap: Sites observing different subpopulations with different infection dynamics +# | fig-cap: Sensors observing different subpopulations with different infection dynamics fig, ax = plt.subplots(figsize=(8, 4)) for subpop in [0, 1]: mask = np.array(subpop_ids_mp) == subpop ax.plot( np.array(obs_days_mp)[mask], - np.array(ww_multi_subpop)[mask], + np.array(ww_multi_subpop.observed)[mask], "-o", label=f"Subpop {subpop}", ) diff --git a/pyrenew/observation/__init__.py b/pyrenew/observation/__init__.py index dec3fadf..8a0cdeab 100644 --- a/pyrenew/observation/__init__.py +++ b/pyrenew/observation/__init__.py @@ -6,21 +6,22 @@ ------------ ``BaseObservationProcess`` is the abstract base. 
Concrete subclasses: -- ``Counts``: Jurisdiction-level counts (admissions, deaths) -- ``CountsBySite``: Site-specific disaggregated counts -- ``Measurements``: Continuous site-level signals (e.g., wastewater) +- ``Counts``: Aggregate counts (admissions, deaths) +- ``CountsBySubpop``: Subpopulation-level counts +- ``Measurements``: Continuous subpopulation-level signals (e.g., wastewater) All observation processes implement: -- ``_expected_signal(infections)``: transforms infections to expected values -- ``sample()``: calls ``_expected_signal()`` then applies noise model +- ``sample()``: Sample observations given infections +- ``infection_resolution()``: returns ``"aggregate"`` or ``"subpop"`` +- ``lookback_days()``: returns required infection history length Noise models (``CountNoise``, ``MeasurementNoise``) are composable—pass them to observation constructors to control the output distribution. """ from pyrenew.observation.base import BaseObservationProcess -from pyrenew.observation.count_observations import Counts, CountsBySite +from pyrenew.observation.count_observations import Counts, CountsBySubpop from pyrenew.observation.measurements import Measurements from pyrenew.observation.negativebinomial import NegativeBinomialObservation from pyrenew.observation.noise import ( @@ -31,21 +32,23 @@ PoissonNoise, ) from pyrenew.observation.poisson import PoissonObservation +from pyrenew.observation.types import ObservationSample __all__ = [ # Existing (kept for backward compatibility) "NegativeBinomialObservation", "PoissonObservation", - # New base classes + # Base classes and types "BaseObservationProcess", - # New noise models + "ObservationSample", + # Noise models "CountNoise", "PoissonNoise", "NegativeBinomialNoise", "MeasurementNoise", "HierarchicalNormalNoise", - # New observation processes + # Observation processes "Counts", - "CountsBySite", + "CountsBySubpop", "Measurements", ] diff --git a/pyrenew/observation/base.py b/pyrenew/observation/base.py index 
8468566a..635d386f 100644 --- a/pyrenew/observation/base.py +++ b/pyrenew/observation/base.py @@ -37,8 +37,8 @@ class BaseObservationProcess(RandomVariable): Subclasses must implement: - ``validate()``: Validate parameters (call ``_validate_pmf()`` for PMFs) - - ``get_required_lookback()``: Return PMF length for initialization - - ``infection_resolution()``: Return ``"jurisdiction"`` or ``"site"`` + - ``lookback_days()``: Return PMF length for initialization + - ``infection_resolution()``: Return ``"aggregate"`` or ``"subpop"`` - ``_expected_signal()``: Transform infections to expected values - ``sample()``: Apply noise model to expected signal @@ -91,7 +91,7 @@ def validate(self) -> None: pass # pragma: no cover @abstractmethod - def get_required_lookback(self) -> int: + def lookback_days(self) -> int: """ Return the number of days this observation process needs to look back. @@ -115,25 +115,27 @@ def get_required_lookback(self) -> int: @abstractmethod def infection_resolution(self) -> str: """ - Return the resolution of infections this observation uses. + Return whether this observation uses aggregate or subpop infections. Returns one of: - - ``"jurisdiction"``: Uses jurisdiction-level aggregated infections - - ``"site"``: Uses site-level disaggregated infections + - ``"aggregate"``: Uses a single aggregated infection trajectory. + Shape: ``(n_days,)`` + - ``"subpop"``: Uses subpopulation-level infection trajectories. + Shape: ``(n_days, n_subpops)``, indexed via ``subpop_indices``. 
Returns ------- str - Either ``"jurisdiction"`` or ``"site"`` + Either ``"aggregate"`` or ``"subpop"`` Examples -------- - >>> # Aggregated count observations use jurisdiction-level - >>> hosp_obs.infection_resolution() # Returns "jurisdiction" + >>> # Aggregated count observations + >>> hosp_obs.infection_resolution() # Returns "aggregate" >>> - >>> # Wastewater uses site-level - >>> ww_obs.infection_resolution() # Returns "site" + >>> # Subpopulation-level observations (wastewater, subpop-specific counts) + >>> ww_obs.infection_resolution() # Returns "subpop" Notes ----- @@ -285,8 +287,8 @@ def _expected_signal( ---------- infections : ArrayLike Infections from the infection process. - Shape: (n_days,) for jurisdiction-level observations - Shape: (n_days, n_sites) for site-level observations + Shape: (n_days,) for aggregate observations + Shape: (n_days, n_subpops) for subpop-level observations Returns ------- diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index 437c37e6..7a883932 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -14,6 +14,7 @@ from pyrenew.metaclass import RandomVariable from pyrenew.observation.base import BaseObservationProcess from pyrenew.observation.noise import CountNoise +from pyrenew.observation.types import ObservationSample class _CountBase(BaseObservationProcess): @@ -67,7 +68,7 @@ def validate(self) -> None: self.noise.validate() - def get_required_lookback(self) -> int: + def lookback_days(self) -> int: """ Return delay PMF length. @@ -85,7 +86,7 @@ def infection_resolution(self) -> str: Returns ------- str - "jurisdiction" for aggregated, "site" for disaggregated. + "aggregate" or "subpop". """ raise NotImplementedError("Subclasses must implement infection_resolution()") @@ -100,8 +101,8 @@ def _expected_signal( ---------- infections : ArrayLike Infections from the infection process. 
- Shape: (n_days,) for jurisdiction-level - Shape: (n_days, n_sites) for site-level + Shape: (n_days,) for aggregate + Shape: (n_days, n_subpops) for subpop-level Returns ------- @@ -127,10 +128,10 @@ def convolve_col(col): # numpydoc ignore=GL08 class Counts(_CountBase): """ - Aggregated count observation for jurisdiction-level data. + Aggregated count observation. - Maps jurisdiction-level infections to aggregated counts through - ascertainment x delay convolution with composable noise model. + Maps aggregate infections to counts through ascertainment x delay + convolution with composable noise model. Parameters ---------- @@ -167,21 +168,29 @@ class Counts(_CountBase): def infection_resolution(self) -> str: """ - Return "jurisdiction" for aggregated observations. + Return "aggregate" for aggregated observations. Returns ------- str - The string "jurisdiction". + The string "aggregate". """ - return "jurisdiction" + return "aggregate" + + def __repr__(self) -> str: + """Return string representation.""" + return ( + f"Counts(ascertainment_rate_rv={self.ascertainment_rate_rv!r}, " + f"delay_distribution_rv={self.temporal_pmf_rv!r}, " + f"noise={self.noise!r})" + ) def sample( self, infections: ArrayLike, counts: ArrayLike | None = None, times: ArrayLike | None = None, - ) -> ArrayLike: + ) -> ObservationSample: """ Sample aggregated counts with dense or sparse observations. @@ -191,7 +200,7 @@ def sample( Parameters ---------- infections : ArrayLike - Jurisdiction-level infections from the infection process. + Aggregate infections from the infection process. Shape: (n_days,) counts : ArrayLike | None Observed counts. Dense: (n_days,), Sparse: (n_obs,), None: prior. @@ -200,9 +209,9 @@ def sample( Returns ------- - ArrayLike - Observed or sampled counts. - Dense: (n_days,), Sparse: (n_obs,) + ObservationSample + Named tuple with `observed` (sampled/conditioned counts) and + `expected` (expected counts before noise). 
""" expected_counts = self._expected_signal(infections) self._deterministic("expected_counts", expected_counts) @@ -214,18 +223,20 @@ def sample( else: expected_obs = expected_counts_safe - return self.noise.sample( + observed = self.noise.sample( name="counts", expected=expected_obs, obs=counts, ) + return ObservationSample(observed=observed, expected=expected_counts) + -class CountsBySite(_CountBase): +class CountsBySubpop(_CountBase): """ - Disaggregated count observation for site-specific data. + Subpopulation-level count observation. - Maps site-level infections to site-specific counts through + Maps subpopulation-level infections to counts through ascertainment x delay convolution with composable noise model. Parameters @@ -244,19 +255,19 @@ class CountsBySite(_CountBase): Examples -------- >>> from pyrenew.deterministic import DeterministicVariable, DeterministicPMF - >>> from pyrenew.observation import CountsBySite, PoissonNoise + >>> from pyrenew.observation import CountsBySubpop, PoissonNoise >>> import jax.numpy as jnp >>> import numpyro >>> >>> delay_pmf = jnp.array([0.3, 0.4, 0.3]) - >>> counts_obs = CountsBySite( + >>> counts_obs = CountsBySubpop( ... ascertainment_rate_rv=DeterministicVariable("ihr", 0.02), ... delay_distribution_rv=DeterministicPMF("delay", delay_pmf), ... noise=PoissonNoise(), ... ) >>> >>> with numpyro.handlers.seed(rng_seed=42): - ... infections = jnp.ones((30, 3)) * 500 # 30 days, 3 sites + ... infections = jnp.ones((30, 3)) * 500 # 30 days, 3 subpops ... times = jnp.array([10, 15, 10, 15]) ... subpop_indices = jnp.array([0, 0, 1, 1]) ... sampled = counts_obs.sample( @@ -267,16 +278,24 @@ class CountsBySite(_CountBase): ... 
) """ + def __repr__(self) -> str: + """Return string representation.""" + return ( + f"CountsBySubpop(ascertainment_rate_rv={self.ascertainment_rate_rv!r}, " + f"delay_distribution_rv={self.temporal_pmf_rv!r}, " + f"noise={self.noise!r})" + ) + def infection_resolution(self) -> str: """ - Return "site" for disaggregated observations. + Return "subpop" for subpopulation-level observations. Returns ------- str - The string "site". + The string "subpop". """ - return "site" + return "subpop" def sample( self, @@ -284,9 +303,9 @@ def sample( subpop_indices: ArrayLike, times: ArrayLike, counts: ArrayLike | None = None, - ) -> ArrayLike: + ) -> ObservationSample: """ - Sample disaggregated counts with flexible indexing. + Sample subpopulation-level counts with flexible indexing. Validation is performed before JAX tracing at runtime, prior to calling this method. @@ -294,8 +313,8 @@ def sample( Parameters ---------- infections : ArrayLike - Site-level infections from the infection process. - Shape: (n_days, n_sites) + Subpopulation-level infections from the infection process. + Shape: (n_days, n_subpops) subpop_indices : ArrayLike Subpopulation index for each observation (0-indexed). Shape: (n_obs,) @@ -307,21 +326,23 @@ def sample( Returns ------- - ArrayLike - Observed or sampled counts. - Shape: (n_obs,) + ObservationSample + Named tuple with `observed` (sampled/conditioned counts) and + `expected` (expected counts before noise, shape: n_days x n_subpops). 
""" - # Compute expected counts for all sites + # Compute expected counts for all subpops expected_counts_all = self._expected_signal(infections) - self._deterministic("expected_counts_by_site", expected_counts_all) + self._deterministic("expected_counts_by_subpop", expected_counts_all) # Replace NaN padding with 0 for distribution creation expected_counts_safe = jnp.nan_to_num(expected_counts_all, nan=0.0) expected_obs = expected_counts_safe[times, subpop_indices] - return self.noise.sample( - name="counts_by_site", + observed = self.noise.sample( + name="counts_by_subpop", expected=expected_obs, obs=counts, ) + + return ObservationSample(observed=observed, expected=expected_counts_all) diff --git a/pyrenew/observation/measurements.py b/pyrenew/observation/measurements.py index 65a95fa9..33e38980 100644 --- a/pyrenew/observation/measurements.py +++ b/pyrenew/observation/measurements.py @@ -11,6 +11,7 @@ from pyrenew.metaclass import RandomVariable from pyrenew.observation.base import BaseObservationProcess from pyrenew.observation.noise import MeasurementNoise +from pyrenew.observation.types import ObservationSample class Measurements(BaseObservationProcess): @@ -37,7 +38,7 @@ class Measurements(BaseObservationProcess): See Also -------- pyrenew.observation.noise.HierarchicalNormalNoise : - Suitable noise model for site-level measurements + Suitable noise model for sensor-level measurements pyrenew.observation.base.BaseObservationProcess : Parent class with common observation utilities """ @@ -55,39 +56,47 @@ def __init__( temporal_pmf_rv : RandomVariable Temporal distribution PMF (e.g., shedding kinetics). noise : MeasurementNoise - Noise model (e.g., HierarchicalNormalNoise with site effects). + Noise model (e.g., HierarchicalNormalNoise with sensor effects). 
""" super().__init__(temporal_pmf_rv=temporal_pmf_rv) self.noise = noise + def __repr__(self) -> str: + """Return string representation.""" + return ( + f"{self.__class__.__name__}(" + f"temporal_pmf_rv={self.temporal_pmf_rv!r}, " + f"noise={self.noise!r})" + ) + def infection_resolution(self) -> str: """ - Return "site" for measurement observations. + Return "subpop" for measurement observations. - Measurement observations require site-level infections - because each measurement site corresponds to a specific catchment. + Measurement observations require subpopulation-level infections + because each measurement corresponds to a specific catchment area. Returns ------- str - ``"site"`` + ``"subpop"`` """ - return "site" + return "subpop" def sample( self, infections: ArrayLike, subpop_indices: ArrayLike, - site_indices: ArrayLike, + sensor_indices: ArrayLike, times: ArrayLike, concentrations: ArrayLike | None, - n_sites: int, - ) -> ArrayLike: + n_sensors: int, + ) -> ObservationSample: """ - Sample measurements from observed sites. + Sample measurements from observed sensors. This method does not perform runtime validation of index values - (times, subpop_indices, site_indices). Validate observation data + (times, subpop_indices, sensor_indices). Validate observation data before sampling. Transforms infections to expected values via signal-specific processing @@ -97,26 +106,26 @@ def sample( ---------- infections : ArrayLike Infections from the infection process. - Shape: (n_days, n_sites) + Shape: (n_days, n_subpops) subpop_indices : ArrayLike Subpopulation index for each observation (0-indexed). Shape: (n_obs,) - site_indices : ArrayLike - Site index for each observation (0-indexed). + sensor_indices : ArrayLike + Sensor index for each observation (0-indexed). Shape: (n_obs,) times : ArrayLike Day index for each observation (0-indexed). Shape: (n_obs,) concentrations : ArrayLike | None Observed measurements (n_obs,), or None for prior sampling. 
- n_sites : int - Total number of measurement sites. + n_sensors : int + Total number of measurement sensors. Returns ------- - ArrayLike - Observed or sampled measurements. - Shape: (n_obs,) + ObservationSample + Named tuple with `observed` (sampled/conditioned measurements) and + `expected` (expected values before noise, shape: n_days x n_subpops). """ expected_values = self._expected_signal(infections) @@ -124,10 +133,12 @@ def sample( expected_obs = expected_values[times, subpop_indices] - return self.noise.sample( + observed = self.noise.sample( name="concentrations", expected=expected_obs, obs=concentrations, - site_indices=site_indices, - n_sites=n_sites, + sensor_indices=sensor_indices, + n_sensors=n_sensors, ) + + return ObservationSample(observed=observed, expected=expected_values) diff --git a/pyrenew/observation/noise.py b/pyrenew/observation/noise.py index 4065061e..647b1781 100644 --- a/pyrenew/observation/noise.py +++ b/pyrenew/observation/noise.py @@ -13,8 +13,8 @@ Measurement Noise ----------------- -- ``HierarchicalNormalNoise``: Normal noise with hierarchical site effects. - Takes ``site_mode_prior_rv`` and ``site_sd_prior_rv`` for site-level +- ``HierarchicalNormalNoise``: Normal noise with hierarchical sensor effects. + Takes ``sensor_mode_prior_rv`` and ``sensor_sd_prior_rv`` for sensor-level bias and variability. """ @@ -87,6 +87,10 @@ def __init__(self) -> None: """Initialize Poisson noise (no parameters).""" pass + def __repr__(self) -> str: + """Return string representation.""" + return "PoissonNoise()" + def validate(self) -> None: """Validate Poisson noise (always valid).""" pass @@ -151,6 +155,10 @@ def __init__(self, concentration_rv: RandomVariable) -> None: """ self.concentration_rv = concentration_rv + def __repr__(self) -> str: + """Return string representation.""" + return f"NegativeBinomialNoise(concentration_rv={self.concentration_rv!r})" + def validate(self) -> None: """ Validate concentration is positive. 
@@ -228,7 +236,7 @@ def sample( obs : ArrayLike | None Observed measurements for conditioning, or None for prior sampling. **kwargs - Additional context (e.g., site indices). + Additional context (e.g., sensor indices). Returns ------- @@ -252,18 +260,18 @@ def validate(self) -> None: class HierarchicalNormalNoise(MeasurementNoise): """ - Normal noise with hierarchical site-level effects. + Normal noise with hierarchical sensor-level effects. - Observation model: ``obs ~ Normal(expected + site_mode, site_sd)`` - where site_mode and site_sd are hierarchically modeled. + Observation model: ``obs ~ Normal(expected + sensor_mode, sensor_sd)`` + where sensor_mode and sensor_sd are hierarchically modeled. Parameters ---------- - site_mode_prior_rv : RandomVariable - Hierarchical prior for site-level modes (log-scale biases). + sensor_mode_prior_rv : RandomVariable + Hierarchical prior for sensor-level modes (log-scale biases). Must support ``sample(n_groups=...)`` interface. - site_sd_prior_rv : RandomVariable - Hierarchical prior for site-level SDs (must be > 0). + sensor_sd_prior_rv : RandomVariable + Hierarchical prior for sensor-level SDs (must be > 0). Must support ``sample(n_groups=...)`` interface. Notes @@ -273,30 +281,38 @@ class HierarchicalNormalNoise(MeasurementNoise): See Also -------- pyrenew.randomvariable.HierarchicalNormalPrior : - Suitable prior for site_mode_prior_rv + Suitable prior for sensor_mode_prior_rv pyrenew.randomvariable.GammaGroupSdPrior : - Suitable prior for site_sd_prior_rv + Suitable prior for sensor_sd_prior_rv """ def __init__( self, - site_mode_prior_rv: RandomVariable, - site_sd_prior_rv: RandomVariable, + sensor_mode_prior_rv: RandomVariable, + sensor_sd_prior_rv: RandomVariable, ) -> None: """ Initialize hierarchical Normal noise. Parameters ---------- - site_mode_prior_rv : RandomVariable - Hierarchical prior for site-level modes (log-scale biases). 
+ sensor_mode_prior_rv : RandomVariable + Hierarchical prior for sensor-level modes (log-scale biases). Must support ``sample(n_groups=...)`` interface. - site_sd_prior_rv : RandomVariable - Hierarchical prior for site-level SDs (must be > 0). + sensor_sd_prior_rv : RandomVariable + Hierarchical prior for sensor-level SDs (must be > 0). Must support ``sample(n_groups=...)`` interface. """ - self.site_mode_prior_rv = site_mode_prior_rv - self.site_sd_prior_rv = site_sd_prior_rv + self.sensor_mode_prior_rv = sensor_mode_prior_rv + self.sensor_sd_prior_rv = sensor_sd_prior_rv + + def __repr__(self) -> str: + """Return string representation.""" + return ( + f"HierarchicalNormalNoise(" + f"sensor_mode_prior_rv={self.sensor_mode_prior_rv!r}, " + f"sensor_sd_prior_rv={self.sensor_sd_prior_rv!r})" + ) def validate(self) -> None: """ @@ -314,11 +330,11 @@ def sample( expected: ArrayLike, obs: ArrayLike | None = None, *, - site_indices: ArrayLike, - n_sites: int, + sensor_indices: ArrayLike, + n_sensors: int, ) -> ArrayLike: """ - Sample from Normal distribution with site-level hierarchical effects. + Sample from Normal distribution with sensor-level hierarchical effects. Parameters ---------- @@ -330,27 +346,27 @@ def sample( obs : ArrayLike | None Observed log-scale measurements for conditioning. Shape: (n_obs,) - site_indices : ArrayLike - Site index for each observation (0-indexed). + sensor_indices : ArrayLike + Sensor index for each observation (0-indexed). Shape: (n_obs,) - n_sites : int - Total number of sites. + n_sensors : int + Total number of sensors. Returns ------- ArrayLike - Normal distributed measurements with hierarchical site effects. + Normal distributed measurements with hierarchical sensor effects. Shape: (n_obs,) Raises ------ ValueError - If site_sd samples non-positive values. + If sensor_sd samples non-positive values. 
""" - site_mode = self.site_mode_prior_rv.sample(n_groups=n_sites) - site_sd = self.site_sd_prior_rv.sample(n_groups=n_sites) + sensor_mode = self.sensor_mode_prior_rv.sample(n_groups=n_sensors) + sensor_sd = self.sensor_sd_prior_rv.sample(n_groups=n_sensors) - loc = expected + site_mode[site_indices] - scale = site_sd[site_indices] + loc = expected + sensor_mode[sensor_indices] + scale = sensor_sd[sensor_indices] return numpyro.sample(name, dist.Normal(loc=loc, scale=scale), obs=obs) diff --git a/test/conftest.py b/test/conftest.py index 73854684..6f55681a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -127,9 +127,9 @@ def medium_shedding_pmf(): @pytest.fixture -def site_mode_prior(): +def sensor_mode_prior(): """ - Standard hierarchical normal prior for site modes. + Standard hierarchical normal prior for sensor modes. Returns ------- @@ -137,12 +137,12 @@ def site_mode_prior(): A hierarchical normal prior with standard deviation 0.5. """ return HierarchicalNormalPrior( - name="ww_site_mode", sd_rv=DeterministicVariable("mode_sd", 0.5) + name="ww_sensor_mode", sd_rv=DeterministicVariable("mode_sd", 0.5) ) @pytest.fixture -def site_mode_prior_tight(): +def sensor_mode_prior_tight(): """ Tight hierarchical normal prior for deterministic-like behavior. @@ -152,14 +152,14 @@ def site_mode_prior_tight(): A hierarchical normal prior with small standard deviation 0.01. """ return HierarchicalNormalPrior( - name="ww_site_mode", sd_rv=DeterministicVariable("mode_sd_tight", 0.01) + name="ww_sensor_mode", sd_rv=DeterministicVariable("mode_sd_tight", 0.01) ) @pytest.fixture -def site_sd_prior(): +def sensor_sd_prior(): """ - Standard gamma prior for site standard deviations. + Standard gamma prior for sensor standard deviations. Returns ------- @@ -167,7 +167,7 @@ def site_sd_prior(): A gamma prior for group standard deviations. 
""" return GammaGroupSdPrior( - name="ww_site_sd", + name="ww_sensor_sd", sd_mean_rv=DeterministicVariable("sd_mean", 0.3), sd_concentration_rv=DeterministicVariable("sd_concentration", 4.0), sd_min=0.10, @@ -175,7 +175,7 @@ def site_sd_prior(): @pytest.fixture -def site_sd_prior_tight(): +def sensor_sd_prior_tight(): """ Tight gamma prior for deterministic-like behavior. @@ -185,7 +185,7 @@ def site_sd_prior_tight(): A gamma prior with small mean for tight behavior. """ return GammaGroupSdPrior( - name="ww_site_sd", + name="ww_sensor_sd", sd_mean_rv=DeterministicVariable("sd_mean_tight", 0.01), sd_concentration_rv=DeterministicVariable("sd_concentration_tight", 4.0), sd_min=0.005, diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index 58f0b33e..ddb4d684 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -12,7 +12,7 @@ from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.observation import ( Counts, - CountsBySite, + CountsBySubpop, NegativeBinomialNoise, PoissonNoise, ) @@ -66,13 +66,14 @@ def test_sample_returns_correct_shape(self, counts_process): infections = jnp.ones(30) * 100 with numpyro.handlers.seed(rng_seed=42): - counts = counts_process.sample( + result = counts_process.sample( infections=infections, counts=None, ) - assert counts.shape[0] > 0 - assert counts.ndim == 1 + assert result.observed.shape[0] > 0 + assert result.observed.ndim == 1 + assert result.expected.shape == infections.shape def test_delay_convolution(self, counts_factory, short_delay_pmf): """Test that delay is properly applied.""" @@ -82,16 +83,16 @@ def test_delay_convolution(self, counts_factory, short_delay_pmf): infections = infections.at[10].set(1000) with numpyro.handlers.seed(rng_seed=42): - counts = process.sample( + result = process.sample( infections=infections, counts=None, ) # Timeline alignment: output length equals input length - assert counts.shape[0] == len(infections) + 
assert result.observed.shape[0] == len(infections) # First len(delay_pmf)-1 days are NaN (appear as -1 after NegativeBinomial sampling) - assert jnp.all(counts[1:] >= 0) - assert jnp.sum(counts[counts >= 0]) > 0 + assert jnp.all(result.observed[1:] >= 0) + assert jnp.sum(result.observed[result.observed >= 0]) > 0 def test_ascertainment_scaling(self, counts_factory, simple_delay_pmf): """Test that ascertainment rate properly scales counts.""" @@ -105,11 +106,11 @@ def test_ascertainment_scaling(self, counts_factory, simple_delay_pmf): ) with numpyro.handlers.seed(rng_seed=42): - counts = process.sample( + result = process.sample( infections=infections, counts=None, ) - results.append(jnp.mean(counts)) + results.append(jnp.mean(result.observed)) # Higher ascertainment rate should lead to more counts assert results[1] > results[0] @@ -127,11 +128,11 @@ def test_negative_binomial_observation(self, counts_factory, simple_delay_pmf): samples = [] for seed in range(5): with numpyro.handlers.seed(rng_seed=seed): - counts = process.sample( + result = process.sample( infections=infections, counts=None, ) - samples.append(jnp.sum(counts)) + samples.append(jnp.sum(result.observed)) # Should have some variability due to negative binomial sampling assert jnp.std(jnp.array(samples)) > 0 @@ -155,14 +156,14 @@ def test_with_stochastic_ascertainment(self, short_shedding_pmf): infections = jnp.ones(20) * 100 with numpyro.handlers.seed(rng_seed=42): - counts = process.sample( + result = process.sample( infections=infections, counts=None, ) - assert counts.shape[0] > 0 + assert result.observed.shape[0] > 0 # Skip NaN padding - valid_counts = counts[2:] + valid_counts = result.observed[2:] assert jnp.all(valid_counts >= 0) def test_with_stochastic_concentration(self, simple_delay_pmf): @@ -180,13 +181,13 @@ def test_with_stochastic_concentration(self, simple_delay_pmf): infections = jnp.ones(20) * 100 with numpyro.handlers.seed(rng_seed=42): - counts = process.sample( + result = 
process.sample( infections=infections, counts=None, ) - assert counts.shape[0] > 0 - assert jnp.all(counts >= 0) + assert result.observed.shape[0] > 0 + assert jnp.all(result.observed >= 0) class TestCountsEdgeCases: @@ -197,25 +198,25 @@ def test_zero_infections(self, counts_process): infections = jnp.zeros(20) with numpyro.handlers.seed(rng_seed=42): - counts = counts_process.sample( + result = counts_process.sample( infections=infections, counts=None, ) - assert counts.shape[0] > 0 + assert result.observed.shape[0] > 0 def test_small_infections(self, counts_process): """Test with small infection values.""" infections = jnp.ones(20) * 10 with numpyro.handlers.seed(rng_seed=42): - counts = counts_process.sample( + result = counts_process.sample( infections=infections, counts=None, ) - assert counts.shape[0] > 0 - assert jnp.all(counts >= 0) + assert result.observed.shape[0] > 0 + assert jnp.all(result.observed >= 0) def test_long_delay_distribution(self, counts_factory, long_delay_pmf): """Test with longer delay distribution.""" @@ -224,15 +225,15 @@ def test_long_delay_distribution(self, counts_factory, long_delay_pmf): infections = create_mock_infections(40, peak_day=20, shape="spike") with numpyro.handlers.seed(rng_seed=42): - counts = process.sample( + result = process.sample( infections=infections, counts=None, ) # Timeline alignment maintained - assert counts.shape[0] == infections.shape[0] + assert result.observed.shape[0] == infections.shape[0] # Skip NaN padding: 10-day delay -> first 9 days are NaN - valid_counts = counts[9:] + valid_counts = result.observed[9:] assert jnp.sum(valid_counts) > 0 @@ -249,14 +250,14 @@ def test_sparse_observations(self, counts_process): counts_data = jnp.array([10, 12, 8, 15]) with numpyro.handlers.seed(rng_seed=42): - counts = counts_process.sample( + result = counts_process.sample( infections=infections, counts=counts_data, times=times, ) - assert counts.shape == times.shape - assert jnp.allclose(counts, counts_data) + 
assert result.observed.shape == times.shape + assert jnp.allclose(result.observed, counts_data) def test_sparse_vs_dense_sampling(self, counts_process): """Test that sparse sampling gives different output shape than dense.""" @@ -265,7 +266,7 @@ def test_sparse_vs_dense_sampling(self, counts_process): # Dense: prior sampling (counts=None, no times) with numpyro.handlers.seed(rng_seed=42): - dense_counts = counts_process.sample( + dense_result = counts_process.sample( infections=infections, counts=None, ) @@ -274,18 +275,18 @@ def test_sparse_vs_dense_sampling(self, counts_process): times = jnp.array([5, 10, 15, 20]) sparse_obs_data = jnp.array([10, 12, 8, 15]) with numpyro.handlers.seed(rng_seed=42): - sparse_counts = counts_process.sample( + sparse_result = counts_process.sample( infections=infections, counts=sparse_obs_data, times=times, ) # Dense prior produces full length output - assert dense_counts.shape == (n_days,) + assert dense_result.observed.shape == (n_days,) # Sparse observations produce output matching times shape - assert sparse_counts.shape == times.shape - assert jnp.allclose(sparse_counts, sparse_obs_data) + assert sparse_result.observed.shape == times.shape + assert jnp.allclose(sparse_result.observed, sparse_obs_data) def test_prior_sampling_ignores_times(self, counts_process): """Test that times parameter is ignored when counts=None (prior sampling).""" @@ -295,61 +296,62 @@ def test_prior_sampling_ignores_times(self, counts_process): # When counts=None, times is ignored - output is dense with numpyro.handlers.seed(rng_seed=42): - prior_with_times = counts_process.sample( + result_with_times = counts_process.sample( infections=infections, counts=None, times=times, ) with numpyro.handlers.seed(rng_seed=42): - prior_without_times = counts_process.sample( + result_without_times = counts_process.sample( infections=infections, counts=None, ) # Both should produce dense output of shape (n_days,) - assert prior_with_times.shape == (n_days,) - assert 
prior_without_times.shape == (n_days,) + assert result_with_times.observed.shape == (n_days,) + assert result_without_times.observed.shape == (n_days,) # With same seed, outputs should be identical - assert jnp.allclose(prior_with_times, prior_without_times) + assert jnp.allclose(result_with_times.observed, result_without_times.observed) -class TestCountsBySite: - """Test CountsBySite for disaggregated observations.""" +class TestCountsBySubpop: + """Test CountsBySubpop for subpopulation-level observations.""" def test_sample_returns_correct_shape(self): - """Test that CountsBySite sample returns correct shape.""" + """Test that CountsBySubpop sample returns correct shape.""" delay_pmf = jnp.array([0.3, 0.4, 0.3]) - process = CountsBySite( + process = CountsBySubpop( ascertainment_rate_rv=DeterministicVariable("ihr", 0.02), delay_distribution_rv=DeterministicPMF("delay", delay_pmf), noise=PoissonNoise(), ) - infections = jnp.ones((30, 3)) * 500 # 30 days, 3 sites + infections = jnp.ones((30, 3)) * 500 # 30 days, 3 subpops times = jnp.array([10, 15, 10, 15]) subpop_indices = jnp.array([0, 0, 1, 1]) with numpyro.handlers.seed(rng_seed=42): - counts = process.sample( + result = process.sample( infections=infections, subpop_indices=subpop_indices, times=times, counts=None, ) - assert counts.shape == times.shape + assert result.observed.shape == times.shape + assert result.expected.shape == infections.shape def test_infection_resolution(self): - """Test that CountsBySite returns 'site' resolution.""" + """Test that CountsBySubpop returns 'subpop' resolution.""" delay_pmf = jnp.array([1.0]) - process = CountsBySite( + process = CountsBySubpop( ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), delay_distribution_rv=DeterministicPMF("delay", delay_pmf), noise=PoissonNoise(), ) - assert process.infection_resolution() == "site" + assert process.infection_resolution() == "subpop" class TestPoissonNoise: @@ -366,13 +368,13 @@ def test_poisson_counts(self, 
simple_delay_pmf): infections = jnp.ones(20) * 1000 with numpyro.handlers.seed(rng_seed=42): - counts = process.sample( + result = process.sample( infections=infections, counts=None, ) - assert counts.shape[0] == 20 - assert jnp.all(counts >= 0) + assert result.observed.shape[0] == 20 + assert jnp.all(result.observed >= 0) class TestCountBaseInternalMethods: @@ -435,30 +437,30 @@ def test_validate_invalid_ascertainment_rate_greater_than_one( with pytest.raises(ValueError, match="ascertainment_rate_rv must be in"): process.validate() - def test_get_required_lookback(self, simple_delay_pmf, long_delay_pmf): - """Test get_required_lookback returns PMF length.""" + def test_lookback_days(self, simple_delay_pmf, long_delay_pmf): + """Test lookback_days returns PMF length.""" process_short = Counts( ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), ) - assert process_short.get_required_lookback() == 1 + assert process_short.lookback_days() == 1 process_long = Counts( ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), delay_distribution_rv=DeterministicPMF("delay", long_delay_pmf), noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), ) - assert process_long.get_required_lookback() == 10 + assert process_long.lookback_days() == 10 def test_infection_resolution_counts(self, simple_delay_pmf): - """Test that Counts returns 'jurisdiction' resolution.""" + """Test that Counts returns 'aggregate' resolution.""" process = Counts( ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), ) - assert process.infection_resolution() == "jurisdiction" + assert process.infection_resolution() == "aggregate" class TestNoiseValidation: diff --git a/test/test_observation_measurements.py 
b/test/test_observation_measurements.py index 5e69fd3d..e5c014a2 100644 --- a/test/test_observation_measurements.py +++ b/test/test_observation_measurements.py @@ -27,7 +27,7 @@ def validate(self) -> None: pmf = self.temporal_pmf_rv() self._validate_pmf(pmf, "temporal_pmf_rv") - def get_required_lookback(self) -> int: + def lookback_days(self) -> int: """ Return temporal PMF length. @@ -49,7 +49,7 @@ def _expected_signal(self, infections): """ pmf = self.temporal_pmf_rv() - # Handle 2D infections (n_days, n_sites) + # Handle 2D infections (n_days, n_subpops) if infections.ndim == 1: infections = infections[:, jnp.newaxis] @@ -73,25 +73,25 @@ def test_is_base_observation_process(self): """Test that Measurements inherits from BaseObservationProcess.""" assert issubclass(Measurements, BaseObservationProcess) - def test_infection_resolution_is_site(self): - """Test that Measurements returns 'site' resolution.""" + def test_infection_resolution_is_subpop(self): + """Test that Measurements returns 'subpop' resolution.""" shedding_pmf = jnp.array([0.3, 0.4, 0.3]) - site_mode_prior = HierarchicalNormalPrior( + sensor_mode_prior = HierarchicalNormalPrior( name="mode", sd_rv=DeterministicVariable("sd", 0.5) ) - site_sd_prior = GammaGroupSdPrior( + sensor_sd_prior = GammaGroupSdPrior( name="sd", sd_mean_rv=DeterministicVariable("mean", 0.3), sd_concentration_rv=DeterministicVariable("conc", 4.0), ) - noise = HierarchicalNormalNoise(site_mode_prior, site_sd_prior) + noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) process = ConcreteMeasurements( temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), noise=noise, ) - assert process.infection_resolution() == "site" + assert process.infection_resolution() == "subpop" class TestHierarchicalNormalNoise: @@ -99,67 +99,67 @@ class TestHierarchicalNormalNoise: def test_validate(self): """Test HierarchicalNormalNoise validate method.""" - site_mode_prior = HierarchicalNormalPrior( + sensor_mode_prior = 
HierarchicalNormalPrior( name="mode", sd_rv=DeterministicVariable("mode_sd", 0.5) ) - site_sd_prior = GammaGroupSdPrior( + sensor_sd_prior = GammaGroupSdPrior( name="sd", sd_mean_rv=DeterministicVariable("sd_mean", 0.3), sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), ) - noise = HierarchicalNormalNoise(site_mode_prior, site_sd_prior) + noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) # Should not raise - validation is deferred to sample time noise.validate() def test_sample_shape(self): """Test that HierarchicalNormalNoise produces correct shape.""" - site_mode_prior = HierarchicalNormalPrior( + sensor_mode_prior = HierarchicalNormalPrior( name="mode", sd_rv=DeterministicVariable("mode_sd", 0.5) ) - site_sd_prior = GammaGroupSdPrior( + sensor_sd_prior = GammaGroupSdPrior( name="sd", sd_mean_rv=DeterministicVariable("sd_mean", 0.3), sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), ) - noise = HierarchicalNormalNoise(site_mode_prior, site_sd_prior) + noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) expected = jnp.array([1.0, 2.0, 3.0, 4.0]) - site_indices = jnp.array([0, 0, 1, 1]) + sensor_indices = jnp.array([0, 0, 1, 1]) with numpyro.handlers.seed(rng_seed=42): samples = noise.sample( name="test", expected=expected, obs=None, - site_indices=site_indices, - n_sites=2, + sensor_indices=sensor_indices, + n_sensors=2, ) assert samples.shape == expected.shape def test_sample_with_observations(self): """Test that HierarchicalNormalNoise conditions on observations.""" - site_mode_prior = HierarchicalNormalPrior( + sensor_mode_prior = HierarchicalNormalPrior( name="mode", sd_rv=DeterministicVariable("mode_sd", 0.5) ) - site_sd_prior = GammaGroupSdPrior( + sensor_sd_prior = GammaGroupSdPrior( name="sd", sd_mean_rv=DeterministicVariable("sd_mean", 0.3), sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), ) - noise = HierarchicalNormalNoise(site_mode_prior, site_sd_prior) + noise = 
HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) expected = jnp.array([1.0, 2.0, 3.0, 4.0]) obs = jnp.array([1.1, 2.1, 3.1, 4.1]) - site_indices = jnp.array([0, 0, 1, 1]) + sensor_indices = jnp.array([0, 0, 1, 1]) with numpyro.handlers.seed(rng_seed=42): samples = noise.sample( name="test", expected=expected, obs=obs, - site_indices=site_indices, - n_sites=2, + sensor_indices=sensor_indices, + n_sensors=2, ) # When obs is provided, samples should equal obs @@ -172,52 +172,53 @@ class TestConcreteMeasurements: def test_sample_shape(self): """Test that sample returns correct shape.""" shedding_pmf = jnp.array([0.3, 0.4, 0.3]) - site_mode_prior = HierarchicalNormalPrior( + sensor_mode_prior = HierarchicalNormalPrior( name="mode", sd_rv=DeterministicVariable("mode_sd", 0.5) ) - site_sd_prior = GammaGroupSdPrior( + sensor_sd_prior = GammaGroupSdPrior( name="sd", sd_mean_rv=DeterministicVariable("sd_mean", 0.3), sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), ) - noise = HierarchicalNormalNoise(site_mode_prior, site_sd_prior) + noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) process = ConcreteMeasurements( temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), noise=noise, ) - # 30 days, 2 sites + # 30 days, 2 subpops infections = jnp.ones((30, 2)) * 1000 subpop_indices = jnp.array([0, 0, 1, 1]) - site_indices = jnp.array([0, 0, 1, 1]) + sensor_indices = jnp.array([0, 0, 1, 1]) times = jnp.array([10, 15, 10, 15]) with numpyro.handlers.seed(rng_seed=42): - samples = process.sample( + result = process.sample( infections=infections, subpop_indices=subpop_indices, - site_indices=site_indices, + sensor_indices=sensor_indices, times=times, concentrations=None, - n_sites=2, + n_sensors=2, ) - assert samples.shape == times.shape + assert result.observed.shape == times.shape + assert result.expected.shape == infections.shape def test_expected_signal_stored(self): """Test that expected_log_conc is stored as deterministic.""" 
shedding_pmf = jnp.array([0.5, 0.5]) - site_mode_prior = HierarchicalNormalPrior( + sensor_mode_prior = HierarchicalNormalPrior( name="mode", sd_rv=DeterministicVariable("mode_sd", 0.01) ) - site_sd_prior = GammaGroupSdPrior( + sensor_sd_prior = GammaGroupSdPrior( name="sd", sd_mean_rv=DeterministicVariable("sd_mean", 0.01), sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), sd_min=0.001, ) - noise = HierarchicalNormalNoise(site_mode_prior, site_sd_prior) + noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) process = ConcreteMeasurements( temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), @@ -226,7 +227,7 @@ def test_expected_signal_stored(self): infections = jnp.ones((20, 2)) * 1000 subpop_indices = jnp.array([0, 1]) - site_indices = jnp.array([0, 1]) + sensor_indices = jnp.array([0, 1]) times = jnp.array([10, 10]) with numpyro.handlers.seed(rng_seed=42): @@ -234,10 +235,10 @@ def test_expected_signal_stored(self): lambda: process.sample( infections=infections, subpop_indices=subpop_indices, - site_indices=site_indices, + sensor_indices=sensor_indices, times=times, concentrations=None, - n_sites=2, + n_sensors=2, ) ).get_trace() From 57d2fbaacd229768d7989f1a31448de16401f34d Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Fri, 26 Dec 2025 16:05:17 -0500 Subject: [PATCH 06/33] consistent names: 'subpop' (not site), 'sensor' for site/lab measurements, 'aggregate' instead of 'jurisdiction' --- docs/tutorials/observation_processes_measurements.qmd | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/tutorials/observation_processes_measurements.qmd b/docs/tutorials/observation_processes_measurements.qmd index 613e1176..21dbff21 100644 --- a/docs/tutorials/observation_processes_measurements.qmd +++ b/docs/tutorials/observation_processes_measurements.qmd @@ -413,8 +413,12 @@ infections_constant = jnp.ones((40, 1)) * 1500.0 # Each sensor samples at multiple time points times_per_sensor = 10 -observation_times = 
jnp.tile(jnp.arange(10, 30, 2, dtype=jnp.int32), num_sensors) -sensor_ids = jnp.repeat(jnp.arange(num_sensors, dtype=jnp.int32), times_per_sensor) +observation_times = jnp.tile( + jnp.arange(10, 30, 2, dtype=jnp.int32), num_sensors +) +sensor_ids = jnp.repeat( + jnp.arange(num_sensors, dtype=jnp.int32), times_per_sensor +) subpop_ids = jnp.zeros(num_sensors * times_per_sensor, dtype=jnp.int32) with numpyro.handlers.seed(rng_seed=42): From 7efb52413894cdbd861fd1c8c83680f5ce3a2628 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 26 Dec 2025 21:05:27 +0000 Subject: [PATCH 07/33] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pyrenew/observation/__init__.py | 3 ++- pyrenew/observation/count_observations.py | 2 +- pyrenew/observation/measurements.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pyrenew/observation/__init__.py b/pyrenew/observation/__init__.py index 8a0cdeab..7c8972d8 100644 --- a/pyrenew/observation/__init__.py +++ b/pyrenew/observation/__init__.py @@ -20,6 +20,8 @@ to observation constructors to control the output distribution. 
""" +from pyrenew.observation.types import ObservationSample + from pyrenew.observation.base import BaseObservationProcess from pyrenew.observation.count_observations import Counts, CountsBySubpop from pyrenew.observation.measurements import Measurements @@ -32,7 +34,6 @@ PoissonNoise, ) from pyrenew.observation.poisson import PoissonObservation -from pyrenew.observation.types import ObservationSample __all__ = [ # Existing (kept for backward compatibility) diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index 7a883932..135ccca3 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -10,11 +10,11 @@ import jax import jax.numpy as jnp from jax.typing import ArrayLike +from pyrenew.observation.types import ObservationSample from pyrenew.metaclass import RandomVariable from pyrenew.observation.base import BaseObservationProcess from pyrenew.observation.noise import CountNoise -from pyrenew.observation.types import ObservationSample class _CountBase(BaseObservationProcess): diff --git a/pyrenew/observation/measurements.py b/pyrenew/observation/measurements.py index 33e38980..6826dff8 100644 --- a/pyrenew/observation/measurements.py +++ b/pyrenew/observation/measurements.py @@ -7,11 +7,11 @@ """ from jax.typing import ArrayLike +from pyrenew.observation.types import ObservationSample from pyrenew.metaclass import RandomVariable from pyrenew.observation.base import BaseObservationProcess from pyrenew.observation.noise import MeasurementNoise -from pyrenew.observation.types import ObservationSample class Measurements(BaseObservationProcess): From 8a7947f383637810a6dd8d503dd562d858eefca8 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 29 Dec 2025 10:13:46 -0500 Subject: [PATCH 08/33] add observation types --- pyrenew/observation/types.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 pyrenew/observation/types.py diff --git 
a/pyrenew/observation/types.py b/pyrenew/observation/types.py new file mode 100644 index 00000000..b494e2e6 --- /dev/null +++ b/pyrenew/observation/types.py @@ -0,0 +1,28 @@ +# numpydoc ignore=GL08 +""" +Return types for observation processes. + +Named tuples providing structured access to observation process outputs. +""" + +from typing import NamedTuple + +from jax.typing import ArrayLike + + +class ObservationSample(NamedTuple): + """ + Return type for observation process sample() methods. + + Attributes + ---------- + observed : ArrayLike + Sampled or conditioned observations. Shape depends on the + observation process and indexing. + expected : ArrayLike + Expected values before noise is applied. Useful for + diagnostics and posterior predictive checks. + """ + + observed: ArrayLike + expected: ArrayLike From 6dff8cdf0de75f6f7e18df5c797f44153b299e07 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Dec 2025 15:14:34 +0000 Subject: [PATCH 09/33] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pyrenew/observation/__init__.py | 3 +-- pyrenew/observation/count_observations.py | 2 +- pyrenew/observation/measurements.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pyrenew/observation/__init__.py b/pyrenew/observation/__init__.py index 7c8972d8..8a0cdeab 100644 --- a/pyrenew/observation/__init__.py +++ b/pyrenew/observation/__init__.py @@ -20,8 +20,6 @@ to observation constructors to control the output distribution. 
""" -from pyrenew.observation.types import ObservationSample - from pyrenew.observation.base import BaseObservationProcess from pyrenew.observation.count_observations import Counts, CountsBySubpop from pyrenew.observation.measurements import Measurements @@ -34,6 +32,7 @@ PoissonNoise, ) from pyrenew.observation.poisson import PoissonObservation +from pyrenew.observation.types import ObservationSample __all__ = [ # Existing (kept for backward compatibility) diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index 135ccca3..7a883932 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -10,11 +10,11 @@ import jax import jax.numpy as jnp from jax.typing import ArrayLike -from pyrenew.observation.types import ObservationSample from pyrenew.metaclass import RandomVariable from pyrenew.observation.base import BaseObservationProcess from pyrenew.observation.noise import CountNoise +from pyrenew.observation.types import ObservationSample class _CountBase(BaseObservationProcess): diff --git a/pyrenew/observation/measurements.py b/pyrenew/observation/measurements.py index 6826dff8..33e38980 100644 --- a/pyrenew/observation/measurements.py +++ b/pyrenew/observation/measurements.py @@ -7,11 +7,11 @@ """ from jax.typing import ArrayLike -from pyrenew.observation.types import ObservationSample from pyrenew.metaclass import RandomVariable from pyrenew.observation.base import BaseObservationProcess from pyrenew.observation.noise import MeasurementNoise +from pyrenew.observation.types import ObservationSample class Measurements(BaseObservationProcess): From 90e1ad04a52a554f774e736afc4458bf22146e17 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Fri, 2 Jan 2026 15:33:57 -0500 Subject: [PATCH 10/33] Update docs/tutorials/observation_processes_counts.qmd Co-authored-by: Dylan H. 
Morris --- docs/tutorials/observation_processes_counts.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/observation_processes_counts.qmd b/docs/tutorials/observation_processes_counts.qmd index b3b415c7..0bda5fb8 100644 --- a/docs/tutorials/observation_processes_counts.qmd +++ b/docs/tutorials/observation_processes_counts.qmd @@ -33,7 +33,7 @@ $$\lambda_t = \alpha \cdot \sum_{d=0}^{D} I_{t-d} \cdot p_d$$ where: - $I_{t-d}$ is the number of incident (new) infections on day $t-d$ (i.e., $d$ days before day $t$) -- $\alpha$ is the ascertainment rate (e.g., infection-hospitalization ratio) +- $\alpha$ is the rate of ascertained counts per infection (e.g., infection-to-hospital admission rate). This can model a mix of biological effects (e.g. some percentage of infections lead to hospital admissions, but not all) and reporting effects (e.g. some percentage of admissions that occur are reported, but not all). - $p_d$ is the delay distribution from infection to observation, conditional on an infection leading to an observation - $D$ is the maximum delay From c752be5ed03440b23807ed4091f0d38cc3e20ffb Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Fri, 2 Jan 2026 15:34:24 -0500 Subject: [PATCH 11/33] Update docs/tutorials/observation_processes_counts.qmd Co-authored-by: Dylan H. Morris --- docs/tutorials/observation_processes_counts.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/observation_processes_counts.qmd b/docs/tutorials/observation_processes_counts.qmd index 0bda5fb8..2c574e83 100644 --- a/docs/tutorials/observation_processes_counts.qmd +++ b/docs/tutorials/observation_processes_counts.qmd @@ -24,7 +24,7 @@ from pyrenew import datasets Count observation processes model the lag between infections and an observed outcome such as hospital admissions, emergency department visits, confirmed cases, or deaths. 
Observed data can be aggregated or available as subpopulation-level counts, which are modeled by classes `Counts` and `CountsBySubpop`, respectively. -Count observation processes transform infections into expected observed counts by applying an ascertainment rate and convolving with a delay distribution. +Count observation processes transform infections into predicted counts by applying an event probability and/or ascertainment rate and convolving with a delay distribution. The expected observations on day $t$ are: From f136bd742c6390f2b7e640c68ec56b5218db70ac Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Fri, 2 Jan 2026 15:34:37 -0500 Subject: [PATCH 12/33] Update docs/tutorials/observation_processes_counts.qmd Co-authored-by: Dylan H. Morris --- docs/tutorials/observation_processes_counts.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/observation_processes_counts.qmd b/docs/tutorials/observation_processes_counts.qmd index 2c574e83..b8da8147 100644 --- a/docs/tutorials/observation_processes_counts.qmd +++ b/docs/tutorials/observation_processes_counts.qmd @@ -26,7 +26,7 @@ Observed data can be aggregated or available as subpopulation-level counts, whic Count observation processes transform infections into predicted counts by applying an event probability and/or ascertainment rate and convolving with a delay distribution. -The expected observations on day $t$ are: +The predicted observations on day $t$ are: $$\lambda_t = \alpha \cdot \sum_{d=0}^{D} I_{t-d} \cdot p_d$$ From f93d081cc3b1073cd350ec9ed088918b3e401b35 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Fri, 2 Jan 2026 15:34:56 -0500 Subject: [PATCH 13/33] Update docs/tutorials/observation_processes_counts.qmd Co-authored-by: Dylan H. 
Morris --- docs/tutorials/observation_processes_counts.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/observation_processes_counts.qmd b/docs/tutorials/observation_processes_counts.qmd index b8da8147..b263d8ab 100644 --- a/docs/tutorials/observation_processes_counts.qmd +++ b/docs/tutorials/observation_processes_counts.qmd @@ -37,7 +37,7 @@ where: - $p_d$ is the delay distribution from infection to observation, conditional on an infection leading to an observation - $D$ is the maximum delay -Discrete observations are generated by sampling from a noise distribution—either Poisson or negative binomial—to model reporting variability. +Discrete observations are generated by sampling from a noise distribution—e.g. Poisson or negative binomial—to model reporting variability. Poisson assumes variance equals the mean; negative binomial accommodates the overdispersion common in surveillance data. **Note on terminology:** In real-world inference, infections are *latent* (unobserved) and must be estimated from observed data like hospital admissions. In this tutorial, we simulate the observation process by specifying infections directly and showing how they produce hospital admissions through convolution and sampling. From 648b2c76b9dbe9f9365d3fe410ca0b03cb3f32ca Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Fri, 2 Jan 2026 15:35:21 -0500 Subject: [PATCH 14/33] Update docs/tutorials/observation_processes_counts.qmd Co-authored-by: Dylan H. Morris --- docs/tutorials/observation_processes_counts.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/observation_processes_counts.qmd b/docs/tutorials/observation_processes_counts.qmd index b263d8ab..426a9c53 100644 --- a/docs/tutorials/observation_processes_counts.qmd +++ b/docs/tutorials/observation_processes_counts.qmd @@ -40,7 +40,7 @@ where: Discrete observations are generated by sampling from a noise distribution—e.g. 
Poisson or negative binomial—to model reporting variability. Poisson assumes variance equals the mean; negative binomial accommodates the overdispersion common in surveillance data. -**Note on terminology:** In real-world inference, infections are *latent* (unobserved) and must be estimated from observed data like hospital admissions. In this tutorial, we simulate the observation process by specifying infections directly and showing how they produce hospital admissions through convolution and sampling. +**Note on terminology:** In real-world inference, incident infections are typically a *latent* (unobserved) quantity and must be estimated from observed data like hospital admissions. In this tutorial, we simulate the observation process by specifying infections directly and showing how they produce hospital admissions through convolution and sampling. ## Hospital admissions example From 27f057789f5481c58726875cc860a44df6966ed6 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 6 Jan 2026 22:07:27 -0500 Subject: [PATCH 15/33] Update docs/tutorials/observation_processes_measurements.qmd Co-authored-by: Dylan H. Morris --- docs/tutorials/observation_processes_measurements.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/observation_processes_measurements.qmd b/docs/tutorials/observation_processes_measurements.qmd index 21dbff21..37d8f6b2 100644 --- a/docs/tutorials/observation_processes_measurements.qmd +++ b/docs/tutorials/observation_processes_measurements.qmd @@ -22,7 +22,7 @@ from pyrenew.deterministic import DeterministicVariable, DeterministicPMF ## Overview -Measurement observation processes model continuous signals derived from infections, such as viral RNA concentrations in wastewater. Unlike count observations (hospital admissions, deaths), measurements are continuous values that span orders of magnitude. 
+Measurement observation processes model continuous signals derived from infections, such as viral RNA concentrations in wastewater. Unlike count observations (hospital admissions, deaths), measurements are continuous values and in some cases may be negative. The expected measurement on day $t$ is: From d9749c7299617dfb1da58daabd34fa71d76fe14c Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 6 Jan 2026 22:07:44 -0500 Subject: [PATCH 16/33] Update docs/tutorials/observation_processes_measurements.qmd Co-authored-by: Dylan H. Morris --- docs/tutorials/observation_processes_measurements.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/observation_processes_measurements.qmd b/docs/tutorials/observation_processes_measurements.qmd index 37d8f6b2..61d794b2 100644 --- a/docs/tutorials/observation_processes_measurements.qmd +++ b/docs/tutorials/observation_processes_measurements.qmd @@ -31,7 +31,7 @@ $$\lambda_t = \frac{G}{V} \cdot \sum_{d=0}^{D} I_{t-d} \cdot p_d$$ where: - $I_{t-d}$ is the number of incident (new) infections on day $t-d$ -- $G$ is a scaling factor (e.g., genome copies shed per infection) +- $G$ is a scaling factor (here, genome copies shed per infection) - $V$ is a normalization factor (e.g., wastewater volume per person per day) - $p_d$ is the temporal distribution (e.g., viral shedding kinetics) - $D$ is the maximum duration From e2a955b9b75effa3d4ba4ee7c0e7220581e2b59a Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 6 Jan 2026 22:07:58 -0500 Subject: [PATCH 17/33] Update docs/tutorials/observation_processes_measurements.qmd Co-authored-by: Dylan H. 
Morris --- docs/tutorials/observation_processes_measurements.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/observation_processes_measurements.qmd b/docs/tutorials/observation_processes_measurements.qmd index 61d794b2..ab9e1948 100644 --- a/docs/tutorials/observation_processes_measurements.qmd +++ b/docs/tutorials/observation_processes_measurements.qmd @@ -24,7 +24,7 @@ from pyrenew.deterministic import DeterministicVariable, DeterministicPMF Measurement observation processes model continuous signals derived from infections, such as viral RNA concentrations in wastewater. Unlike count observations (hospital admissions, deaths), measurements are continuous values and in some cases may be negative. -The expected measurement on day $t$ is: +As an example, we'll consider the wastewater viral genome measurement from the [Pyrenew-HEW](https://github.com/cdcgov/pyrenew-hew) family of models. The predicted measurement on day $t$ is: $$\lambda_t = \frac{G}{V} \cdot \sum_{d=0}^{D} I_{t-d} \cdot p_d$$ From d30b39113d0512325be447557c307ce5046aba2d Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 6 Jan 2026 22:08:14 -0500 Subject: [PATCH 18/33] Update docs/tutorials/observation_processes_measurements.qmd Co-authored-by: Dylan H. 
Morris --- docs/tutorials/observation_processes_measurements.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/observation_processes_measurements.qmd b/docs/tutorials/observation_processes_measurements.qmd index ab9e1948..408e2ef9 100644 --- a/docs/tutorials/observation_processes_measurements.qmd +++ b/docs/tutorials/observation_processes_measurements.qmd @@ -32,7 +32,7 @@ where: - $I_{t-d}$ is the number of incident (new) infections on day $t-d$ - $G$ is a scaling factor (here, genome copies shed per infection) -- $V$ is a normalization factor (e.g., wastewater volume per person per day) +- $V$ is another scaling factor (here, wastewater volume per person per day) - $p_d$ is the temporal distribution (e.g., viral shedding kinetics) - $D$ is the maximum duration From a655f993b93cbd31a6f54dbc83913b84525902eb Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 6 Jan 2026 22:08:33 -0500 Subject: [PATCH 19/33] Update docs/tutorials/observation_processes_measurements.qmd Co-authored-by: Dylan H. 
Morris --- docs/tutorials/observation_processes_measurements.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/observation_processes_measurements.qmd b/docs/tutorials/observation_processes_measurements.qmd index 408e2ef9..eab28e52 100644 --- a/docs/tutorials/observation_processes_measurements.qmd +++ b/docs/tutorials/observation_processes_measurements.qmd @@ -33,7 +33,7 @@ where: - $I_{t-d}$ is the number of incident (new) infections on day $t-d$ - $G$ is a scaling factor (here, genome copies shed per infection) - $V$ is another scaling factor (here, wastewater volume per person per day) -- $p_d$ is the temporal distribution (e.g., viral shedding kinetics) +- $p_d$ is a delay distribution (here, normalized viral genome shedding by time-since-infection) - $D$ is the maximum duration Observed log-concentrations are generated by sampling from a normal distribution: From e37a5724f9fbb15f50b42ecc6096e618acecc6c4 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 6 Jan 2026 22:08:45 -0500 Subject: [PATCH 20/33] Update docs/tutorials/observation_processes_measurements.qmd Co-authored-by: Dylan H. Morris --- docs/tutorials/observation_processes_measurements.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/observation_processes_measurements.qmd b/docs/tutorials/observation_processes_measurements.qmd index eab28e52..81ddc8af 100644 --- a/docs/tutorials/observation_processes_measurements.qmd +++ b/docs/tutorials/observation_processes_measurements.qmd @@ -34,7 +34,7 @@ where: - $G$ is a scaling factor (here, genome copies shed per infection) - $V$ is another scaling factor (here, wastewater volume per person per day) - $p_d$ is a delay distribution (here, normalized viral genome shedding by time-since-infection) -- $D$ is the maximum duration +- $D$ is the maximum delay considered. 
Observed log-concentrations are generated by sampling from a normal distribution: From 3457f62d1bef6259055e52436cfcb27613e02bab Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 6 Jan 2026 22:09:02 -0500 Subject: [PATCH 21/33] Update pyrenew/observation/noise.py Co-authored-by: Dylan H. Morris --- pyrenew/observation/noise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrenew/observation/noise.py b/pyrenew/observation/noise.py index 647b1781..1ca68129 100644 --- a/pyrenew/observation/noise.py +++ b/pyrenew/observation/noise.py @@ -9,7 +9,7 @@ ----------- - ``PoissonNoise``: Equidispersed counts (variance = mean). No parameters. - ``NegativeBinomialNoise``: Overdispersed counts (variance > mean). - Takes ``concentration_rv`` (higher = less overdispersion). + Takes ``concentration_rv`` (higher concentration = less overdispersed, more Poisson-like). Measurement Noise ----------------- From 25e4dd8d1e6e7df837fbca80033f2b0f9dc1d948 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 6 Jan 2026 22:09:22 -0500 Subject: [PATCH 22/33] Update pyrenew/observation/noise.py Co-authored-by: Dylan H. Morris --- pyrenew/observation/noise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrenew/observation/noise.py b/pyrenew/observation/noise.py index 1ca68129..5e5eb248 100644 --- a/pyrenew/observation/noise.py +++ b/pyrenew/observation/noise.py @@ -8,7 +8,7 @@ Count Noise ----------- - ``PoissonNoise``: Equidispersed counts (variance = mean). No parameters. -- ``NegativeBinomialNoise``: Overdispersed counts (variance > mean). +- ``NegativeBinomialNoise``: Overdispersed counts relative to Poisson (variance > mean). Takes ``concentration_rv`` (higher concentration = less overdispersed, more Poisson-like). 
Measurement Noise From 4bd8f2e9c0ff556c2ce4a2078ceffe6a3da84344 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Wed, 7 Jan 2026 08:39:58 -0500 Subject: [PATCH 23/33] change expected to predicted --- .../observation_processes_measurements.qmd | 8 +-- pyrenew/observation/base.py | 20 ++++---- pyrenew/observation/count_observations.py | 42 ++++++++-------- pyrenew/observation/measurements.py | 20 ++++---- pyrenew/observation/noise.py | 50 +++++++++---------- pyrenew/observation/types.py | 6 +-- test/test_observation_counts.py | 4 +- test/test_observation_measurements.py | 30 +++++------ 8 files changed, 90 insertions(+), 90 deletions(-) diff --git a/docs/tutorials/observation_processes_measurements.qmd b/docs/tutorials/observation_processes_measurements.qmd index 81ddc8af..321ab226 100644 --- a/docs/tutorials/observation_processes_measurements.qmd +++ b/docs/tutorials/observation_processes_measurements.qmd @@ -65,7 +65,7 @@ The core convolution $\sum_{d=0}^{D} I_{t-d} \cdot p_d$ is the same as for count ## Subclassing Measurements for wastewater -The `Measurements` class is abstract—you must subclass it and implement `_expected_signal()` for your specific signal type. Here we create a `Wastewater` class for viral concentration measurements: +The `Measurements` class is abstract—you must subclass it and implement `_predicted_obs()` for your specific signal type. Here we create a `Wastewater` class for viral concentration measurements: ```{python} # | label: wastewater-class @@ -78,7 +78,7 @@ class Wastewater(Measurements): """ Wastewater viral concentration observation process. - Transforms site-level infections into expected log-concentrations + Transforms site-level infections into predicted log-concentrations via shedding kinetics convolution and genome/volume scaling. 
""" @@ -117,9 +117,9 @@ class Wastewater(Measurements): """Return shedding PMF length.""" return len(self.temporal_pmf_rv()) - def _expected_signal(self, infections: ArrayLike) -> ArrayLike: + def _predicted_obs(self, infections: ArrayLike) -> ArrayLike: """ - Compute expected log-concentration from infections. + Compute predicted log-concentration from infections. Applies shedding kinetics convolution, then scales by genome copies and volume to get concentration. diff --git a/pyrenew/observation/base.py b/pyrenew/observation/base.py index 635d386f..3ee71b83 100644 --- a/pyrenew/observation/base.py +++ b/pyrenew/observation/base.py @@ -39,12 +39,12 @@ class BaseObservationProcess(RandomVariable): - ``validate()``: Validate parameters (call ``_validate_pmf()`` for PMFs) - ``lookback_days()``: Return PMF length for initialization - ``infection_resolution()``: Return ``"aggregate"`` or ``"subpop"`` - - ``_expected_signal()``: Transform infections to expected values - - ``sample()``: Apply noise model to expected signal + - ``_predicted_obs()``: Transform infections to predicted values + - ``sample()``: Apply noise model to predicted observations Notes ----- - Computing expected observations on day t requires infection history + Computing predicted observations on day t requires infection history from previous days (determined by the temporal PMF length). The first ``len(pmf) - 1`` days have insufficient history and return NaN. @@ -257,7 +257,7 @@ def _deterministic(self, name: str, value: ArrayLike) -> None: Track a deterministic quantity in the numpyro execution trace. This is a convenience wrapper around ``numpyro.deterministic`` for - tracking intermediate quantities (e.g., latent admissions, expected + tracking intermediate quantities (e.g., latent admissions, predicted concentrations) that are useful for diagnostics and model checking. These quantities are stored in MCMC samples and can be used for model diagnostics and posterior predictive checks. 
@@ -272,16 +272,16 @@ def _deterministic(self, name: str, value: ArrayLike) -> None: numpyro.deterministic(name, value) @abstractmethod - def _expected_signal( + def _predicted_obs( self, infections: ArrayLike, ) -> ArrayLike: """ - Transform infections to expected observation values. + Transform infections to predicted observation values. This is the core transformation that each observation process must implement. It converts infections (from the infection process) - to expected values for the observation model. + to predicted values for the observation model. Parameters ---------- @@ -293,14 +293,14 @@ def _expected_signal( Returns ------- ArrayLike - Expected observation values (counts, log-concentrations, etc.). + Predicted observation values (counts, log-concentrations, etc.). Same shape as input, with first len(pmf)-1 days as NaN. Notes ----- The transformation is observation-specific: - - Count observations: ascertainment x delay convolution -> expected counts + - Count observations: ascertainment x delay convolution -> predicted counts - Wastewater: shedding convolution -> genome scaling -> dilution -> log See Also @@ -315,7 +315,7 @@ def sample(self, **kwargs) -> ArrayLike: Sample from the observation process. Subclasses must implement this method to define the specific - observation model. Typically calls ``_expected_signal`` first, + observation model. Typically calls ``_predicted_obs`` first, then applies the noise model. Parameters diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index 7a883932..6d89d0f5 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -90,12 +90,12 @@ def infection_resolution(self) -> str: """ raise NotImplementedError("Subclasses must implement infection_resolution()") - def _expected_signal( + def _predicted_obs( self, infections: ArrayLike, ) -> ArrayLike: """ - Compute expected counts via ascertainment x delay convolution. 
+ Compute predicted counts via ascertainment x delay convolution. Parameters ---------- @@ -107,7 +107,7 @@ def _expected_signal( Returns ------- ArrayLike - Expected counts with timeline alignment. + Predicted counts with timeline alignment. Same shape as input. First len(delay_pmf)-1 days are NaN. """ @@ -121,9 +121,9 @@ def _expected_signal( def convolve_col(col): # numpydoc ignore=GL08 return self._convolve_with_alignment(col, delay_pmf, ascertainment_rate)[0] - expected_counts = jax.vmap(convolve_col, in_axes=1, out_axes=1)(infections) + predicted_counts = jax.vmap(convolve_col, in_axes=1, out_axes=1)(infections) - return expected_counts[:, 0] if is_1d else expected_counts + return predicted_counts[:, 0] if is_1d else predicted_counts class Counts(_CountBase): @@ -211,25 +211,25 @@ def sample( ------- ObservationSample Named tuple with `observed` (sampled/conditioned counts) and - `expected` (expected counts before noise). + `predicted` (predicted counts before noise). """ - expected_counts = self._expected_signal(infections) - self._deterministic("expected_counts", expected_counts) - expected_counts_safe = jnp.nan_to_num(expected_counts, nan=0.0) + predicted_counts = self._predicted_obs(infections) + self._deterministic("predicted_counts", predicted_counts) + predicted_counts_safe = jnp.nan_to_num(predicted_counts, nan=0.0) # Only use sparse indexing when conditioning on observations if times is not None and counts is not None: - expected_obs = expected_counts_safe[times] + predicted_obs = predicted_counts_safe[times] else: - expected_obs = expected_counts_safe + predicted_obs = predicted_counts_safe observed = self.noise.sample( name="counts", - expected=expected_obs, + predicted=predicted_obs, obs=counts, ) - return ObservationSample(observed=observed, expected=expected_counts) + return ObservationSample(observed=observed, predicted=predicted_counts) class CountsBySubpop(_CountBase): @@ -328,21 +328,21 @@ def sample( ------- ObservationSample Named tuple 
with `observed` (sampled/conditioned counts) and - `expected` (expected counts before noise, shape: n_days x n_subpops). + `predicted` (predicted counts before noise, shape: n_days x n_subpops). """ - # Compute expected counts for all subpops - expected_counts_all = self._expected_signal(infections) + # Compute predicted counts for all subpops + predicted_counts_all = self._predicted_obs(infections) - self._deterministic("expected_counts_by_subpop", expected_counts_all) + self._deterministic("predicted_counts_by_subpop", predicted_counts_all) # Replace NaN padding with 0 for distribution creation - expected_counts_safe = jnp.nan_to_num(expected_counts_all, nan=0.0) - expected_obs = expected_counts_safe[times, subpop_indices] + predicted_counts_safe = jnp.nan_to_num(predicted_counts_all, nan=0.0) + predicted_obs = predicted_counts_safe[times, subpop_indices] observed = self.noise.sample( name="counts_by_subpop", - expected=expected_obs, + predicted=predicted_obs, obs=counts, ) - return ObservationSample(observed=observed, expected=expected_counts_all) + return ObservationSample(observed=observed, predicted=predicted_counts_all) diff --git a/pyrenew/observation/measurements.py b/pyrenew/observation/measurements.py index 33e38980..d9c17d0a 100644 --- a/pyrenew/observation/measurements.py +++ b/pyrenew/observation/measurements.py @@ -19,7 +19,7 @@ class Measurements(BaseObservationProcess): Abstract base for continuous measurement observations. Subclasses implement signal-specific transformations from infections - to expected measurement values, then add measurement noise. + to predicted measurement values, then add measurement noise. Parameters ---------- @@ -31,7 +31,7 @@ class Measurements(BaseObservationProcess): Notes ----- - Subclasses must implement ``_expected_signal()`` according to their + Subclasses must implement ``_predicted_obs()`` according to their specific signal processing (e.g., wastewater shedding kinetics, dilution factors, etc.). 
@@ -99,8 +99,8 @@ def sample( (times, subpop_indices, sensor_indices). Validate observation data before sampling. - Transforms infections to expected values via signal-specific processing - (``_expected_signal``), then applies noise model. + Transforms infections to predicted values via signal-specific processing + (``_predicted_obs``), then applies noise model. Parameters ---------- @@ -125,20 +125,20 @@ def sample( ------- ObservationSample Named tuple with `observed` (sampled/conditioned measurements) and - `expected` (expected values before noise, shape: n_days x n_subpops). + `predicted` (predicted values before noise, shape: n_days x n_subpops). """ - expected_values = self._expected_signal(infections) + predicted_values = self._predicted_obs(infections) - self._deterministic("expected_log_conc", expected_values) + self._deterministic("predicted_log_conc", predicted_values) - expected_obs = expected_values[times, subpop_indices] + predicted_obs = predicted_values[times, subpop_indices] observed = self.noise.sample( name="concentrations", - expected=expected_obs, + predicted=predicted_obs, obs=concentrations, sensor_indices=sensor_indices, n_sensors=n_sensors, ) - return ObservationSample(observed=observed, expected=expected_values) + return ObservationSample(observed=observed, predicted=predicted_values) diff --git a/pyrenew/observation/noise.py b/pyrenew/observation/noise.py index 5e5eb248..1de84f42 100644 --- a/pyrenew/observation/noise.py +++ b/pyrenew/observation/noise.py @@ -36,32 +36,32 @@ class CountNoise(ABC): """ Abstract base for count observation noise models. - Defines how discrete count observations are distributed around expected values. + Defines how discrete count observations are distributed around predicted values. """ @abstractmethod def sample( self, name: str, - expected: ArrayLike, + predicted: ArrayLike, obs: ArrayLike | None = None, ) -> ArrayLike: """ - Sample count observations given expected counts. 
+ Sample count observations given predicted counts. Parameters ---------- name : str Numpyro sample site name. - expected : ArrayLike - Expected count values (non-negative). + predicted : ArrayLike + Predicted count values (non-negative). obs : ArrayLike | None Observed counts for conditioning, or None for prior sampling. Returns ------- ArrayLike - Sampled or conditioned counts, same shape as expected. + Sampled or conditioned counts, same shape as predicted. """ pass # pragma: no cover @@ -98,7 +98,7 @@ def validate(self) -> None: def sample( self, name: str, - expected: ArrayLike, + predicted: ArrayLike, obs: ArrayLike | None = None, ) -> ArrayLike: """ @@ -108,8 +108,8 @@ def sample( ---------- name : str Numpyro sample site name. - expected : ArrayLike - Expected count values. + predicted : ArrayLike + Predicted count values. obs : ArrayLike | None Observed counts for conditioning. @@ -120,7 +120,7 @@ def sample( """ return numpyro.sample( name, - dist.Poisson(rate=expected + _EPSILON), + dist.Poisson(rate=predicted + _EPSILON), obs=obs, ) @@ -178,7 +178,7 @@ def validate(self) -> None: def sample( self, name: str, - expected: ArrayLike, + predicted: ArrayLike, obs: ArrayLike | None = None, ) -> ArrayLike: """ @@ -188,8 +188,8 @@ def sample( ---------- name : str Numpyro sample site name. - expected : ArrayLike - Expected count values. + predicted : ArrayLike + Predicted count values. obs : ArrayLike | None Observed counts for conditioning. @@ -202,7 +202,7 @@ def sample( return numpyro.sample( name, dist.NegativeBinomial2( - mean=expected + _EPSILON, + mean=predicted + _EPSILON, concentration=concentration, ), obs=obs, @@ -213,26 +213,26 @@ class MeasurementNoise(ABC): """ Abstract base for continuous measurement noise models. - Defines how continuous observations are distributed around expected values. + Defines how continuous observations are distributed around predicted values. 
""" @abstractmethod def sample( self, name: str, - expected: ArrayLike, + predicted: ArrayLike, obs: ArrayLike | None = None, **kwargs, ) -> ArrayLike: """ - Sample continuous observations given expected values. + Sample continuous observations given predicted values. Parameters ---------- name : str Numpyro sample site name. - expected : ArrayLike - Expected measurement values. + predicted : ArrayLike + Predicted measurement values. obs : ArrayLike | None Observed measurements for conditioning, or None for prior sampling. **kwargs @@ -241,7 +241,7 @@ def sample( Returns ------- ArrayLike - Sampled or conditioned measurements, same shape as expected. + Sampled or conditioned measurements, same shape as predicted. """ pass # pragma: no cover @@ -262,7 +262,7 @@ class HierarchicalNormalNoise(MeasurementNoise): """ Normal noise with hierarchical sensor-level effects. - Observation model: ``obs ~ Normal(expected + sensor_mode, sensor_sd)`` + Observation model: ``obs ~ Normal(predicted + sensor_mode, sensor_sd)`` where sensor_mode and sensor_sd are hierarchically modeled. Parameters @@ -327,7 +327,7 @@ def validate(self) -> None: def sample( self, name: str, - expected: ArrayLike, + predicted: ArrayLike, obs: ArrayLike | None = None, *, sensor_indices: ArrayLike, @@ -340,8 +340,8 @@ def sample( ---------- name : str Numpyro sample site name. - expected : ArrayLike - Expected log-scale measurement values. + predicted : ArrayLike + Predicted log-scale measurement values. Shape: (n_obs,) obs : ArrayLike | None Observed log-scale measurements for conditioning. 
@@ -366,7 +366,7 @@ def sample( sensor_mode = self.sensor_mode_prior_rv.sample(n_groups=n_sensors) sensor_sd = self.sensor_sd_prior_rv.sample(n_groups=n_sensors) - loc = expected + sensor_mode[sensor_indices] + loc = predicted + sensor_mode[sensor_indices] scale = sensor_sd[sensor_indices] return numpyro.sample(name, dist.Normal(loc=loc, scale=scale), obs=obs) diff --git a/pyrenew/observation/types.py b/pyrenew/observation/types.py index b494e2e6..f8a1163c 100644 --- a/pyrenew/observation/types.py +++ b/pyrenew/observation/types.py @@ -19,10 +19,10 @@ class ObservationSample(NamedTuple): observed : ArrayLike Sampled or conditioned observations. Shape depends on the observation process and indexing. - expected : ArrayLike - Expected values before noise is applied. Useful for + predicted : ArrayLike + Predicted values before noise is applied. Useful for diagnostics and posterior predictive checks. """ observed: ArrayLike - expected: ArrayLike + predicted: ArrayLike diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index ddb4d684..57bdb37a 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -73,7 +73,7 @@ def test_sample_returns_correct_shape(self, counts_process): assert result.observed.shape[0] > 0 assert result.observed.ndim == 1 - assert result.expected.shape == infections.shape + assert result.predicted.shape == infections.shape def test_delay_convolution(self, counts_factory, short_delay_pmf): """Test that delay is properly applied.""" @@ -340,7 +340,7 @@ def test_sample_returns_correct_shape(self): ) assert result.observed.shape == times.shape - assert result.expected.shape == infections.shape + assert result.predicted.shape == infections.shape def test_infection_resolution(self): """Test that CountsBySubpop returns 'subpop' resolution.""" diff --git a/test/test_observation_measurements.py b/test/test_observation_measurements.py index e5c014a2..a3af44a1 100644 --- 
a/test/test_observation_measurements.py +++ b/test/test_observation_measurements.py @@ -38,14 +38,14 @@ def lookback_days(self) -> int: """ return len(self.temporal_pmf_rv()) - def _expected_signal(self, infections): + def _predicted_obs(self, infections): """ - Simple expected signal: log(convolution * scale). + Simple predicted signal: log(convolution * scale). Returns ------- jnp.ndarray - Log-transformed expected signal. + Log-transformed predicted signal. """ pmf = self.temporal_pmf_rv() @@ -58,12 +58,12 @@ def convolve_col(col): # numpydoc ignore=GL08 import jax - expected = jax.vmap(convolve_col, in_axes=1, out_axes=1)(infections) + predicted = jax.vmap(convolve_col, in_axes=1, out_axes=1)(infections) # Apply log10 scaling (simplified from wastewater model) - log_expected = jnp.log(expected + 1e-10) + self.log10_scale * jnp.log(10) + log_predicted = jnp.log(predicted + 1e-10) + self.log10_scale * jnp.log(10) - return log_expected + return log_predicted class TestMeasurementsBase: @@ -123,19 +123,19 @@ def test_sample_shape(self): ) noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) - expected = jnp.array([1.0, 2.0, 3.0, 4.0]) + predicted = jnp.array([1.0, 2.0, 3.0, 4.0]) sensor_indices = jnp.array([0, 0, 1, 1]) with numpyro.handlers.seed(rng_seed=42): samples = noise.sample( name="test", - expected=expected, + predicted=predicted, obs=None, sensor_indices=sensor_indices, n_sensors=2, ) - assert samples.shape == expected.shape + assert samples.shape == predicted.shape def test_sample_with_observations(self): """Test that HierarchicalNormalNoise conditions on observations.""" @@ -149,14 +149,14 @@ def test_sample_with_observations(self): ) noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) - expected = jnp.array([1.0, 2.0, 3.0, 4.0]) + predicted = jnp.array([1.0, 2.0, 3.0, 4.0]) obs = jnp.array([1.1, 2.1, 3.1, 4.1]) sensor_indices = jnp.array([0, 0, 1, 1]) with numpyro.handlers.seed(rng_seed=42): samples = noise.sample( 
name="test", - expected=expected, + predicted=predicted, obs=obs, sensor_indices=sensor_indices, n_sensors=2, @@ -204,10 +204,10 @@ def test_sample_shape(self): ) assert result.observed.shape == times.shape - assert result.expected.shape == infections.shape + assert result.predicted.shape == infections.shape - def test_expected_signal_stored(self): - """Test that expected_log_conc is stored as deterministic.""" + def test_predicted_obs_stored(self): + """Test that predicted_log_conc is stored as deterministic.""" shedding_pmf = jnp.array([0.5, 0.5]) sensor_mode_prior = HierarchicalNormalPrior( name="mode", sd_rv=DeterministicVariable("mode_sd", 0.01) @@ -242,7 +242,7 @@ def test_expected_signal_stored(self): ) ).get_trace() - assert "expected_log_conc" in trace + assert "predicted_log_conc" in trace if __name__ == "__main__": From 4e429bccbf76b8f1403ed80318ce311d23fc67bd Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Wed, 7 Jan 2026 10:17:29 -0500 Subject: [PATCH 24/33] standardize on param name obs for sample method --- .../observation_processes_counts.qmd | 12 +++--- .../observation_processes_measurements.qmd | 10 ++--- pyrenew/observation/count_observations.py | 18 ++++----- pyrenew/observation/measurements.py | 6 +-- test/test_observation_counts.py | 38 +++++++++---------- test/test_observation_measurements.py | 4 +- 6 files changed, 43 insertions(+), 45 deletions(-) diff --git a/docs/tutorials/observation_processes_counts.qmd b/docs/tutorials/observation_processes_counts.qmd index 426a9c53..ad50e074 100644 --- a/docs/tutorials/observation_processes_counts.qmd +++ b/docs/tutorials/observation_processes_counts.qmd @@ -166,7 +166,7 @@ infections = infections.at[infection_spike_day].set(2000) with numpyro.handlers.seed(rng_seed=42): hosp_admissions = hosp_process.sample( infections=infections, - counts=None, # Sample from prior (no observed data) + obs=None, # Sample from prior (no observed data) ) ``` @@ -246,7 +246,7 @@ samples = [] for seed in 
range(n_samples): with numpyro.handlers.seed(rng_seed=seed): - result = hosp_process.sample(infections=infections, counts=None) + result = hosp_process.sample(infections=infections, obs=None) samples.append(np.array(result.observed[day_one:])) samples = np.array(samples) @@ -328,7 +328,7 @@ for ihr_val in ihr_values: samples = [] for seed in range(20): with numpyro.handlers.seed(rng_seed=seed): - result = process.sample(infections=infections_decay, counts=None) + result = process.sample(infections=infections_decay, obs=None) samples.append(np.array(result.observed[day_one:])) results[ihr_val] = np.mean(samples, axis=0) ``` @@ -381,9 +381,7 @@ for ax, conc_val in zip(axes, concentration_values): for seed in range(n_replicates): with numpyro.handlers.seed(rng_seed=seed): - result = process.sample( - infections=infections_constant, counts=None - ) + result = process.sample(infections=infections_constant, obs=None) ax.plot( np.arange(n_plot_days), np.array(result.observed[day_one:]), @@ -416,7 +414,7 @@ hosp_process_poisson = Counts( with numpyro.handlers.seed(rng_seed=42): poisson_result = hosp_process_poisson.sample( infections=infections, - counts=None, + obs=None, ) print( diff --git a/docs/tutorials/observation_processes_measurements.qmd b/docs/tutorials/observation_processes_measurements.qmd index 321ab226..7d0b458c 100644 --- a/docs/tutorials/observation_processes_measurements.qmd +++ b/docs/tutorials/observation_processes_measurements.qmd @@ -292,7 +292,7 @@ with numpyro.handlers.seed(rng_seed=42): subpop_indices=jnp.zeros(n_obs, dtype=jnp.int32), sensor_indices=jnp.zeros(n_obs, dtype=jnp.int32), times=observation_days, - concentrations=None, # Sample from prior + obs=None, n_sensors=1, ) ``` @@ -340,7 +340,7 @@ for seed in range(n_samples): subpop_indices=jnp.zeros(n_obs, dtype=jnp.int32), sensor_indices=jnp.zeros(n_obs, dtype=jnp.int32), times=observation_days, - concentrations=None, + obs=None, n_sensors=1, ) for day_idx, conc in zip(observation_days, 
ww_result.observed): @@ -427,7 +427,7 @@ with numpyro.handlers.seed(rng_seed=42): subpop_indices=subpop_ids, sensor_indices=sensor_ids, times=observation_times, - concentrations=None, + obs=None, n_sensors=num_sensors, ) ``` @@ -488,7 +488,7 @@ for genome_val in genome_values: subpop_indices=jnp.zeros(n_obs_cmp, dtype=jnp.int32), sensor_indices=jnp.zeros(n_obs_cmp, dtype=jnp.int32), times=observation_days_cmp, - concentrations=None, + obs=None, n_sensors=1, ) results[genome_val] = np.array(ww_temp.observed) @@ -545,7 +545,7 @@ with numpyro.handlers.seed(rng_seed=42): subpop_indices=subpop_ids_mp, sensor_indices=sensor_ids_mp, times=obs_days_mp, - concentrations=None, + obs=None, n_sensors=2, ) ``` diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index 6d89d0f5..48a8d574 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -163,7 +163,7 @@ class Counts(_CountBase): >>> >>> with numpyro.handlers.seed(rng_seed=42): ... infections = jnp.ones(30) * 1000 - ... sampled_counts = counts_obs.sample(infections=infections, counts=None) + ... sampled_counts = counts_obs.sample(infections=infections, obs=None) """ def infection_resolution(self) -> str: @@ -188,7 +188,7 @@ def __repr__(self) -> str: def sample( self, infections: ArrayLike, - counts: ArrayLike | None = None, + obs: ArrayLike | None = None, times: ArrayLike | None = None, ) -> ObservationSample: """ @@ -202,7 +202,7 @@ def sample( infections : ArrayLike Aggregate infections from the infection process. Shape: (n_days,) - counts : ArrayLike | None + obs : ArrayLike | None Observed counts. Dense: (n_days,), Sparse: (n_obs,), None: prior. times : ArrayLike | None Day indices for sparse observations. None for dense observations. 
@@ -218,7 +218,7 @@ def sample( predicted_counts_safe = jnp.nan_to_num(predicted_counts, nan=0.0) # Only use sparse indexing when conditioning on observations - if times is not None and counts is not None: + if times is not None and obs is not None: predicted_obs = predicted_counts_safe[times] else: predicted_obs = predicted_counts_safe @@ -226,7 +226,7 @@ def sample( observed = self.noise.sample( name="counts", predicted=predicted_obs, - obs=counts, + obs=obs, ) return ObservationSample(observed=observed, predicted=predicted_counts) @@ -274,7 +274,7 @@ class CountsBySubpop(_CountBase): ... infections=infections, ... subpop_indices=subpop_indices, ... times=times, - ... counts=None, + ... obs=None, ... ) """ @@ -302,7 +302,7 @@ def sample( infections: ArrayLike, subpop_indices: ArrayLike, times: ArrayLike, - counts: ArrayLike | None = None, + obs: ArrayLike | None = None, ) -> ObservationSample: """ Sample subpopulation-level counts with flexible indexing. @@ -321,7 +321,7 @@ def sample( times : ArrayLike Day index for each observation (0-indexed). Shape: (n_obs,) - counts : ArrayLike | None + obs : ArrayLike | None Observed counts (n_obs,), or None for prior sampling. Returns @@ -342,7 +342,7 @@ def sample( observed = self.noise.sample( name="counts_by_subpop", predicted=predicted_obs, - obs=counts, + obs=obs, ) return ObservationSample(observed=observed, predicted=predicted_counts_all) diff --git a/pyrenew/observation/measurements.py b/pyrenew/observation/measurements.py index d9c17d0a..82043e65 100644 --- a/pyrenew/observation/measurements.py +++ b/pyrenew/observation/measurements.py @@ -89,7 +89,7 @@ def sample( subpop_indices: ArrayLike, sensor_indices: ArrayLike, times: ArrayLike, - concentrations: ArrayLike | None, + obs: ArrayLike | None, n_sensors: int, ) -> ObservationSample: """ @@ -116,7 +116,7 @@ def sample( times : ArrayLike Day index for each observation (0-indexed). 
Shape: (n_obs,) - concentrations : ArrayLike | None + obs : ArrayLike | None Observed measurements (n_obs,), or None for prior sampling. n_sensors : int Total number of measurement sensors. @@ -136,7 +136,7 @@ def sample( observed = self.noise.sample( name="concentrations", predicted=predicted_obs, - obs=concentrations, + obs=obs, sensor_indices=sensor_indices, n_sensors=n_sensors, ) diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index 57bdb37a..bcfbd54f 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -68,7 +68,7 @@ def test_sample_returns_correct_shape(self, counts_process): with numpyro.handlers.seed(rng_seed=42): result = counts_process.sample( infections=infections, - counts=None, + obs=None, ) assert result.observed.shape[0] > 0 @@ -85,7 +85,7 @@ def test_delay_convolution(self, counts_factory, short_delay_pmf): with numpyro.handlers.seed(rng_seed=42): result = process.sample( infections=infections, - counts=None, + obs=None, ) # Timeline alignment: output length equals input length @@ -108,7 +108,7 @@ def test_ascertainment_scaling(self, counts_factory, simple_delay_pmf): with numpyro.handlers.seed(rng_seed=42): result = process.sample( infections=infections, - counts=None, + obs=None, ) results.append(jnp.mean(result.observed)) @@ -130,7 +130,7 @@ def test_negative_binomial_observation(self, counts_factory, simple_delay_pmf): with numpyro.handlers.seed(rng_seed=seed): result = process.sample( infections=infections, - counts=None, + obs=None, ) samples.append(jnp.sum(result.observed)) @@ -158,7 +158,7 @@ def test_with_stochastic_ascertainment(self, short_shedding_pmf): with numpyro.handlers.seed(rng_seed=42): result = process.sample( infections=infections, - counts=None, + obs=None, ) assert result.observed.shape[0] > 0 @@ -183,7 +183,7 @@ def test_with_stochastic_concentration(self, simple_delay_pmf): with numpyro.handlers.seed(rng_seed=42): result = process.sample( infections=infections, 
- counts=None, + obs=None, ) assert result.observed.shape[0] > 0 @@ -200,7 +200,7 @@ def test_zero_infections(self, counts_process): with numpyro.handlers.seed(rng_seed=42): result = counts_process.sample( infections=infections, - counts=None, + obs=None, ) assert result.observed.shape[0] > 0 @@ -212,7 +212,7 @@ def test_small_infections(self, counts_process): with numpyro.handlers.seed(rng_seed=42): result = counts_process.sample( infections=infections, - counts=None, + obs=None, ) assert result.observed.shape[0] > 0 @@ -227,7 +227,7 @@ def test_long_delay_distribution(self, counts_factory, long_delay_pmf): with numpyro.handlers.seed(rng_seed=42): result = process.sample( infections=infections, - counts=None, + obs=None, ) # Timeline alignment maintained @@ -252,7 +252,7 @@ def test_sparse_observations(self, counts_process): with numpyro.handlers.seed(rng_seed=42): result = counts_process.sample( infections=infections, - counts=counts_data, + obs=counts_data, times=times, ) @@ -264,11 +264,11 @@ def test_sparse_vs_dense_sampling(self, counts_process): n_days = 30 infections = jnp.ones(n_days) * 100 - # Dense: prior sampling (counts=None, no times) + # Dense: prior sampling (obs=None, no times) with numpyro.handlers.seed(rng_seed=42): dense_result = counts_process.sample( infections=infections, - counts=None, + obs=None, ) # Sparse with observed data: only some days @@ -277,7 +277,7 @@ def test_sparse_vs_dense_sampling(self, counts_process): with numpyro.handlers.seed(rng_seed=42): sparse_result = counts_process.sample( infections=infections, - counts=sparse_obs_data, + obs=sparse_obs_data, times=times, ) @@ -289,23 +289,23 @@ def test_sparse_vs_dense_sampling(self, counts_process): assert jnp.allclose(sparse_result.observed, sparse_obs_data) def test_prior_sampling_ignores_times(self, counts_process): - """Test that times parameter is ignored when counts=None (prior sampling).""" + """Test that times parameter is ignored when obs=None (prior sampling).""" n_days = 
30 infections = jnp.ones(n_days) * 100 times = jnp.array([5, 10, 15, 20]) - # When counts=None, times is ignored - output is dense + # When obs=None, times is ignored - output is dense with numpyro.handlers.seed(rng_seed=42): result_with_times = counts_process.sample( infections=infections, - counts=None, + obs=None, times=times, ) with numpyro.handlers.seed(rng_seed=42): result_without_times = counts_process.sample( infections=infections, - counts=None, + obs=None, ) # Both should produce dense output of shape (n_days,) @@ -336,7 +336,7 @@ def test_sample_returns_correct_shape(self): infections=infections, subpop_indices=subpop_indices, times=times, - counts=None, + obs=None, ) assert result.observed.shape == times.shape @@ -370,7 +370,7 @@ def test_poisson_counts(self, simple_delay_pmf): with numpyro.handlers.seed(rng_seed=42): result = process.sample( infections=infections, - counts=None, + obs=None, ) assert result.observed.shape[0] == 20 diff --git a/test/test_observation_measurements.py b/test/test_observation_measurements.py index a3af44a1..ae5b70c2 100644 --- a/test/test_observation_measurements.py +++ b/test/test_observation_measurements.py @@ -199,7 +199,7 @@ def test_sample_shape(self): subpop_indices=subpop_indices, sensor_indices=sensor_indices, times=times, - concentrations=None, + obs=None, n_sensors=2, ) @@ -237,7 +237,7 @@ def test_predicted_obs_stored(self): subpop_indices=subpop_indices, sensor_indices=sensor_indices, times=times, - concentrations=None, + obs=None, n_sensors=2, ) ).get_trace() From d05249a510f27f8e677ef41ca32c90fca8d63f8d Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Wed, 7 Jan 2026 10:26:03 -0500 Subject: [PATCH 25/33] standardize sample - all methods must have arg obs --- pyrenew/observation/base.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pyrenew/observation/base.py b/pyrenew/observation/base.py index 3ee71b83..f415671a 100644 --- a/pyrenew/observation/base.py +++ 
b/pyrenew/observation/base.py @@ -310,7 +310,11 @@ def _predicted_obs( pass # pragma: no cover @abstractmethod - def sample(self, **kwargs) -> ArrayLike: + def sample( + self, + obs: ArrayLike | None = None, + **kwargs, + ) -> ArrayLike: """ Sample from the observation process. @@ -320,11 +324,10 @@ def sample(self, **kwargs) -> ArrayLike: Parameters ---------- + obs : ArrayLike | None + Observed data for conditioning, or None for prior predictive sampling. **kwargs - Subclass-specific parameters. At minimum, should include: - - - infections from the infection process - - Observed data (or None for prior predictive sampling) + Subclass-specific parameters (e.g., infections from the infection process). Returns ------- From 613a8f1748567975291a238c7d79c0a534a0e611 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Wed, 7 Jan 2026 11:34:20 -0500 Subject: [PATCH 26/33] in order to use np TruncatedDistribution, change use of Gamma to TruncatedNormal --- .../observation_processes_measurements.qmd | 18 +++- pyrenew/observation/noise.py | 2 +- pyrenew/randomvariable/__init__.py | 4 +- pyrenew/randomvariable/hierarchical.py | 89 ++++++++++--------- test/conftest.py | 26 +++--- test/test_observation_measurements.py | 38 ++++---- test/test_randomvariable_hierarchical.py | 48 +++++----- 7 files changed, 118 insertions(+), 107 deletions(-) diff --git a/docs/tutorials/observation_processes_measurements.qmd b/docs/tutorials/observation_processes_measurements.qmd index 7d0b458c..9b32567e 100644 --- a/docs/tutorials/observation_processes_measurements.qmd +++ b/docs/tutorials/observation_processes_measurements.qmd @@ -16,7 +16,10 @@ import numpyro import matplotlib.pyplot as plt from pyrenew.observation import Measurements, HierarchicalNormalNoise -from pyrenew.randomvariable import HierarchicalNormalPrior, GammaGroupSdPrior +from pyrenew.randomvariable import ( + HierarchicalNormalPrior, + TruncatedNormalGroupSdPrior, +) from pyrenew.deterministic import DeterministicVariable, 
DeterministicPMF ``` @@ -212,6 +215,13 @@ ml_per_person_per_day = 1000.0 A measurement depends on both the thing being measured (e.g., wastewater from a treatment plant) and the calibration of the lab/instruments used to obtain that measurement. We call this combination a "sensor"—the WWTP/lab pair that determines the measurement characteristics. Different sensors have systematic biases and variabilities that we model with hierarchical effects. +For the sensor-level standard deviations, we use a truncated normal prior. This is a common choice for SD parameters because: + +- It naturally constrains values to be positive +- The `loc` parameter represents our prior expectation for the SD +- The `scale` parameter controls how uncertain we are about the SD +- The `sd_min` floor prevents numerical issues when SDs approach zero + ```{python} # | label: noise-model # Sensor-level mode: systematic differences between sensors @@ -221,10 +231,10 @@ sensor_mode_prior = HierarchicalNormalPrior( ) # Sensor-level SD: measurement variability within each sensor -sensor_sd_prior = GammaGroupSdPrior( +sensor_sd_prior = TruncatedNormalGroupSdPrior( name="ww_sensor_sd", - sd_mean_rv=DeterministicVariable("sd_mean", 0.3), - sd_concentration_rv=DeterministicVariable("sd_concentration", 4.0), + loc_rv=DeterministicVariable("sd_loc", 0.3), + scale_rv=DeterministicVariable("sd_scale", 0.15), sd_min=0.10, ) diff --git a/pyrenew/observation/noise.py b/pyrenew/observation/noise.py index 1de84f42..47d2d4ad 100644 --- a/pyrenew/observation/noise.py +++ b/pyrenew/observation/noise.py @@ -282,7 +282,7 @@ class HierarchicalNormalNoise(MeasurementNoise): -------- pyrenew.randomvariable.HierarchicalNormalPrior : Suitable prior for sensor_mode_prior_rv - pyrenew.randomvariable.GammaGroupSdPrior : + pyrenew.randomvariable.TruncatedNormalGroupSdPrior : Suitable prior for sensor_sd_prior_rv """ diff --git a/pyrenew/randomvariable/__init__.py b/pyrenew/randomvariable/__init__.py index 75594c58..09812b9c 
100644 --- a/pyrenew/randomvariable/__init__.py +++ b/pyrenew/randomvariable/__init__.py @@ -6,9 +6,9 @@ StaticDistributionalVariable, ) from pyrenew.randomvariable.hierarchical import ( - GammaGroupSdPrior, HierarchicalNormalPrior, StudentTGroupModePrior, + TruncatedNormalGroupSdPrior, ) from pyrenew.randomvariable.transformedvariable import TransformedVariable @@ -18,6 +18,6 @@ "DynamicDistributionalVariable", "TransformedVariable", "HierarchicalNormalPrior", - "GammaGroupSdPrior", + "TruncatedNormalGroupSdPrior", "StudentTGroupModePrior", ] diff --git a/pyrenew/randomvariable/hierarchical.py b/pyrenew/randomvariable/hierarchical.py index e97126d4..45528d87 100644 --- a/pyrenew/randomvariable/hierarchical.py +++ b/pyrenew/randomvariable/hierarchical.py @@ -7,7 +7,6 @@ dynamic group sizes at sample time with proper numpyro plate contexts. """ -import jax.numpy as jnp import numpyro import numpyro.distributions as dist @@ -107,42 +106,52 @@ def sample(self, n_groups: int, **kwargs): return effects -class GammaGroupSdPrior(RandomVariable): +class TruncatedNormalGroupSdPrior(RandomVariable): """ - Gamma prior for group-level standard deviations, bounded away from zero. + Truncated Normal prior for group-level standard deviations. - Samples n_groups positive values from Gamma(concentration, rate) + sd_min - within a numpyro plate context. + Samples n_groups positive values from TruncatedNormal(loc, scale, low=sd_min) + within a numpyro plate context. Uses NumPyro's native truncated distribution + support for proper likelihood computation. Parameters ---------- name : str Unique name for the sampled parameter in numpyro. - sd_mean_rv : RandomVariable - RandomVariable returning the mean of the Gamma distribution. - sd_concentration_rv : RandomVariable - RandomVariable returning the concentration (shape) parameter of Gamma. + loc_rv : RandomVariable + RandomVariable returning the location (mean) of the underlying Normal. 
+ scale_rv : RandomVariable + RandomVariable returning the scale (std) of the underlying Normal. sd_min : float, default=0.05 - Minimum SD value (lower bound). + Minimum SD value (left truncation point). Notes ----- - This class parameterizes Gamma by mean and concentration rather than - shape and rate, which is often more interpretable. The rate is computed as - concentration / mean. + Truncated normal is a common choice for standard deviation priors because: - The sd_min floor prevents numerical issues when standard deviations - approach zero. + - It naturally constrains values to be positive (via left truncation) + - The location parameter represents the expected SD value + - The scale parameter controls uncertainty about the SD + - It integrates well with NumPyro's truncation support for proper likelihood + + Parameter guidance: + + - ``loc``: Set to your prior expectation for the SD (e.g., 0.3 for + moderate sensor variability on log scale) + - ``scale``: Controls uncertainty; smaller values give tighter priors + around ``loc`` + - ``sd_min``: Prevents numerical issues when SDs approach zero; typical + values are 0.01-0.1 depending on the scale of the data Examples -------- >>> from pyrenew.deterministic import DeterministicVariable - >>> from pyrenew.randomvariable import GammaGroupSdPrior + >>> from pyrenew.randomvariable import TruncatedNormalGroupSdPrior >>> import numpyro >>> - >>> mean_rv = DeterministicVariable("sd_mean", 0.3) - >>> conc_rv = DeterministicVariable("sd_conc", 4.0) - >>> prior = GammaGroupSdPrior("site_sd", mean_rv, conc_rv, sd_min=0.05) + >>> loc_rv = DeterministicVariable("sd_loc", 0.3) + >>> scale_rv = DeterministicVariable("sd_scale", 0.15) + >>> prior = TruncatedNormalGroupSdPrior("site_sd", loc_rv, scale_rv, sd_min=0.05) >>> >>> with numpyro.handlers.seed(rng_seed=42): ... 
sds = prior.sample(n_groups=5) @@ -153,44 +162,44 @@ class GammaGroupSdPrior(RandomVariable): def __init__( self, name: str, - sd_mean_rv: RandomVariable, - sd_concentration_rv: RandomVariable, + loc_rv: RandomVariable, + scale_rv: RandomVariable, sd_min: float = 0.05, ) -> None: """ - Default constructor for GammaGroupSdPrior. + Default constructor for TruncatedNormalGroupSdPrior. Parameters ---------- name : str Unique name for the sampled parameter in numpyro. - sd_mean_rv : RandomVariable - RandomVariable returning the mean of the Gamma distribution. - sd_concentration_rv : RandomVariable - RandomVariable returning the concentration (shape) parameter. + loc_rv : RandomVariable + RandomVariable returning the location (mean) of the underlying Normal. + scale_rv : RandomVariable + RandomVariable returning the scale (std) of the underlying Normal. sd_min : float, default=0.05 - Minimum SD value (lower bound). + Minimum SD value (left truncation point). Returns ------- None """ - if not isinstance(sd_mean_rv, RandomVariable): + if not isinstance(loc_rv, RandomVariable): raise TypeError( - f"sd_mean_rv must be a RandomVariable, got {type(sd_mean_rv).__name__}. " + f"loc_rv must be a RandomVariable, got {type(loc_rv).__name__}. " "Use DeterministicVariable(name, value) to wrap a fixed value." ) - if not isinstance(sd_concentration_rv, RandomVariable): + if not isinstance(scale_rv, RandomVariable): raise TypeError( - f"sd_concentration_rv must be a RandomVariable, got {type(sd_concentration_rv).__name__}. " + f"scale_rv must be a RandomVariable, got {type(scale_rv).__name__}. " "Use DeterministicVariable(name, value) to wrap a fixed value." 
) if sd_min < 0: raise ValueError(f"sd_min must be non-negative, got {sd_min}") self.name = name - self.sd_mean_rv = sd_mean_rv - self.sd_concentration_rv = sd_concentration_rv + self.loc_rv = loc_rv + self.scale_rv = scale_rv self.sd_min = sd_min def validate(self): @@ -213,19 +222,13 @@ def sample(self, n_groups: int, **kwargs): ArrayLike Array of shape (n_groups,) with values >= sd_min. """ - sd_mean = self.sd_mean_rv() - concentration = self.sd_concentration_rv() - rate = concentration / sd_mean + loc = self.loc_rv() + scale = self.scale_rv() with numpyro.plate(f"n_{self.name}", n_groups): - raw_sd = numpyro.sample( - f"{self.name}_raw", - dist.Gamma(concentration, rate), - ) - - group_sd = numpyro.deterministic( + group_sd = numpyro.sample( self.name, - jnp.maximum(raw_sd, self.sd_min), + dist.TruncatedNormal(loc=loc, scale=scale, low=self.sd_min), ) return group_sd diff --git a/test/conftest.py b/test/conftest.py index 6f55681a..d159106a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -10,7 +10,7 @@ from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.observation import Counts, NegativeBinomialNoise -from pyrenew.randomvariable import GammaGroupSdPrior, HierarchicalNormalPrior +from pyrenew.randomvariable import HierarchicalNormalPrior, TruncatedNormalGroupSdPrior # ============================================================================= # PMF Fixtures @@ -159,17 +159,17 @@ def sensor_mode_prior_tight(): @pytest.fixture def sensor_sd_prior(): """ - Standard gamma prior for sensor standard deviations. + Standard truncated normal prior for sensor standard deviations. Returns ------- - GammaGroupSdPrior - A gamma prior for group standard deviations. + TruncatedNormalGroupSdPrior + A truncated normal prior for group standard deviations. 
""" - return GammaGroupSdPrior( + return TruncatedNormalGroupSdPrior( name="ww_sensor_sd", - sd_mean_rv=DeterministicVariable("sd_mean", 0.3), - sd_concentration_rv=DeterministicVariable("sd_concentration", 4.0), + loc_rv=DeterministicVariable("sd_loc", 0.3), + scale_rv=DeterministicVariable("sd_scale", 0.15), sd_min=0.10, ) @@ -177,17 +177,17 @@ def sensor_sd_prior(): @pytest.fixture def sensor_sd_prior_tight(): """ - Tight gamma prior for deterministic-like behavior. + Tight truncated normal prior for deterministic-like behavior. Returns ------- - GammaGroupSdPrior - A gamma prior with small mean for tight behavior. + TruncatedNormalGroupSdPrior + A truncated normal prior with small scale for tight behavior. """ - return GammaGroupSdPrior( + return TruncatedNormalGroupSdPrior( name="ww_sensor_sd", - sd_mean_rv=DeterministicVariable("sd_mean_tight", 0.01), - sd_concentration_rv=DeterministicVariable("sd_concentration_tight", 4.0), + loc_rv=DeterministicVariable("sd_loc_tight", 0.01), + scale_rv=DeterministicVariable("sd_scale_tight", 0.005), sd_min=0.005, ) diff --git a/test/test_observation_measurements.py b/test/test_observation_measurements.py index ae5b70c2..4287404d 100644 --- a/test/test_observation_measurements.py +++ b/test/test_observation_measurements.py @@ -11,7 +11,7 @@ from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.observation import HierarchicalNormalNoise, Measurements from pyrenew.observation.base import BaseObservationProcess -from pyrenew.randomvariable import GammaGroupSdPrior, HierarchicalNormalPrior +from pyrenew.randomvariable import HierarchicalNormalPrior, TruncatedNormalGroupSdPrior class ConcreteMeasurements(Measurements): @@ -79,10 +79,10 @@ def test_infection_resolution_is_subpop(self): sensor_mode_prior = HierarchicalNormalPrior( name="mode", sd_rv=DeterministicVariable("sd", 0.5) ) - sensor_sd_prior = GammaGroupSdPrior( + sensor_sd_prior = TruncatedNormalGroupSdPrior( name="sd", - 
sd_mean_rv=DeterministicVariable("mean", 0.3), - sd_concentration_rv=DeterministicVariable("conc", 4.0), + loc_rv=DeterministicVariable("loc", 0.3), + scale_rv=DeterministicVariable("scale", 0.15), ) noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) @@ -102,10 +102,10 @@ def test_validate(self): sensor_mode_prior = HierarchicalNormalPrior( name="mode", sd_rv=DeterministicVariable("mode_sd", 0.5) ) - sensor_sd_prior = GammaGroupSdPrior( + sensor_sd_prior = TruncatedNormalGroupSdPrior( name="sd", - sd_mean_rv=DeterministicVariable("sd_mean", 0.3), - sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + loc_rv=DeterministicVariable("sd_loc", 0.3), + scale_rv=DeterministicVariable("sd_scale", 0.15), ) noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) # Should not raise - validation is deferred to sample time @@ -116,10 +116,10 @@ def test_sample_shape(self): sensor_mode_prior = HierarchicalNormalPrior( name="mode", sd_rv=DeterministicVariable("mode_sd", 0.5) ) - sensor_sd_prior = GammaGroupSdPrior( + sensor_sd_prior = TruncatedNormalGroupSdPrior( name="sd", - sd_mean_rv=DeterministicVariable("sd_mean", 0.3), - sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + loc_rv=DeterministicVariable("sd_loc", 0.3), + scale_rv=DeterministicVariable("sd_scale", 0.15), ) noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) @@ -142,10 +142,10 @@ def test_sample_with_observations(self): sensor_mode_prior = HierarchicalNormalPrior( name="mode", sd_rv=DeterministicVariable("mode_sd", 0.5) ) - sensor_sd_prior = GammaGroupSdPrior( + sensor_sd_prior = TruncatedNormalGroupSdPrior( name="sd", - sd_mean_rv=DeterministicVariable("sd_mean", 0.3), - sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + loc_rv=DeterministicVariable("sd_loc", 0.3), + scale_rv=DeterministicVariable("sd_scale", 0.15), ) noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) @@ -175,10 +175,10 @@ def test_sample_shape(self): 
sensor_mode_prior = HierarchicalNormalPrior( name="mode", sd_rv=DeterministicVariable("mode_sd", 0.5) ) - sensor_sd_prior = GammaGroupSdPrior( + sensor_sd_prior = TruncatedNormalGroupSdPrior( name="sd", - sd_mean_rv=DeterministicVariable("sd_mean", 0.3), - sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + loc_rv=DeterministicVariable("sd_loc", 0.3), + scale_rv=DeterministicVariable("sd_scale", 0.15), ) noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) @@ -212,10 +212,10 @@ def test_predicted_obs_stored(self): sensor_mode_prior = HierarchicalNormalPrior( name="mode", sd_rv=DeterministicVariable("mode_sd", 0.01) ) - sensor_sd_prior = GammaGroupSdPrior( + sensor_sd_prior = TruncatedNormalGroupSdPrior( name="sd", - sd_mean_rv=DeterministicVariable("sd_mean", 0.01), - sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + loc_rv=DeterministicVariable("sd_loc", 0.01), + scale_rv=DeterministicVariable("sd_scale", 0.005), sd_min=0.001, ) noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) diff --git a/test/test_randomvariable_hierarchical.py b/test/test_randomvariable_hierarchical.py index 89a58c3a..8d017a73 100644 --- a/test/test_randomvariable_hierarchical.py +++ b/test/test_randomvariable_hierarchical.py @@ -8,9 +8,9 @@ from pyrenew.deterministic import DeterministicVariable from pyrenew.randomvariable import ( DistributionalVariable, - GammaGroupSdPrior, HierarchicalNormalPrior, StudentTGroupModePrior, + TruncatedNormalGroupSdPrior, ) @@ -69,15 +69,15 @@ def test_accepts_distributional_variable_for_sd(self): assert samples.shape == (5,) -class TestGammaGroupSdPrior: - """Test GammaGroupSdPrior.""" +class TestTruncatedNormalGroupSdPrior: + """Test TruncatedNormalGroupSdPrior.""" def test_sample_shape(self): """Test that sample returns correct shape.""" - prior = GammaGroupSdPrior( + prior = TruncatedNormalGroupSdPrior( "sd", - sd_mean_rv=DeterministicVariable("sd_mean", 0.5), - 
sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + loc_rv=DeterministicVariable("sd_loc", 0.5), + scale_rv=DeterministicVariable("sd_scale", 0.15), ) with numpyro.handlers.seed(rng_seed=42): @@ -87,10 +87,10 @@ def test_sample_shape(self): def test_respects_sd_min(self): """Test that sd_min is enforced as lower bound.""" - prior = GammaGroupSdPrior( + prior = TruncatedNormalGroupSdPrior( "sd", - sd_mean_rv=DeterministicVariable("sd_mean", 0.1), - sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + loc_rv=DeterministicVariable("sd_loc", 0.1), + scale_rv=DeterministicVariable("sd_scale", 0.05), sd_min=0.5, ) @@ -101,38 +101,36 @@ def test_respects_sd_min(self): def test_rejects_non_random_variable_params(self): """Test that non-RandomVariable parameters are rejected.""" - with pytest.raises(TypeError, match="sd_mean_rv must be a RandomVariable"): - GammaGroupSdPrior( + with pytest.raises(TypeError, match="loc_rv must be a RandomVariable"): + TruncatedNormalGroupSdPrior( "sd", - sd_mean_rv=0.5, - sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + loc_rv=0.5, + scale_rv=DeterministicVariable("sd_scale", 0.15), ) - with pytest.raises( - TypeError, match="sd_concentration_rv must be a RandomVariable" - ): - GammaGroupSdPrior( + with pytest.raises(TypeError, match="scale_rv must be a RandomVariable"): + TruncatedNormalGroupSdPrior( "sd", - sd_mean_rv=DeterministicVariable("sd_mean", 0.5), - sd_concentration_rv=4.0, + loc_rv=DeterministicVariable("sd_loc", 0.5), + scale_rv=0.15, ) def test_rejects_negative_sd_min(self): """Test that negative sd_min is rejected.""" with pytest.raises(ValueError, match="sd_min must be non-negative"): - GammaGroupSdPrior( + TruncatedNormalGroupSdPrior( "sd", - sd_mean_rv=DeterministicVariable("sd_mean", 0.5), - sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + loc_rv=DeterministicVariable("sd_loc", 0.5), + scale_rv=DeterministicVariable("sd_scale", 0.15), sd_min=-0.1, ) def test_validate(self): """Test 
that validate() runs without error.""" - prior = GammaGroupSdPrior( + prior = TruncatedNormalGroupSdPrior( "sd", - sd_mean_rv=DeterministicVariable("sd_mean", 0.5), - sd_concentration_rv=DeterministicVariable("sd_conc", 4.0), + loc_rv=DeterministicVariable("sd_loc", 0.5), + scale_rv=DeterministicVariable("sd_scale", 0.15), ) prior.validate() # Should not raise From 910fa70848a55230861cedd9137109e6cbb44008 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Wed, 7 Jan 2026 12:38:09 -0500 Subject: [PATCH 27/33] Remove the provided helper classes for sampling priors (i.e. pyrenew/randomvariable/hierarchical.py), per reviewer comments --- .../observation_processes_measurements.qmd | 34 +- pyrenew/observation/noise.py | 61 ++-- pyrenew/randomvariable/__init__.py | 8 - pyrenew/randomvariable/hierarchical.py | 339 ------------------ test/conftest.py | 55 ++- test/test_observation_measurements.py | 78 ++-- test/test_randomvariable_hierarchical.py | 202 ----------- 7 files changed, 96 insertions(+), 681 deletions(-) delete mode 100644 pyrenew/randomvariable/hierarchical.py delete mode 100644 test/test_randomvariable_hierarchical.py diff --git a/docs/tutorials/observation_processes_measurements.qmd b/docs/tutorials/observation_processes_measurements.qmd index 9b32567e..edfb807b 100644 --- a/docs/tutorials/observation_processes_measurements.qmd +++ b/docs/tutorials/observation_processes_measurements.qmd @@ -15,11 +15,10 @@ import numpy as np import numpyro import matplotlib.pyplot as plt +import numpyro.distributions as dist + from pyrenew.observation import Measurements, HierarchicalNormalNoise -from pyrenew.randomvariable import ( - HierarchicalNormalPrior, - TruncatedNormalGroupSdPrior, -) +from pyrenew.randomvariable import DistributionalVariable from pyrenew.deterministic import DeterministicVariable, DeterministicPMF ``` @@ -215,33 +214,22 @@ ml_per_person_per_day = 1000.0 A measurement depends on both the thing being measured (e.g., wastewater from a treatment plant) and 
the calibration of the lab/instruments used to obtain that measurement. We call this combination a "sensor"—the WWTP/lab pair that determines the measurement characteristics. Different sensors have systematic biases and variabilities that we model with hierarchical effects. -For the sensor-level standard deviations, we use a truncated normal prior. This is a common choice for SD parameters because: - -- It naturally constrains values to be positive -- The `loc` parameter represents our prior expectation for the SD -- The `scale` parameter controls how uncertain we are about the SD -- The `sd_min` floor prevents numerical issues when SDs approach zero +The `HierarchicalNormalNoise` model samples sensor-level parameters within a plate context, so any `RandomVariable` can be used. For the sensor-level standard deviations, we use a truncated normal prior—a common choice for SD parameters because it naturally constrains values to be positive. ```{python} # | label: noise-model -# Sensor-level mode: systematic differences between sensors -sensor_mode_prior = HierarchicalNormalPrior( - name="ww_sensor_mode", - sd_rv=DeterministicVariable("mode_sd", 0.5), -) +# Sensor-level mode: systematic differences between sensors (zero-centered normal) +sensor_mode_rv = DistributionalVariable("ww_sensor_mode", dist.Normal(0, 0.5)) -# Sensor-level SD: measurement variability within each sensor -sensor_sd_prior = TruncatedNormalGroupSdPrior( - name="ww_sensor_sd", - loc_rv=DeterministicVariable("sd_loc", 0.3), - scale_rv=DeterministicVariable("sd_scale", 0.15), - sd_min=0.10, +# Sensor-level SD: measurement variability within each sensor (truncated normal) +sensor_sd_rv = DistributionalVariable( + "ww_sensor_sd", dist.TruncatedNormal(loc=0.3, scale=0.15, low=0.10) ) # Create the noise model ww_noise = HierarchicalNormalNoise( - sensor_mode_prior_rv=sensor_mode_prior, - sensor_sd_prior_rv=sensor_sd_prior, + sensor_mode_rv=sensor_mode_rv, + sensor_sd_rv=sensor_sd_rv, ) ``` diff --git 
a/pyrenew/observation/noise.py b/pyrenew/observation/noise.py index 47d2d4ad..9ce282e7 100644 --- a/pyrenew/observation/noise.py +++ b/pyrenew/observation/noise.py @@ -263,55 +263,63 @@ class HierarchicalNormalNoise(MeasurementNoise): Normal noise with hierarchical sensor-level effects. Observation model: ``obs ~ Normal(predicted + sensor_mode, sensor_sd)`` - where sensor_mode and sensor_sd are hierarchically modeled. + where sensor_mode and sensor_sd are sampled per-sensor within a plate. Parameters ---------- - sensor_mode_prior_rv : RandomVariable - Hierarchical prior for sensor-level modes (log-scale biases). - Must support ``sample(n_groups=...)`` interface. - sensor_sd_prior_rv : RandomVariable - Hierarchical prior for sensor-level SDs (must be > 0). - Must support ``sample(n_groups=...)`` interface. + sensor_mode_rv : RandomVariable + Prior for sensor-level modes (log-scale biases). + Sampled once per sensor within a plate context. + Example: ``DistributionalVariable("mode", dist.Normal(0, 0.5))`` + sensor_sd_rv : RandomVariable + Prior for sensor-level SDs (should be > 0). + Sampled once per sensor within a plate context. + Example: ``DistributionalVariable("sd", dist.TruncatedNormal(0.3, 0.15, low=0.05))`` Notes ----- Expects data already on log scale for wastewater applications. - See Also + The sensor-level parameters are sampled within a numpyro plate context, + so any standard RandomVariable can be used (no special interface required). + + Examples -------- - pyrenew.randomvariable.HierarchicalNormalPrior : - Suitable prior for sensor_mode_prior_rv - pyrenew.randomvariable.TruncatedNormalGroupSdPrior : - Suitable prior for sensor_sd_prior_rv + >>> from pyrenew.randomvariable import DistributionalVariable + >>> import numpyro.distributions as dist + >>> + >>> noise = HierarchicalNormalNoise( + ... sensor_mode_rv=DistributionalVariable("mode", dist.Normal(0, 0.5)), + ... 
sensor_sd_rv=DistributionalVariable("sd", dist.TruncatedNormal(0.3, 0.15, low=0.05)), + ... ) """ def __init__( self, - sensor_mode_prior_rv: RandomVariable, - sensor_sd_prior_rv: RandomVariable, + sensor_mode_rv: RandomVariable, + sensor_sd_rv: RandomVariable, ) -> None: """ Initialize hierarchical Normal noise. Parameters ---------- - sensor_mode_prior_rv : RandomVariable - Hierarchical prior for sensor-level modes (log-scale biases). - Must support ``sample(n_groups=...)`` interface. - sensor_sd_prior_rv : RandomVariable - Hierarchical prior for sensor-level SDs (must be > 0). - Must support ``sample(n_groups=...)`` interface. + sensor_mode_rv : RandomVariable + Prior for sensor-level modes (log-scale biases). + Sampled once per sensor within a plate context. + sensor_sd_rv : RandomVariable + Prior for sensor-level SDs (should be > 0). + Sampled once per sensor within a plate context. """ - self.sensor_mode_prior_rv = sensor_mode_prior_rv - self.sensor_sd_prior_rv = sensor_sd_prior_rv + self.sensor_mode_rv = sensor_mode_rv + self.sensor_sd_rv = sensor_sd_rv def __repr__(self) -> str: """Return string representation.""" return ( f"HierarchicalNormalNoise(" - f"sensor_mode_prior_rv={self.sensor_mode_prior_rv!r}, " - f"sensor_sd_prior_rv={self.sensor_sd_prior_rv!r})" + f"sensor_mode_rv={self.sensor_mode_rv!r}, " + f"sensor_sd_rv={self.sensor_sd_rv!r})" ) def validate(self) -> None: @@ -363,8 +371,9 @@ def sample( ValueError If sensor_sd samples non-positive values. 
""" - sensor_mode = self.sensor_mode_prior_rv.sample(n_groups=n_sensors) - sensor_sd = self.sensor_sd_prior_rv.sample(n_groups=n_sensors) + with numpyro.plate("sensor", n_sensors): + sensor_mode = self.sensor_mode_rv() + sensor_sd = self.sensor_sd_rv() loc = predicted + sensor_mode[sensor_indices] scale = sensor_sd[sensor_indices] diff --git a/pyrenew/randomvariable/__init__.py b/pyrenew/randomvariable/__init__.py index 09812b9c..c599d101 100644 --- a/pyrenew/randomvariable/__init__.py +++ b/pyrenew/randomvariable/__init__.py @@ -5,11 +5,6 @@ DynamicDistributionalVariable, StaticDistributionalVariable, ) -from pyrenew.randomvariable.hierarchical import ( - HierarchicalNormalPrior, - StudentTGroupModePrior, - TruncatedNormalGroupSdPrior, -) from pyrenew.randomvariable.transformedvariable import TransformedVariable __all__ = [ @@ -17,7 +12,4 @@ "StaticDistributionalVariable", "DynamicDistributionalVariable", "TransformedVariable", - "HierarchicalNormalPrior", - "TruncatedNormalGroupSdPrior", - "StudentTGroupModePrior", ] diff --git a/pyrenew/randomvariable/hierarchical.py b/pyrenew/randomvariable/hierarchical.py deleted file mode 100644 index 45528d87..00000000 --- a/pyrenew/randomvariable/hierarchical.py +++ /dev/null @@ -1,339 +0,0 @@ -# numpydoc ignore=GL08 -""" -Hierarchical prior distributions for group-level random effects. - -These classes provide random variables that sample from hierarchical -distributions with a `sample(n_groups=...)` interface, enabling -dynamic group sizes at sample time with proper numpyro plate contexts. -""" - -import numpyro -import numpyro.distributions as dist - -from pyrenew.metaclass import RandomVariable - - -class HierarchicalNormalPrior(RandomVariable): - """ - Zero-centered Normal prior for group-level effects. - - Samples n_groups values from Normal(0, sd) within a numpyro plate context. - - Parameters - ---------- - name : str - Unique name for the sampled parameter in numpyro. 
- sd_rv : RandomVariable - RandomVariable returning the standard deviation. - - Notes - ----- - This class is designed for hierarchical models where group effects - are assumed to be drawn from a common distribution centered at zero. - The number of groups is specified at sample time, allowing dynamic - group sizes. - - Examples - -------- - >>> from pyrenew.deterministic import DeterministicVariable - >>> from pyrenew.randomvariable import HierarchicalNormalPrior - >>> import numpyro - >>> - >>> sd_rv = DeterministicVariable("sd", 0.5) - >>> prior = HierarchicalNormalPrior("site_effects", sd_rv) - >>> - >>> with numpyro.handlers.seed(rng_seed=42): - ... effects = prior.sample(n_groups=5) - >>> effects.shape - (5,) - """ - - def __init__( - self, - name: str, - sd_rv: RandomVariable, - ) -> None: - """ - Default constructor for HierarchicalNormalPrior. - - Parameters - ---------- - name : str - Unique name for the sampled parameter in numpyro. - sd_rv : RandomVariable - RandomVariable returning the standard deviation. - - Returns - ------- - None - """ - if not isinstance(sd_rv, RandomVariable): - raise TypeError( - f"sd_rv must be a RandomVariable, got {type(sd_rv).__name__}. " - "Use DeterministicVariable(name, value) to wrap a fixed value." - ) - - self.name = name - self.sd_rv = sd_rv - - def validate(self): - """Validate the random variable (no-op for this class).""" - pass - - def sample(self, n_groups: int, **kwargs): - """ - Sample group-level effects. - - Parameters - ---------- - n_groups : int - Number of groups. - **kwargs - Additional keyword arguments (unused). - - Returns - ------- - ArrayLike - Array of shape (n_groups,) containing sampled effects. - """ - sd = self.sd_rv() - - with numpyro.plate(f"n_{self.name}", n_groups): - effects = numpyro.sample( - self.name, - dist.Normal(0.0, sd), - ) - return effects - - -class TruncatedNormalGroupSdPrior(RandomVariable): - """ - Truncated Normal prior for group-level standard deviations. 
- - Samples n_groups positive values from TruncatedNormal(loc, scale, low=sd_min) - within a numpyro plate context. Uses NumPyro's native truncated distribution - support for proper likelihood computation. - - Parameters - ---------- - name : str - Unique name for the sampled parameter in numpyro. - loc_rv : RandomVariable - RandomVariable returning the location (mean) of the underlying Normal. - scale_rv : RandomVariable - RandomVariable returning the scale (std) of the underlying Normal. - sd_min : float, default=0.05 - Minimum SD value (left truncation point). - - Notes - ----- - Truncated normal is a common choice for standard deviation priors because: - - - It naturally constrains values to be positive (via left truncation) - - The location parameter represents the expected SD value - - The scale parameter controls uncertainty about the SD - - It integrates well with NumPyro's truncation support for proper likelihood - - Parameter guidance: - - - ``loc``: Set to your prior expectation for the SD (e.g., 0.3 for - moderate sensor variability on log scale) - - ``scale``: Controls uncertainty; smaller values give tighter priors - around ``loc`` - - ``sd_min``: Prevents numerical issues when SDs approach zero; typical - values are 0.01-0.1 depending on the scale of the data - - Examples - -------- - >>> from pyrenew.deterministic import DeterministicVariable - >>> from pyrenew.randomvariable import TruncatedNormalGroupSdPrior - >>> import numpyro - >>> - >>> loc_rv = DeterministicVariable("sd_loc", 0.3) - >>> scale_rv = DeterministicVariable("sd_scale", 0.15) - >>> prior = TruncatedNormalGroupSdPrior("site_sd", loc_rv, scale_rv, sd_min=0.05) - >>> - >>> with numpyro.handlers.seed(rng_seed=42): - ... sds = prior.sample(n_groups=5) - >>> all(sds >= 0.05) - True - """ - - def __init__( - self, - name: str, - loc_rv: RandomVariable, - scale_rv: RandomVariable, - sd_min: float = 0.05, - ) -> None: - """ - Default constructor for TruncatedNormalGroupSdPrior. 
- - Parameters - ---------- - name : str - Unique name for the sampled parameter in numpyro. - loc_rv : RandomVariable - RandomVariable returning the location (mean) of the underlying Normal. - scale_rv : RandomVariable - RandomVariable returning the scale (std) of the underlying Normal. - sd_min : float, default=0.05 - Minimum SD value (left truncation point). - - Returns - ------- - None - """ - if not isinstance(loc_rv, RandomVariable): - raise TypeError( - f"loc_rv must be a RandomVariable, got {type(loc_rv).__name__}. " - "Use DeterministicVariable(name, value) to wrap a fixed value." - ) - if not isinstance(scale_rv, RandomVariable): - raise TypeError( - f"scale_rv must be a RandomVariable, got {type(scale_rv).__name__}. " - "Use DeterministicVariable(name, value) to wrap a fixed value." - ) - if sd_min < 0: - raise ValueError(f"sd_min must be non-negative, got {sd_min}") - - self.name = name - self.loc_rv = loc_rv - self.scale_rv = scale_rv - self.sd_min = sd_min - - def validate(self): - """Validate the random variable (no-op for this class).""" - pass - - def sample(self, n_groups: int, **kwargs): - """ - Sample group-level standard deviations. - - Parameters - ---------- - n_groups : int - Number of groups. - **kwargs - Additional keyword arguments (unused). - - Returns - ------- - ArrayLike - Array of shape (n_groups,) with values >= sd_min. - """ - loc = self.loc_rv() - scale = self.scale_rv() - - with numpyro.plate(f"n_{self.name}", n_groups): - group_sd = numpyro.sample( - self.name, - dist.TruncatedNormal(loc=loc, scale=scale, low=self.sd_min), - ) - return group_sd - - -class StudentTGroupModePrior(RandomVariable): - """ - Zero-centered Student-t prior for group-level modes (robust alternative to Normal). - - Samples n_groups values from StudentT(df, 0, sd) within a numpyro plate context. - This is useful when group effects may have heavier tails than a Normal distribution. 
- - Parameters - ---------- - name : str - Unique name for the sampled parameter in numpyro. - sd_rv : RandomVariable - RandomVariable returning the scale parameter. - df_rv : RandomVariable - RandomVariable returning the degrees of freedom. - - Notes - ----- - The Student-t distribution approaches the Normal distribution as df -> infinity. - Lower df values give heavier tails, making the prior more robust to outliers. - Common choices include df=3 (heavy tails) or df=7 (moderate tails). - - Examples - -------- - >>> from pyrenew.deterministic import DeterministicVariable - >>> from pyrenew.randomvariable import StudentTGroupModePrior - >>> import numpyro - >>> - >>> sd_rv = DeterministicVariable("scale", 0.5) - >>> df_rv = DeterministicVariable("df", 4.0) - >>> prior = StudentTGroupModePrior("site_modes", sd_rv, df_rv) - >>> - >>> with numpyro.handlers.seed(rng_seed=42): - ... modes = prior.sample(n_groups=5) - >>> modes.shape - (5,) - """ - - def __init__( - self, - name: str, - sd_rv: RandomVariable, - df_rv: RandomVariable, - ) -> None: - """ - Default constructor for StudentTGroupModePrior. - - Parameters - ---------- - name : str - Unique name for the sampled parameter in numpyro. - sd_rv : RandomVariable - RandomVariable returning the scale parameter. - df_rv : RandomVariable - RandomVariable returning the degrees of freedom. - - Returns - ------- - None - """ - if not isinstance(sd_rv, RandomVariable): - raise TypeError( - f"sd_rv must be a RandomVariable, got {type(sd_rv).__name__}. " - "Use DeterministicVariable(name, value) to wrap a fixed value." - ) - if not isinstance(df_rv, RandomVariable): - raise TypeError( - f"df_rv must be a RandomVariable, got {type(df_rv).__name__}. " - "Use DeterministicVariable(name, value) to wrap a fixed value." 
- ) - - self.name = name - self.sd_rv = sd_rv - self.df_rv = df_rv - - def validate(self): - """Validate the random variable (no-op for this class).""" - pass - - def sample(self, n_groups: int, **kwargs): - """ - Sample group-level modes. - - Parameters - ---------- - n_groups : int - Number of groups. - **kwargs - Additional keyword arguments (unused). - - Returns - ------- - ArrayLike - Array of shape (n_groups,) containing sampled modes. - """ - sd = self.sd_rv() - df = self.df_rv() - - with numpyro.plate(f"n_{self.name}", n_groups): - effects = numpyro.sample( - self.name, - dist.StudentT(df=df, loc=0.0, scale=sd), - ) - return effects diff --git a/test/conftest.py b/test/conftest.py index d159106a..bc95fd10 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -6,11 +6,12 @@ """ import jax.numpy as jnp +import numpyro.distributions as dist import pytest from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.observation import Counts, NegativeBinomialNoise -from pyrenew.randomvariable import HierarchicalNormalPrior, TruncatedNormalGroupSdPrior +from pyrenew.randomvariable import DistributionalVariable # ============================================================================= # PMF Fixtures @@ -122,73 +123,63 @@ def medium_shedding_pmf(): # ============================================================================= -# Hierarchical Prior Fixtures +# Sensor Prior Fixtures # ============================================================================= @pytest.fixture -def sensor_mode_prior(): +def sensor_mode_rv(): """ - Standard hierarchical normal prior for sensor modes. + Standard normal prior for sensor modes. Returns ------- - HierarchicalNormalPrior - A hierarchical normal prior with standard deviation 0.5. + DistributionalVariable + A normal prior with standard deviation 0.5. 
""" - return HierarchicalNormalPrior( - name="ww_sensor_mode", sd_rv=DeterministicVariable("mode_sd", 0.5) - ) + return DistributionalVariable("ww_sensor_mode", dist.Normal(0, 0.5)) @pytest.fixture -def sensor_mode_prior_tight(): +def sensor_mode_rv_tight(): """ - Tight hierarchical normal prior for deterministic-like behavior. + Tight normal prior for deterministic-like behavior. Returns ------- - HierarchicalNormalPrior - A hierarchical normal prior with small standard deviation 0.01. + DistributionalVariable + A normal prior with small standard deviation 0.01. """ - return HierarchicalNormalPrior( - name="ww_sensor_mode", sd_rv=DeterministicVariable("mode_sd_tight", 0.01) - ) + return DistributionalVariable("ww_sensor_mode", dist.Normal(0, 0.01)) @pytest.fixture -def sensor_sd_prior(): +def sensor_sd_rv(): """ Standard truncated normal prior for sensor standard deviations. Returns ------- - TruncatedNormalGroupSdPrior - A truncated normal prior for group standard deviations. - """ - return TruncatedNormalGroupSdPrior( - name="ww_sensor_sd", - loc_rv=DeterministicVariable("sd_loc", 0.3), - scale_rv=DeterministicVariable("sd_scale", 0.15), - sd_min=0.10, + DistributionalVariable + A truncated normal prior for sensor standard deviations. + """ + return DistributionalVariable( + "ww_sensor_sd", dist.TruncatedNormal(0.3, 0.15, low=0.10) ) @pytest.fixture -def sensor_sd_prior_tight(): +def sensor_sd_rv_tight(): """ Tight truncated normal prior for deterministic-like behavior. Returns ------- - TruncatedNormalGroupSdPrior + DistributionalVariable A truncated normal prior with small scale for tight behavior. 
""" - return TruncatedNormalGroupSdPrior( - name="ww_sensor_sd", - loc_rv=DeterministicVariable("sd_loc_tight", 0.01), - scale_rv=DeterministicVariable("sd_scale_tight", 0.005), - sd_min=0.005, + return DistributionalVariable( + "ww_sensor_sd", dist.TruncatedNormal(0.01, 0.005, low=0.005) ) diff --git a/test/test_observation_measurements.py b/test/test_observation_measurements.py index 4287404d..32e053eb 100644 --- a/test/test_observation_measurements.py +++ b/test/test_observation_measurements.py @@ -6,12 +6,13 @@ import jax.numpy as jnp import numpyro +import numpyro.distributions as dist import pytest -from pyrenew.deterministic import DeterministicPMF, DeterministicVariable +from pyrenew.deterministic import DeterministicPMF from pyrenew.observation import HierarchicalNormalNoise, Measurements from pyrenew.observation.base import BaseObservationProcess -from pyrenew.randomvariable import HierarchicalNormalPrior, TruncatedNormalGroupSdPrior +from pyrenew.randomvariable import DistributionalVariable class ConcreteMeasurements(Measurements): @@ -76,15 +77,11 @@ def test_is_base_observation_process(self): def test_infection_resolution_is_subpop(self): """Test that Measurements returns 'subpop' resolution.""" shedding_pmf = jnp.array([0.3, 0.4, 0.3]) - sensor_mode_prior = HierarchicalNormalPrior( - name="mode", sd_rv=DeterministicVariable("sd", 0.5) + sensor_mode_rv = DistributionalVariable("mode", dist.Normal(0, 0.5)) + sensor_sd_rv = DistributionalVariable( + "sd", dist.TruncatedNormal(0.3, 0.15, low=0.05) ) - sensor_sd_prior = TruncatedNormalGroupSdPrior( - name="sd", - loc_rv=DeterministicVariable("loc", 0.3), - scale_rv=DeterministicVariable("scale", 0.15), - ) - noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) + noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) process = ConcreteMeasurements( temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), @@ -99,29 +96,21 @@ class TestHierarchicalNormalNoise: def test_validate(self): 
"""Test HierarchicalNormalNoise validate method.""" - sensor_mode_prior = HierarchicalNormalPrior( - name="mode", sd_rv=DeterministicVariable("mode_sd", 0.5) - ) - sensor_sd_prior = TruncatedNormalGroupSdPrior( - name="sd", - loc_rv=DeterministicVariable("sd_loc", 0.3), - scale_rv=DeterministicVariable("sd_scale", 0.15), + sensor_mode_rv = DistributionalVariable("mode", dist.Normal(0, 0.5)) + sensor_sd_rv = DistributionalVariable( + "sd", dist.TruncatedNormal(0.3, 0.15, low=0.05) ) - noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) + noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) # Should not raise - validation is deferred to sample time noise.validate() def test_sample_shape(self): """Test that HierarchicalNormalNoise produces correct shape.""" - sensor_mode_prior = HierarchicalNormalPrior( - name="mode", sd_rv=DeterministicVariable("mode_sd", 0.5) - ) - sensor_sd_prior = TruncatedNormalGroupSdPrior( - name="sd", - loc_rv=DeterministicVariable("sd_loc", 0.3), - scale_rv=DeterministicVariable("sd_scale", 0.15), + sensor_mode_rv = DistributionalVariable("mode", dist.Normal(0, 0.5)) + sensor_sd_rv = DistributionalVariable( + "sd", dist.TruncatedNormal(0.3, 0.15, low=0.05) ) - noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) + noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) predicted = jnp.array([1.0, 2.0, 3.0, 4.0]) sensor_indices = jnp.array([0, 0, 1, 1]) @@ -139,15 +128,11 @@ def test_sample_shape(self): def test_sample_with_observations(self): """Test that HierarchicalNormalNoise conditions on observations.""" - sensor_mode_prior = HierarchicalNormalPrior( - name="mode", sd_rv=DeterministicVariable("mode_sd", 0.5) + sensor_mode_rv = DistributionalVariable("mode", dist.Normal(0, 0.5)) + sensor_sd_rv = DistributionalVariable( + "sd", dist.TruncatedNormal(0.3, 0.15, low=0.05) ) - sensor_sd_prior = TruncatedNormalGroupSdPrior( - name="sd", - loc_rv=DeterministicVariable("sd_loc", 0.3), - 
scale_rv=DeterministicVariable("sd_scale", 0.15), - ) - noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) + noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) predicted = jnp.array([1.0, 2.0, 3.0, 4.0]) obs = jnp.array([1.1, 2.1, 3.1, 4.1]) @@ -172,15 +157,11 @@ class TestConcreteMeasurements: def test_sample_shape(self): """Test that sample returns correct shape.""" shedding_pmf = jnp.array([0.3, 0.4, 0.3]) - sensor_mode_prior = HierarchicalNormalPrior( - name="mode", sd_rv=DeterministicVariable("mode_sd", 0.5) - ) - sensor_sd_prior = TruncatedNormalGroupSdPrior( - name="sd", - loc_rv=DeterministicVariable("sd_loc", 0.3), - scale_rv=DeterministicVariable("sd_scale", 0.15), + sensor_mode_rv = DistributionalVariable("mode", dist.Normal(0, 0.5)) + sensor_sd_rv = DistributionalVariable( + "sd", dist.TruncatedNormal(0.3, 0.15, low=0.05) ) - noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) + noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) process = ConcreteMeasurements( temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), @@ -209,16 +190,11 @@ def test_sample_shape(self): def test_predicted_obs_stored(self): """Test that predicted_log_conc is stored as deterministic.""" shedding_pmf = jnp.array([0.5, 0.5]) - sensor_mode_prior = HierarchicalNormalPrior( - name="mode", sd_rv=DeterministicVariable("mode_sd", 0.01) - ) - sensor_sd_prior = TruncatedNormalGroupSdPrior( - name="sd", - loc_rv=DeterministicVariable("sd_loc", 0.01), - scale_rv=DeterministicVariable("sd_scale", 0.005), - sd_min=0.001, + sensor_mode_rv = DistributionalVariable("mode", dist.Normal(0, 0.01)) + sensor_sd_rv = DistributionalVariable( + "sd", dist.TruncatedNormal(0.01, 0.005, low=0.001) ) - noise = HierarchicalNormalNoise(sensor_mode_prior, sensor_sd_prior) + noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) process = ConcreteMeasurements( temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), diff --git 
a/test/test_randomvariable_hierarchical.py b/test/test_randomvariable_hierarchical.py deleted file mode 100644 index 8d017a73..00000000 --- a/test/test_randomvariable_hierarchical.py +++ /dev/null @@ -1,202 +0,0 @@ -"""Unit tests for hierarchical prior distributions.""" - -import jax.numpy as jnp -import numpyro -import numpyro.distributions as dist -import pytest - -from pyrenew.deterministic import DeterministicVariable -from pyrenew.randomvariable import ( - DistributionalVariable, - HierarchicalNormalPrior, - StudentTGroupModePrior, - TruncatedNormalGroupSdPrior, -) - - -class TestHierarchicalNormalPrior: - """Test HierarchicalNormalPrior.""" - - def test_sample_shape(self): - """Test that sample returns correct shape.""" - prior = HierarchicalNormalPrior( - "effect", sd_rv=DeterministicVariable("sd", 1.0) - ) - - with numpyro.handlers.seed(rng_seed=42): - samples = prior.sample(n_groups=5) - - assert samples.shape == (5,) - - def test_smaller_sd_produces_tighter_distribution(self): - """Test that smaller sd produces samples closer to zero.""" - prior_tight = HierarchicalNormalPrior( - "a", sd_rv=DeterministicVariable("sd_tight", 0.1) - ) - prior_wide = HierarchicalNormalPrior( - "b", sd_rv=DeterministicVariable("sd_wide", 10.0) - ) - - n_samples = 1000 - with numpyro.handlers.seed(rng_seed=42): - samples_tight = prior_tight.sample(n_groups=n_samples) - with numpyro.handlers.seed(rng_seed=43): - samples_wide = prior_wide.sample(n_groups=n_samples) - - # Tight prior should have smaller standard deviation - assert jnp.std(samples_tight) < jnp.std(samples_wide) - - def test_validate(self): - """Test that validate() runs without error.""" - prior = HierarchicalNormalPrior( - "effect", sd_rv=DeterministicVariable("sd", 1.0) - ) - prior.validate() # Should not raise - - def test_rejects_non_random_variable_sd(self): - """Test that non-RandomVariable sd_rv is rejected.""" - with pytest.raises(TypeError, match="sd_rv must be a RandomVariable"): - 
HierarchicalNormalPrior("effect", sd_rv=1.0) - - def test_accepts_distributional_variable_for_sd(self): - """Test that DistributionalVariable can be used for sd_rv.""" - sd_rv = DistributionalVariable("sd", dist.HalfNormal(1.0)) - prior = HierarchicalNormalPrior("effect", sd_rv=sd_rv) - - with numpyro.handlers.seed(rng_seed=42): - samples = prior.sample(n_groups=5) - - assert samples.shape == (5,) - - -class TestTruncatedNormalGroupSdPrior: - """Test TruncatedNormalGroupSdPrior.""" - - def test_sample_shape(self): - """Test that sample returns correct shape.""" - prior = TruncatedNormalGroupSdPrior( - "sd", - loc_rv=DeterministicVariable("sd_loc", 0.5), - scale_rv=DeterministicVariable("sd_scale", 0.15), - ) - - with numpyro.handlers.seed(rng_seed=42): - samples = prior.sample(n_groups=5) - - assert samples.shape == (5,) - - def test_respects_sd_min(self): - """Test that sd_min is enforced as lower bound.""" - prior = TruncatedNormalGroupSdPrior( - "sd", - loc_rv=DeterministicVariable("sd_loc", 0.1), - scale_rv=DeterministicVariable("sd_scale", 0.05), - sd_min=0.5, - ) - - with numpyro.handlers.seed(rng_seed=42): - samples = prior.sample(n_groups=100) - - assert jnp.all(samples >= 0.5) - - def test_rejects_non_random_variable_params(self): - """Test that non-RandomVariable parameters are rejected.""" - with pytest.raises(TypeError, match="loc_rv must be a RandomVariable"): - TruncatedNormalGroupSdPrior( - "sd", - loc_rv=0.5, - scale_rv=DeterministicVariable("sd_scale", 0.15), - ) - - with pytest.raises(TypeError, match="scale_rv must be a RandomVariable"): - TruncatedNormalGroupSdPrior( - "sd", - loc_rv=DeterministicVariable("sd_loc", 0.5), - scale_rv=0.15, - ) - - def test_rejects_negative_sd_min(self): - """Test that negative sd_min is rejected.""" - with pytest.raises(ValueError, match="sd_min must be non-negative"): - TruncatedNormalGroupSdPrior( - "sd", - loc_rv=DeterministicVariable("sd_loc", 0.5), - scale_rv=DeterministicVariable("sd_scale", 0.15), - 
sd_min=-0.1, - ) - - def test_validate(self): - """Test that validate() runs without error.""" - prior = TruncatedNormalGroupSdPrior( - "sd", - loc_rv=DeterministicVariable("sd_loc", 0.5), - scale_rv=DeterministicVariable("sd_scale", 0.15), - ) - prior.validate() # Should not raise - - -class TestStudentTGroupModePrior: - """Test StudentTGroupModePrior.""" - - def test_sample_shape(self): - """Test that sample returns correct shape.""" - prior = StudentTGroupModePrior( - "mode", - sd_rv=DeterministicVariable("sd", 1.0), - df_rv=DeterministicVariable("df", 4.0), - ) - - with numpyro.handlers.seed(rng_seed=42): - samples = prior.sample(n_groups=5) - - assert samples.shape == (5,) - - def test_heavier_tails_than_normal(self): - """Test Student-t produces more extreme values than Normal.""" - # df=2 gives very heavy tails - student_prior = StudentTGroupModePrior( - "s", - sd_rv=DeterministicVariable("sd_s", 1.0), - df_rv=DeterministicVariable("df", 2.0), - ) - normal_prior = HierarchicalNormalPrior( - "n", sd_rv=DeterministicVariable("sd_n", 1.0) - ) - - n_samples = 5000 - with numpyro.handlers.seed(rng_seed=42): - student_samples = student_prior.sample(n_groups=n_samples) - with numpyro.handlers.seed(rng_seed=42): - normal_samples = normal_prior.sample(n_groups=n_samples) - - # Student-t should have more extreme values (higher max absolute value) - assert jnp.max(jnp.abs(student_samples)) > jnp.max(jnp.abs(normal_samples)) - - def test_rejects_non_random_variable_params(self): - """Test that non-RandomVariable parameters are rejected.""" - with pytest.raises(TypeError, match="sd_rv must be a RandomVariable"): - StudentTGroupModePrior( - "mode", - sd_rv=1.0, - df_rv=DeterministicVariable("df", 4.0), - ) - - with pytest.raises(TypeError, match="df_rv must be a RandomVariable"): - StudentTGroupModePrior( - "mode", - sd_rv=DeterministicVariable("sd", 1.0), - df_rv=4.0, - ) - - def test_validate(self): - """Test that validate() runs without error.""" - prior = 
StudentTGroupModePrior( - "mode", - sd_rv=DeterministicVariable("sd", 1.0), - df_rv=DeterministicVariable("df", 4.0), - ) - prior.validate() # Should not raise - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) From 70199552b3eb70641fa0ebdabc555609faef3634 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Wed, 7 Jan 2026 14:48:11 -0500 Subject: [PATCH 28/33] updated observation process measurements per reviewer comments --- .../observation_processes_measurements.qmd | 314 ++++++++++-------- 1 file changed, 179 insertions(+), 135 deletions(-) diff --git a/docs/tutorials/observation_processes_measurements.qmd b/docs/tutorials/observation_processes_measurements.qmd index edfb807b..ab5af73e 100644 --- a/docs/tutorials/observation_processes_measurements.qmd +++ b/docs/tutorials/observation_processes_measurements.qmd @@ -4,7 +4,7 @@ format: gfm engine: jupyter --- -This tutorial demonstrates how to use the `Measurements` observation process to model continuous measurement data such as wastewater viral concentrations. +This tutorial demonstrates how to use the `Measurements` observation process to model continuous measurement data. We first explain the general framework, then illustrate with a wastewater viral concentration example. ```{python} # | label: setup @@ -22,52 +22,155 @@ from pyrenew.randomvariable import DistributionalVariable from pyrenew.deterministic import DeterministicVariable, DeterministicPMF ``` +# Part 1: The Measurements Framework + ## Overview -Measurement observation processes model continuous signals derived from infections, such as viral RNA concentrations in wastewater. Unlike count observations (hospital admissions, deaths), measurements are continuous values and in some cases may be negative. +The `Measurements` class models continuous signals derived from infections. 
Unlike count observations (hospital admissions, deaths), measurements are continuous values that may span orders of magnitude or even be negative (e.g., log-transformed data). -As an example, we'll consider the wastewater viral genome measurement from the [Pyrenew-HEW](https://github.com/cdcgov/pyrenew-hew) family of models. The predicted measurement on day $t$ is: +**Examples of measurement data:** -$$\lambda_t = \frac{G}{V} \cdot \sum_{d=0}^{D} I_{t-d} \cdot p_d$$ +- Wastewater viral concentrations +- Air quality pathogen levels +- Serological assay results +- Environmental sensor readings -where: +### The general pattern + +All measurement observation processes follow the same pattern: -- $I_{t-d}$ is the number of incident (new) infections on day $t-d$ -- $G$ is a scaling factor (here, genome copies shed per infection) -- $V$ is another scaling factor (here, wastewater volume per person per day) -- $p_d$ is a delay distribution (here, normalized viral genome shedding by time-since-infection) -- $D$ is the maximum delay considered. +$$\text{observed} \sim \text{Noise}\bigl(\text{predicted}(\text{infections})\bigr)$$ -Observed log-concentrations are generated by sampling from a normal distribution: +where: + +1. **`_predicted_obs(infections)`**: Transforms infections into predicted measurement values (you implement this) +2. **Noise model**: Adds stochastic variation around predictions (provided by PyRenew) -$$\log(C_t) \sim \text{Normal}(\mu = \log(\lambda_t), \sigma)$$ +The `Measurements` base class provides: -The log-normal distribution is appropriate for concentration data because concentrations are strictly positive and often span several orders of magnitude. 
+- Convolution utilities for temporal delays +- Timeline alignment between infections and observations +- Integration with hierarchical noise models +- Support for multiple sensors and subpopulations ### Comparison with count observations -The core convolution $\sum_{d=0}^{D} I_{t-d} \cdot p_d$ is the same as for count observations. The key differences are: +The core convolution structure is shared with count observations, but key aspects differ: | Aspect | Counts | Measurements | |--------|--------|--------------| -| Scaling factor | Ascertainment rate $\alpha \in [0,1]$ | Domain-specific (e.g., $G/V$) | -| Temporal PMF | Delay distribution | Shedding/decay kinetics | -| Output space | Expected counts (linear) | Log-concentrations | -| Noise model | Poisson or Negative Binomial | Normal on log scale | -| Subpop structure | Optional (`CountsBySubpop`) | Inherent (hierarchical effects) | +| Output type | Discrete counts | Continuous values | +| Output space | Linear (expected counts) | Often log-transformed | +| Noise model | Poisson or Negative Binomial | Normal (often on log scale) | +| Scaling | Ascertainment rate $\alpha \in [0,1]$ | Domain-specific | +| Subpop structure | Optional (`CountsBySubpop`) | Inherent (sensor/site effects) | + +## The noise model + +Measurement data typically exhibits **sensor-level variability**: different instruments, labs, or sampling locations have systematic biases and different precision levels. 
+ +`HierarchicalNormalNoise` models this with two per-sensor parameters: + +- **Sensor mode**: Systematic bias (additive shift) +- **Sensor SD**: Measurement precision (noise level) + +``` +observed ~ Normal(predicted + sensor_mode[sensor], sensor_sd[sensor]) +``` + +The noise model samples sensor-level parameters within a plate, so any `RandomVariable` can be used as a prior: + +```{python} +# | label: noise-model-general +# Sensor modes: zero-centered, allowing positive or negative bias +sensor_mode_rv = DistributionalVariable("sensor_mode", dist.Normal(0, 0.5)) + +# Sensor SDs: must be positive, truncated normal is a common choice +sensor_sd_rv = DistributionalVariable( + "sensor_sd", dist.TruncatedNormal(loc=0.3, scale=0.15, low=0.05) +) + +# Create noise model +noise = HierarchicalNormalNoise( + sensor_mode_rv=sensor_mode_rv, + sensor_sd_rv=sensor_sd_rv, +) +``` + +## The indexing system -**Key features of measurement data:** +Measurement observations use three index arrays to map observations to their context: + +| Index array | Purpose | +|-------------|---------| +| `times` | Day index for each observation | +| `subpop_indices` | Which infection trajectory (subpopulation) generated each observation | +| `sensor_indices` | Which sensor made each observation (determines noise parameters) | + +This flexible indexing supports: + +- **Irregular sampling**: Observations don't need to be daily +- **Multiple sensors per subpopulation**: Different labs analyzing the same source +- **Multiple subpopulations per sensor**: One sensor serving multiple areas (less common) + +## Subclassing Measurements + +To create a measurement process for your domain, subclass `Measurements` and implement: + +1. **`_predicted_obs(infections)`**: Transform infections to predicted values +2. **`validate()`**: Check parameter validity +3. 
**`lookback_days()`**: Return the temporal PMF length + +```python +class MyMeasurement(Measurements): + def __init__(self, temporal_pmf_rv, noise, my_scaling_param): + super().__init__(temporal_pmf_rv=temporal_pmf_rv, noise=noise) + self.my_scaling_param = my_scaling_param + + def _predicted_obs(self, infections): + # Your domain-specific transformation here + pmf = self.temporal_pmf_rv() + # ... convolve, scale, transform ... + return predicted_values + + def validate(self): + pmf = self.temporal_pmf_rv() + self._validate_pmf(pmf, "temporal_pmf_rv") + + def lookback_days(self): + return len(self.temporal_pmf_rv()) +``` + +--- + +# Part 2: Wastewater Example + +We now illustrate the framework with a wastewater viral concentration model, based on the [PyRenew-HEW](https://github.com/cdcgov/pyrenew-hew) family of models. + +## The wastewater signal + +Wastewater treatment plants measure viral RNA concentrations in sewage. The predicted concentration depends on: + +- **Infections**: People shed virus into wastewater +- **Shedding kinetics**: Viral shedding peaks a few days after infection +- **Scaling factors**: Genome copies per infection, wastewater volume + +The predicted log-concentration on day $t$ is: + +$$\log(\lambda_t) = \log\left(\frac{G}{V} \cdot \sum_{d=0}^{D} I_{t-d} \cdot p_d\right)$$ + +where: -- **Multiple sites**: Each jurisdiction has several measurement sites (e.g., wastewater treatment plants). -- **Irregular sampling**: Sites measure on different schedules. -- **Site-level variability**: Lab protocols and sampling methods vary across sites. -- **Temporal lag**: Signal peaks several days after infection. +- $I_{t-d}$ is infections on day $t-d$ +- $p_d$ is the shedding kinetics PMF (fraction shed on day $d$ post-infection) +- $G$ is genome copies shed per infection +- $V$ is wastewater volume per person per day -**Note on terminology:** In real-world inference, infections are *latent* (unobserved) and must be estimated from observed data. 
In this tutorial, we simulate the observation process by specifying infections directly and showing how they produce measurements through convolution and sampling. +Observations are log-concentrations with normal noise: -## Subclassing Measurements for wastewater +$$y_t \sim \text{Normal}(\log(\lambda_t) + \text{sensor\_mode}, \text{sensor\_sd})$$ -The `Measurements` class is abstract—you must subclass it and implement `_predicted_obs()` for your specific signal type. Here we create a `Wastewater` class for viral concentration measurements: +## Implementing the Wastewater class ```{python} # | label: wastewater-class @@ -151,7 +254,9 @@ class Wastewater(Measurements): return jnp.log(concentration) ``` -## Viral shedding kinetics +## Configuring wastewater-specific parameters + +### Viral shedding kinetics The shedding PMF describes what fraction of total viral shedding occurs on each day after infection: @@ -199,41 +304,40 @@ plt.tight_layout() plt.show() ``` -## Genome copies and wastewater volume +### Genome copies and wastewater volume ```{python} # | label: scaling-params -# Log10 genome copies shed per infection (typical: 8-10) +# Log10 genome copies shed per infection (typical range: 8-10) log10_genome_rv = DeterministicVariable("log10_genome", 9.0) # Wastewater volume per person per day (mL) ml_per_person_per_day = 1000.0 ``` -## Noise model with sensor-level effects - -A measurement depends on both the thing being measured (e.g., wastewater from a treatment plant) and the calibration of the lab/instruments used to obtain that measurement. We call this combination a "sensor"—the WWTP/lab pair that determines the measurement characteristics. Different sensors have systematic biases and variabilities that we model with hierarchical effects. +### Sensor-level noise -The `HierarchicalNormalNoise` model samples sensor-level parameters within a plate context, so any `RandomVariable` can be used. 
For the sensor-level standard deviations, we use a truncated normal prior—a common choice for SD parameters because it naturally constrains values to be positive. +For wastewater, a "sensor" is a WWTP/lab pair—the combination of treatment plant and laboratory that determines measurement characteristics: ```{python} -# | label: noise-model -# Sensor-level mode: systematic differences between sensors (zero-centered normal) -sensor_mode_rv = DistributionalVariable("ww_sensor_mode", dist.Normal(0, 0.5)) +# | label: ww-noise-model +# Sensor-level mode: systematic differences between WWTP/lab pairs +ww_sensor_mode_rv = DistributionalVariable( + "ww_sensor_mode", dist.Normal(0, 0.5) +) -# Sensor-level SD: measurement variability within each sensor (truncated normal) -sensor_sd_rv = DistributionalVariable( +# Sensor-level SD: measurement variability within each WWTP/lab pair +ww_sensor_sd_rv = DistributionalVariable( "ww_sensor_sd", dist.TruncatedNormal(loc=0.3, scale=0.15, low=0.10) ) -# Create the noise model ww_noise = HierarchicalNormalNoise( - sensor_mode_rv=sensor_mode_rv, - sensor_sd_rv=sensor_sd_rv, + sensor_mode_rv=ww_sensor_mode_rv, + sensor_sd_rv=ww_sensor_sd_rv, ) ``` -## Creating the observation process +## Creating the wastewater observation process ```{python} # | label: create-process @@ -247,28 +351,26 @@ ww_process = Wastewater( print(f"Required lookback: {ww_process.lookback_days()} days") ``` -### Timeline alignment and lookback period +## Demonstrations -The observation process convolves infections with a shedding distribution, maintaining alignment between input and output: day $t$ in the output corresponds to day $t$ in the input. +### Timeline alignment -Wastewater concentrations depend on infections from prior days (the length of our shedding distribution minus one). The method `lookback_days()` returns this length; the first valid observation day is at index `lookback - 1`. Earlier days are marked invalid with NaN. 
+The observation process maintains alignment: day $t$ in output corresponds to day $t$ in input. The first `lookback_days() - 1` days have incomplete history and are marked invalid. ```{python} # | label: helper-function def first_valid_observation_day(obs_process) -> int: - """Return the first day index with complete infection history for convolution.""" + """Return the first day index with complete infection history.""" return obs_process.lookback_days() - 1 ``` -## Simulating wastewater observations from infections +### Simulating from an infection spike -To demonstrate how the observation process works, we simulate concentrations from a spike of infections on a single day. +To see how infections spread into concentrations via shedding kinetics, we simulate from a single-day spike: ```{python} # | label: simulate-spike n_days = 50 - -# First valid observation day day_one = first_valid_observation_day(ww_process) # Create infections with a spike (shape: n_days x n_subpops) @@ -320,11 +422,9 @@ plt.tight_layout() plt.show() ``` -Because all infections occur on a single day, we can see how they spread into wastewater concentrations over subsequent days according to the shedding kinetics. - -## Observation noise +### Observation noise -The log-normal noise model adds stochastic variation. 
Sampling multiple times from the same infections shows the range of possible observations: +Sampling multiple times from the same infections shows the range of possible observations: ```{python} # | label: sample-realizations @@ -386,23 +486,9 @@ plt.tight_layout() plt.show() ``` -```{python} -# | label: timeline-stats -print("Timeline Analysis:") -print( - f" Infection spike on day {spike_day}: {infections[infection_spike_day, 0]:.0f} people" -) -print( - f" Mode delay from infection to concentration peak: {mode_shedding_day} days" -) -print( - f" Expected concentration peak: day {spike_day + mode_shedding_day} (= {spike_day} + {mode_shedding_day})" -) -``` - -## Sensor-level variability +### Sensor-level variability -Different sensors measuring the same underlying infections will show systematic differences: +Different sensors (WWTP/lab pairs) measuring the same underlying infections show systematic differences: ```{python} # | label: multi-sensor @@ -454,71 +540,11 @@ plt.tight_layout() plt.show() ``` -Each sensor has a different baseline concentration (sensor mode) and different measurement variability (sensor SD). The hierarchical model learns these sensor-specific effects during inference. 
- -## Effect of genome shedding parameter - -Higher genome shedding means higher observed concentrations: - -```{python} -# | label: compare-genome -genome_values = [8.5, 9.0, 9.5, 10.0] -infections_decay = 2000.0 * jnp.exp(-jnp.arange(40) / 15.0) -infections_decay = infections_decay.reshape(-1, 1) - -observation_days_cmp = jnp.arange(10, 30, dtype=jnp.int32) -n_obs_cmp = len(observation_days_cmp) - -results = {} -for genome_val in genome_values: - process_temp = Wastewater( - shedding_kinetics_rv=shedding_rv, - log10_genome_per_infection_rv=DeterministicVariable( - "log10_genome", genome_val - ), - ml_per_person_per_day=ml_per_person_per_day, - noise=ww_noise, - ) - - with numpyro.handlers.seed(rng_seed=42): - ww_temp = process_temp.sample( - infections=infections_decay, - subpop_indices=jnp.zeros(n_obs_cmp, dtype=jnp.int32), - sensor_indices=jnp.zeros(n_obs_cmp, dtype=jnp.int32), - times=observation_days_cmp, - obs=None, - n_sensors=1, - ) - results[genome_val] = np.array(ww_temp.observed) -``` - -```{python} -# | label: fig-genome-effect -# | fig-cap: Effect of genome shedding parameter on concentrations -fig, ax = plt.subplots(figsize=(8, 4)) +Each sensor has a different baseline (sensor mode) and measurement variability (sensor SD). During inference, these sensor-specific effects are learned from data. -colors = ["steelblue", "coral", "green", "purple"] -for (genome_val, conc), color in zip(results.items(), colors): - ax.plot( - np.array(observation_days_cmp), - conc, - "-o", - color=color, - label=f"log10 = {genome_val}", - markersize=4, - ) +### Multiple subpopulations -ax.set_xlabel("Day") -ax.set_ylabel("Log Viral Concentration") -ax.set_title("Effect of Genome Shedding Parameter") -ax.legend() -plt.tight_layout() -plt.show() -``` - -## Multiple subpopulations - -In hierarchical models, each sensor serves a distinct subpopulation (catchment area). 
Different subpopulations can have different infection levels: +In hierarchical models, each sensor may serve a distinct catchment area (subpopulation) with different infection dynamics: ```{python} # | label: multi-subpop @@ -550,7 +576,7 @@ with numpyro.handlers.seed(rng_seed=42): ```{python} # | label: fig-multi-subpop -# | fig-cap: Sensors observing different subpopulations with different infection dynamics +# | fig-cap: Sensors observing different subpopulations fig, ax = plt.subplots(figsize=(8, 4)) for subpop in [0, 1]: @@ -569,3 +595,21 @@ ax.legend() plt.tight_layout() plt.show() ``` + +--- + +## Summary + +The `Measurements` framework provides: + +1. **A consistent interface** for continuous observation processes +2. **Hierarchical noise models** that capture sensor-level variability +3. **Flexible indexing** for irregular, multi-sensor, multi-subpopulation data +4. **Convolution utilities** with proper timeline alignment + +To use it for your domain: + +1. Subclass `Measurements` +2. Implement `_predicted_obs()` with your signal transformation +3. Configure appropriate priors for sensor-level effects +4. 
Use the indexing system to map observations to their context From 2455242322a06a57021d27a1357fc2530d3c9729 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Wed, 7 Jan 2026 15:02:01 -0500 Subject: [PATCH 29/33] more unit tests --- test/test_observation_counts.py | 17 +++++++++++++++ test/test_observation_measurements.py | 31 +++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index bcfbd54f..33f4eade 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -463,6 +463,23 @@ def test_infection_resolution_counts(self, simple_delay_pmf): assert process.infection_resolution() == "aggregate" +class TestNoiseRepr: + """Test noise model __repr__ methods.""" + + def test_poisson_noise_repr(self): + """Test PoissonNoise __repr__ method.""" + noise = PoissonNoise() + assert repr(noise) == "PoissonNoise()" + + def test_negative_binomial_noise_repr(self): + """Test NegativeBinomialNoise __repr__ method.""" + conc_rv = DeterministicVariable("conc", 10.0) + noise = NegativeBinomialNoise(conc_rv) + repr_str = repr(noise) + assert "NegativeBinomialNoise" in repr_str + assert "concentration_rv" in repr_str + + class TestNoiseValidation: """Test noise model validation methods.""" diff --git a/test/test_observation_measurements.py b/test/test_observation_measurements.py index 32e053eb..b4ed3488 100644 --- a/test/test_observation_measurements.py +++ b/test/test_observation_measurements.py @@ -94,6 +94,18 @@ def test_infection_resolution_is_subpop(self): class TestHierarchicalNormalNoise: """Test HierarchicalNormalNoise model.""" + def test_repr(self): + """Test HierarchicalNormalNoise __repr__ method.""" + sensor_mode_rv = DistributionalVariable("mode", dist.Normal(0, 0.5)) + sensor_sd_rv = DistributionalVariable( + "sd", dist.TruncatedNormal(0.3, 0.15, low=0.05) + ) + noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) + repr_str = repr(noise) + assert 
"HierarchicalNormalNoise" in repr_str + assert "sensor_mode_rv" in repr_str + assert "sensor_sd_rv" in repr_str + def test_validate(self): """Test HierarchicalNormalNoise validate method.""" sensor_mode_rv = DistributionalVariable("mode", dist.Normal(0, 0.5)) @@ -154,6 +166,25 @@ def test_sample_with_observations(self): class TestConcreteMeasurements: """Test concrete Measurements implementation.""" + def test_repr(self): + """Test Measurements __repr__ method.""" + shedding_pmf = jnp.array([0.3, 0.4, 0.3]) + sensor_mode_rv = DistributionalVariable("mode", dist.Normal(0, 0.5)) + sensor_sd_rv = DistributionalVariable( + "sd", dist.TruncatedNormal(0.3, 0.15, low=0.05) + ) + noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) + + process = ConcreteMeasurements( + temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), + noise=noise, + ) + + repr_str = repr(process) + assert "ConcreteMeasurements" in repr_str + assert "temporal_pmf_rv" in repr_str + assert "noise" in repr_str + def test_sample_shape(self): """Test that sample returns correct shape.""" shedding_pmf = jnp.array([0.3, 0.4, 0.3]) From 101e40168fb57b8e9237ff02cc4e9bba5d52d23a Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 8 Jan 2026 14:54:26 -0500 Subject: [PATCH 30/33] use plotnine for plots in observation_processes tutorial --- .../observation_processes_counts.qmd | 614 ++++++++++++------ mkdocs.yml | 2 +- pyproject.toml | 1 + uv.lock | 73 +++ 4 files changed, 502 insertions(+), 188 deletions(-) diff --git a/docs/tutorials/observation_processes_counts.qmd b/docs/tutorials/observation_processes_counts.qmd index ad50e074..71055d96 100644 --- a/docs/tutorials/observation_processes_counts.qmd +++ b/docs/tutorials/observation_processes_counts.qmd @@ -1,23 +1,37 @@ --- title: "Observation processes for count data" -format: gfm +format: + gfm: + fig-width: 16 + fig-height: 10 engine: jupyter --- This tutorial demonstrates how to use the `Counts` observation process to model count data such 
as hospital admissions, emergency department visits, or deaths. +
+Show code ```{python} # | label: setup # | output: false import jax.numpy as jnp import numpy as np import numpyro +import plotnine as p9 +import pandas as pd +from pathlib import Path +import sys +import warnings +from plotnine.exceptions import PlotnineWarning +warnings.filterwarnings("ignore", category=PlotnineWarning) + import matplotlib.pyplot as plt from pyrenew.observation import Counts, NegativeBinomialNoise, PoissonNoise from pyrenew.deterministic import DeterministicVariable, DeterministicPMF from pyrenew import datasets ``` +
## Overview @@ -48,12 +62,14 @@ For hospital admissions data, we construct a `Counts` observation process. The delay is the key mechanism: infections from $d$ days ago ($I_{t-d}$) contribute to today's hospital admissions ($\lambda_t$) weighted by the probability ($p_d$) that an infection leads to hospitalization after exactly $d$ days. The convolution sums these contributions across all past days. The process generates hospital admissions by sampling from a negative binomial distribution: + $$Y_t \sim \text{NegativeBinomial}(\mu = \lambda_t, \text{concentration} = \phi)$$ The concentration parameter $\phi$ (sometimes called $k$ or the dispersion parameter) controls overdispersion: as $\phi \to \infty$, the distribution approaches Poisson; smaller values allow greater overdispersion. We use the negative binomial distribution because real-world hospital admission counts exhibit overdispersion—the variance exceeds the mean. The Poisson distribution assumes variance equals the mean, which is too restrictive. The negative binomial adds an overdispersion term: + $$\text{Var}[Y_t] = \mu + \frac{\mu^2}{\phi}$$ In this example, we use fixed parameter values for illustration; in practice, these parameters would be estimated from data using weakly informative priors. @@ -63,47 +79,68 @@ In this example, we use fixed parameter values for illustration; in practice, th The delay distribution specifies the probability that an infected person is hospitalized $d$ days after infection, conditional on the infection leading to a hospitalization. For example, if `hosp_delay_pmf[5] = 0.2`, then 20% of infections that result in hospitalization will appear as hospital admissions 5 days after infection. -We load a delay distribution from PyRenew's datasets: +We load a delay distribution from PyRenew's datasets which peaks around day 8-9 post-infection, compute summary statistics, and plot it. 
```{python} -# | label: load-delay +# | label: delay-distribution inf_hosp_int = datasets.load_infection_admission_interval() hosp_delay_pmf = jnp.array(inf_hosp_int["probability_mass"].to_numpy()) - delay_rv = DeterministicPMF("inf_to_hosp_delay", hosp_delay_pmf) # Summary statistics days = np.arange(len(hosp_delay_pmf)) mean_delay = float(np.sum(days * hosp_delay_pmf)) mode_delay = int(np.argmax(hosp_delay_pmf)) -print(f"Mode delay: {mode_delay} days, Mean delay: {mean_delay:.1f} days") +sd = float(np.sqrt(np.sum(days**2 * hosp_delay_pmf) - mean_delay**2)) +print(f"mode delay: {mode_delay}, mean delay: {mean_delay:.1f}, sd: {sd:.1f}") ``` +
+Show code ```{python} -# | label: fig-delay-distribution -# | fig-cap: Infection-to-hospitalization delay distribution -fig, ax = plt.subplots(figsize=(8, 4)) -ax.bar(days, hosp_delay_pmf, color="steelblue", alpha=0.7, edgecolor="black") -ax.axvline( - mode_delay, - color="purple", - linestyle="-", - linewidth=2, - label=f"Mode: {mode_delay}", +# | label: plot-delay-distribution +delay_df = pd.DataFrame( + {"days": days, "probability": np.array(hosp_delay_pmf)} ) -ax.axvline( - mean_delay, - color="red", - linestyle="--", - linewidth=2, - label=f"Mean: {mean_delay:.1f}", + +plot_delay = ( + p9.ggplot(delay_df, p9.aes(x="days", y="probability")) + + p9.geom_col(fill="steelblue", alpha=0.7, color="black") + + p9.geom_vline( + xintercept=mode_delay, color="purple", linetype="solid", size=1 + ) + + p9.geom_vline( + xintercept=mean_delay, color="red", linetype="dashed", size=1 + ) + + p9.labs( + x="Days from infection to hospitalization", + y="Probability", + title="Infection-to-Hospitalization Delay Distribution", + ) + + p9.theme_grey() + + p9.theme(plot_title=p9.element_text(size=14, weight="bold")) + + p9.annotate( + "text", + x=mode_delay + 8, + y=max(delay_df["probability"]) * 0.95, + label=f"Mode: {mode_delay} days", + color="purple", + size=10, + ) + + p9.annotate( + "text", + x=mean_delay + 8, + y=max(delay_df["probability"]) * 0.8, + label=f"Mean: {mean_delay:.1f} days", + color="red", + size=10, + ) ) -ax.set_xlabel("Days from infection to hospitalization") -ax.set_ylabel("Probability") -ax.set_title("Infection-to-Hospitalization Delay Distribution") -ax.legend() -plt.tight_layout() -plt.show() +``` +
+```{python} +# | label: viz-delay-distribution +plot_delay ``` ## Creating a Counts observation process @@ -114,6 +151,12 @@ A `Counts` object takes the following arguments: - **`delay_distribution_rv`**: delay distribution from infection to observation (PMF) - **`noise`**: noise model (`PoissonNoise()` or `NegativeBinomialNoise(concentration_rv)`) +For hospital admissions, the ascertainment rate is specifically called the infection-hospitalization rate (IHR). +In this example, the percentage of infections which lead to hospitalization is treated as a fixed value, +which will allow us to see how different values affect the model. +The concentration parameter for the negative binomial noise model is also fixed. +In practice, both of these parameters would be given a somewhat informative prior and then inferred. + ```{python} # | label: create-counts-process # Infection-hospitalization ratio (1% of infections lead to hospitalization) @@ -128,8 +171,6 @@ hosp_process = Counts( delay_distribution_rv=delay_rv, noise=NegativeBinomialNoise(concentration_rv), ) - -print(f"Required lookback: {hosp_process.lookback_days()} days") ``` ### Timeline alignment and lookback period @@ -140,159 +181,254 @@ Hospital admissions depend on infections from prior days (the length of our dela ```{python} # | label: helper-function +print(f"Required lookback: {hosp_process.lookback_days()} days") + def first_valid_observation_day(obs_process) -> int: """Return the first day index with complete infection history for convolution.""" return obs_process.lookback_days() - 1 ``` -## Simulating hospital admissions from infections +## Simulating observed hospital admissions given a single day's worth of infections + +To demonstrate how a `Counts` observation process works, we examine how infections occurring on a single day result in observed hospital admissions. -To demonstrate how the observation process works, we simulate admissions from a spike of infections on a single day. 
```{python} # | label: simulate-spike n_days = 100 -lookback = hosp_process.lookback_days() - -# First valid observation day -day_one = lookback - 1 +day_one = first_valid_observation_day(hosp_process) # Create infections with a spike infection_spike_day = day_one + 10 infections = jnp.zeros(n_days) infections = infections.at[infection_spike_day].set(2000) - -# Sample hospital admissions -with numpyro.handlers.seed(rng_seed=42): - hosp_admissions = hosp_process.sample( - infections=infections, - obs=None, # Sample from prior (no observed data) - ) ``` +We plot the infections starting from day_one (the first valid observation day, after the lookback period). +
+Show code ```{python} -# | label: fig-spike-infections -# | fig-cap: Input infections with a single-day spike +# | label: plot-infections # Plot relative to first valid observation day -spike_day = infection_spike_day - day_one n_plot_days = n_days - day_one +rel_spike_day = infection_spike_day - day_one -fig, ax = plt.subplots(figsize=(8, 4)) -ax.plot( - np.arange(n_plot_days), - np.array(infections[day_one:]), - "-o", - color="darkblue", +infections_df = pd.DataFrame( + { + "day": np.arange(n_plot_days), + "count": np.array(infections[day_one:]), + } ) -ax.axvline(spike_day, color="darkred", linestyle="--", alpha=0.7) -ax.annotate( - f"Infection spike\n(day {spike_day})", - xy=(spike_day, 1800), - xytext=(spike_day + 5, 1800), - fontsize=10, - color="darkred", + +max_infection_count = float(jnp.max(infections[day_one:])) + +plot_infections = ( + p9.ggplot(infections_df, p9.aes(x="day", y="count")) + + p9.geom_line(color="darkblue", size=1) + + p9.geom_point(color="darkblue", size=2) + + p9.geom_vline( + xintercept=rel_spike_day, + color="darkred", + linetype="dashed", + alpha=0.5, + ) + + p9.labs(x="Day", y="Daily Infections", title="Infections (Input)") + + p9.theme_grey() + + p9.theme(plot_title=p9.element_text(size=13, weight="bold")) + + p9.annotate( + "text", + x=rel_spike_day + 2, + y=max_infection_count * 0.9, + label=f"Infection spike (day {rel_spike_day})", + color="darkred", + size=10, + ) ) -ax.set_xlabel("Day") -ax.set_ylabel("Daily Infections") -ax.set_title("Infections (Input)") -plt.tight_layout() -plt.show() ``` +
+```{python} +# | label: viz-infections +plot_infections +``` + +Because all infections occur on a single day, we can see how one day's worth of infections results in hospital admissions spread over subsequent days according to the delay distribution. -Because all infections occur on a single day, we can see how they spread into hospital admissions over subsequent days according to the delay distribution: +## Predicted admissions without observation noise + +First, we compute the predicted admissions from the convolution alone, without observation noise. This is the mean of the distribution from which samples are drawn. ```{python} -# | label: fig-spike-admissions -# | fig-cap: Hospital admissions from a single-day infection spike -fig, ax = plt.subplots(figsize=(8, 4)) -ax.plot( - np.arange(n_plot_days), - np.array(hosp_admissions.observed[day_one:]), - "-o", - color="purple", +# | label: predicted-no-noise +# Compute predicted admissions (convolution only, no observation noise) +from pyrenew.convolve import compute_delay_ascertained_incidence + +# Scale infections by IHR (ascertainment rate) +infections_scaled = infections * float(ihr_rv.sample()) +predicted_admissions, offset = compute_delay_ascertained_incidence( + p_observed_given_incident=1.0, + latent_incidence=infections_scaled, + delay_incidence_to_observation_pmf=hosp_delay_pmf, + pad=True, ) -ax.axvline( - spike_day, - color="darkred", - linestyle="--", - alpha=0.5, - label="Infection spike", +``` + +
+Show code +```{python} +# | label: plot-predicted-no-noise +# Relative peak day for plotting +peak_day = rel_spike_day + mode_delay + +# Plot predicted admissions (x-axis: day_one = first valid observation day) +predicted_df = pd.DataFrame( + { + "day": np.arange(n_plot_days), + "admissions": np.array(predicted_admissions[day_one:]), + } ) -ax.axvline( - spike_day + mode_delay, - color="purple", - linestyle="--", - alpha=0.5, - label="Peak admissions", + +max_predicted = float(predicted_df["admissions"].max()) +plot_predicted = ( + p9.ggplot(predicted_df, p9.aes(x="day", y="admissions")) + + p9.geom_line(color="purple", size=1) + + p9.geom_point(color="purple", size=1.5) + + p9.geom_vline( + xintercept=rel_spike_day, + color="darkred", + linetype="dashed", + alpha=0.5, + ) + + p9.geom_vline( + xintercept=peak_day, + color="purple", + linetype="dashed", + alpha=0.5, + ) + + p9.labs( + x="Day", + y="Hospital Admissions", + title="Predicted Hospital Admissions (Deterministic)", + ) + + p9.theme_grey() + + p9.annotate( + "text", + x=rel_spike_day, + y=max_predicted * 1.05, + label=f"Infection spike\n(day {rel_spike_day})", + color="darkred", + size=9, + ha="center", + ) + + p9.annotate( + "text", + x=peak_day, + y=max_predicted * 1.05, + label=f"Peak\n(day {peak_day})", + color="purple", + size=9, + ha="center", + ) ) -ax.set_xlabel("Day") -ax.set_ylabel("Hospital Admissions") -ax.set_title("Hospital Admissions (Output)") -ax.legend() -plt.tight_layout() -plt.show() +``` +
+```{python} +# | label: viz-predicted +plot_predicted ``` -The admissions peak occurs `{python} mode_delay` days after the infection spike, matching the mode of the delay distribution. +The predicted admissions mirror the delay distribution, shifted by the infection spike day and scaled by the IHR. -## Observation noise + +## Observation Noise (Negative Binomial) The negative binomial distribution adds stochastic variation. Sampling multiple times from the same infections shows the range of possible observations: ```{python} # | label: sample-realizations +# Sample 50 realizations of hospital admissions from the same infection spike n_samples = 50 -samples = [] +samples_list = [] for seed in range(n_samples): with numpyro.handlers.seed(rng_seed=seed): - result = hosp_process.sample(infections=infections, obs=None) - samples.append(np.array(result.observed[day_one:])) + hosp_sample = hosp_process.sample(infections=infections, obs=None) + + for i, val in enumerate(hosp_sample.observed[day_one:]): + samples_list.append( + { + "day": i, + "admissions": float(val), + "sample": seed, + "type": "sampled", + } + ) -samples = np.array(samples) -sample_mean = samples.mean(axis=0) +# Add predicted values +for i, val in enumerate(predicted_admissions[day_one:]): + samples_list.append( + { + "day": i, + "admissions": float(val), + "sample": -1, + "type": "predicted", + } + ) ``` +
+Show code ```{python} -# | label: fig-sampled-admissions -# | fig-cap: Multiple realizations showing observation noise -fig, ax = plt.subplots(figsize=(8, 4)) - -# Plot all samples in light orange -for i in range(1, n_samples): - ax.plot( - np.arange(n_plot_days), - samples[i], +# | label: plot-realizations +samples_df = pd.DataFrame(samples_list) +sampled_df = samples_df[samples_df["type"] == "sampled"] +predicted_noise_df = samples_df[samples_df["type"] == "predicted"] + +# Separate one sample to highlight +highlight_sample = 0 +other_samples_df = sampled_df[sampled_df["sample"] != highlight_sample] +highlight_df = sampled_df[sampled_df["sample"] == highlight_sample] + +plot_50_samples = ( + p9.ggplot() + + p9.geom_line( + p9.aes(x="day", y="admissions", group="sample"), + data=other_samples_df, color="orange", alpha=0.15, - linewidth=0.5, + size=0.5, ) - -# Highlight one sample -ax.plot( - np.arange(n_plot_days), - samples[0], - color="steelblue", - linewidth=1, - label="One realization", -) - -# Sample mean -ax.plot( - np.arange(n_plot_days), - sample_mean, - color="darkred", - linewidth=1.2, - label="Sample mean", + + p9.geom_line( + p9.aes(x="day", y="admissions"), + data=highlight_df, + color="steelblue", + size=1, + ) + + p9.geom_line( + p9.aes(x="day", y="admissions"), + data=predicted_noise_df, + color="darkred", + size=1.2, + ) + + p9.geom_vline( + xintercept=rel_spike_day, + color="darkblue", + linetype="dashed", + alpha=0.5, + ) + + p9.labs( + x="Day", + y="Hospital Admissions", + title=f"Observation Noise: {n_samples} Samples from Same Infections", + subtitle="Blue: one realization | Orange: other samples | Dark red: predicted", + ) + + p9.theme_grey() ) +``` +
-ax.axvline(spike_day, color="darkblue", linestyle="--", alpha=0.5) -ax.set_xlabel("Day") -ax.set_ylabel("Hospital Admissions") -ax.set_title(f"Observation Noise: {n_samples} Samples from Same Infections") -ax.legend() -plt.tight_layout() -plt.show() +```{python} +# | label: viz-realizations +plot_50_samples ``` ```{python} @@ -300,60 +436,72 @@ plt.show() # Print timeline statistics print("Timeline Analysis:") print( - f" Infection spike on day {spike_day}: {infections[infection_spike_day]:.0f} people" + f" Infection spike on day {rel_spike_day}: {infections[infection_spike_day]:.0f} people" ) print(f" Mode delay from infection to hospitalization: {mode_delay} days") print( - f" Expected hospitalization peak: day {spike_day + mode_delay} (= {spike_day} + {mode_delay})" + f" Predicted hospitalization peak: day {rel_spike_day + mode_delay} (= {rel_spike_day} + {mode_delay})" ) ``` ## Effect of the ascertainment rate -The ascertainment rate directly scales expected hospital admissions. We compare IHR values of 0.5% and 2.5%: +The ascertainment rate (here, the infection-hospitalization rate or IHR) directly scales the number of predicted hospital admissions. +We compare two contrasting IHR values: **0.5%** and **2.5%**. 
```{python} # | label: compare-ihr +# Two contrasting IHR values ihr_values = [0.005, 0.025] -infections_decay = 3000 * jnp.exp(-jnp.arange(n_days) / 20.0) +peak_value = 3000 # Peak infections +infections_decay = peak_value * jnp.exp(-jnp.arange(n_days) / 20.0) -results = {} +# Compute predicted hospital admissions (no noise) for each IHR +results_list = [] for ihr_val in ihr_values: - process = Counts( - ascertainment_rate_rv=DeterministicVariable("ihr", ihr_val), - delay_distribution_rv=delay_rv, - noise=NegativeBinomialNoise(concentration_rv), + infections_scaled = infections_decay * ihr_val + predicted_hosp, _ = compute_delay_ascertained_incidence( + p_observed_given_incident=1.0, + latent_incidence=infections_scaled, + delay_incidence_to_observation_pmf=hosp_delay_pmf, + pad=True, ) - # Average over multiple samples to show the effect on the mean - samples = [] - for seed in range(20): - with numpyro.handlers.seed(rng_seed=seed): - result = process.sample(infections=infections_decay, obs=None) - samples.append(np.array(result.observed[day_one:])) - results[ihr_val] = np.mean(samples, axis=0) + + for i, admit in enumerate(predicted_hosp[day_one:]): + results_list.append( + { + "day": i, + "admissions": float(admit), + "IHR": f"IHR = {ihr_val:.1%}", + } + ) ``` + +
+Show code ```{python} -# | label: fig-ihr-comparison -# | fig-cap: Effect of different infection-hospitalization rates -fig, ax = plt.subplots(figsize=(8, 4)) -colors = ["steelblue", "darkred"] - -for (ihr_val, mean_sample), color in zip(results.items(), colors): - ax.plot( - np.arange(n_plot_days), - mean_sample, - color=color, - linewidth=1.5, - label=f"IHR = {ihr_val:.1%}", +# | label: plot-ihr-comparisons +results_df = pd.DataFrame(results_list) + +plot_ihr = ( + p9.ggplot(results_df, p9.aes(x="day", y="admissions", color="IHR")) + + p9.geom_line(size=1) + + p9.scale_color_manual(values=["steelblue", "darkred"]) + + p9.labs( + x="Day", + y="Predicted Hospital Admissions", + title="Effect of IHR on Predicted Hospital Admissions", + color="Infection-Hospitalization\nrate", ) + + p9.theme_grey() +) +``` +
-ax.set_xlabel("Day") -ax.set_ylabel("Hospital Admissions (mean of samples)") -ax.set_title("Effect of IHR on Hospital Admissions") -ax.legend() -plt.tight_layout() -plt.show() +```{python} +# | label: viz-ihr-comparisons +plot_ihr ``` ## Negative binomial concentration parameter @@ -363,40 +511,81 @@ The concentration parameter $\phi$ controls overdispersion: - Higher $\phi$ → less overdispersion (approaches Poisson) - Lower $\phi$ → more overdispersion (noisier data) +We compare three concentration values spanning two orders of magnitude: + +- **φ = 1**: high overdispersion (noisy) +- **φ = 10**: moderate overdispersion +- **φ = 100**: nearly Poisson (minimal noise) + ```{python} -# | label: fig-concentration-comparison -# | fig-cap: Effect of concentration parameter on variability -infections_constant = 2000 * jnp.ones(n_days) +# | label: concentration-comparisons +# Use constant infections +peak_value = 2000 +infections_constant = peak_value * jnp.ones(n_days) + +# Concentration values spanning two orders of magnitude concentration_values = [1.0, 10.0, 100.0] n_replicates = 10 -fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True) - -for ax, conc_val in zip(axes, concentration_values): - process = Counts( +# Collect results +conc_results = [] +for conc_val in concentration_values: + conc_rv_temp = DeterministicVariable("conc", conc_val) + process_temp = Counts( ascertainment_rate_rv=ihr_rv, delay_distribution_rv=delay_rv, - noise=NegativeBinomialNoise(DeterministicVariable("conc", conc_val)), + noise=NegativeBinomialNoise(conc_rv_temp), ) for seed in range(n_replicates): with numpyro.handlers.seed(rng_seed=seed): - result = process.sample(infections=infections_constant, obs=None) - ax.plot( - np.arange(n_plot_days), - np.array(result.observed[day_one:]), - color="steelblue", - alpha=0.5, - linewidth=0.8, - ) + hosp_temp = process_temp.sample( + infections=infections_constant, + obs=None, + ) + + # Use relative days + for i, admit in 
enumerate(hosp_temp.observed[day_one:]): + conc_results.append( + { + "day": i, + "admissions": float(admit), + "concentration": f"φ = {int(conc_val)}", + "replicate": seed, + } + ) +``` - ax.set_xlabel("Day") - ax.set_title(f"φ = {int(conc_val)}") +
+Show code +```{python} +# | label: plot-concentration-comparisons +conc_df = pd.DataFrame(conc_results) + +# Convert to ordered categorical +conc_df["concentration"] = pd.Categorical( + conc_df["concentration"], + categories=["φ = 1", "φ = 10", "φ = 100"], + ordered=True, +) -axes[0].set_ylabel("Hospital Admissions") -fig.suptitle("Effect of Concentration Parameter on Variability") -plt.tight_layout() -plt.show() +plot_concentration = ( + p9.ggplot(conc_df, p9.aes(x="day", y="admissions", group="replicate")) + + p9.geom_line(alpha=0.5, size=0.8, color="steelblue") + + p9.facet_wrap("~ concentration", ncol=3) + + p9.labs( + x="Day", + y="Hospital Admissions", + title="Effect of Concentration Parameter on Variability", + ) + + p9.theme_grey() +) +``` +
+ +```{python} +# | label: viz-concentration-comparisons +plot_concentration ``` ## Swapping noise models @@ -421,3 +610,54 @@ print( f"Sampled {len(poisson_result.observed)} days of hospital admissions with Poisson noise" ) ``` + +We can visualize the Poisson noise model using the same constant infection scenario as the concentration comparison above. Since Poisson assumes variance equals the mean, it produces less variability than the negative binomial with low concentration values. + +```{python} +# | label: poisson-realizations +# Sample multiple realizations with Poisson noise +n_replicates_poisson = 10 + +poisson_results = [] +for seed in range(n_replicates_poisson): + with numpyro.handlers.seed(rng_seed=seed): + poisson_temp = hosp_process_poisson.sample( + infections=infections_constant, + obs=None, + ) + + for i, admit in enumerate(poisson_temp.observed[day_one:]): + poisson_results.append( + { + "day": i, + "admissions": float(admit), + "replicate": seed, + } + ) +``` + +
+Show code +```{python} +# | label: plot-poisson-realizations +poisson_df = pd.DataFrame(poisson_results) + +plot_poisson = ( + p9.ggplot(poisson_df, p9.aes(x="day", y="admissions", group="replicate")) + + p9.geom_line(alpha=0.5, size=0.8, color="steelblue") + + p9.labs( + x="Day", + y="Hospital Admissions", + title="Poisson Noise Model (Variance = Mean)", + ) + + p9.theme_grey() + + p9.ylim(0, 105) +) +``` +
+ +To see the reduction in noise, it is ncessary to keep the y-axis on the same scale as in the previous plot. +```{python} +# | label: viz-poisson-realizations +plot_poisson +``` diff --git a/mkdocs.yml b/mkdocs.yml index a4da9382..1c2745eb 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -64,9 +64,9 @@ markdown_extensions: - callouts extra_javascript: - - javascripts/katex.js - https://unpkg.com/katex@0/dist/katex.min.js - https://unpkg.com/katex@0/dist/contrib/auto-render.min.js + - javascripts/katex.js extra_css: - https://unpkg.com/katex@0/dist/katex.min.css diff --git a/pyproject.toml b/pyproject.toml index 01d7800a..514e6d3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dev = [ "nbconvert>=7.16.6", "pytest>=8.4.2", "pytest-cov>=6.3.0", + "plotnine>=0.14.0", "pytest-mpl>=0.17.0", "scipy>=1.16.1", ] diff --git a/uv.lock b/uv.lock index f2477f56..6a0bfd17 100644 --- a/uv.lock +++ b/uv.lock @@ -949,6 +949,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7a/f0/8282d9641415e9e33df173516226b404d367a0fc55e1a60424a152913abc/mistune-3.1.4-py3-none-any.whl", hash = "sha256:93691da911e5d9d2e23bc54472892aff676df27a75274962ff9edc210364266d", size = 53481, upload-time = "2025-08-29T07:20:42.218Z" }, ] +[[package]] +name = "mizani" +version = "0.14.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "pandas" }, + { name = "scipy" }, + { name = "tzdata", marker = "sys_platform == 'emscripten' or sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6d/19/98f2bd61e5441b687e0a5d3b36041981cc032451f2d11472021b040d27fd/mizani-0.14.3.tar.gz", hash = "sha256:c2fb886b3c9e8109be5b8fd21e1130fba1f0a20230a987146240221209fc0ddd", size = 772470, upload-time = "2025-10-30T20:16:53.268Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/d2/4ffcaa27c8a4b4f9ad456da4821c76dfbdfada23e8210cd4d80e1eb3236a/mizani-0.14.3-py3-none-any.whl", hash = 
"sha256:6d2ca9b1b8366ff85668f0cc1b6095f1e702e26e66f132c4f02a949efa32a688", size = 133433, upload-time = "2025-10-30T20:16:51.218Z" }, +] + [[package]] name = "mkdocs" version = "1.6.1" @@ -1348,6 +1363,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" }, ] +[[package]] +name = "patsy" +version = "1.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/be/44/ed13eccdd0519eff265f44b670d46fbb0ec813e2274932dc1c0e48520f7d/patsy-1.0.2.tar.gz", hash = "sha256:cdc995455f6233e90e22de72c37fcadb344e7586fb83f06696f54d92f8ce74c0", size = 399942, upload-time = "2025-10-20T16:17:37.535Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f1/70/ba4b949bdc0490ab78d545459acd7702b211dfccf7eb89bbc1060f52818d/patsy-1.0.2-py2.py3-none-any.whl", hash = "sha256:37bfddbc58fcf0362febb5f54f10743f8b21dd2aa73dec7e7ef59d1b02ae668a", size = 233301, upload-time = "2025-10-20T16:17:36.563Z" }, +] + [[package]] name = "pexpect" version = "4.9.0" @@ -1427,6 +1454,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/73/cb/ac7874b3e5d58441674fb70742e6c374b28b0c7cb988d37d991cde47166c/platformdirs-4.5.0-py3-none-any.whl", hash = "sha256:e578a81bb873cbb89a41fcc904c7ef523cc18284b7e3b3ccf06aca1403b7ebd3", size = 18651, upload-time = "2025-10-08T17:44:47.223Z" }, ] +[[package]] +name = "plotnine" +version = "0.15.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "matplotlib" }, + { name = "mizani" }, + { name = "numpy" }, + { name = "pandas" }, + { name = "scipy" }, + { name = "statsmodels" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/11/14/3adedabe6b8710caee34e4ac9f4edc48218a381594ee1980c323b8866577/plotnine-0.15.2.tar.gz", hash = "sha256:ec2e4cdf2d022eb0dab63ef4aa0017ce0d84c60bd99d55093e72637fddf757e6", size = 6787690, upload-time = "2025-12-12T10:41:37.249Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/58/27/4e6ffe2f095fbfd6285343aa6114903a4cf011564b4f1f2bb706341472df/plotnine-0.15.2-py3-none-any.whl", hash = "sha256:7dc508bc51625b9b9f945e274d8ee4463cf30b280749190a5b707e6828003fa6", size = 1332822, upload-time = "2025-12-12T10:41:34.203Z" }, +] + [[package]] name = "pluggy" version = "1.6.0" @@ -1582,6 +1626,7 @@ dev = [ { name = "mkdocstrings" }, { name = "mkdocstrings-python" }, { name = "nbconvert" }, + { name = "plotnine" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "pytest-mpl" }, @@ -1609,6 +1654,7 @@ dev = [ { name = "mkdocstrings", specifier = ">=0.30.0" }, { name = "mkdocstrings-python", specifier = ">=1.18.2" }, { name = "nbconvert", specifier = ">=7.16.6" }, + { name = "plotnine", specifier = ">=0.14.0" }, { name = "pytest", specifier = ">=8.4.2" }, { name = "pytest-cov", specifier = ">=6.3.0" }, { name = "pytest-mpl", specifier = ">=0.17.0" }, @@ -1971,6 +2017,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521, upload-time = "2023-09-30T13:58:03.53Z" }, ] +[[package]] +name = "statsmodels" +version = "0.14.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "packaging" }, + { name = "pandas" }, + { name = "patsy" }, + { name = "scipy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0d/81/e8d74b34f85285f7335d30c5e3c2d7c0346997af9f3debf9a0a9a63de184/statsmodels-0.14.6.tar.gz", hash = 
"sha256:4d17873d3e607d398b85126cd4ed7aad89e4e9d89fc744cdab1af3189a996c2a", size = 20689085, upload-time = "2025-12-05T23:08:39.522Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/81/59/a5aad5b0cc266f5be013db8cde563ac5d2a025e7efc0c328d83b50c72992/statsmodels-0.14.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:47ee7af083623d2091954fa71c7549b8443168f41b7c5dce66510274c50fd73e", size = 10072009, upload-time = "2025-12-05T23:11:14.021Z" }, + { url = "https://files.pythonhosted.org/packages/53/dd/d8cfa7922fc6dc3c56fa6c59b348ea7de829a94cd73208c6f8202dd33f17/statsmodels-0.14.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:aa60d82e29fcd0a736e86feb63a11d2380322d77a9369a54be8b0965a3985f71", size = 9980018, upload-time = "2025-12-05T23:11:30.907Z" }, + { url = "https://files.pythonhosted.org/packages/ee/77/0ec96803eba444efd75dba32f2ef88765ae3e8f567d276805391ec2c98c6/statsmodels-0.14.6-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:89ee7d595f5939cc20bf946faedcb5137d975f03ae080f300ebb4398f16a5bd4", size = 10060269, upload-time = "2025-12-05T23:11:46.338Z" }, + { url = "https://files.pythonhosted.org/packages/10/b9/fd41f1f6af13a1a1212a06bb377b17762feaa6d656947bf666f76300fc05/statsmodels-0.14.6-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:730f3297b26749b216a06e4327fe0be59b8d05f7d594fb6caff4287b69654589", size = 10324155, upload-time = "2025-12-05T23:12:01.805Z" }, + { url = "https://files.pythonhosted.org/packages/ee/0f/a6900e220abd2c69cd0a07e3ad26c71984be6061415a60e0f17b152ecf08/statsmodels-0.14.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f1c08befa85e93acc992b72a390ddb7bd876190f1360e61d10cf43833463bc9c", size = 10349765, upload-time = "2025-12-05T23:12:18.018Z" }, + { url = "https://files.pythonhosted.org/packages/98/08/b79f0c614f38e566eebbdcff90c0bcacf3c6ba7a5bbb12183c09c29ca400/statsmodels-0.14.6-cp313-cp313-win_amd64.whl", hash = 
"sha256:8021271a79f35b842c02a1794465a651a9d06ec2080f76ebc3b7adce77d08233", size = 9540043, upload-time = "2025-12-05T23:12:33.887Z" }, + { url = "https://files.pythonhosted.org/packages/71/de/09540e870318e0c7b58316561d417be45eff731263b4234fdd2eee3511a8/statsmodels-0.14.6-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:00781869991f8f02ad3610da6627fd26ebe262210287beb59761982a8fa88cae", size = 10069403, upload-time = "2025-12-05T23:12:48.424Z" }, + { url = "https://files.pythonhosted.org/packages/ab/f0/63c1bfda75dc53cee858006e1f46bd6d6f883853bea1b97949d0087766ca/statsmodels-0.14.6-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:73f305fbf31607b35ce919fae636ab8b80d175328ed38fdc6f354e813b86ee37", size = 9989253, upload-time = "2025-12-05T23:13:05.274Z" }, + { url = "https://files.pythonhosted.org/packages/c1/98/b0dfb4f542b2033a3341aa5f1bdd97024230a4ad3670c5b0839d54e3dcab/statsmodels-0.14.6-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e443e7077a6e2d3faeea72f5a92c9f12c63722686eb80bb40a0f04e4a7e267ad", size = 10090802, upload-time = "2025-12-05T23:13:20.653Z" }, + { url = "https://files.pythonhosted.org/packages/34/0e/2408735aca9e764643196212f9069912100151414dd617d39ffc72d77eee/statsmodels-0.14.6-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3414e40c073d725007a6603a18247ab7af3467e1af4a5e5a24e4c27bc26673b4", size = 10337587, upload-time = "2025-12-05T23:13:37.597Z" }, + { url = "https://files.pythonhosted.org/packages/0f/36/4d44f7035ab3c0b2b6a4c4ebb98dedf36246ccbc1b3e2f51ebcd7ac83abb/statsmodels-0.14.6-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a518d3f9889ef920116f9fa56d0338069e110f823926356946dae83bc9e33e19", size = 10363350, upload-time = "2025-12-05T23:13:53.08Z" }, + { url = "https://files.pythonhosted.org/packages/26/33/f1652d0c59fa51de18492ee2345b65372550501ad061daa38f950be390b6/statsmodels-0.14.6-cp314-cp314-win_amd64.whl", hash = 
"sha256:151b73e29f01fe619dbce7f66d61a356e9d1fe5e906529b78807df9189c37721", size = 9588010, upload-time = "2025-12-05T23:14:07.28Z" }, +] + [[package]] name = "tinycss2" version = "1.4.0" From bd5546f608a8a308db128ff54e95e4171724ba55 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Jan 2026 19:54:40 +0000 Subject: [PATCH 31/33] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/tutorials/observation_processes_counts.qmd | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/tutorials/observation_processes_counts.qmd b/docs/tutorials/observation_processes_counts.qmd index 71055d96..59153e41 100644 --- a/docs/tutorials/observation_processes_counts.qmd +++ b/docs/tutorials/observation_processes_counts.qmd @@ -23,6 +23,7 @@ from pathlib import Path import sys import warnings from plotnine.exceptions import PlotnineWarning + warnings.filterwarnings("ignore", category=PlotnineWarning) import matplotlib.pyplot as plt @@ -92,7 +93,9 @@ days = np.arange(len(hosp_delay_pmf)) mean_delay = float(np.sum(days * hosp_delay_pmf)) mode_delay = int(np.argmax(hosp_delay_pmf)) sd = float(np.sqrt(np.sum(days**2 * hosp_delay_pmf) - mean_delay**2)) -print(f"mode delay: {mode_delay}, mean delay: {mean_delay:.1f}, sd: {sd:.1f}") +print( + f"mode delay: {mode_delay}, mean delay: {mean_delay:.1f}, sd: {sd:.1f}" +) ```
@@ -183,6 +186,7 @@ Hospital admissions depend on infections from prior days (the length of our dela # | label: helper-function print(f"Required lookback: {hosp_process.lookback_days()} days") + def first_valid_observation_day(obs_process) -> int: """Return the first day index with complete infection history for convolution.""" return obs_process.lookback_days() - 1 From f3a583dc5a58258362ddaee86b5e8a019fdeabf7 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 8 Jan 2026 21:16:33 -0500 Subject: [PATCH 32/33] update observation_processes tutorials to use plotnine --- .../observation_processes_counts.qmd | 68 +-- .../observation_processes_measurements.qmd | 395 +++++++++++------- 2 files changed, 266 insertions(+), 197 deletions(-) diff --git a/docs/tutorials/observation_processes_counts.qmd b/docs/tutorials/observation_processes_counts.qmd index 71055d96..6715279a 100644 --- a/docs/tutorials/observation_processes_counts.qmd +++ b/docs/tutorials/observation_processes_counts.qmd @@ -4,13 +4,17 @@ format: gfm: fig-width: 16 fig-height: 10 + html: + toc: true + embed-resources: true + self-contained-math: true + code-fold: true + code-tools: true engine: jupyter --- This tutorial demonstrates how to use the `Counts` observation process to model count data such as hospital admissions, emergency department visits, or deaths. -
-Show code ```{python} # | label: setup # | output: false @@ -23,6 +27,7 @@ from pathlib import Path import sys import warnings from plotnine.exceptions import PlotnineWarning + warnings.filterwarnings("ignore", category=PlotnineWarning) import matplotlib.pyplot as plt @@ -31,7 +36,6 @@ from pyrenew.observation import Counts, NegativeBinomialNoise, PoissonNoise from pyrenew.deterministic import DeterministicVariable, DeterministicPMF from pyrenew import datasets ``` -
## Overview @@ -92,11 +96,11 @@ days = np.arange(len(hosp_delay_pmf)) mean_delay = float(np.sum(days * hosp_delay_pmf)) mode_delay = int(np.argmax(hosp_delay_pmf)) sd = float(np.sqrt(np.sum(days**2 * hosp_delay_pmf) - mean_delay**2)) -print(f"mode delay: {mode_delay}, mean delay: {mean_delay:.1f}, sd: {sd:.1f}") +print( + f"mode delay: {mode_delay}, mean delay: {mean_delay:.1f}, sd: {sd:.1f}" +) ``` -
-Show code ```{python} # | label: plot-delay-distribution delay_df = pd.DataFrame( @@ -136,10 +140,6 @@ plot_delay = ( size=10, ) ) -``` -
-```{python} -# | label: viz-delay-distribution plot_delay ``` @@ -183,6 +183,7 @@ Hospital admissions depend on infections from prior days (the length of our dela # | label: helper-function print(f"Required lookback: {hosp_process.lookback_days()} days") + def first_valid_observation_day(obs_process) -> int: """Return the first day index with complete infection history for convolution.""" return obs_process.lookback_days() - 1 @@ -205,8 +206,6 @@ infections = infections.at[infection_spike_day].set(2000) ``` We plot the infections starting from day_one (the first valid observation day, after the lookback period). -
-Show code ```{python} # | label: plot-infections # Plot relative to first valid observation day @@ -244,10 +243,6 @@ plot_infections = ( size=10, ) ) -``` -
-```{python} -# | label: viz-infections plot_infections ``` @@ -272,8 +267,6 @@ predicted_admissions, offset = compute_delay_ascertained_incidence( ) ``` -
-Show code ```{python} # | label: plot-predicted-no-noise # Relative peak day for plotting @@ -329,10 +322,6 @@ plot_predicted = ( ha="center", ) ) -``` -
-```{python} -# | label: viz-predicted plot_predicted ``` @@ -375,8 +364,6 @@ for i, val in enumerate(predicted_admissions[day_one:]): ) ``` -
-Show code ```{python} # | label: plot-realizations samples_df = pd.DataFrame(samples_list) @@ -423,11 +410,6 @@ plot_50_samples = ( ) + p9.theme_grey() ) -``` -
- -```{python} -# | label: viz-realizations plot_50_samples ``` @@ -478,8 +460,6 @@ for ihr_val in ihr_values: ``` -
-Show code ```{python} # | label: plot-ihr-comparisons results_df = pd.DataFrame(results_list) @@ -496,11 +476,6 @@ plot_ihr = ( ) + p9.theme_grey() ) -``` -
- -```{python} -# | label: viz-ihr-comparisons plot_ihr ``` @@ -556,8 +531,6 @@ for conc_val in concentration_values: ) ``` -
-Show code ```{python} # | label: plot-concentration-comparisons conc_df = pd.DataFrame(conc_results) @@ -580,11 +553,6 @@ plot_concentration = ( ) + p9.theme_grey() ) -``` -
- -```{python} -# | label: viz-concentration-comparisons plot_concentration ``` @@ -613,6 +581,8 @@ print( We can visualize the Poisson noise model using the same constant infection scenario as the concentration comparison above. Since Poisson assumes variance equals the mean, it produces less variability than the negative binomial with low concentration values. +To see the reduction in noise, it is necessary to keep the y-axis on the same scale as in the previous plot. + ```{python} # | label: poisson-realizations # Sample multiple realizations with Poisson noise @@ -634,12 +604,6 @@ for seed in range(n_replicates_poisson): "replicate": seed, } ) -``` - -
-Show code -```{python} -# | label: plot-poisson-realizations poisson_df = pd.DataFrame(poisson_results) plot_poisson = ( @@ -653,11 +617,5 @@ plot_poisson = ( + p9.theme_grey() + p9.ylim(0, 105) ) -``` -
- -To see the reduction in noise, it is ncessary to keep the y-axis on the same scale as in the previous plot. -```{python} -# | label: viz-poisson-realizations plot_poisson ``` diff --git a/docs/tutorials/observation_processes_measurements.qmd b/docs/tutorials/observation_processes_measurements.qmd index ab5af73e..66c65aa4 100644 --- a/docs/tutorials/observation_processes_measurements.qmd +++ b/docs/tutorials/observation_processes_measurements.qmd @@ -1,6 +1,15 @@ --- title: "Observation processes for continuous measurements" -format: gfm +format: + gfm: + fig-width: 16 + fig-height: 10 + html: + toc: true + embed-resources: true + self-contained-math: true + code-fold: true + code-tools: true engine: jupyter --- @@ -14,6 +23,8 @@ import jax.numpy as jnp import numpy as np import numpyro import matplotlib.pyplot as plt +import pandas as pd +import plotnine as p9 import numpyro.distributions as dist @@ -22,9 +33,7 @@ from pyrenew.randomvariable import DistributionalVariable from pyrenew.deterministic import DeterministicVariable, DeterministicPMF ``` -# Part 1: The Measurements Framework - -## Overview +# The Measurements Framework The `Measurements` class models continuous signals derived from infections. Unlike count observations (hospital admissions, deaths), measurements are continuous values that may span orders of magnitude or even be negative (e.g., log-transformed data). @@ -35,7 +44,7 @@ The `Measurements` class models continuous signals derived from infections. 
Unli - Serological assay results - Environmental sensor readings -### The general pattern +## The general pattern All measurement observation processes follow the same pattern: @@ -53,7 +62,7 @@ The `Measurements` base class provides: - Integration with hierarchical noise models - Support for multiple sensors and subpopulations -### Comparison with count observations +## Comparison with count observations The core convolution structure is shared with count observations, but key aspects differ: @@ -141,15 +150,16 @@ class MyMeasurement(Measurements): return len(self.temporal_pmf_rv()) ``` ---- -# Part 2: Wastewater Example +# Wastewater Example -We now illustrate the framework with a wastewater viral concentration model, based on the [PyRenew-HEW](https://github.com/cdcgov/pyrenew-hew) family of models. +To illustrate the framework, we specify a wastewater viral concentration observation process, +based on the [PyRenew-HEW](https://github.com/cdcgov/pyrenew-hew) family of models. ## The wastewater signal -Wastewater treatment plants measure viral RNA concentrations in sewage. The predicted concentration depends on: +Wastewater treatment plants measure viral RNA concentrations in sewage. 
+The predicted concentration depends on: - **Infections**: People shed virus into wastewater - **Shedding kinetics**: Viral shedding peaks a few days after infection @@ -278,30 +288,45 @@ print(f"Mode: {mode_shedding_day} days, Mean: {mean_shedding_day:.1f} days") ``` ```{python} -# | label: fig-shedding -# | fig-cap: Viral shedding kinetics distribution -fig, ax = plt.subplots(figsize=(8, 4)) -ax.bar(days, shedding_pmf, color="steelblue", alpha=0.7, edgecolor="black") -ax.axvline( - mode_shedding_day, - color="purple", - linestyle="-", - linewidth=2, - label=f"Mode: {mode_shedding_day}", +# | label: plot-shedding +# Visualize the shedding distribution +shedding_df = pd.DataFrame( + {"days": days, "probability": np.array(shedding_pmf)} ) -ax.axvline( - mean_shedding_day, - color="red", - linestyle="--", - linewidth=2, - label=f"Mean: {mean_shedding_day:.1f}", + +( + p9.ggplot(shedding_df, p9.aes(x="days", y="probability")) + + p9.geom_col(fill="steelblue", alpha=0.7, color="black") + + p9.geom_vline( + xintercept=mode_shedding_day, color="purple", linetype="solid", size=1 + ) + + p9.geom_vline( + xintercept=mean_shedding_day, color="red", linetype="dashed", size=1 + ) + + p9.labs( + x="Days after infection", + y="Fraction of total shedding", + title="Viral Shedding Kinetics", + ) + + p9.theme_grey() + + p9.theme(plot_title=p9.element_text(size=14, weight="bold")) + + p9.annotate( + "text", + x=mode_shedding_day + 2, + y=max(shedding_df["probability"]) * 0.95, + label=f"Mode: {mode_shedding_day} days", + color="purple", + size=10, + ) + + p9.annotate( + "text", + x=mean_shedding_day + 2, + y=max(shedding_df["probability"]) * 0.8, + label=f"Mean: {mean_shedding_day:.1f} days", + color="red", + size=10, + ) ) -ax.set_xlabel("Days after infection") -ax.set_ylabel("Fraction of total shedding") -ax.set_title("Viral Shedding Kinetics") -ax.legend() -plt.tight_layout() -plt.show() ``` ### Genome copies and wastewater volume @@ -364,7 +389,7 @@ def 
first_valid_observation_day(obs_process) -> int: return obs_process.lookback_days() - 1 ``` -### Simulating from an infection spike +### Simulating from observations from a single-day infection spike To see how infections spread into concentrations via shedding kinetics, we simulate from a single-day spike: @@ -379,7 +404,7 @@ infections = jnp.zeros((n_days, 1)) # 1 subpopulation infections = infections.at[infection_spike_day, 0].set(2000.0) # For plotting -spike_day = infection_spike_day - day_one +rel_spike_day = infection_spike_day - day_one n_plot_days = n_days - day_one # Observation times and indices @@ -396,30 +421,43 @@ with numpyro.handlers.seed(rng_seed=42): n_sensors=1, ) ``` +We plot the resulting observations starting from the first valid observation day. ```{python} -# | label: fig-spike-infections -# | fig-cap: Input infections with a single-day spike -fig, ax = plt.subplots(figsize=(8, 4)) -ax.plot( - np.arange(n_plot_days), - np.array(infections[day_one:, 0]), - "-o", - color="darkblue", +# | label: plot-spike-infections +infections_df = pd.DataFrame( + { + "day": np.arange(n_plot_days), + "infections": np.array(infections[day_one:, 0]), + } ) -ax.axvline(spike_day, color="darkred", linestyle="--", alpha=0.7) -ax.annotate( - f"Infection spike\n(day {spike_day})", - xy=(spike_day, 1800), - xytext=(spike_day + 3, 1800), - fontsize=10, - color="darkred", + +max_infection_count = float(jnp.max(infections[day_one:])) + +plot_infections = ( + p9.ggplot(infections_df, p9.aes(x="day", y="infections")) + + p9.geom_line(color="darkblue", size=1) + + p9.geom_point(color="darkblue", size=2) + + p9.geom_vline( + xintercept=rel_spike_day, + color="darkred", + linetype="dashed", + alpha=0.5, + ) + + p9.labs(x="Day", y="Daily Infections", title="Infections (Input)") + + p9.theme_grey() + + p9.theme(plot_title=p9.element_text(size=13, weight="bold")) + + p9.annotate( + "text", + x=rel_spike_day, + y=max_infection_count * 1.05, + label=f"Infection spike\n(day 
{rel_spike_day})", + color="darkred", + size=10, + ha="center", + ) ) -ax.set_xlabel("Day") -ax.set_ylabel("Daily Infections") -ax.set_title("Infections (Input)") -plt.tight_layout() -plt.show() +plot_infections ``` ### Observation noise @@ -429,7 +467,7 @@ Sampling multiple times from the same infections shows the range of possible obs ```{python} # | label: sample-realizations n_samples = 50 -samples_by_day = {int(d - day_one): [] for d in observation_days} +ww_samples_list = [] for seed in range(n_samples): with numpyro.handlers.seed(rng_seed=seed): @@ -442,109 +480,174 @@ for seed in range(n_samples): n_sensors=1, ) for day_idx, conc in zip(observation_days, ww_result.observed): - samples_by_day[int(day_idx) - day_one].append(float(conc)) + ww_samples_list.append( + { + "day": int(day_idx) - day_one, + "log_concentration": float(conc), + "sample": seed, + } + ) -# Convert to arrays for plotting -plot_days = sorted(samples_by_day.keys()) -all_samples = np.array( - [[samples_by_day[d][i] for d in plot_days] for i in range(n_samples)] -) -sample_mean = all_samples.mean(axis=0) +ww_samples_df = pd.DataFrame(ww_samples_list) ``` ```{python} -# | label: fig-sampled-concentrations -# | fig-cap: Multiple realizations showing observation noise -fig, ax = plt.subplots(figsize=(8, 4)) - -# Plot all samples -for i in range(1, n_samples): - ax.plot( - plot_days, all_samples[i], color="orange", alpha=0.15, linewidth=0.5 - ) - -# Highlight one sample -ax.plot( - plot_days, - all_samples[0], - color="steelblue", - linewidth=1, - label="One realization", +# | label: plot-sampled-concentrations +# Compute mean across samples for each day +mean_by_day = ( + ww_samples_df.groupby("day")["log_concentration"].mean().reset_index() ) - -# Sample mean -ax.plot( - plot_days, sample_mean, color="darkred", linewidth=1.2, label="Sample mean" +mean_by_day["sample"] = -1 + +# Relative peak day for plotting (using mode, not mean, since distribution is skewed) +peak_day = rel_spike_day + 
mode_shedding_day + +# Separate one sample to highlight +highlight_sample = 0 +other_samples_df = ww_samples_df[ww_samples_df["sample"] != highlight_sample] +highlight_df = ww_samples_df[ww_samples_df["sample"] == highlight_sample] + +# For annotation positioning +max_conc = ww_samples_df["log_concentration"].max() + +( + p9.ggplot() + + p9.geom_line( + p9.aes(x="day", y="log_concentration", group="sample"), + data=other_samples_df, + color="orange", + alpha=0.15, + size=0.5, + ) + + p9.geom_line( + p9.aes(x="day", y="log_concentration"), + data=highlight_df, + color="steelblue", + size=1, + ) + + p9.geom_line( + p9.aes(x="day", y="log_concentration"), + data=mean_by_day, + color="darkred", + size=1.2, + ) + + p9.geom_vline( + xintercept=rel_spike_day, + color="darkblue", + linetype="dashed", + alpha=0.5, + ) + + p9.geom_vline( + xintercept=peak_day, + color="darkred", + linetype="dotted", + alpha=0.7, + ) + + p9.annotate( + "text", + x=rel_spike_day, + y=max_conc * 1.05, + label=f"Infection spike\n(day {rel_spike_day})", + color="darkblue", + size=9, + ha="center", + ) + + p9.annotate( + "text", + x=peak_day, + y=max_conc * 0.98, + label=f"Expected peak\n(day {peak_day})", + color="darkred", + size=9, + ha="center", + ) + + p9.labs( + x="Day", + y="Log Viral Concentration", + title=f"Observation Noise: {n_samples} Samples from Same Infections", + subtitle="Blue: one realization | Orange: other samples | Dark red: sample mean", + ) + + p9.theme_grey() ) - -ax.axvline(spike_day, color="darkblue", linestyle="--", alpha=0.5) -ax.set_xlabel("Day") -ax.set_ylabel("Log Viral Concentration") -ax.set_title(f"Observation Noise: {n_samples} Samples from Same Infections") -ax.legend() -plt.tight_layout() -plt.show() ``` ### Sensor-level variability -Different sensors (WWTP/lab pairs) measuring the same underlying infections show systematic differences: +The previous plot showed variability from repeatedly sampling the entire observation process (resampling sensor parameters 
and noise each time). In practice, we have multiple physical sensors, each with fixed but unknown characteristics. + +This plot shows four sensors observing the **same infection spike**. Each sensor has: + +- A **sensor mode** (systematic bias): shifts all observations up or down +- A **sensor SD** (measurement precision): determines noise level around predictions + +These parameters are sampled once per sensor, then held fixed across all observations from that sensor. ```{python} # | label: multi-sensor num_sensors = 4 -infections_constant = jnp.ones((40, 1)) * 1500.0 -# Each sensor samples at multiple time points -times_per_sensor = 10 -observation_times = jnp.tile( - jnp.arange(10, 30, 2, dtype=jnp.int32), num_sensors -) +# Use the same observation times and infections as the sampled-concentrations plot +sensor_obs_times = jnp.tile(observation_days, num_sensors) sensor_ids = jnp.repeat( - jnp.arange(num_sensors, dtype=jnp.int32), times_per_sensor + jnp.arange(num_sensors, dtype=jnp.int32), len(observation_days) ) -subpop_ids = jnp.zeros(num_sensors * times_per_sensor, dtype=jnp.int32) +subpop_ids = jnp.zeros(num_sensors * len(observation_days), dtype=jnp.int32) with numpyro.handlers.seed(rng_seed=42): ww_multi_sensor = ww_process.sample( - infections=infections_constant, + infections=infections, # Same spike as before subpop_indices=subpop_ids, sensor_indices=sensor_ids, - times=observation_times, + times=sensor_obs_times, obs=None, n_sensors=num_sensors, ) + +# Create DataFrame for plotting (using relative days) +multi_sensor_df = pd.DataFrame( + { + "day": np.array(sensor_obs_times) - day_one, + "log_concentration": np.array(ww_multi_sensor.observed), + "sensor": [f"Sensor {i}" for i in np.array(sensor_ids)], + } +) ``` ```{python} -# | label: fig-multi-sensor -# | fig-cap: Multiple sensors observing the same infections -fig, ax = plt.subplots(figsize=(8, 4)) - -colors = ["steelblue", "coral", "green", "purple"] -for sensor in range(num_sensors): - mask = 
np.array(sensor_ids) == sensor - ax.plot( - np.array(observation_times)[mask], - np.array(ww_multi_sensor.observed)[mask], - "-o", - color=colors[sensor], - label=f"Sensor {sensor}", +# | label: plot-multi-sensor +# Use same y-axis range as sampled-concentrations plot for comparison +y_min = ww_samples_df["log_concentration"].min() +y_max = ww_samples_df["log_concentration"].max() + +( + p9.ggplot( + multi_sensor_df, p9.aes(x="day", y="log_concentration", color="sensor") ) - -ax.set_xlabel("Day") -ax.set_ylabel("Log Viral Concentration") -ax.set_title("Multiple Sensors Observing Same Infections") -ax.legend() -plt.tight_layout() -plt.show() + + p9.geom_line(size=1) + + p9.geom_point(size=2) + + p9.ylim(y_min, y_max * 1.05) + + p9.labs( + x="Day", + y="Log Viral Concentration", + title="Four Sensors Observing the Same Infection Spike", + color="Sensor", + ) + + p9.theme_grey() +) ``` -Each sensor has a different baseline (sensor mode) and measurement variability (sensor SD). During inference, these sensor-specific effects are learned from data. +Compare this to the previous plot: here, each colored line represents a distinct physical sensor with its own systematic bias. The vertical spread between sensors reflects differences in sensor modes, while the noise within each line reflects each sensor's measurement precision. During inference, these sensor-specific effects are learned from data. ### Multiple subpopulations -In hierarchical models, each sensor may serve a distinct catchment area (subpopulation) with different infection dynamics: +In regional surveillance, each wastewater treatment plant serves a distinct **catchment area** (subpopulation) with its own infection dynamics. The `subpop_indices` array maps each observation to the appropriate infection trajectory. 
+ +This example shows two subpopulations with different epidemic curves: + +- **Subpopulation 0**: Slow decay (e.g., large urban area with sustained transmission) +- **Subpopulation 1**: Fast decay (e.g., smaller community with rapid burnout) + +Each subpopulation is observed by its own sensor. The observed concentrations reflect both the underlying infection differences AND the sensor-specific measurement characteristics. ```{python} # | label: multi-subpop @@ -572,30 +675,38 @@ with numpyro.handlers.seed(rng_seed=42): obs=None, n_sensors=2, ) + +# Create DataFrame for plotting +multi_subpop_df = pd.DataFrame( + { + "day": np.array(obs_days_mp), + "log_concentration": np.array(ww_multi_subpop.observed), + "subpopulation": [f"Subpop {i}" for i in np.array(subpop_ids_mp)], + } +) ``` ```{python} -# | label: fig-multi-subpop -# | fig-cap: Sensors observing different subpopulations -fig, ax = plt.subplots(figsize=(8, 4)) - -for subpop in [0, 1]: - mask = np.array(subpop_ids_mp) == subpop - ax.plot( - np.array(obs_days_mp)[mask], - np.array(ww_multi_subpop.observed)[mask], - "-o", - label=f"Subpop {subpop}", +# | label: plot-multi-subpop +( + p9.ggplot( + multi_subpop_df, + p9.aes(x="day", y="log_concentration", color="subpopulation"), ) - -ax.set_xlabel("Day") -ax.set_ylabel("Log Viral Concentration") -ax.set_title("Multiple Subpopulations with Different Infection Dynamics") -ax.legend() -plt.tight_layout() -plt.show() + + p9.geom_line(size=1) + + p9.geom_point(size=2) + + p9.labs( + x="Day", + y="Log Viral Concentration", + title="Two Subpopulations with Different Infection Dynamics", + color="Subpopulation", + ) + + p9.theme_grey() +) ``` +The diverging trajectories reflect the different underlying infection curves. Subpopulation 1 starts higher but decays faster, while Subpopulation 0 maintains more sustained levels. In a full model, you would jointly infer the infection trajectories for each subpopulation while accounting for sensor-specific biases. 
+ --- ## Summary From 9306edcd6bb23be8b8e04cee622fab634ac67de1 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 8 Jan 2026 21:40:08 -0500 Subject: [PATCH 33/33] minor tutorials edits --- .../observation_processes_measurements.qmd | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/tutorials/observation_processes_measurements.qmd b/docs/tutorials/observation_processes_measurements.qmd index 66c65aa4..ddf89a8d 100644 --- a/docs/tutorials/observation_processes_measurements.qmd +++ b/docs/tutorials/observation_processes_measurements.qmd @@ -33,7 +33,7 @@ from pyrenew.randomvariable import DistributionalVariable from pyrenew.deterministic import DeterministicVariable, DeterministicPMF ``` -# The Measurements Framework +## The Measurements Class The `Measurements` class models continuous signals derived from infections. Unlike count observations (hospital admissions, deaths), measurements are continuous values that may span orders of magnitude or even be negative (e.g., log-transformed data). @@ -44,7 +44,7 @@ The `Measurements` class models continuous signals derived from infections. 
Unli - Serological assay results - Environmental sensor readings -## The general pattern +### The general pattern All measurement observation processes follow the same pattern: @@ -62,7 +62,7 @@ The `Measurements` base class provides: - Integration with hierarchical noise models - Support for multiple sensors and subpopulations -## Comparison with count observations +### Comparison with count observations The core convolution structure is shared with count observations, but key aspects differ: @@ -74,7 +74,7 @@ The core convolution structure is shared with count observations, but key aspect | Scaling | Ascertainment rate $\alpha \in [0,1]$ | Domain-specific | | Subpop structure | Optional (`CountsBySubpop`) | Inherent (sensor/site effects) | -## The noise model +### The noise model Measurement data typically exhibits **sensor-level variability**: different instruments, labs, or sampling locations have systematic biases and different precision levels. @@ -106,7 +106,7 @@ noise = HierarchicalNormalNoise( ) ``` -## The indexing system +### The indexing system Measurement observations use three index arrays to map observations to their context: @@ -122,7 +122,7 @@ This flexible indexing supports: - **Multiple sensors per subpopulation**: Different labs analyzing the same source - **Multiple subpopulations per sensor**: One sensor serving multiple areas (less common) -## Subclassing Measurements +### Subclassing Measurements To create a measurement process for your domain, subclass `Measurements` and implement: @@ -151,12 +151,12 @@ class MyMeasurement(Measurements): ``` -# Wastewater Example +## Measurement Example: Wastewater To illustrate the framework, we specify a wastewater viral concentration observation process, based on the [PyRenew-HEW](https://github.com/cdcgov/pyrenew-hew) family of models. -## The wastewater signal +**The wastewater signal** Wastewater treatment plants measure viral RNA concentrations in sewage. 
The predicted concentration depends on: @@ -180,7 +180,7 @@ Observations are log-concentrations with normal noise: $$y_t \sim \text{Normal}(\log(\lambda_t) + \text{sensor\_mode}, \text{sensor\_sd})$$ -## Implementing the Wastewater class +### Implementing the Wastewater class ```{python} # | label: wastewater-class @@ -264,9 +264,9 @@ class Wastewater(Measurements): return jnp.log(concentration) ``` -## Configuring wastewater-specific parameters +#### Configuring wastewater-specific parameters -### Viral shedding kinetics +**Viral shedding kinetics** The shedding PMF describes what fraction of total viral shedding occurs on each day after infection: @@ -329,7 +329,7 @@ shedding_df = pd.DataFrame( ) ``` -### Genome copies and wastewater volume +**Genome copies and wastewater volume** ```{python} # | label: scaling-params @@ -362,7 +362,7 @@ ww_noise = HierarchicalNormalNoise( ) ``` -## Creating the wastewater observation process +### Creating the wastewater observation process ```{python} # | label: create-process @@ -376,7 +376,7 @@ ww_process = Wastewater( print(f"Required lookback: {ww_process.lookback_days()} days") ``` -## Demonstrations +## Simulations ### Timeline alignment @@ -711,7 +711,7 @@ The diverging trajectories reflect the different underlying infection curves. Su ## Summary -The `Measurements` framework provides: +The `Measurements` class provides: 1. **A consistent interface** for continuous observation processes 2. **Hierarchical noise models** that capture sensor-level variability