Skip to content

Commit 62df2a9

Browse files
committed
Fixes for plots popping up in RasCAL
1 parent e9208d1 commit 62df2a9

File tree

3 files changed

+66
-37
lines changed

3 files changed

+66
-37
lines changed

ratapi/utils/plotting.py

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool
9494

9595
def plot_ref_sld_helper(
9696
data: PlotEventData,
97-
fig: Optional[matplotlib.pyplot.figure] = None,
97+
fig: matplotlib.pyplot.figure,
9898
delay: bool = True,
9999
confidence_intervals: Union[dict, None] = None,
100100
linear_x: bool = False,
@@ -112,8 +112,8 @@ def plot_ref_sld_helper(
112112
data : PlotEventData
113113
The plot event data that contains all the information
114114
to generate the ref and sld plots
115-
fig : matplotlib.pyplot.figure, optional
116-
The figure class that has two subplots
115+
fig : matplotlib.pyplot.figure
116+
The figure object that has two subplots
117117
delay : bool, default: True
118118
Controls whether to delay 0.005s after plot is created
119119
confidence_intervals : dict or None, default None
@@ -134,19 +134,13 @@ def plot_ref_sld_helper(
134134
animated : bool, default: False
135135
Controls whether the animated property of foreground plot elements should be set.
136136
137-
Returns
138-
-------
139-
fig : matplotlib.pyplot.figure
140-
The figure class that has two subplots
141-
142137
"""
143138
preserve_zoom = False
144139

145-
if fig is None:
146-
fig = plt.subplots(1, 2)[0]
147-
elif len(fig.axes) != 2:
140+
if len(fig.axes) != 2:
148141
fig.clf()
149142
fig.subplots(1, 2)
143+
150144
fig.subplots_adjust(wspace=0.3)
151145

152146
ref_plot: plt.Axes = fig.axes[0]
@@ -233,13 +227,12 @@ def plot_ref_sld_helper(
233227
if delay:
234228
plt.pause(0.005)
235229

236-
return fig
237-
238230

239231
def plot_ref_sld(
240232
project: ratapi.Project,
241233
results: Union[ratapi.outputs.Results, ratapi.outputs.BayesResults],
242234
block: bool = False,
235+
fig: Optional[matplotlib.pyplot.figure] = None,
243236
return_fig: bool = False,
244237
bayes: Literal[65, 95, None] = None,
245238
linear_x: bool = False,
@@ -259,6 +252,8 @@ def plot_ref_sld(
259252
The result from the calculation
260253
block : bool, default: False
261254
Indicates the plot should block until it is closed
255+
fig : matplotlib.pyplot.figure, optional
256+
The figure object that has two subplots
262257
return_fig : bool, default False
263258
If True, return the figure instead of displaying it.
264259
bayes : 65, 95 or None, default None
@@ -336,11 +331,15 @@ def plot_ref_sld(
336331
else:
337332
confidence_intervals = None
338333

339-
figure = plt.subplots(1, 2)[0]
334+
if fig is None:
335+
fig = plt.subplots(1, 2)[0]
336+
elif len(fig.axes) != 2:
337+
fig.clf()
338+
fig.subplots(1, 2)
340339

341340
plot_ref_sld_helper(
342341
data,
343-
figure,
342+
fig,
344343
confidence_intervals=confidence_intervals,
345344
linear_x=linear_x,
346345
q4=q4,
@@ -351,7 +350,7 @@ def plot_ref_sld(
351350
)
352351

353352
if return_fig:
354-
return figure
353+
return fig
355354

356355
plt.show(block=block)
357356

@@ -486,7 +485,7 @@ def update_plot(self, data):
486485
"""
487486
if self.figure is not None:
488487
self.figure.clf()
489-
self.figure = ratapi.plotting.plot_ref_sld_helper(
488+
plot_ref_sld_helper(
490489
data,
491490
self.figure,
492491
linear_x=self.linear_x,
@@ -520,7 +519,7 @@ def update_foreground(self, data):
520519
"""
521520
self.set_animated(True)
522521
self.figure.canvas.restore_region(self.bg)
523-
plot_data = ratapi.plotting._extract_plot_data(data, self.q4, self.show_error_bar, self.shift_value)
522+
plot_data = _extract_plot_data(data, self.q4, self.show_error_bar, self.shift_value)
524523

525524
offset = 2 if self.show_error_bar else 1
526525
for i in range(
@@ -649,6 +648,7 @@ def plot_corner(
649648
params: Union[list[Union[int, str]], None] = None,
650649
smooth: bool = True,
651650
block: bool = False,
651+
fig: Optional[matplotlib.pyplot.figure] = None,
652652
return_fig: bool = False,
653653
hist_kwargs: Union[dict, None] = None,
654654
hist2d_kwargs: Union[dict, None] = None,
@@ -666,6 +666,8 @@ def plot_corner(
666666
Whether to apply Gaussian smoothing to the corner plot.
667667
block : bool, default False
668668
Whether Python should block until the plot is closed.
669+
fig : matplotlib.pyplot.figure, optional
670+
The figure object to use for plot.
669671
return_fig: bool, default False
670672
If True, return the figure as an object instead of showing it.
671673
hist_kwargs : dict
@@ -696,7 +698,12 @@ def plot_corner(
696698

697699
num_params = len(params)
698700

699-
fig, axes = plt.subplots(num_params, num_params, figsize=(11, 10))
701+
if fig is None:
702+
fig, axes = plt.subplots(num_params, num_params, figsize=(11, 10))
703+
else:
704+
fig.clf()
705+
axes = fig.subplots(num_params, num_params)
706+
700707
# i is row, j is column
701708
for i, row_param in enumerate(params):
702709
for j, col_param in enumerate(params):
@@ -956,7 +963,9 @@ def plot_contour(
956963
plt.show(block=block)
957964

958965

959-
def panel_plot_helper(plot_func: Callable, indices: list[int]) -> matplotlib.figure.Figure:
966+
def panel_plot_helper(
967+
plot_func: Callable, indices: list[int], fig: Optional[matplotlib.pyplot.figure] = None
968+
) -> matplotlib.figure.Figure:
960969
"""Generate a panel-based plot from a single plot function.
961970
962971
Parameters
@@ -965,6 +974,8 @@ def panel_plot_helper(plot_func: Callable, indices: list[int]) -> matplotlib.fig
965974
A function which plots one parameter on an Axes object, given its index.
966975
indices : list[int]
967976
The list of indices to pass into ``plot_func``.
977+
fig : matplotlib.pyplot.figure, optional
978+
The figure object to use for plot.
968979
969980
Returns
970981
-------
@@ -974,10 +985,18 @@ def panel_plot_helper(plot_func: Callable, indices: list[int]) -> matplotlib.fig
974985
"""
975986
nplots = len(indices)
976987
nrows, ncols = ceil(sqrt(nplots)), round(sqrt(nplots))
977-
fig = plt.subplots(nrows, ncols, figsize=(11, 10))[0]
988+
989+
if fig is None:
990+
fig = plt.subplots(nrows, ncols, figsize=(11, 10))[0]
991+
else:
992+
fig.clf()
993+
fig.subplots(nrows, ncols)
978994
axs = fig.get_axes()
979995

980996
for plot_num, index in enumerate(indices):
997+
axs[plot_num].tick_params(which="both", labelsize="medium")
998+
axs[plot_num].xaxis.offsetText.set_fontsize("small")
999+
axs[plot_num].yaxis.offsetText.set_fontsize("small")
9811000
plot_func(axs[plot_num], index)
9821001

9831002
# blank unused plots
@@ -998,6 +1017,7 @@ def plot_hists(
9981017
dict[Literal["normal", "lognor", "kernel", None]], Literal["normal", "lognor", "kernel", None]
9991018
] = None,
10001019
block: bool = False,
1020+
fig: Optional[matplotlib.pyplot.figure] = None,
10011021
return_fig: bool = False,
10021022
**hist_settings,
10031023
):
@@ -1031,6 +1051,8 @@ def plot_hists(
10311051
e.g. to apply 'normal' to all unset parameters, set `estimated_density = {'default': 'normal'}`.
10321052
block : bool, default False
10331053
Whether Python should block until the plot is closed.
1054+
fig : matplotlib.pyplot.figure, optional
1055+
The figure object to use for plot.
10341056
return_fig: bool, default False
10351057
If True, return the figure as an object instead of showing it.
10361058
hist_settings :
@@ -1090,6 +1112,7 @@ def validate_dens_type(dens_type: Union[str, None], param: str):
10901112
**hist_settings,
10911113
),
10921114
params,
1115+
fig,
10931116
)
10941117
if return_fig:
10951118
return fig
@@ -1102,6 +1125,7 @@ def plot_chain(
11021125
params: Union[list[Union[int, str]], None] = None,
11031126
maxpoints: int = 15000,
11041127
block: bool = False,
1128+
fig: Optional[matplotlib.pyplot.figure] = None,
11051129
return_fig: bool = False,
11061130
):
11071131
"""Plot the MCMC chain for each parameter of a Bayesian analysis.
@@ -1117,6 +1141,8 @@ def plot_chain(
11171141
The maximum number of points to plot for each parameter.
11181142
block : bool, default False
11191143
Whether Python should block until the plot is closed.
1144+
fig : matplotlib.pyplot.figure, optional
1145+
The figure object to use for plot.
11201146
return_fig: bool, default False
11211147
If True, return the figure as an object instead of showing it.
11221148
@@ -1142,9 +1168,9 @@ def plot_chain(
11421168

11431169
def plot_one_chain(axes: Axes, i: int):
11441170
axes.plot(range(0, nsimulations, skip), chain[:, i][0:nsimulations:skip])
1145-
axes.set_title(results.fitNames[i])
1171+
axes.set_title(results.fitNames[i], fontsize="small")
11461172

1147-
fig = panel_plot_helper(plot_one_chain, params)
1173+
fig = panel_plot_helper(plot_one_chain, params, fig=fig)
11481174
if return_fig:
11491175
return fig
11501176
plt.show(block=block)

tests/test_controls.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import contextlib
44
import os
5-
from pathlib import Path
65
import tempfile
6+
from pathlib import Path
77
from typing import Any, Union
88

99
import pydantic
@@ -45,14 +45,15 @@ def test_extra_property_error() -> None:
4545
with pytest.raises(pydantic.ValidationError, match="Object has no attribute 'test'"):
4646
controls.test = 1
4747

48+
4849
@pytest.mark.parametrize(
4950
"inputs",
5051
[
51-
{"parallel": Parallel.Contrasts, "resampleMinAngle": 0.66},
52-
{"procedure": 'simplex'},
53-
{"procedure": 'dream', "nSamples": 504, "nChains": 1200},
54-
{"procedure": 'de', "crossoverProbability": 0.45, "strategy": Strategies.RandomEitherOrAlgorithm},
55-
{"procedure": 'ns', "nMCMC": 4, "propScale": 0.6},
52+
{"parallel": Parallel.Contrasts, "resampleMinAngle": 0.66},
53+
{"procedure": "simplex"},
54+
{"procedure": "dream", "nSamples": 504, "nChains": 1200},
55+
{"procedure": "de", "crossoverProbability": 0.45, "strategy": Strategies.RandomEitherOrAlgorithm},
56+
{"procedure": "ns", "nMCMC": 4, "propScale": 0.6},
5657
],
5758
)
5859
def test_save_load(inputs):
@@ -68,6 +69,7 @@ def test_save_load(inputs):
6869
for field in Controls.model_fields:
6970
assert getattr(converted_controls, field) == getattr(original_controls, field)
7071

72+
7173
class TestCalculate:
7274
"""Tests the Calculate class."""
7375

tests/test_plotting.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ def fig(request) -> plt.figure:
5050
"""Creates the fixture for the tests."""
5151
plt.close("all")
5252
figure = plt.subplots(1, 2)[0]
53-
return RATplot.plot_ref_sld_helper(fig=figure, data=domains_data() if request.param else data())
53+
RATplot.plot_ref_sld_helper(fig=figure, data=domains_data() if request.param else data())
54+
return figure
5455

5556

5657
@pytest.fixture
@@ -68,7 +69,8 @@ def bayes_fig(request) -> plt.figure:
6869
for sld in dat.sldProfiles
6970
],
7071
}
71-
return RATplot.plot_ref_sld_helper(data=dat, fig=figure, confidence_intervals=confidence_intervals)
72+
RATplot.plot_ref_sld_helper(data=dat, fig=figure, confidence_intervals=confidence_intervals)
73+
return figure
7274

7375

7476
@pytest.mark.parametrize("fig", [False, True], indirect=True)
@@ -120,8 +122,7 @@ def test_ref_sld_color_formatting(fig: plt.figure) -> None:
120122
assert sld_plot.get_lines()[i].get_color() == sld_plot.get_lines()[i + 1].get_color()
121123

122124

123-
@pytest.mark.parametrize("bayes", [65, 95])
124-
def test_ref_sld_bayes(fig, bayes_fig, bayes):
125+
def test_ref_sld_bayes(fig, bayes_fig):
125126
"""Test that shading is correctly added to the figure when confidence intervals are supplied."""
126127
# the shading is of type PolyCollection
127128
for axes in fig.axes:
@@ -137,7 +138,7 @@ def test_sld_profile_function_call(mock: MagicMock) -> None:
137138
"""Tests the makeSLDProfile function called with
138139
correct args.
139140
"""
140-
RATplot.plot_ref_sld_helper(data())
141+
RATplot.plot_ref_sld_helper(data(), plt.subplots(1, 2)[0])
141142

142143
assert mock.call_count == 3
143144
assert mock.call_args_list[0].args[0] == 2.07e-06
@@ -211,9 +212,9 @@ def test_plot_ref_sld(mock: MagicMock, input_project, reflectivity_calculation_r
211212
def test_ref_sld_subplot_correction():
212213
"""Test that if an incorrect number of subplots is corrected in the figure helper."""
213214
fig = plt.subplots(1, 3)[0]
214-
ref_sld_fig = RATplot.plot_ref_sld_helper(data=data(), fig=fig)
215-
assert ref_sld_fig.axes[0].get_subplotspec().get_gridspec().get_geometry() == (1, 2)
216-
assert len(ref_sld_fig.axes) == 2
215+
RATplot.plot_ref_sld_helper(data=data(), fig=fig)
216+
assert fig.axes[0].get_subplotspec().get_gridspec().get_geometry() == (1, 2)
217+
assert len(fig.axes) == 2
217218

218219

219220
@patch("ratapi.utils.plotting.plot_ref_sld_helper")

0 commit comments

Comments
 (0)