Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions RATapi/utils/plotting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Plots using the matplotlib library"""

import copy
from functools import partial, wraps
from math import ceil, floor, sqrt
from statistics import stdev
Expand Down Expand Up @@ -119,7 +120,7 @@ def plot_ref_sld_helper(
if confidence_intervals is not None:
ref_min, ref_max = confidence_intervals["reflectivity"][i]
mult = (1 if not q4 else r[:, 0] ** 4) / div
ref_plot.fill_between(r[:, 0], ref_min / div, ref_max / div, alpha=0.6, color="grey")
ref_plot.fill_between(r[:, 0], ref_min * mult, ref_max * mult, alpha=0.6, color="grey")

if data.dataPresent[i]:
sd_x = sd[:, 0]
Expand Down Expand Up @@ -241,10 +242,13 @@ def plot_ref_sld(
"""
data = PlotEventData()

# We need to take a copy of reflectivity and SLD in case we are plotting a
# shaded plot and will therefore change the plotted data to that from the
# centre of the Bayesian distribution
data.modelType = project.model
data.reflectivity = results.reflectivity
data.reflectivity = copy.deepcopy(results.reflectivity)
data.shiftedData = results.shiftedData
data.sldProfiles = results.sldProfiles
data.sldProfiles = copy.deepcopy(results.sldProfiles)
data.resampledLayers = results.resampledLayers
data.dataPresent = RATapi.inputs.make_data_present(project)
data.subRoughs = results.contrastParams.subRoughs
Expand Down Expand Up @@ -275,6 +279,12 @@ def plot_ref_sld(
for sld in results.predictionIntervals.sld
],
}
# For a shaded plot, use the mean values from predictionIntervals
for reflectivity, mean_reflectivity in zip(data.reflectivity, results.predictionIntervals.reflectivity):
reflectivity[:, 1] = mean_reflectivity[2]
for sldProfile, mean_sld_profile in zip(data.sldProfiles, results.predictionIntervals.sld):
for sld, mean_sld in zip(sldProfile, mean_sld_profile):
sld[:, 1] = mean_sld[2]
else:
raise ValueError(
"Shaded confidence intervals are only available for the results of Bayesian analysis (NS or DREAM)"
Expand Down
14 changes: 13 additions & 1 deletion cpp/rat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,18 @@ py::list pyList1DFromRatCellWrap2D(const T& values)
return result;
}

template <typename T>
py::list pyList1DFromRatCellWrap1D(const T& values)
{
py::list result;

for (int32_T idx0{0}; idx0 < values.size(0); idx0++) {
result.append(pyArrayFromRatArray2d(values[idx0].f1));
}

return result;
}

