Skip to content

Commit b0121f3

Browse files
authored
Adds method to get bayes result procedure and bumps version to dev10 (#192)
1 parent 0ef08e4 commit b0121f3

File tree

4 files changed

+39
-1
lines changed

4 files changed

+39
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ build-backend = 'setuptools.build_meta'
88

99
[project]
1010
name = "ratapi"
11-
version = "0.0.0.dev9"
11+
version = "0.0.0.dev10"
1212
description = "Python extension for the Reflectivity Analysis Toolbox (RAT)"
1313
readme = "README.md"
1414
requires-python = ">=3.10"

ratapi/outputs.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,19 @@ class BayesResults(Results):
538538
nestedSamplerOutput: NestedSamplerOutput
539539
chain: np.ndarray
540540

541+
def from_procedure(self) -> Procedures:
542+
"""Return the procedure that created the result.
543+
544+
Returns
545+
-------
546+
procedure: Procedures
547+
The procedure that created the result.
548+
"""
549+
samples = self.nestedSamplerOutput.nestSamples
550+
if samples.shape == (1, 2) and not np.any(samples):
551+
return Procedures.DREAM
552+
return Procedures.NS
553+
541554
def save(self, filepath: str | Path = "./results.json"):
542555
"""Save the BayesResults object to a JSON file.
543556

tests/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6546,6 +6546,18 @@ def dream_results():
65466546
)
65476547

65486548

6549+
@pytest.fixture
6550+
def nested_sampler_results(dream_results):
6551+
results = dream_results
6552+
results.nestedSamplerOutput = ratapi.outputs.NestedSamplerOutput(
6553+
logZ=-28.99992503667041,
6554+
logZErr=0.3391187711291207,
6555+
nestSamples=np.ones((100, 9)),
6556+
postSamples=np.ones((100, 10)),
6557+
)
6558+
return results
6559+
6560+
65496561
@pytest.fixture
65506562
def r1_default_project():
65516563
"""The Project corresponding to the data in R1defaultProject.mat."""

tests/test_outputs.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,19 @@ def test_make_results(test_procedure, test_output_results, test_bayes, test_resu
191191
check_results_equal(test_results, results)
192192

193193

194+
@pytest.mark.parametrize(
195+
["test_procedure", "test_results"],
196+
[
197+
(Procedures.NS, "nested_sampler_results"),
198+
(Procedures.DREAM, "dream_results"),
199+
],
200+
)
201+
def test_results_procedure(test_procedure, test_results, request) -> None:
202+
"""Test that bayes results object return correct procedure."""
203+
test_output_results = request.getfixturevalue(test_results)
204+
assert test_output_results.from_procedure() == test_procedure
205+
206+
194207
@pytest.mark.parametrize(
195208
["test_output_results", "test_str"],
196209
[

0 commit comments

Comments
 (0)