Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 204 additions & 1 deletion ratapi/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,205 @@ def plot_ref_sld(
plt.show(block=block)


class BlittingSupport:
"""Create a SLD plot that uses blitting to get faster draws.

The blit plot stores the background from an
initial draw then updates the foreground (lines and error bars) if the background is not changed.

Parameters
----------
data : PlotEventData
The plot event data that contains all the information
to generate the ref and sld plots
fig : matplotlib.pyplot.figure, optional
The figure class that has two subplots
linear_x : bool, default: False
Controls whether the x-axis on reflectivity plot uses the linear scale
q4 : bool, default: False
Controls whether Q^4 is plotted on the reflectivity plot
show_error_bar : bool, default: True
Controls whether the error bars are shown
show_grid : bool, default: False
Controls whether the grid is shown
show_legend : bool, default: True
Controls whether the legend is shown
shift_value : float, default: 100
A value between 1 and 100 that controls the spacing between the reflectivity plots for each of the contrasts
"""

def __init__(
self,
data,
fig=None,
linear_x: bool = False,
q4: bool = False,
show_error_bar: bool = True,
show_grid: bool = False,
show_legend: bool = True,
shift_value: float = 100,
):
self.figure = fig
self.linear_x = linear_x
self.q4 = q4
self.show_error_bar = show_error_bar
self.show_grid = show_grid
self.show_legend = show_legend
self.shift_value = shift_value
self.update_plot(data)
self.event_id = self.figure.canvas.mpl_connect("resize_event", self.resizeEvent)

def __del__(self):
self.figure.canvas.mpl_disconnect(self.event_id)

def resizeEvent(self, _event):
"""Ensure the background is updated after a resize event."""
self.__background_changed = True

def update(self, data):
"""Update the foreground, if background has not changed otherwise it updates full plot.

Parameters
----------
data : PlotEventData
The plot event data that contains all the information
to generate the ref and sld plots
"""
if self.__background_changed:
self.update_plot(data)
else:
self.update_foreground(data)

def __setattr__(self, name, value):
old_value = getattr(self, name, None)
if value == old_value:
return

super().__setattr__(name, value)
if name in ["figure", "linear_x", "q4", "show_error_bar", "show_grid", "show_legend", "shift_value"]:
self.__background_changed = True

def set_animated(self, is_animated: bool):
"""Set the animated property of foreground plot elements.

Parameters
----------
is_animated : bool
Indicates if the animated property should be set.
"""
for line in self.figure.axes[0].lines:
line.set_animated(is_animated)
for line in self.figure.axes[1].lines:
line.set_animated(is_animated)
for container in self.figure.axes[0].containers:
container[2][0].set_animated(is_animated)

def adjust_error_bar(self, error_bar_container, x, y, y_error):
"""Adjust the error bar data.

Parameters
----------
error_bar_container : Tuple
Tuple containing the artist of the errorbar i.e. (data line, cap lines, bar lines)
x : np.ndarray
The shifted data x axis data
y : np.ndarray
The shifted data y axis data
y_error : np.ndarray
The shifted data y axis error data
"""
line, _, (bars_y,) = error_bar_container

line.set_data(x, y)
x_base = x
y_base = y

y_error_top = y_base + y_error
y_error_bottom = y_base - y_error

new_segments_y = [np.array([[x, yt], [x, yb]]) for x, yt, yb in zip(x_base, y_error_top, y_error_bottom)]
bars_y.set_segments(new_segments_y)

def update_plot(self, data):
"""Update the full plot.

Parameters
----------
data : PlotEventData
The plot event data that contains all the information
to generate the ref and sld plots
"""
if self.figure is not None:
self.figure.clf()
self.figure = ratapi.plotting.plot_ref_sld_helper(
data,
self.figure,
linear_x=self.linear_x,
q4=self.q4,
show_error_bar=self.show_error_bar,
show_grid=self.show_grid,
show_legend=self.show_legend,
animated=True,
)
self.figure.tight_layout(pad=1)
self.figure.canvas.draw()
self.bg = self.figure.canvas.copy_from_bbox(self.figure.bbox)
for line in self.figure.axes[0].lines:
self.figure.axes[0].draw_artist(line)
for line in self.figure.axes[1].lines:
self.figure.axes[1].draw_artist(line)
for container in self.figure.axes[0].containers:
self.figure.axes[0].draw_artist(container[2][0])
self.figure.canvas.blit(self.figure.bbox)
self.set_animated(False)
self.__background_changed = False

def update_foreground(self, data):
"""Update the plot foreground only.

Parameters
----------
data : PlotEventData
The plot event data that contains all the information
to generate the ref and sld plots
"""
self.set_animated(True)
self.figure.canvas.restore_region(self.bg)
plot_data = ratapi.plotting._extract_plot_data(data, self.q4, self.show_error_bar, self.shift_value)

offset = 2 if self.show_error_bar else 1
for i in range(
0,
len(self.figure.axes[0].lines),
):
self.figure.axes[0].lines[i].set_data(plot_data["ref"][i // offset][0], plot_data["ref"][i // offset][1])
self.figure.axes[0].draw_artist(self.figure.axes[0].lines[i])

i = 0
for j in range(len(plot_data["sld"])):
for sld in plot_data["sld"][j]:
self.figure.axes[1].lines[i].set_data(sld[0], sld[1])
self.figure.axes[1].draw_artist(self.figure.axes[1].lines[i])
i += 1

if plot_data["sld_resample"]:
for resampled in plot_data["sld_resample"][j]:
self.figure.axes[1].lines[i].set_data(resampled[0], resampled[1])
self.figure.axes[1].draw_artist(self.figure.axes[1].lines[i])
i += 1

for i, container in enumerate(self.figure.axes[0].containers):
self.adjust_error_bar(
container, plot_data["error"][i][0], plot_data["error"][i][1], plot_data["error"][i][2]
)
self.figure.axes[0].draw_artist(container[2][0])
self.figure.axes[0].draw_artist(container[0])

self.figure.canvas.blit(self.figure.bbox)
self.figure.canvas.flush_events()
self.set_animated(False)


class LivePlot:
"""Create a plot that gets updates from the plot event during a calculation.

Expand All @@ -369,6 +568,7 @@ class LivePlot:
def __init__(self, block=False):
self.block = block
self.closed = False
self.blit_plot = None

def __enter__(self):
self.figure = plt.subplots(1, 2)[0]
Expand All @@ -394,7 +594,10 @@ def plotEvent(self, event):

"""
if not self.closed and self.figure.number in plt.get_fignums():
plot_ref_sld_helper(event, self.figure)
if self.blit_plot is None:
self.blit_plot = BlittingSupport(event, self.figure)
else:
self.blit_plot.update(event)

def __exit__(self, _exc_type, _exc_val, _traceback):
ratapi.events.clear(ratapi.events.EventTypes.Plot, self.plotEvent)
Expand Down
3 changes: 3 additions & 0 deletions tests/test_orso_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def prist():
],
)
@pytest.mark.parametrize("absorption", [True, False])
@pytest.mark.skip(reason="orsopy database website (https://slddb.esss.dk/slddb/) is not available")
def test_orso_model_to_rat(model, absorption):
"""Test that orso_model_to_rat gives the expected parameters, layers and model."""

Expand Down Expand Up @@ -72,6 +73,7 @@ def test_orso_model_to_rat(model, absorption):
"prist5_10K_m_025.Rqz.ort",
],
)
@pytest.mark.skip(reason="orsopy database website (https://slddb.esss.dk/slddb/) is not available")
def test_load_ort_data(test_data):
"""Test that .ort data is loaded correctly."""
# manually get the test data for comparison
Expand Down Expand Up @@ -104,6 +106,7 @@ def test_load_ort_data(test_data):
["prist5_10K_m_025.Rqz.ort", "prist.json"],
],
)
@pytest.mark.skip(reason="orsopy database website (https://slddb.esss.dk/slddb/) is not available")
def test_load_ort_project(test_data, expected_data):
"""Test that a project with model data is loaded correctly."""
ort_data = ORSOProject(Path(TEST_DIR_PATH, test_data))
Expand Down
Loading