From df29affbad95004f57780f43da5955e65e11a4f2 Mon Sep 17 00:00:00 2001 From: Stephen Nneji Date: Thu, 20 Nov 2025 12:38:02 +0000 Subject: [PATCH] Adds method to get bayes result procedure and bumps version to dev10 --- pyproject.toml | 2 +- ratapi/outputs.py | 13 +++++++++++++ tests/conftest.py | 12 ++++++++++++ tests/test_outputs.py | 13 +++++++++++++ 4 files changed, 39 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ff9cd29b..3cfbe5c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = 'setuptools.build_meta' [project] name = "ratapi" -version = "0.0.0.dev9" +version = "0.0.0.dev10" description = "Python extension for the Reflectivity Analysis Toolbox (RAT)" readme = "README.md" requires-python = ">=3.10" diff --git a/ratapi/outputs.py b/ratapi/outputs.py index 2add1653..c1f5c0ed 100644 --- a/ratapi/outputs.py +++ b/ratapi/outputs.py @@ -538,6 +538,19 @@ class BayesResults(Results): nestedSamplerOutput: NestedSamplerOutput chain: np.ndarray + def from_procedure(self) -> Procedures: + """Return the procedure that created the result. + + Returns + ------- + procedure: Procedures + The procedure that created the result. + """ + samples = self.nestedSamplerOutput.nestSamples + if samples.shape == (1, 2) and not np.any(samples): + return Procedures.DREAM + return Procedures.NS + def save(self, filepath: str | Path = "./results.json"): """Save the BayesResults object to a JSON file. diff --git a/tests/conftest.py b/tests/conftest.py index 854dc9b8..b717ec4e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6546,6 +6546,18 @@ def dream_results(): ) +@pytest.fixture +def nested_sampler_results(dream_results): + results = dream_results + results.nestedSamplerOutput = ratapi.outputs.NestedSamplerOutput( + logZ=-28.99992503667041, + logZErr=0.3391187711291207, + nestSamples=np.ones((100, 9)), + postSamples=np.ones((100, 10)), + ) + return results + + @pytest.fixture def r1_default_project(): """The Project corresponding to the data in R1defaultProject.mat.""" diff --git a/tests/test_outputs.py b/tests/test_outputs.py index 9cd1ad90..7d2afff9 100644 --- a/tests/test_outputs.py +++ b/tests/test_outputs.py @@ -191,6 +191,19 @@ def test_make_results(test_procedure, test_output_results, test_bayes, test_resu check_results_equal(test_results, results) +@pytest.mark.parametrize( + ["test_procedure", "test_results"], + [ + (Procedures.NS, "nested_sampler_results"), + (Procedures.DREAM, "dream_results"), + ], +) +def test_results_procedure(test_procedure, test_results, request) -> None: + """Test that bayes results object return correct procedure.""" + test_output_results = request.getfixturevalue(test_results) + assert test_output_results.from_procedure() == test_procedure + + @pytest.mark.parametrize( ["test_output_results", "test_str"], [