diff --git a/pymc_extras/model/marginal/marginal_model.py b/pymc_extras/model/marginal/marginal_model.py index b6ca25bf1..1ac7f1741 100644 --- a/pymc_extras/model/marginal/marginal_model.py +++ b/pymc_extras/model/marginal/marginal_model.py @@ -11,7 +11,7 @@ from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform from pymc.distributions.transforms import Chain from pymc.logprob.transforms import IntervalTransform -from pymc.model import Model +from pymc.model import Model, modelcontext from pymc.model.fgraph import ( ModelFreeRV, ModelValuedVar, @@ -337,8 +337,9 @@ def transform_posterior_pts(model, posterior_pts): def recover_marginals( - model: Model, idata: InferenceData, + *, + model: Model | None = None, var_names: Sequence[str] | None = None, return_samples: bool = True, extend_inferencedata: bool = True, @@ -389,6 +390,15 @@ def recover_marginals( """ + # Temporary error message for helping with migration + # Will be removed in a future release + if isinstance(idata, Model): + raise TypeError( + "The order of arguments of `recover_marginals` changed. The first input must be an idata" + ) + + model = modelcontext(model) + unmarginal_model = unmarginalize(model) # Find the names of the marginalized variables diff --git a/pyproject.toml b/pyproject.toml index 787902f94..8503796a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,10 +105,13 @@ filterwarnings =[ 'ignore:os\.fork\(\) was called\.:RuntimeWarning', # Preliz needs to update for pytensor > 2.35 - 'ignore:.*`pytensor\.graph\.basic\.ancestors`.*`pytensor\.graph\.traversal\.ancestors`.*:FutureWarning:^preliz(\.|$)', + 'ignore:.*`pytensor\.graph\.basic\.ancestors`.*`pytensor\.graph\.traversal\.ancestors`.*:FutureWarning:^preliz(\.|$)', # OpenMP library warning on windows CI - 'ignore::RuntimeWarning:threadpoolctl' + 'ignore::RuntimeWarning:threadpoolctl', + + # ArviZ warning related to 1.0 release + 'ignore:\nArviZ is undergoing a major refactor.*:FutureWarning' ] [tool.coverage.report] diff --git a/tests/model/marginal/test_marginal_model.py b/tests/model/marginal/test_marginal_model.py index d9e50569a..b3cbaef93 100644 --- a/tests/model/marginal/test_marginal_model.py +++ b/tests/model/marginal/test_marginal_model.py @@ -816,7 +816,8 @@ def test_unmarginalize(): class TestRecoverMarginals: - def test_basic(self): + @pytest.mark.parametrize("explicit_model", (True, False)) + def test_basic(self, explicit_model): with Model() as m: sigma = pm.HalfNormal("sigma") p = np.array([0.5, 0.2, 0.3]) @@ -837,7 +838,12 @@ def test_basic(self): ) idata = InferenceData(posterior=dict_to_dataset(prior)) - idata = recover_marginals(marginal_m, idata, return_samples=True) + if explicit_model: + idata = recover_marginals(idata, model=marginal_m, return_samples=True) + else: + with marginal_m: + idata = recover_marginals(idata, return_samples=True) + post = idata.posterior assert "k" in post assert "lp_k" in post @@ -881,7 +887,8 @@ def test_coords(self): posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior}) ) - idata = recover_marginals(marginal_m, idata, return_samples=True) + with marginal_m: + idata = recover_marginals(idata, return_samples=True) post = idata.posterior assert post.idx.dims == ("chain", "draw", "year") assert post.lp_idx.dims == ("chain", "draw", "year", "lp_idx_dim") @@ -907,7 +914,7 @@ def test_batched(self): posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior}) ) - idata = recover_marginals(marginal_m, idata, return_samples=True) + idata = recover_marginals(idata, return_samples=True) post = idata.posterior assert post["y"].shape == (1, 20, 2, 3) assert post["idx"].shape == (1, 20, 3, 2) @@ -933,7 +940,7 @@ def test_nested(self): ) idata = InferenceData(posterior=dict_to_dataset(prior)) - idata = recover_marginals(marginal_m, idata, return_samples=True) + idata = recover_marginals(idata, return_samples=True) post = idata.posterior assert "idx" in post assert "lp_idx" in post diff --git a/tests/test_blackjax_smc.py b/tests/test_blackjax_smc.py index 06245d0ba..cd002f79a 100644 --- a/tests/test_blackjax_smc.py +++ b/tests/test_blackjax_smc.py @@ -75,6 +75,9 @@ def fast_model(): return m +@pytest.mark.xfail( + reason="TODO: fix AttributeError: 'TemperedSMCState' object has no attribute 'lmbda'" +) @pytest.mark.parametrize( "kernel, check_for_integration_steps, inner_kernel_params", [