diff --git a/phaser/engines/common/output.py b/phaser/engines/common/output.py index 7100d18..31d641d 100644 --- a/phaser/engines/common/output.py +++ b/phaser/engines/common/output.py @@ -12,6 +12,51 @@ from phaser.state import ReconsState from phaser.plan import SaveOptions +## TODO: move these utility functions to something like utils.plotting +def interpolate_values(a1, a2, c, d, cutoff=0.5): + c = max(c, d*cutoff) + c = min(c, d) + interpolated_value1 = a1/(1-cutoff)/d*c+a1-a1/(1-cutoff) + interpolated_value2 = (a2-1)/(1-cutoff)/d * c + a2 -(a2-1)/(1-cutoff) + interpolated_value1 = max(0, min(1, interpolated_value1)) + interpolated_value2 = max(0, min(1, interpolated_value2)) + return (interpolated_value1, interpolated_value2) + +def grad_to_rgb(angle, absolute, max_abs, offset, SL = False): + import matplotlib.colors + """Get the rgb value for the given `angle` and the `absolute` value + Intended usage: grad = numpy.array(list(map(lambda p, q: grad_to_rgb(p, q, max_abs), deg, mag))) + + Parameters + ---------- + angle : float + The angle in radians + absolute : float + The absolute value of the gradient + max_abs : float + The maximum value against which all others are normalized (color-wise). + + Returns + ------- + array_like + The rgb value as a tuple with values [0..1] + """ + + # normalize angle + angle = (angle + offset) % (2 * numpy.pi) + if angle < 0: + angle += 2 * numpy.pi + + if SL: + # print('specify global max_abs!!!') + return matplotlib.colors.hsv_to_rgb((angle / 2 / numpy.pi, + numpy.where(absolute / max_abs>1, 1, absolute / max_abs), + numpy.where(absolute / max_abs>1, 1, absolute / max_abs))) + else: + return matplotlib.colors.hsv_to_rgb((angle / 2 / numpy.pi, + 1, + 1)) + def output_images(state: ReconsState, out_dir: Path, options: SaveOptions): for ty in options.images: @@ -176,6 +221,36 @@ def _save_object_mag(state: ReconsState, out_path: Path, options: SaveOptions, s write_opts['metadata']['axes'] = 'ZYX' if stack else 'YX' w.write(obj_mag, **write_opts) +## TODO: refactor using colorize_complex to match tilt implementation +def _plot_scan_update(state: ReconsState, out_path: Path, options: SaveOptions): + from matplotlib import pyplot + fig, ax = pyplot.subplots(figsize=(4, 4), dpi=options.plot_dpi, constrained_layout=True) + ax.set_aspect(1.) + [left, right, bottom, top] = state.object.sampling.mpl_extent() + ax.set_xlim(left, right) + ax.set_ylim(bottom, top) + + scan = to_numpy(state.scan.data) + + scan = to_numpy(state.scan.data) + disp = to_numpy(state.scan.data) - to_numpy(state.scan.initial_scan) + pos_y, pos_x = scan[..., 1].ravel(), scan[..., 0].ravel() + dY, dX = disp[..., 1].ravel(), disp[..., 0].ravel() + + deg = -numpy.arctan2(dY, dX) + mag = numpy.sqrt(dY ** 2 + dX ** 2) + max_abs = numpy.nanmax(mag) + + grad = numpy.array(list(map(lambda p, q: grad_to_rgb(p, q, max_abs, offset=0), deg, mag))) + + width = 0.1 + # hal = hl = 1. / width * length + headwidth = width + ax.quiver(pos_y, pos_x, dY, dX, color=grad, angles='xy', units='xy', width=width, edgecolor=None, linewidth=1, pivot='mid', scale_units='xy', scale=.1) + # ax.set_axis_off() + + fig.savefig(out_path) + pyplot.close(fig) def _plot_scan(state: ReconsState, out_path: Path, options: SaveOptions): from matplotlib import pyplot @@ -186,7 +261,7 @@ def _plot_scan(state: ReconsState, out_path: Path, options: SaveOptions): ax.set_xlim(left, right) ax.set_ylim(bottom, top) - scan = to_numpy(state.scan) + scan = to_numpy(state.scan.data) i = numpy.arange(scan[..., 0].size) ax.scatter(scan[..., 1].ravel(), scan[..., 0].ravel(), c=i, cmap='plasma', s=0.5, edgecolors='none') @@ -209,7 +284,7 @@ def _plot_tilt(state: ReconsState, out_path: Path, options: SaveOptions): ax.set_xlim(left, right) ax.set_ylim(bottom, top) - scan = to_numpy(state.scan) + scan = to_numpy(state.scan.data) tilt = to_numpy(state.tilt) tilt = tilt[..., 1] + tilt[..., 0]*1.j max_tilt = max(numpy.max(numpy.abs(tilt)), 1.0) # at least 1 mrad @@ -247,7 +322,8 @@ def _plot_tilt(state: ReconsState, out_path: Path, options: SaveOptions): 'object_mag_stack': partial(_save_object_mag, stack=True), 'object_mag_sum': partial(_save_object_mag, stack=False), 'scan': _plot_scan, + 'scan_update': _plot_scan_update, 'tilt': _plot_tilt, } # save functions with special handling of file extensions -_PLOT_FUNCS: t.Set[str] = {'scan', 'tilt'} \ No newline at end of file +_PLOT_FUNCS: t.Set[str] = {'scan', 'tilt', 'scan_update'} \ No newline at end of file diff --git a/phaser/engines/common/position_correction.py b/phaser/engines/common/position_correction.py index 2f0809f..724540a 100644 --- a/phaser/engines/common/position_correction.py +++ b/phaser/engines/common/position_correction.py @@ -41,8 +41,8 @@ def __init__(self, args: None, props: MomentumPositionSolverProps): self.momentum = props.momentum def init_state(self, sim: ReconsState) -> NDArray[numpy.floating]: - xp = get_array_module(sim.scan) - return xp.zeros_like(sim.scan) + xp = get_array_module(sim.scan.data) + return xp.zeros_like(sim.scan.data) def perform_update( self, diff --git a/phaser/engines/common/regularizers.py b/phaser/engines/common/regularizers.py index 416186e..1f03349 100644 --- a/phaser/engines/common/regularizers.py +++ b/phaser/engines/common/regularizers.py @@ -11,12 +11,157 @@ ) from phaser.utils.image import convolve1d from phaser.state import ReconsState -from phaser.hooks.regularization import ( +from phaser.hooks.regularization import (ScanConstraintProps, ClampObjectAmplitudeProps, LimitProbeSupportProps, NonNegObjectPhaseProps, RegularizeLayersProps, ObjLowPassProps, GaussianProps, CostRegularizerProps, TVRegularizerProps, UnstructuredGaussianProps ) +logger = logging.getLogger(__name__) + +class ScanUpdate(t.NamedTuple): + """ + Scan update object for holding the scan constraint row index and previous position arrays + """ + previous: numpy.typing.NDArray[numpy.floating] + row_bins: t.Optional[numpy.typing.NDArray[numpy.integer]] = None +class ScanConstraint: + """ + Constraints for the scan position updates. + This per iteration regularizer takes the unconstrained position updates and applies a + weighted average of its affine, line (row) averaged, high pass filtered, or low pass filtered components. + Currently, only the affine and line averaged constraints are implemented. + + See ref for details: + S. Ning, W. Xu, L. Loh, Z. Lu, M. Bosman, F. Zhang, Q. He, An integrated constrained gradient descent (iCGD) protocol to correct scan-positional errors for electron ptychography with high accuracy and precision. Ultramicroscopy 248, 113716 (2023). + """ + def __init__(self, args: None, props: ScanConstraintProps): + self.valid_kinds: t.Set[str] = {'affine', 'line', 'hpf', 'lpf'} + self.constraints: t.Dict[str, float] = {} #= {'default': 1.0} + + total_constraint_weight = 0 + for kind in self.valid_kinds: + if getattr(props, kind) > 0: + val = getattr(props, kind) + self.constraints[kind] = val + + total_constraint_weight += val + + if total_constraint_weight > 1.0: + raise ValueError("Sum of scan constraint weights cannot exceed 1.0") + + self.constraints['default'] = 1-total_constraint_weight + + logger.info(f"Initialized scan constraint with kinds {list(self.constraints.keys())} and weights {list(self.constraints.values())}") + + def init_state(self, sim: ReconsState) -> ScanUpdate: + if 'line' in self.constraints: + if (sim.scan.metadata.get('type') != 'raster') | (sim.scan.metadata.get('rows') is None): + raise ValueError("Line scan constraint cannot be applied to scans without row metadata") + row_vals = sim.scan.metadata.get('rows') + if isinstance(row_vals, list): + row_vals = numpy.array(row_vals, dtype=numpy.integer) + state = ScanUpdate(previous=sim.scan.data.copy(), row_bins=row_vals.ravel()) + else: + state = ScanUpdate(previous=sim.scan.data.copy(), row_bins=None) + return state + + def apply_group(self, group: NDArray[numpy.integer], sim: ReconsState, state: ScanUpdate) -> t.Tuple[ReconsState, ScanUpdate]: + return self.apply_iter(sim, state) + + def apply_iter(self, sim: ReconsState, state: ScanUpdate) -> t.Tuple[ReconsState, ScanUpdate]: + xp = get_array_module(sim.scan.data) + update = xp.zeros_like(sim.scan.data, dtype=sim.scan.data.dtype) + for kind, weight in self.constraints.items(): + match kind: + case 'affine': + update += _scan_affine(sim.scan.data, state.previous) * weight + case 'line': + if state.row_bins is not None: + update += _scan_line(sim.scan.data, state.previous, state.row_bins) * weight + case 'hpf': + pass + case 'lpf': + pass + case 'default': + update += _scan_default(sim.scan.data, state.previous) * weight + + sim.scan.data = state.previous + update + state = ScanUpdate(previous=sim.scan.data.copy(), row_bins=state.row_bins) + return (sim, state) + +# @partial(jit, donate_argnames=('pos',), cupy_fuse=True) +def _scan_default( + pos: NDArray[numpy.floating], + prev: NDArray[numpy.floating], +) -> NDArray[numpy.floating]: + """ + Pass through function for calculating the scan update from final and initial scan positions. + + :param pos: N x 2 array of unconstrained updated scan positions + :type pos: NDArray[numpy.floating] + :param prev: N x 2 array of scan positions, before update (previous iteration) + :type prev: NDArray[numpy.floating] + :return: N x 2 array of updates to the scan positions + :rtype: NDArray[floating[Any]] + """ + return pos - prev + +def _scan_affine( + pos: NDArray[numpy.floating], + prev: NDArray[numpy.floating], +) -> NDArray[numpy.floating]: + """ + Calculates and returns the affine component of the update between final and initial scan positions. + + :param pos: N x 2 array of unconstrained updated scan positions + :type pos: NDArray[numpy.floating] + :param prev: N x 2 array of scan positions, before update (previous iteration) + :type prev: NDArray[numpy.floating] + :return: N x 2 array of updates to the scan positions (affine only) + :rtype: NDArray[floating[Any]] + """ + xp = get_array_module(pos) + + disp_update = pos - prev + ones = xp.ones((pos.shape[0], 1), pos.dtype) + pos_prev = xp.concatenate([pos, ones], axis=1) + left = xp.matmul(pos_prev.T, disp_update) + right = xp.matmul(pos_prev.T, pos_prev) + A = xp.matmul(xp.linalg.inv(right), left) + constraint = xp.matmul(pos_prev, A) + center_ones = xp.ones((1, 1), pos.dtype) + center = xp.concatenate([xp.average(pos, axis = 0, keepdims=True), center_ones], axis=1, dtype=pos.dtype) + center_shift = center @ A + constraint -= center_shift + return constraint + +# @partial(jit, donate_argnames=('pos',), cupy_fuse=True) +def _scan_line( + pos: NDArray[numpy.floating], + prev: NDArray[numpy.floating], + rows: NDArray[numpy.integer], +) -> NDArray[numpy.floating]: + """ + Calculates and returns a line (row) averaged update from the unconstrained final and initial scan positions. + + :param pos: N x 2 array of unconstrained updated scan positions + :type pos: NDArray[numpy.floating] + :param prev: N x 2 array of scan positions, before update (previous iteration) + :type prev: NDArray[numpy.floating] + :param rows: an array of row indices corresponding to the N positions given by pos and prev + :type rows: NDArray[numpy.integer] + :return: N x 2 array of updates to the scan positions (row averaged) + :rtype: NDArray[floating[Any]] + """ + xp = get_array_module(pos) + + disp_val = pos - prev + + y_shifts = xp.bincount(rows, disp_val[:,0]) / xp.bincount(rows) + x_shifts = xp.bincount(rows, disp_val[:,1]) / xp.bincount(rows) + constraint = xp.stack([y_shifts[rows], x_shifts[rows]], axis=1, dtype=pos.dtype) + return constraint class ClampObjectAmplitude: def __init__(self, args: None, props: ClampObjectAmplitudeProps): @@ -250,7 +395,7 @@ def calc_loss_group( xp = get_array_module(sim.object.data) cost = xp.sum(xp.abs(sim.object.data - 1.0)) - cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) @@ -272,7 +417,7 @@ def calc_loss_group( cost = xp.sum(abs2(sim.object.data - 1.0)) - cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) # type: ignore @@ -293,7 +438,7 @@ def calc_loss_group( xp = get_array_module(sim.object.data) cost = xp.sum(xp.abs(xp.angle(sim.object.data))) - cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) @@ -319,7 +464,7 @@ def calc_loss_group( xp.abs(fft2(xp.prod(sim.object.data, axis=0))) ) # scale cost by fraction of the total reconstruction in the group - cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) @@ -351,7 +496,7 @@ def calc_loss_group( #) # scale cost by fraction of the total reconstruction in the group # TODO also scale by # of pixels or similar? - cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) @@ -377,7 +522,7 @@ def calc_loss_group( xp.sum(abs2(xp.diff(sim.object.data, axis=-2))) ) # scale cost by fraction of the total reconstruction in the group - cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) # type: ignore @@ -403,7 +548,7 @@ def calc_loss_group( cost = xp.sum(xp.abs(xp.diff(sim.object.data, axis=0))) # scale cost by fraction of the total reconstruction in the group - cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) @@ -429,7 +574,7 @@ def calc_loss_group( cost = xp.sum(abs2(xp.diff(sim.object.data, axis=0))) # scale cost by fraction of the total reconstruction in the group - cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) # type: ignore @@ -519,7 +664,7 @@ def __init__(self, args: None, props: UnstructuredGaussianProps): self.attr_path = props.attr_path def init_state(self, sim: ReconsState) -> NDArray[numpy.floating]: - xp = get_array_module(sim.scan) + xp = get_array_module(sim.scan.data) try: self.getattr_nested(sim, self.attr_path) except AttributeError as e: @@ -547,8 +692,8 @@ def setattr_nested(self, obj: t.Any, attr_path: str, value: t.Any): def apply_iter(self, sim: ReconsState, state: NDArray[numpy.floating]) -> t.Tuple[ReconsState, NDArray[numpy.floating]]: from scipy.spatial import KDTree obj_samp = sim.object.sampling - scan_flat = sim.scan.reshape(-1, 2) - scan_ndim = sim.scan.ndim - 1 + scan_flat = sim.scan.data.reshape(-1, 2) + scan_ndim = sim.scan.data.ndim - 1 attr = self.getattr_nested(sim, self.attr_path) vals = t.cast(NDArray[numpy.inexact], getattr(attr, 'data', attr)) # Extract raw array diff --git a/phaser/engines/common/simulation.py b/phaser/engines/common/simulation.py index a6e15b6..ce11bad 100644 --- a/phaser/engines/common/simulation.py +++ b/phaser/engines/common/simulation.py @@ -248,7 +248,7 @@ def cutout_group( """Returns (probe, obj) in the cutout region""" probes = state.probe.data - group_scan = state.scan[tuple(group)] + group_scan = state.scan.data[tuple(group)] group_obj = state.object.sampling.get_view_at_pos(state.object.data, group_scan, probes.shape[-2:]) # group probes in real space # shape (len(group), 1, Ny, Nx) diff --git a/phaser/engines/conventional/run.py b/phaser/engines/conventional/run.py index 2264805..ce11008 100644 --- a/phaser/engines/conventional/run.py +++ b/phaser/engines/conventional/run.py @@ -9,7 +9,6 @@ from phaser.types import process_flag, flag_any_true from ..common.simulation import SimulationState, make_propagators, GroupManager - def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: logger = logging.getLogger(__name__) @@ -45,7 +44,7 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: solver = props.solver(props) sim = solver.init(sim) - groups = GroupManager(sim.state.scan, props.grouping, props.compact, seed=seed) + groups = GroupManager(sim.state.scan.data, props.grouping, props.compact, seed=seed) calc_error_mask = mask_fraction_of_groups(len(groups), props.calc_error_fraction) @@ -73,7 +72,7 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: # runs rescaling sim = solver.presolve( - sim, groups.iter(sim.state.scan), + sim, groups.iter(sim.state.scan.data), patterns=patterns, pattern_mask=pattern_mask, propagators=propagators ) @@ -84,12 +83,12 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: for i in range(1, props.niter+1): sim.state.iter.engine_iter = i sim.state.iter.total_iter = start_i + i - + iter_update_positions = update_positions({'state': sim.state, 'niter': props.niter}) iter_shuffle_groups = shuffle_groups({'state': sim.state, 'niter': props.niter}) sim, pos_update, group_errors = solver.run_iteration( - sim, groups.iter(sim.state.scan, i, iter_shuffle_groups), + sim, groups.iter(sim.state.scan.data, i, iter_shuffle_groups), patterns=patterns, pattern_mask=pattern_mask, propagators=propagators, update_object=update_object({'state': sim.state, 'niter': props.niter}), update_probe=update_probe({'state': sim.state, 'niter': props.niter}), @@ -101,28 +100,28 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: assert_dtype(sim.state.object.data, cdtype) assert_dtype(sim.state.probe.data, cdtype) - sim = sim.apply_iter_constraints() - if iter_update_positions: if not position_solver: raise ValueError("Updating positions with no PositionSolver specified") # subtract mean position update pos_update -= xp.mean(pos_update, tuple(range(pos_update.ndim - 1))) - pos_update, position_solver_state = position_solver.perform_update(sim.state.scan, pos_update, position_solver_state) + pos_update, position_solver_state = position_solver.perform_update(sim.state.scan.data, pos_update, position_solver_state) # subtract mean again (this can change with momentum) pos_update -= xp.mean(pos_update, tuple(range(pos_update.ndim - 1))) pos_update_rms = float(xp.mean(xp.linalg.norm(pos_update, axis=-1, keepdims=True))) logger.info(f"Position update: mean {pos_update_rms}") - sim.state.scan += pos_update - assert_dtype(sim.state.scan, dtype) + sim.state.scan.data += pos_update + assert_dtype(sim.state.scan.data, dtype) # check positions are at least overlapping object - sim.state.object.sampling.check_scan(sim.state.scan, sim.state.probe.sampling.extent / 2.) + sim.state.object.sampling.check_scan(sim.state.scan.data, sim.state.probe.sampling.extent / 2.) progress['pos_update_rms'].iters.append(i + start_i) progress['pos_update_rms'].values.append(pos_update_rms) + sim = sim.apply_iter_constraints() + error = None if group_errors is not None and len(group_errors): error = float(to_numpy(xp.nanmean(xp.concatenate(group_errors)))) diff --git a/phaser/engines/conventional/solvers.py b/phaser/engines/conventional/solvers.py index d689d95..f3bb0b0 100644 --- a/phaser/engines/conventional/solvers.py +++ b/phaser/engines/conventional/solvers.py @@ -86,7 +86,7 @@ def run_iteration( new_obj_mag = xp.zeros_like(self.obj_mag) new_probe_mag = xp.zeros_like(self.probe_mag) - pos_update = xp.zeros_like(sim.state.scan, dtype=sim.dtype) + pos_update = xp.zeros_like(sim.state.scan.data, dtype=sim.dtype) iter_errors = [] for (group_i, (group, group_patterns)) in enumerate(stream_patterns(groups, patterns, xp=xp, @@ -345,7 +345,7 @@ def run_iteration( xp = sim.xp # TODO: ePIE position update - pos_update = xp.zeros_like(sim.state.scan) + pos_update = xp.zeros_like(sim.state.scan.data) iter_errors = [] beta_object = process_schedule(self.plan.beta_object)({'state': sim.state, 'niter': self.engine_plan.niter}) diff --git a/phaser/engines/gradient/run.py b/phaser/engines/gradient/run.py index 7482f1d..ea62d09 100644 --- a/phaser/engines/gradient/run.py +++ b/phaser/engines/gradient/run.py @@ -69,7 +69,7 @@ def process_solvers( _PATH_MAP: t.Dict[t.Tuple[str, ...], ReconsVar] = { ('object', 'data'): 'object', ('probe', 'data'): 'probe', - ('scan',): 'positions', + ('scan', 'data'): 'positions', ('tilt',): 'tilt' } @@ -114,8 +114,9 @@ def apply_update(state: ReconsState, update: t.Dict[ReconsVar, numpy.ndarray]) - # subtract mean position update xp = get_array_module(update['positions']) update['positions'] -= xp.mean(update['positions'], tuple(range(update['positions'].ndim - 1))) - - state.scan += update['positions'] + # state.scan.prev_step = state.scan.data # check that this functions correctly. positions should be only a per-iteration solver + # print(f"{state.scan.data.shape} {update['positions'].shape}") + state.scan.data += update['positions'] return state @@ -178,7 +179,7 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: } # shuffle_groups defaults to True for sparse groups, False for compact groups shuffle_groups = process_flag(props.shuffle_groups or not props.compact) - groups = GroupManager(state.scan, props.grouping, props.compact, seed) + groups = GroupManager(state.scan.data, props.grouping, props.compact, seed) observer.init_engine( state, recons_name=args['recons_name'], @@ -190,7 +191,7 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: # runs rescaling rescale_factors = [] - for (group_i, (group, group_patterns)) in enumerate(stream_patterns(groups.iter(state.scan), + for (group_i, (group, group_patterns)) in enumerate(stream_patterns(groups.iter(state.scan.data), patterns, xp=xp, buf_n=props.buffer_n_groups)): group_rescale_factors = dry_run(state, group, propagators, group_patterns, xp=xp, dtype=dtype) rescale_factors.append(group_rescale_factors) @@ -252,7 +253,7 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: for (solver, solver_state) in zip(iter_solvers, iter_solver_states) ] - for (group_i, (group, group_patterns)) in enumerate(stream_patterns(groups.iter(state.scan, i, iter_shuffle_groups), + for (group_i, (group, group_patterns)) in enumerate(stream_patterns(groups.iter(state.scan.data, i, iter_shuffle_groups), patterns, xp=xp, buf_n=props.buffer_n_groups)): (state, group_losses, iter_grads, solver_states) = run_group( state, group=group, vars=iter_vars, @@ -314,8 +315,8 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: if 'positions' in iter_vars: # check positions are at least overlapping object - state.object.sampling.check_scan(state.scan, state.probe.sampling.extent / 2.) - assert_dtype(state.scan, dtype) + state.object.sampling.check_scan(state.scan.data, state.probe.sampling.extent / 2.) + assert_dtype(state.scan.data, dtype) state.progress = progress observer.update_iteration(state, i, props.niter, losses) @@ -403,7 +404,7 @@ def run_model( ) -> t.Tuple[Float, t.Tuple[SolverStates, t.Dict[str, Float]]]: # apply vars to simulation sim = insert_vars(vars, sim, group) - group_scan = sim.scan + group_scan = sim.scan.data group_tilts = sim.tilt (ky, kx) = sim.probe.sampling.recip_grid(dtype=dtype, xp=xp) @@ -460,8 +461,8 @@ def dry_run( (ky, kx) = sim.probe.sampling.recip_grid(dtype=dtype, xp=xp) probes = sim.probe.data - group_obj = sim.object.sampling.get_view_at_pos(sim.object.data, sim.scan[tuple(group)], probes.shape[-2:]) - group_subpx_filters = fourier_shift_filter(ky, kx, sim.object.sampling.get_subpx_shifts(sim.scan[tuple(group)], probes.shape[-2:]))[:, None, ...] + group_obj = sim.object.sampling.get_view_at_pos(sim.object.data, sim.scan.data[tuple(group)], probes.shape[-2:]) + group_subpx_filters = fourier_shift_filter(ky, kx, sim.object.sampling.get_subpx_shifts(sim.scan.data[tuple(group)], probes.shape[-2:]))[:, None, ...] probes = ifft2(fft2(probes) * group_subpx_filters) def sim_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], psi): diff --git a/phaser/execute.py b/phaser/execute.py index 593a71c..05b5cce 100644 --- a/phaser/execute.py +++ b/phaser/execute.py @@ -124,7 +124,6 @@ def _normalize_observers( return ObserverSet(obs) - def _normalize_scan_shape( patterns: Patterns, state: ReconsState ) -> t.Tuple[Patterns, ReconsState]: @@ -136,8 +135,8 @@ def _normalize_scan_shape( dimensional shape of the two. 'state.tilt' is reshaped as well. """ patterns_shape = patterns.patterns.shape[:-2] - scan_shape = state.scan.shape[:-1] - + scan_shape = state.scan.data.shape[:-1] + print('patterns shape:', patterns_shape, 'scan shape:', scan_shape) n_patterns = math.prod(patterns_shape) n_scan = math.prod(scan_shape) if n_scan != n_patterns: @@ -147,7 +146,18 @@ def _normalize_scan_shape( new_shape = scan_shape if len(scan_shape) > len(patterns_shape) else patterns_shape patterns.patterns = patterns.patterns.reshape((*new_shape, *patterns.patterns.shape[-2:])) - state.scan = state.scan.reshape((*new_shape, 2)) + state.scan.data = state.scan.data.reshape((*new_shape, 2)) + + ## TODO: check that this functions as intended, mainly when loading .h5 files and not applying dropnans + if 'raster' == state.scan.metadata.get('type'): + if 'rows' in state.scan.metadata: + if isinstance(state.scan.metadata.get('rows'), list): + state.scan.metadata['rows'] = numpy.array(state.scan.metadata.get('rows'), dtype=numpy.integer) + state.scan.metadata['rows'] = (state.scan.metadata['rows']).reshape((*new_shape, 1)) + if 'cols' in state.scan.metadata: + if isinstance(state.scan.metadata.get('cols'), list): + state.scan.metadata['cols'] = numpy.array(state.scan.metadata.get('cols'), dtype=numpy.integer) + state.scan.metadata['cols'] = (state.scan.metadata['cols']).reshape((*new_shape, 1)) if state.tilt is not None: n_tilt = math.prod(state.tilt.shape[:-1]) @@ -319,14 +329,14 @@ def initialize_reconstruction( elif tilt_hook is not None: logging.info("Initializing tilt...") tilt = pane.from_data(tilt_hook, TiltHook)( # type: ignore - {'dtype': dtype, 'xp': xp, 'shape': scan.shape[:-1]} + {'dtype': dtype, 'xp': xp, 'shape': scan.data.shape[:-1]} ) else: tilt = None obj_pad_px: float = plan.engines[0].obj_pad_px if len(plan.engines) > 0 else 5.0 # type: ignore obj_sampling = ObjectSampling.from_scan( - scan, sampling.sampling, sampling.extent / 2. + obj_pad_px * sampling.sampling + scan.data, sampling.sampling, sampling.extent / 2. + obj_pad_px * sampling.sampling ) if init_state.object is not None and plan.init.object is None: @@ -411,7 +421,7 @@ def prepare_for_engine(patterns: Patterns, state: ReconsState, xp: t.Any, engine obj_sampling = obj_sampling.with_sampling(state.probe.sampling.sampling) obj_sampling_pad = obj_sampling.expand_to_scan( - state.scan, state.probe.sampling.extent / 2. + engine.obj_pad_px * state.probe.sampling.sampling + state.scan.data, state.probe.sampling.extent / 2. + engine.obj_pad_px * state.probe.sampling.sampling ) if obj_sampling_pad != obj_sampling: @@ -447,7 +457,7 @@ def prepare_for_engine(patterns: Patterns, state: ReconsState, xp: t.Any, engine solver_vars = set(itertools.chain.from_iterable(engine.solvers.keys())) if 'tilt' in solver_vars and state.tilt is None: logging.info("Creating new, zeroed tilt map...") - state.tilt = xp.zeros_like(state.scan) + state.tilt = xp.zeros_like(state.scan.data) return patterns, state diff --git a/phaser/hooks/__init__.py b/phaser/hooks/__init__.py index f861270..cca3cb1 100644 --- a/phaser/hooks/__init__.py +++ b/phaser/hooks/__init__.py @@ -12,7 +12,7 @@ if t.TYPE_CHECKING: from phaser.utils.num import Sampling from phaser.utils.object import ObjectSampling - from ..state import ObjectState, ProbeState, ReconsState, Patterns # noqa: F401 + from ..state import ObjectState, ProbeState, ScanState, ReconsState, Patterns from ..execute import Observer @@ -136,7 +136,7 @@ class RasterScanProps(Dataclass): affine: t.Optional[t.Annotated[NDArray[numpy.floating], annotations.shape((2, 2))]] = None -class ScanHook(Hook[ScanHookArgs, NDArray[numpy.floating]]): +class ScanHook(Hook[ScanHookArgs, 'ScanState']): known = { 'raster': ('phaser.hooks.scan:raster_scan', RasterScanProps), } diff --git a/phaser/hooks/preprocessing.py b/phaser/hooks/preprocessing.py index 14ba218..41beabf 100644 --- a/phaser/hooks/preprocessing.py +++ b/phaser/hooks/preprocessing.py @@ -14,6 +14,8 @@ logger = logging.getLogger(__name__) +## TODO: actually test that cropping and dropnans work individually and together +## crop should be functional? might not be on load? probably just leave it until it becomes an issue def crop_data(raw_data: RawData, props: CropDataProps) -> RawData: if raw_data['patterns'].ndim != 4: @@ -83,9 +85,14 @@ def add_poisson_noise(raw_data: RawData, props: PoissonProps) -> RawData: def drop_nan_patterns(args: PostInitArgs, props: DropNanProps) -> t.Tuple[Patterns, ReconsState]: xp = get_array_module(args['data'].patterns) - + # flatten scan, tilt, and patterns - scan = args['state'].scan.reshape(-1, 2) + scan_pos = args['state'].scan.data.reshape(-1, 2) + prev_step = args['state'].scan.initial_scan.reshape(-1, 2) + scan_meta = args['state'].scan.metadata + scan_meta['rows'] = scan_meta['rows'].reshape(-1, 1) + scan_meta['cols'] = scan_meta['cols'].reshape(-1, 1) + tilt = None if args['state'].tilt is None else args['state'].tilt.reshape(-1, 2) patterns = args['data'].patterns.reshape(-1, *args['data'].patterns.shape[-2:]) @@ -97,11 +104,21 @@ def drop_nan_patterns(args: PostInitArgs, props: DropNanProps) -> t.Tuple[Patter logger.info(f"Dropping {n}/{patterns.shape[0]} patterns which are at least {props.threshold:.1%} NaN values") patterns = patterns[~mask] - if scan.shape[0] == xp.size(mask): + if scan_pos.shape[0] == xp.size(mask): # apply mask to scan as well - scan = scan[~mask] - elif scan.shape[0] != patterns.shape[0]: - raise ValueError(f"# of scan positions {scan.shape[0]} doesn't match # of patterns" + scan_pos = scan_pos[~mask] + prev_step = prev_step[~mask] + if scan_meta['type'] == 'raster': + if 'rows' in scan_meta: + if scan_meta['rows'].shape[:-1] != args['state'].scan.data.shape[:-1]: + raise ValueError("Scan 'rows' metadata shape doesn't match scan data shape") + scan_meta['rows'] = scan_meta['rows'][~mask] + if 'cols' in scan_meta: + if scan_meta['cols'].shape[:-1] != args['state'].scan.data.shape[:-1]: + raise ValueError("Scan 'cols' metadata shape doesn't match scan data shape") + scan_meta['cols'] = scan_meta['cols'][~mask] + elif scan_pos.shape[0] != patterns.shape[0]: + raise ValueError(f"# of scan positions {scan_pos.shape[0]} doesn't match # of patterns" f" before ({mask.size}) or after ({patterns.shape[0]}) filtering") # otherwise, we assume the mask has already been applied to the scan @@ -112,7 +129,11 @@ def drop_nan_patterns(args: PostInitArgs, props: DropNanProps) -> t.Tuple[Patter raise ValueError(f"# of tilt positions {tilt.shape[0]} doesn't match # of patterns" f" before ({mask.size}) or after ({patterns.shape[0]}) filtering") - args['state'].scan = scan + print(scan_pos.shape, scan_meta['cols'].shape) + args['state'].scan.data = scan_pos + args['state'].scan.initial_scan = prev_step ## check that copy is correct or necessary + args['state'].scan.metadata = scan_meta + args['state'].tilt = tilt args['data'].patterns = patterns diff --git a/phaser/hooks/regularization.py b/phaser/hooks/regularization.py index 6659d7d..29c787f 100644 --- a/phaser/hooks/regularization.py +++ b/phaser/hooks/regularization.py @@ -32,7 +32,12 @@ def name(self) -> str: def calc_loss_group(self, group: NDArray[numpy.integer], sim: 'ReconsState', state: StateT) -> t.Tuple['Float', StateT]: ... - +class ScanConstraintProps(Dataclass): + affine: float = 0.0 + line: float = 0.0 + hpf: float = 0.0 + lpf: float = 0.0 + class ClampObjectAmplitudeProps(Dataclass): amplitude: t.Union[float, t.List[t.Optional[float]]] = 1.1 @@ -76,6 +81,7 @@ class NonNegObjectPhaseProps(Dataclass): class IterConstraintHook(Hook[None, IterConstraint]): known = { + 'scan_constraint': ('phaser.engines.common.regularizers:ScanConstraint', ScanConstraintProps), 'clamp_object_amplitude': ('phaser.engines.common.regularizers:ClampObjectAmplitude', ClampObjectAmplitudeProps), 'limit_probe_support': ('phaser.engines.common.regularizers:LimitProbeSupport', LimitProbeSupportProps), 'layers': ('phaser.engines.common.regularizers:RegularizeLayers', RegularizeLayersProps), diff --git a/phaser/hooks/scan.py b/phaser/hooks/scan.py index a0b932f..47f08e7 100644 --- a/phaser/hooks/scan.py +++ b/phaser/hooks/scan.py @@ -6,9 +6,9 @@ from phaser.utils.num import cast_array_module from phaser.utils.scan import make_raster_scan from . import ScanHookArgs, RasterScanProps +from ..state import ScanState - -def raster_scan(args: ScanHookArgs, props: RasterScanProps) -> NDArray[numpy.floating]: +def raster_scan(args: ScanHookArgs, props: RasterScanProps) -> ScanState: xp = cast_array_module(args['xp']) logger = logging.getLogger(__name__) @@ -29,9 +29,14 @@ def raster_scan(args: ScanHookArgs, props: RasterScanProps) -> NDArray[numpy.flo f" rotation {rot:.2f} deg" f" affine transformation {affine.ravel() if affine is not None else 'None'}") - scan = make_raster_scan( + scan_pos, rows, cols = make_raster_scan( props.shape, step_size, rot, affine, dtype=args['dtype'], xp=xp, ) - - return scan + print(rows.dtype) + logger.info(f"Generated scan with {scan_pos.shape[0]}x{scan_pos.shape[1]}={scan_pos.shape[0]*scan_pos.shape[1]} positions.") + logger.info(f"Accompanying scan metadata: type 'raster', {rows.shape} rows (y), {cols.shape} cols (x).") + return ScanState(data = scan_pos, + initial_scan = scan_pos.copy(), + metadata={'type': 'raster', + 'rows': rows, 'cols': cols}) \ No newline at end of file diff --git a/phaser/plan.py b/phaser/plan.py index 8690eae..03e9a5c 100644 --- a/phaser/plan.py +++ b/phaser/plan.py @@ -11,7 +11,7 @@ SaveType: t.TypeAlias = t.Literal[ 'probe', 'probe_mag', 'probe_recip', 'probe_recip_mag', 'object_phase_stack', 'object_phase_sum', - 'object_mag_stack', 'object_mag_sum', + 'object_mag_stack', 'object_mag_sum','scan_update', 'scan', 'tilt', ] diff --git a/phaser/state.py b/phaser/state.py index 5502a74..c5ffb8d 100644 --- a/phaser/state.py +++ b/phaser/state.py @@ -1,5 +1,4 @@ import typing as t - import numpy from numpy.typing import NDArray from typing_extensions import Self @@ -13,7 +12,6 @@ from phaser.utils.image import _InterpBoundaryMode from phaser.observer import Observer, ObserverSet - @tree_dataclass class Patterns(): patterns: NDArray[numpy.floating] @@ -137,6 +135,32 @@ def copy(self) -> Self: import copy return copy.deepcopy(self) +@tree_dataclass(static_fields=('metadata',)) +class ScanState: + data: NDArray[numpy.floating] + """Scan coordinates (y, x), in length units. Shape (..., 2)""" + initial_scan: NDArray[numpy.floating] + """Initial scan coordinates (y, x), in length units. Shape (..., 2)""" + metadata: t.Union[t.Dict[str, t.Any], str] + """Metadata dictionary. Contains scan type and other info, such as rows and cols for raster scan positions""" + + def to_xp(self, xp: t.Any) -> Self: + return self.__class__( + xp.asarray(self.data), + xp.asarray(self.initial_scan), + self.metadata, + ) + + def to_numpy(self) -> Self: + return self.__class__( + to_numpy(self.data), + to_numpy(self.initial_scan), + self.metadata, + ) + + def copy(self) -> Self: + import copy + return copy.deepcopy(self) @tree_dataclass(kw_only=True, drop_fields=('progress',)) class ReconsState: @@ -145,8 +169,7 @@ class ReconsState: probe: ProbeState object: ObjectState - scan: NDArray[numpy.floating] - """Scan coordinates (y, x), in length units. Shape (..., 2)""" + scan: ScanState #NDArray[numpy.floating] tilt: t.Optional[NDArray[numpy.floating]] = None """Tilt angles (y, x) per scan position, in mrad. Shape (..., 2)""" progress: t.Dict[str, ProgressState] = field(default_factory=dict) @@ -156,7 +179,7 @@ def to_xp(self, xp: t.Any) -> Self: iter=self.iter, probe=self.probe.to_xp(xp), object=self.object.to_xp(xp), - scan=xp.asarray(self.scan), + scan=self.scan.to_xp(xp), tilt=None if self.tilt is None else xp.asarray(self.tilt), progress=self.progress, wavelength=self.wavelength, @@ -167,7 +190,7 @@ def to_numpy(self) -> Self: iter=self.iter.to_numpy(), probe=self.probe.to_numpy(), object=self.object.to_numpy(), - scan=to_numpy(self.scan), + scan=self.scan.to_numpy(), tilt=None if self.tilt is None else to_numpy(self.tilt), progress=self.progress, wavelength=float(self.wavelength), @@ -194,8 +217,7 @@ class PartialReconsState: probe: t.Optional[ProbeState] = None object: t.Optional[ObjectState] = None - scan: t.Optional[NDArray[numpy.floating]] = None - """Scan coordinates (y, x), in length units. Shape (..., 2)""" + scan: t.Optional[ScanState] = None tilt: t.Optional[NDArray[numpy.floating]] = None progress: t.Optional[t.Dict[str, ProgressState]] = None @@ -204,7 +226,7 @@ def to_numpy(self) -> Self: iter=self.iter.to_numpy() if self.iter is not None else None, probe=self.probe.to_numpy() if self.probe is not None else None, object=self.object.to_numpy() if self.object is not None else None, - scan=to_numpy(self.scan) if self.scan is not None else None, + scan=self.scan.to_numpy() if self.scan is not None else None, tilt=to_numpy(self.tilt) if self.tilt is not None else None, wavelength=float(self.wavelength) if self.wavelength is not None else None, progress=self.progress, @@ -222,7 +244,7 @@ def to_complete(self) -> ReconsState: wavelength=t.cast(Float, self.wavelength), probe=t.cast(ProbeState, self.probe), object=t.cast(ObjectState, self.object), - scan=t.cast(NDArray[numpy.floating], self.scan), + scan=t.cast(ScanState, self.scan), tilt=self.tilt, progress=progress, iter=iter, ) @@ -257,4 +279,6 @@ def with_observer(self, observer: t.Union['Observer', t.Iterable['Observer']]) - else: observers.extend(observer) - return self.__class__(self.patterns, self.state, self.name, ObserverSet(observers)) \ No newline at end of file + return self.__class__(self.patterns, self.state, self.name, ObserverSet(observers)) + + diff --git a/phaser/types.py b/phaser/types.py index c6a83b7..4f29fce 100644 --- a/phaser/types.py +++ b/phaser/types.py @@ -91,7 +91,6 @@ def __init__(self, state: 'ReconsState', continue_next_engine: bool = False): class Dataclass(pane.PaneBase, kw_only=True, allow_extra=False): ... - class SliceList(Dataclass): thicknesses: t.List[float] diff --git a/phaser/utils/io.py b/phaser/utils/io.py index 6890807..b31ae7c 100644 --- a/phaser/utils/io.py +++ b/phaser/utils/io.py @@ -1,15 +1,16 @@ import contextlib from pathlib import Path import typing as t - +import json +# import pane import numpy from numpy.typing import NDArray import h5py +from phaser.utils import tree from phaser.utils.num import Sampling, to_numpy from phaser.utils.object import ObjectSampling -from phaser.state import ReconsState, IterState, ProbeState, ObjectState, ProgressState, PartialReconsState - +from phaser.state import ReconsState, IterState, ProbeState, ScanState, ObjectState, ProgressState, PartialReconsState HdfLike: t.TypeAlias = t.Union[h5py.File, str, Path] OpenMode: t.TypeAlias = t.Literal['r', 'r+', 'w', 'w-', 'x', 'a'] @@ -106,11 +107,12 @@ def hdf5_read_state(file: HdfLike) -> PartialReconsState: probe = hdf5_read_probe_state(_assert_group(file['probe'])) if 'probe' in file else None obj = hdf5_read_object_state(_assert_group(file['object'])) if 'object' in file else None iter = hdf5_read_iter_state(_assert_group(file['iter'])) if 'iter' in file else IterState.empty() - scan = numpy.asarray(_hdf5_read_dataset(file, 'scan', numpy.float64)) if 'scan' in file else None + scan = hdf5_read_scan_state(_assert_group(file['scan'])) if 'scan' in file else None + # scan = numpy.asarray(_hdf5_read_dataset(file, 'scan', numpy.float64)) if 'scan' in file else None tilt = numpy.asarray(_hdf5_read_dataset(file, 'tilt', numpy.float64)) if 'tilt' in file else None if tilt is not None and scan is not None: - assert tilt.shape == scan.shape + assert tilt.shape == scan.data.shape progress = hdf5_read_progress_state(_assert_group(file['progress'])) if 'progress' in file else None return PartialReconsState( @@ -131,6 +133,19 @@ def hdf5_read_probe_state(group: h5py.Group) -> ProbeState: data=probes ) +# if this is called, already assumed that scan is state and contains metadata group +def hdf5_read_scan_state(group: h5py.Group) -> ScanState: + scan = _hdf5_read_dataset(group, 'data', numpy.floating) + assert scan.ndim == 2 + initial = _hdf5_read_dataset(group, 'initial', numpy.floating) + assert initial.ndim == 2 + + meta_d = json.loads(_hdf5_read_string(group, 'metadata')) + + return ScanState( + data=scan, initial_scan=initial, + metadata=meta_d + ) def hdf5_read_object_state(group: h5py.Group) -> ObjectState: obj = numpy.asarray(_hdf5_read_dataset(group, 'data', numpy.complexfloating)) @@ -186,7 +201,7 @@ def hdf5_read_progress_state(group: h5py.Group) -> t.Dict[str, ProgressState]: d[k] = ProgressState(iters.tolist(), values.tolist()) return d - + def hdf5_write_state(state: t.Union[ReconsState, PartialReconsState], file: HdfLike): file = open_hdf5(file, 'w') # overwrite if existing @@ -199,7 +214,8 @@ def hdf5_write_state(state: t.Union[ReconsState, PartialReconsState], file: HdfL if state.object is not None: hdf5_write_object_state(state.object, file.create_group("object")) if state.scan is not None: - file.create_dataset('scan', data=to_numpy(state.scan).astype(numpy.float64)) + hdf5_write_scan_state(state.scan, file.create_group("scan")) + # file.create_dataset('scan', data=to_numpy(state.scan.astype(numpy.float64))) if state.tilt is not None: file.create_dataset('tilt', data=to_numpy(state.tilt).astype(numpy.float64)) if state.iter is not None: @@ -214,10 +230,21 @@ def hdf5_write_probe_state(state: ProbeState, group: h5py.Group): dataset.dims[0].label = 'mode' dataset.dims[1].label = 'y' dataset.dims[2].label = 'x' - group.create_dataset('sampling', data=state.sampling.sampling.astype(numpy.float64)) group.create_dataset('extent', data=state.sampling.extent.astype(numpy.float64)) +def hdf5_write_scan_state(state: ScanState, group: h5py.Group): + state = state.to_numpy() + assert state.data.ndim == 2 + dataset = group.create_dataset('data', data=state.data) + dataset.dims[0].label = 'position' + dataset.dims[1].label = 'yx' + dataset = group.create_dataset('initial', data=state.initial_scan) + dataset.dims[0].label = 'position' + dataset.dims[1].label = 'yx' + meta_to_dump = tree.map(lambda v: to_numpy(v).tolist() if hasattr(v, '__array_priority__') else v, state.metadata) + + dataset = group.create_dataset('metadata', data=json.dumps(meta_to_dump)) def hdf5_write_object_state(state: ObjectState, group: h5py.Group): assert state.data.ndim == 3 diff --git a/phaser/utils/scan.py b/phaser/utils/scan.py index b559f1d..f54cffe 100644 --- a/phaser/utils/scan.py +++ b/phaser/utils/scan.py @@ -7,21 +7,24 @@ import numpy from numpy.typing import ArrayLike, DTypeLike, NDArray +from phaser.utils.tree import tree_dataclass + from .num import get_array_module, cast_array_module, NumT +## FIXME: output to Tuple? importance of array number types @t.overload def make_raster_scan(shape: t.Tuple[int, int], scan_step: ArrayLike, # pyright: ignore[reportOverlappingOverload] - rotation: float = 0., affine: t.Union[None, ArrayLike] = None, *, dtype: NumT, xp: t.Any = None) -> NDArray[NumT]: + rotation: float = 0., affine: t.Union[None, ArrayLike] = None, *, dtype: NumT, xp: t.Any = None) -> t.Tuple[NDArray[NumT], NDArray[NumT], NDArray[NumT]]: ... @t.overload def make_raster_scan(shape: t.Tuple[int, int], scan_step: ArrayLike, - rotation: float = 0., affine: t.Union[None, ArrayLike] = None, *, dtype: t.Optional[DTypeLike] = None, xp: t.Any = None) -> NDArray[numpy.floating]: + rotation: float = 0., affine: t.Union[None, ArrayLike] = None, *, dtype: t.Optional[DTypeLike] = None, xp: t.Any = None) -> t.Tuple[NDArray[numpy.floating], NDArray[numpy.integer], NDArray[numpy.integer]]: ... def make_raster_scan(shape: t.Tuple[int, int], scan_step: ArrayLike, - rotation: float = 0., affine: t.Union[None, ArrayLike] = None, *, dtype: t.Any = None, xp: t.Any = None) -> NDArray[numpy.number]: + rotation: float = 0., affine: t.Union[None, ArrayLike] = None, *, dtype: t.Any = None, xp: t.Any = None) -> t.Tuple[NDArray[numpy.number], NDArray[numpy.number], NDArray[numpy.number]]: """ Make a raster scan, centered around the origin. @@ -47,6 +50,12 @@ def make_raster_scan(shape: t.Tuple[int, int], scan_step: ArrayLike, pts = xp2.stack(xp2.meshgrid(yy, xx, indexing='ij'), axis=-1) pts *= xp2.broadcast_to(xp2.asarray(scan_step, dtype=dtype), (2,)) + yy_ind = xp2.arange(shape[0]) + xx_ind = xp2.arange(shape[1]) + grid_inds = xp2.stack(xp2.meshgrid(yy_ind, xx_ind, indexing='ij'), axis=-1) + yy_grid = grid_inds[..., 0] + xx_grid = grid_inds[..., 1] + if affine is not None: affine = xp2.asarray(affine, dtype=dtype) pts = (pts @ affine.T) @@ -56,9 +65,10 @@ def make_raster_scan(shape: t.Tuple[int, int], scan_step: ArrayLike, mat = xp2.asarray([[numpy.cos(theta), -numpy.sin(theta)], [numpy.sin(theta), numpy.cos(theta)]], dtype=dtype) pts = (pts @ mat.T) - return t.cast(NDArray[numpy.number], pts) + return t.cast(NDArray[numpy.number], pts), t.cast(NDArray[numpy.integer], yy_grid), t.cast(NDArray[numpy.integer], xx_grid) __all__ = [ 'make_raster_scan', + # 'RasterScanMetadata' ] \ No newline at end of file diff --git a/tests/test_initialization.py b/tests/test_initialization.py index b8494b4..0669d48 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -11,7 +11,7 @@ from phaser.hooks import RawData from phaser.plan import ReconsPlan from phaser.execute import load_raw_data, initialize_reconstruction -from phaser.state import PartialReconsState, ProbeState +from phaser.state import PartialReconsState, ProbeState, ScanState def load_empty(args, props) -> RawData: @@ -110,7 +110,10 @@ def test_load_raw_data_prev_state(caplog): } probe_state = ProbeState(Sampling((64, 64), sampling=(1.0, 1.0)), numpy.zeros((64, 64), dtype=numpy.complex64)) - scan_state = numpy.zeros((32, 32, 2)) + scan_pos = numpy.zeros((32, 32, 2)) + scan_state = ScanState(data=scan_pos, initial_scan=scan_pos, metadata={'type': 'raster', + 'rows': scan_pos.reshape(-1, 2)[:,0], + 'cols': scan_pos.reshape(-1, 2)[:,1]}) xp = numpy with caplog.at_level(logging.WARNING): @@ -133,7 +136,7 @@ def test_load_raw_data_prev_state(caplog): # probe from state overrides probe from raw data assert numpy.all(numpy.isclose(recons.state.probe.data, probe_state.data)) # but scan should be modeled - assert ~numpy.all(numpy.isclose(recons.state.scan, scan_state)) + assert ~numpy.all(numpy.isclose(recons.state.scan.data, scan_state.data)) plan['init'] = { 'scan': {}, @@ -150,7 +153,7 @@ def test_load_raw_data_prev_state(caplog): # both should be modeled assert ~numpy.all(numpy.isclose(recons.state.probe.data, probe_state.data)) - assert ~numpy.all(numpy.isclose(recons.state.scan, scan_state)) + assert ~numpy.all(numpy.isclose(recons.state.scan.data, scan_state.data)) def test_load_3d_raw_data(): @@ -180,5 +183,5 @@ def test_load_3d_raw_data(): }) recons = initialize_reconstruction(plan) - assert recons.state.scan.shape == (*scan_shape, 2) + assert recons.state.scan.data.shape == (*scan_shape, 2) assert recons.patterns.patterns.shape == (*scan_shape, *det_shape) \ No newline at end of file