diff --git a/pyproject.toml b/pyproject.toml index ff9cd29..3cfbe5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = 'setuptools.build_meta' [project] name = "ratapi" -version = "0.0.0.dev9" +version = "0.0.0.dev10" description = "Python extension for the Reflectivity Analysis Toolbox (RAT)" readme = "README.md" requires-python = ">=3.10" diff --git a/ratapi/outputs.py b/ratapi/outputs.py index 2add165..c1f5c0e 100644 --- a/ratapi/outputs.py +++ b/ratapi/outputs.py @@ -538,6 +538,19 @@ class BayesResults(Results): nestedSamplerOutput: NestedSamplerOutput chain: np.ndarray + def from_procedure(self) -> Procedures: + """Return the procedure that created the result. + + Returns + ------- + procedure: Procedures + The procedure that created the result. + """ + samples = self.nestedSamplerOutput.nestSamples + if samples.shape == (1, 2) and not np.any(samples): + return Procedures.DREAM + return Procedures.NS + def save(self, filepath: str | Path = "./results.json"): """Save the BayesResults object to a JSON file. diff --git a/tests/conftest.py b/tests/conftest.py index 854dc9b..b717ec4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6546,6 +6546,18 @@ def dream_results(): ) +@pytest.fixture +def nested_sampler_results(dream_results): + results = dream_results + results.nestedSamplerOutput = ratapi.outputs.NestedSamplerOutput( + logZ=-28.99992503667041, + logZErr=0.3391187711291207, + nestSamples=np.ones((100, 9)), + postSamples=np.ones((100, 10)), + ) + return results + + @pytest.fixture def r1_default_project(): """The Project corresponding to the data in R1defaultProject.mat.""" diff --git a/tests/test_outputs.py b/tests/test_outputs.py index 9cd1ad9..7d2afff 100644 --- a/tests/test_outputs.py +++ b/tests/test_outputs.py @@ -191,6 +191,19 @@ def test_make_results(test_procedure, test_output_results, test_bayes, test_resu check_results_equal(test_results, results) +@pytest.mark.parametrize( + ["test_procedure", "test_results"], + [ + (Procedures.NS, "nested_sampler_results"), + (Procedures.DREAM, "dream_results"), + ], +) +def test_results_procedure(test_procedure, test_results, request) -> None: + """Test that bayes results object return correct procedure.""" + test_output_results = request.getfixturevalue(test_results) + assert test_output_results.from_procedure() == test_procedure + + @pytest.mark.parametrize( ["test_output_results", "test_str"], [