@@ -94,7 +94,7 @@ def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool
9494
9595def plot_ref_sld_helper (
9696 data : PlotEventData ,
97- fig : Optional [ matplotlib .pyplot .figure ] = None ,
97+ fig : matplotlib .pyplot .figure ,
9898 delay : bool = True ,
9999 confidence_intervals : Union [dict , None ] = None ,
100100 linear_x : bool = False ,
@@ -112,8 +112,8 @@ def plot_ref_sld_helper(
112112 data : PlotEventData
113113 The plot event data that contains all the information
114114 to generate the ref and sld plots
115- fig : matplotlib.pyplot.figure, optional
116- The figure class that has two subplots
115+ fig : matplotlib.pyplot.figure
116+ The figure object that has two subplots
117117 delay : bool, default: True
118118 Controls whether to delay 0.005s after plot is created
119119 confidence_intervals : dict or None, default None
@@ -134,19 +134,13 @@ def plot_ref_sld_helper(
134134 animated : bool, default: False
135135 Controls whether the animated property of foreground plot elements should be set.
136136
137- Returns
138- -------
139- fig : matplotlib.pyplot.figure
140- The figure class that has two subplots
141-
142137 """
143138 preserve_zoom = False
144139
145- if fig is None :
146- fig = plt .subplots (1 , 2 )[0 ]
147- elif len (fig .axes ) != 2 :
140+ if len (fig .axes ) != 2 :
148141 fig .clf ()
149142 fig .subplots (1 , 2 )
143+
150144 fig .subplots_adjust (wspace = 0.3 )
151145
152146 ref_plot : plt .Axes = fig .axes [0 ]
@@ -233,13 +227,12 @@ def plot_ref_sld_helper(
233227 if delay :
234228 plt .pause (0.005 )
235229
236- return fig
237-
238230
239231def plot_ref_sld (
240232 project : ratapi .Project ,
241233 results : Union [ratapi .outputs .Results , ratapi .outputs .BayesResults ],
242234 block : bool = False ,
235+ fig : Optional [matplotlib .pyplot .figure ] = None ,
243236 return_fig : bool = False ,
244237 bayes : Literal [65 , 95 , None ] = None ,
245238 linear_x : bool = False ,
@@ -259,6 +252,8 @@ def plot_ref_sld(
259252 The result from the calculation
260253 block : bool, default: False
261254 Indicates the plot should block until it is closed
255+ fig : matplotlib.pyplot.figure, optional
256+ The figure object that has two subplots
262257 return_fig : bool, default False
263258 If True, return the figure instead of displaying it.
264259 bayes : 65, 95 or None, default None
@@ -336,11 +331,15 @@ def plot_ref_sld(
336331 else :
337332 confidence_intervals = None
338333
339- figure = plt .subplots (1 , 2 )[0 ]
334+ if fig is None :
335+ fig = plt .subplots (1 , 2 )[0 ]
336+ elif len (fig .axes ) != 2 :
337+ fig .clf ()
338+ fig .subplots (1 , 2 )
340339
341340 plot_ref_sld_helper (
342341 data ,
343- figure ,
342+ fig ,
344343 confidence_intervals = confidence_intervals ,
345344 linear_x = linear_x ,
346345 q4 = q4 ,
@@ -351,7 +350,7 @@ def plot_ref_sld(
351350 )
352351
353352 if return_fig :
354- return figure
353+ return fig
355354
356355 plt .show (block = block )
357356
@@ -486,7 +485,7 @@ def update_plot(self, data):
486485 """
487486 if self .figure is not None :
488487 self .figure .clf ()
489- self . figure = ratapi . plotting . plot_ref_sld_helper (
488+ plot_ref_sld_helper (
490489 data ,
491490 self .figure ,
492491 linear_x = self .linear_x ,
@@ -520,7 +519,7 @@ def update_foreground(self, data):
520519 """
521520 self .set_animated (True )
522521 self .figure .canvas .restore_region (self .bg )
523- plot_data = ratapi . plotting . _extract_plot_data (data , self .q4 , self .show_error_bar , self .shift_value )
522+ plot_data = _extract_plot_data (data , self .q4 , self .show_error_bar , self .shift_value )
524523
525524 offset = 2 if self .show_error_bar else 1
526525 for i in range (
@@ -649,6 +648,7 @@ def plot_corner(
649648 params : Union [list [Union [int , str ]], None ] = None ,
650649 smooth : bool = True ,
651650 block : bool = False ,
651+ fig : Optional [matplotlib .pyplot .figure ] = None ,
652652 return_fig : bool = False ,
653653 hist_kwargs : Union [dict , None ] = None ,
654654 hist2d_kwargs : Union [dict , None ] = None ,
@@ -666,6 +666,8 @@ def plot_corner(
666666 Whether to apply Gaussian smoothing to the corner plot.
667667 block : bool, default False
668668 Whether Python should block until the plot is closed.
669+ fig : matplotlib.pyplot.figure, optional
670+ The figure object to use for plot.
669671 return_fig: bool, default False
670672 If True, return the figure as an object instead of showing it.
671673 hist_kwargs : dict
@@ -696,7 +698,12 @@ def plot_corner(
696698
697699 num_params = len (params )
698700
699- fig , axes = plt .subplots (num_params , num_params , figsize = (11 , 10 ))
701+ if fig is None :
702+ fig , axes = plt .subplots (num_params , num_params , figsize = (11 , 10 ))
703+ else :
704+ fig .clf ()
705+ axes = fig .subplots (num_params , num_params )
706+
700707 # i is row, j is column
701708 for i , row_param in enumerate (params ):
702709 for j , col_param in enumerate (params ):
@@ -956,7 +963,9 @@ def plot_contour(
956963 plt .show (block = block )
957964
958965
959- def panel_plot_helper (plot_func : Callable , indices : list [int ]) -> matplotlib .figure .Figure :
966+ def panel_plot_helper (
967+ plot_func : Callable , indices : list [int ], fig : Optional [matplotlib .pyplot .figure ] = None
968+ ) -> matplotlib .figure .Figure :
960969 """Generate a panel-based plot from a single plot function.
961970
962971 Parameters
@@ -965,6 +974,8 @@ def panel_plot_helper(plot_func: Callable, indices: list[int]) -> matplotlib.fig
965974 A function which plots one parameter on an Axes object, given its index.
966975 indices : list[int]
967976 The list of indices to pass into ``plot_func``.
977+ fig : matplotlib.pyplot.figure, optional
978+ The figure object to use for plot.
968979
969980 Returns
970981 -------
@@ -974,10 +985,18 @@ def panel_plot_helper(plot_func: Callable, indices: list[int]) -> matplotlib.fig
974985 """
975986 nplots = len (indices )
976987 nrows , ncols = ceil (sqrt (nplots )), round (sqrt (nplots ))
977- fig = plt .subplots (nrows , ncols , figsize = (11 , 10 ))[0 ]
988+
989+ if fig is None :
990+ fig = plt .subplots (nrows , ncols , figsize = (11 , 10 ))[0 ]
991+ else :
992+ fig .clf ()
993+ fig .subplots (nrows , ncols )
978994 axs = fig .get_axes ()
979995
980996 for plot_num , index in enumerate (indices ):
997+ axs [plot_num ].tick_params (which = "both" , labelsize = "medium" )
998+ axs [plot_num ].xaxis .offsetText .set_fontsize ("small" )
999+ axs [plot_num ].yaxis .offsetText .set_fontsize ("small" )
9811000 plot_func (axs [plot_num ], index )
9821001
9831002 # blank unused plots
@@ -998,6 +1017,7 @@ def plot_hists(
9981017 dict [Literal ["normal" , "lognor" , "kernel" , None ]], Literal ["normal" , "lognor" , "kernel" , None ]
9991018 ] = None ,
10001019 block : bool = False ,
1020+ fig : Optional [matplotlib .pyplot .figure ] = None ,
10011021 return_fig : bool = False ,
10021022 ** hist_settings ,
10031023):
@@ -1031,6 +1051,8 @@ def plot_hists(
10311051 e.g. to apply 'normal' to all unset parameters, set `estimated_density = {'default': 'normal'}`.
10321052 block : bool, default False
10331053 Whether Python should block until the plot is closed.
1054+ fig : matplotlib.pyplot.figure, optional
1055+ The figure object to use for plot.
10341056 return_fig: bool, default False
10351057 If True, return the figure as an object instead of showing it.
10361058 hist_settings :
@@ -1090,6 +1112,7 @@ def validate_dens_type(dens_type: Union[str, None], param: str):
10901112 ** hist_settings ,
10911113 ),
10921114 params ,
1115+ fig ,
10931116 )
10941117 if return_fig :
10951118 return fig
@@ -1102,6 +1125,7 @@ def plot_chain(
11021125 params : Union [list [Union [int , str ]], None ] = None ,
11031126 maxpoints : int = 15000 ,
11041127 block : bool = False ,
1128+ fig : Optional [matplotlib .pyplot .figure ] = None ,
11051129 return_fig : bool = False ,
11061130):
11071131 """Plot the MCMC chain for each parameter of a Bayesian analysis.
@@ -1117,6 +1141,8 @@ def plot_chain(
11171141 The maximum number of points to plot for each parameter.
11181142 block : bool, default False
11191143 Whether Python should block until the plot is closed.
1144+ fig : matplotlib.pyplot.figure, optional
1145+ The figure object to use for plot.
11201146 return_fig: bool, default False
11211147 If True, return the figure as an object instead of showing it.
11221148
@@ -1142,9 +1168,9 @@ def plot_chain(
11421168
11431169 def plot_one_chain (axes : Axes , i : int ):
11441170 axes .plot (range (0 , nsimulations , skip ), chain [:, i ][0 :nsimulations :skip ])
1145- axes .set_title (results .fitNames [i ])
1171+ axes .set_title (results .fitNames [i ], fontsize = "small" )
11461172
1147- fig = panel_plot_helper (plot_one_chain , params )
1173+ fig = panel_plot_helper (plot_one_chain , params , fig = fig )
11481174 if return_fig :
11491175 return fig
11501176 plt .show (block = block )
0 commit comments