template <typename T>
py::list pyList2dFromRatCellWrap(const T& values)
{
Expand Down Expand Up @@ -1234,7 +1246,7 @@ BayesResults bayesResultsFromStruct9T(const RAT::struct9_T results)

bayesResults.chain = pyArrayFromRatArray2d(results.chain);

bayesResults.predictionIntervals.reflectivity = pyList1DFromRatCellWrap2D<coder::array<RAT::cell_wrap_11, 1U>>(results.predictionIntervals.reflectivity);
bayesResults.predictionIntervals.reflectivity = pyList1DFromRatCellWrap1D<coder::array<RAT::cell_wrap_11, 1U>>(results.predictionIntervals.reflectivity);
bayesResults.predictionIntervals.sld = pyList2dFromRatCellWrap<coder::array<RAT::cell_wrap_11, 2U>>(results.predictionIntervals.sld);
bayesResults.predictionIntervals.sampleChi = pyArray1dFromBoundedArray<coder::bounded_array<real_T, 1000U, 1U>>(results.predictionIntervals.sampleChi);

Expand Down
20 changes: 0 additions & 20 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1827,7 +1827,6 @@ def dream_bayes():
9.14706158e-07,
9.14597081e-07,
9.14576711e-07,
9.14569865e-07,
],
[
-2.26996619e-06,
Expand Down Expand Up @@ -1859,7 +1858,6 @@ def dream_bayes():
9.14706158e-07,
9.14597081e-07,
9.14576711e-07,
9.14569865e-07,
],
[
-1.50598689e-07,
Expand Down Expand Up @@ -1891,7 +1889,6 @@ def dream_bayes():
3.33756709e-06,
3.33751126e-06,
3.33750083e-06,
3.33749733e-06,
],
[
2.07300000e-06,
Expand Down Expand Up @@ -1923,7 +1920,6 @@ def dream_bayes():
5.87958511e-06,
5.87958515e-06,
5.87958515e-06,
5.87958515e-06,
],
[
2.07300000e-06,
Expand Down Expand Up @@ -1955,7 +1951,6 @@ def dream_bayes():
5.87958511e-06,
5.87958515e-06,
5.87958515e-06,
5.87958515e-06,
],
],
),
Expand Down Expand Up @@ -1991,7 +1986,6 @@ def dream_bayes():
4.65378471e-06,
4.65378475e-06,
4.65378475e-06,
4.65378475e-06,
],
[
-1.35107208e-06,
Expand Down Expand Up @@ -2023,7 +2017,6 @@ def dream_bayes():
4.65378471e-06,
4.65378475e-06,
4.65378475e-06,
4.65378475e-06,
],
[
3.19875093e-07,
Expand Down Expand Up @@ -2055,7 +2048,6 @@ def dream_bayes():
4.92196879e-06,
4.92248720e-06,
4.92261931e-06,
4.92266441e-06,
],
[
2.07300000e-06,
Expand Down Expand Up @@ -2087,7 +2079,6 @@ def dream_bayes():
5.17758173e-06,
5.17859423e-06,
5.17885226e-06,
5.17894035e-06,
],
[
2.07300000e-06,
Expand Down Expand Up @@ -2119,7 +2110,6 @@ def dream_bayes():
5.17758173e-06,
5.17859423e-06,
5.17885226e-06,
5.17894035e-06,
],
],
),
Expand Down Expand Up @@ -4140,7 +4130,6 @@ def dream_results():
9.14706158e-07,
9.14597081e-07,
9.14576711e-07,
9.14569865e-07,
],
[
-2.26996619e-06,
Expand Down Expand Up @@ -4172,7 +4161,6 @@ def dream_results():
9.14706158e-07,
9.14597081e-07,
9.14576711e-07,
9.14569865e-07,
],
[
-1.50598689e-07,
Expand Down Expand Up @@ -4204,7 +4192,6 @@ def dream_results():
3.33756709e-06,
3.33751126e-06,
3.33750083e-06,
3.33749733e-06,
],
[
2.07300000e-06,
Expand Down Expand Up @@ -4236,7 +4223,6 @@ def dream_results():
5.87958511e-06,
5.87958515e-06,
5.87958515e-06,
5.87958515e-06,
],
[
2.07300000e-06,
Expand Down Expand Up @@ -4268,7 +4254,6 @@ def dream_results():
5.87958511e-06,
5.87958515e-06,
5.87958515e-06,
5.87958515e-06,
],
]
),
Expand Down Expand Up @@ -4304,7 +4289,6 @@ def dream_results():
4.65378471e-06,
4.65378475e-06,
4.65378475e-06,
4.65378475e-06,
],
[
-1.35107208e-06,
Expand Down Expand Up @@ -4336,7 +4320,6 @@ def dream_results():
4.65378471e-06,
4.65378475e-06,
4.65378475e-06,
4.65378475e-06,
],
[
3.19875093e-07,
Expand Down Expand Up @@ -4368,7 +4351,6 @@ def dream_results():
4.92196879e-06,
4.92248720e-06,
4.92261931e-06,
4.92266441e-06,
],
[
2.07300000e-06,
Expand Down Expand Up @@ -4400,7 +4382,6 @@ def dream_results():
5.17758173e-06,
5.17859423e-06,
5.17885226e-06,
5.17894035e-06,
],
[
2.07300000e-06,
Expand Down Expand Up @@ -4432,7 +4413,6 @@ def dream_results():
5.17758173e-06,
5.17859423e-06,
5.17885226e-06,
5.17894035e-06,
],
]
),
Expand Down
4 changes: 2 additions & 2 deletions tests/test_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def dream_str():
"'Background parameter D2O', 'Background parameter SMW', 'D2O', 'SMW'],\n"
"predictionIntervals = PredictionIntervals(\n"
"\treflectivity = [Data array: [5 x 21], Data array: [5 x 21]],\n"
"\tsld = [[Data array: [5 x 30], Data array: [5 x 30]]],\n"
"\tsld = [[Data array: [5 x 29], Data array: [5 x 29]]],\n"
"\tsampleChi = Data array: [1000],\n"
"),\n"
"confidenceIntervals = ConfidenceIntervals(\n"
Expand Down Expand Up @@ -191,7 +191,7 @@ def test_make_results(test_procedure, test_output_results, test_bayes, test_resu
@pytest.mark.parametrize(
["test_output_results", "test_str"],
[
# ("reflectivity_calculation_results", "reflectivity_calculation_str"),
("reflectivity_calculation_results", "reflectivity_calculation_str"),
("dream_results", "dream_str"),
],
)
Expand Down
8 changes: 6 additions & 2 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,14 @@ def test_plot_ref_sld(mock: MagicMock, input_project, reflectivity_calculation_r
assert figure.axes[0].get_subplotspec().get_gridspec().get_geometry() == (1, 2)
assert len(figure.axes) == 2

for reflectivity, reflectivity_results in zip(data.reflectivity, reflectivity_calculation_results.reflectivity):
assert (reflectivity == reflectivity_results).all()
for sldProfile, result_sld_profile in zip(data.sldProfiles, reflectivity_calculation_results.sldProfiles):
for sld, sld_results in zip(sldProfile, result_sld_profile):
assert (sld == sld_results).all()

assert data.modelType == input_project.model
assert data.reflectivity == reflectivity_calculation_results.reflectivity
assert data.shiftedData == reflectivity_calculation_results.shiftedData
assert data.sldProfiles == reflectivity_calculation_results.sldProfiles
assert data.resampledLayers == reflectivity_calculation_results.resampledLayers
assert data.dataPresent.size == 0
assert (data.subRoughs == reflectivity_calculation_results.contrastParams.subRoughs).all()
Expand Down
Loading