From 6eaed789e1865398273f94ff6bab75df1e96ce0b Mon Sep 17 00:00:00 2001 From: michaelxu01 Date: Thu, 8 Jan 2026 11:40:34 -0500 Subject: [PATCH 01/19] untested refactor, conventional still needs fix, basic hook and for scan constraint --- phaser/engines/common/position_correction.py | 4 +- phaser/engines/common/regularizers.py | 69 +++++++++--- phaser/engines/common/simulation.py | 2 +- phaser/engines/conventional/run.py | 10 +- phaser/engines/gradient/run.py | 20 ++-- phaser/execute.py | 16 +-- phaser/hooks/__init__.py | 4 +- phaser/hooks/preprocessing.py | 28 +++-- phaser/hooks/regularization.py | 6 +- phaser/hooks/scan.py | 12 ++- phaser/state.py | 39 +++++-- phaser/utils/io.py | 34 +++++- phaser/utils/scan.py | 104 ++++++++++++++++++- 13 files changed, 282 insertions(+), 66 deletions(-) diff --git a/phaser/engines/common/position_correction.py b/phaser/engines/common/position_correction.py index 2f0809f..724540a 100644 --- a/phaser/engines/common/position_correction.py +++ b/phaser/engines/common/position_correction.py @@ -41,8 +41,8 @@ def __init__(self, args: None, props: MomentumPositionSolverProps): self.momentum = props.momentum def init_state(self, sim: ReconsState) -> NDArray[numpy.floating]: - xp = get_array_module(sim.scan) - return xp.zeros_like(sim.scan) + xp = get_array_module(sim.scan.data) + return xp.zeros_like(sim.scan.data) def perform_update( self, diff --git a/phaser/engines/common/regularizers.py b/phaser/engines/common/regularizers.py index 416186e..5e9c6c6 100644 --- a/phaser/engines/common/regularizers.py +++ b/phaser/engines/common/regularizers.py @@ -11,12 +11,57 @@ ) from phaser.utils.image import convolve1d from phaser.state import ReconsState -from phaser.hooks.regularization import ( +from phaser.hooks.regularization import (ScanConstraintProps, ClampObjectAmplitudeProps, LimitProbeSupportProps, NonNegObjectPhaseProps, RegularizeLayersProps, ObjLowPassProps, GaussianProps, CostRegularizerProps, TVRegularizerProps, UnstructuredGaussianProps ) +class ScanConstraint: + def __init__(self, args: None, props: ScanConstraintProps): + self.min: t.Optional[float] + self.max: t.Optional[float] + + if isinstance(props.amplitude, list): + self.min, self.max = props.amplitude + else: + self.min = None + self.max = props.amplitude + + def init_state(self, sim: ReconsState) -> None: + return None + + def apply_group(self, group: NDArray[numpy.integer], sim: ReconsState, state: None) -> t.Tuple[ReconsState, None]: + return self.apply_iter(sim, state) + + def apply_iter(self, sim: ReconsState, state: None) -> t.Tuple[ReconsState, None]: + cast = to_real_dtype(sim.object.data.dtype) + sim.object.data = clamp_amplitude(sim.object.data, None if self.min is None else cast(self.min), None if self.max is None else cast(self.max)) + return (sim, None) + + +@partial(jit, donate_argnames=('obj',), cupy_fuse=True) +def scan_affine( + obj: NDArray[numpy.complexfloating], + min: t.Union[float, numpy.floating, None], + max: t.Union[float, numpy.floating, None] +) -> NDArray[numpy.complexfloating]: + xp = get_array_module(obj) + + obj_amp = xp.abs(obj) + new_amp = obj_amp + + if min is not None and max is not None: + new_amp = xp.clip(new_amp, min, max) + elif min is not None: + new_amp = xp.maximum(new_amp, min) + elif max is not None: + new_amp = xp.minimum(new_amp, max) + else: + return obj + + scale = xp.where(obj_amp > 0, new_amp / obj_amp, 0.0) #no divide by 0 + return obj * scale class ClampObjectAmplitude: def __init__(self, args: None, props: ClampObjectAmplitudeProps): @@ -250,7 +295,7 @@ def calc_loss_group( xp = get_array_module(sim.object.data) cost = xp.sum(xp.abs(sim.object.data - 1.0)) - cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) @@ -272,7 +317,7 @@ def calc_loss_group( cost = xp.sum(abs2(sim.object.data - 1.0)) - cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) # type: ignore @@ -293,7 +338,7 @@ def calc_loss_group( xp = get_array_module(sim.object.data) cost = xp.sum(xp.abs(xp.angle(sim.object.data))) - cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) @@ -319,7 +364,7 @@ def calc_loss_group( xp.abs(fft2(xp.prod(sim.object.data, axis=0))) ) # scale cost by fraction of the total reconstruction in the group - cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) @@ -351,7 +396,7 @@ def calc_loss_group( #) # scale cost by fraction of the total reconstruction in the group # TODO also scale by # of pixels or similar? - cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) @@ -377,7 +422,7 @@ def calc_loss_group( xp.sum(abs2(xp.diff(sim.object.data, axis=-2))) ) # scale cost by fraction of the total reconstruction in the group - cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) # type: ignore @@ -403,7 +448,7 @@ def calc_loss_group( cost = xp.sum(xp.abs(xp.diff(sim.object.data, axis=0))) # scale cost by fraction of the total reconstruction in the group - cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) @@ -429,7 +474,7 @@ def calc_loss_group( cost = xp.sum(abs2(xp.diff(sim.object.data, axis=0))) # scale cost by fraction of the total reconstruction in the group - cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) # type: ignore @@ -519,7 +564,7 @@ def __init__(self, args: None, props: UnstructuredGaussianProps): self.attr_path = props.attr_path def init_state(self, sim: ReconsState) -> NDArray[numpy.floating]: - xp = get_array_module(sim.scan) + xp = get_array_module(sim.scan.data) try: self.getattr_nested(sim, self.attr_path) except AttributeError as e: @@ -547,8 +592,8 @@ def setattr_nested(self, obj: t.Any, attr_path: str, value: t.Any): def apply_iter(self, sim: ReconsState, state: NDArray[numpy.floating]) -> t.Tuple[ReconsState, NDArray[numpy.floating]]: from scipy.spatial import KDTree obj_samp = sim.object.sampling - scan_flat = sim.scan.reshape(-1, 2) - scan_ndim = sim.scan.ndim - 1 + scan_flat = sim.scan.data.reshape(-1, 2) + scan_ndim = sim.scan.data.ndim - 1 attr = self.getattr_nested(sim, self.attr_path) vals = t.cast(NDArray[numpy.inexact], getattr(attr, 'data', attr)) # Extract raw array diff --git a/phaser/engines/common/simulation.py b/phaser/engines/common/simulation.py index a6e15b6..ce11bad 100644 --- a/phaser/engines/common/simulation.py +++ b/phaser/engines/common/simulation.py @@ -248,7 +248,7 @@ def cutout_group( """Returns (probe, obj) in the cutout region""" probes = state.probe.data - group_scan = state.scan[tuple(group)] + group_scan = state.scan.data[tuple(group)] group_obj = state.object.sampling.get_view_at_pos(state.object.data, group_scan, probes.shape[-2:]) # group probes in real space # shape (len(group), 1, Ny, Nx) diff --git a/phaser/engines/conventional/run.py b/phaser/engines/conventional/run.py index 2264805..195b001 100644 --- a/phaser/engines/conventional/run.py +++ b/phaser/engines/conventional/run.py @@ -8,7 +8,7 @@ from phaser.state import ReconsState, ProgressState from phaser.types import process_flag, flag_any_true from ..common.simulation import SimulationState, make_propagators, GroupManager - +## FIXME: not yet updated for conventional engines def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: logger = logging.getLogger(__name__) @@ -45,7 +45,7 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: solver = props.solver(props) sim = solver.init(sim) - groups = GroupManager(sim.state.scan, props.grouping, props.compact, seed=seed) + groups = GroupManager(sim.state.scan.data, props.grouping, props.compact, seed=seed) calc_error_mask = mask_fraction_of_groups(len(groups), props.calc_error_fraction) @@ -73,7 +73,7 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: # runs rescaling sim = solver.presolve( - sim, groups.iter(sim.state.scan), + sim, groups.iter(sim.state.scan.data), patterns=patterns, pattern_mask=pattern_mask, propagators=propagators ) @@ -84,12 +84,12 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: for i in range(1, props.niter+1): sim.state.iter.engine_iter = i sim.state.iter.total_iter = start_i + i - + iter_update_positions = update_positions({'state': sim.state, 'niter': props.niter}) iter_shuffle_groups = shuffle_groups({'state': sim.state, 'niter': props.niter}) sim, pos_update, group_errors = solver.run_iteration( - sim, groups.iter(sim.state.scan, i, iter_shuffle_groups), + sim, groups.iter(sim.state.scan.data, i, iter_shuffle_groups), patterns=patterns, pattern_mask=pattern_mask, propagators=propagators, update_object=update_object({'state': sim.state, 'niter': props.niter}), update_probe=update_probe({'state': sim.state, 'niter': props.niter}), diff --git a/phaser/engines/gradient/run.py b/phaser/engines/gradient/run.py index 7482f1d..41a36a5 100644 --- a/phaser/engines/gradient/run.py +++ b/phaser/engines/gradient/run.py @@ -114,8 +114,8 @@ def apply_update(state: ReconsState, update: t.Dict[ReconsVar, numpy.ndarray]) - # subtract mean position update xp = get_array_module(update['positions']) update['positions'] -= xp.mean(update['positions'], tuple(range(update['positions'].ndim - 1))) - - state.scan += update['positions'] + print(f"{state.scan.data.shape} {update['positions'].shape}") + state.scan.data += update['positions'] return state @@ -178,7 +178,7 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: } # shuffle_groups defaults to True for sparse groups, False for compact groups shuffle_groups = process_flag(props.shuffle_groups or not props.compact) - groups = GroupManager(state.scan, props.grouping, props.compact, seed) + groups = GroupManager(state.scan.data, props.grouping, props.compact, seed) observer.init_engine( state, recons_name=args['recons_name'], @@ -190,7 +190,7 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: # runs rescaling rescale_factors = [] - for (group_i, (group, group_patterns)) in enumerate(stream_patterns(groups.iter(state.scan), + for (group_i, (group, group_patterns)) in enumerate(stream_patterns(groups.iter(state.scan.data), patterns, xp=xp, buf_n=props.buffer_n_groups)): group_rescale_factors = dry_run(state, group, propagators, group_patterns, xp=xp, dtype=dtype) rescale_factors.append(group_rescale_factors) @@ -252,7 +252,7 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: for (solver, solver_state) in zip(iter_solvers, iter_solver_states) ] - for (group_i, (group, group_patterns)) in enumerate(stream_patterns(groups.iter(state.scan, i, iter_shuffle_groups), + for (group_i, (group, group_patterns)) in enumerate(stream_patterns(groups.iter(state.scan.data, i, iter_shuffle_groups), patterns, xp=xp, buf_n=props.buffer_n_groups)): (state, group_losses, iter_grads, solver_states) = run_group( state, group=group, vars=iter_vars, @@ -314,8 +314,8 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: if 'positions' in iter_vars: # check positions are at least overlapping object - state.object.sampling.check_scan(state.scan, state.probe.sampling.extent / 2.) - assert_dtype(state.scan, dtype) + state.object.sampling.check_scan(state.scan.data, state.probe.sampling.extent / 2.) + assert_dtype(state.scan.data, dtype) state.progress = progress observer.update_iteration(state, i, props.niter, losses) @@ -403,7 +403,7 @@ def run_model( ) -> t.Tuple[Float, t.Tuple[SolverStates, t.Dict[str, Float]]]: # apply vars to simulation sim = insert_vars(vars, sim, group) - group_scan = sim.scan + group_scan = sim.scan.data group_tilts = sim.tilt (ky, kx) = sim.probe.sampling.recip_grid(dtype=dtype, xp=xp) @@ -460,8 +460,8 @@ def dry_run( (ky, kx) = sim.probe.sampling.recip_grid(dtype=dtype, xp=xp) probes = sim.probe.data - group_obj = sim.object.sampling.get_view_at_pos(sim.object.data, sim.scan[tuple(group)], probes.shape[-2:]) - group_subpx_filters = fourier_shift_filter(ky, kx, sim.object.sampling.get_subpx_shifts(sim.scan[tuple(group)], probes.shape[-2:]))[:, None, ...] + group_obj = sim.object.sampling.get_view_at_pos(sim.object.data, sim.scan.data[tuple(group)], probes.shape[-2:]) + group_subpx_filters = fourier_shift_filter(ky, kx, sim.object.sampling.get_subpx_shifts(sim.scan.data[tuple(group)], probes.shape[-2:]))[:, None, ...] probes = ifft2(fft2(probes) * group_subpx_filters) def sim_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], psi): diff --git a/phaser/execute.py b/phaser/execute.py index 593a71c..4de760a 100644 --- a/phaser/execute.py +++ b/phaser/execute.py @@ -124,7 +124,7 @@ def _normalize_observers( return ObserverSet(obs) - +## FIXME: the scan normalization here - happens before dropnans and scan data flattening, but may alter shape and therefore rows/cols? why is this needed def _normalize_scan_shape( patterns: Patterns, state: ReconsState ) -> t.Tuple[Patterns, ReconsState]: @@ -136,8 +136,8 @@ def _normalize_scan_shape( dimensional shape of the two. 'state.tilt' is reshaped as well. """ patterns_shape = patterns.patterns.shape[:-2] - scan_shape = state.scan.shape[:-1] - + scan_shape = state.scan.data.shape[:-1] + print('patterns shape:', patterns_shape, 'scan shape:', scan_shape) n_patterns = math.prod(patterns_shape) n_scan = math.prod(scan_shape) if n_scan != n_patterns: @@ -147,7 +147,7 @@ def _normalize_scan_shape( new_shape = scan_shape if len(scan_shape) > len(patterns_shape) else patterns_shape patterns.patterns = patterns.patterns.reshape((*new_shape, *patterns.patterns.shape[-2:])) - state.scan = state.scan.reshape((*new_shape, 2)) + state.scan.data = state.scan.data.reshape((*new_shape, 2)) if state.tilt is not None: n_tilt = math.prod(state.tilt.shape[:-1]) @@ -319,14 +319,14 @@ def initialize_reconstruction( elif tilt_hook is not None: logging.info("Initializing tilt...") tilt = pane.from_data(tilt_hook, TiltHook)( # type: ignore - {'dtype': dtype, 'xp': xp, 'shape': scan.shape[:-1]} + {'dtype': dtype, 'xp': xp, 'shape': scan.data.shape[:-1]} ) else: tilt = None obj_pad_px: float = plan.engines[0].obj_pad_px if len(plan.engines) > 0 else 5.0 # type: ignore obj_sampling = ObjectSampling.from_scan( - scan, sampling.sampling, sampling.extent / 2. + obj_pad_px * sampling.sampling + scan.data, sampling.sampling, sampling.extent / 2. + obj_pad_px * sampling.sampling ) if init_state.object is not None and plan.init.object is None: @@ -411,7 +411,7 @@ def prepare_for_engine(patterns: Patterns, state: ReconsState, xp: t.Any, engine obj_sampling = obj_sampling.with_sampling(state.probe.sampling.sampling) obj_sampling_pad = obj_sampling.expand_to_scan( - state.scan, state.probe.sampling.extent / 2. + engine.obj_pad_px * state.probe.sampling.sampling + state.scan.data, state.probe.sampling.extent / 2. + engine.obj_pad_px * state.probe.sampling.sampling ) if obj_sampling_pad != obj_sampling: @@ -447,7 +447,7 @@ def prepare_for_engine(patterns: Patterns, state: ReconsState, xp: t.Any, engine solver_vars = set(itertools.chain.from_iterable(engine.solvers.keys())) if 'tilt' in solver_vars and state.tilt is None: logging.info("Creating new, zeroed tilt map...") - state.tilt = xp.zeros_like(state.scan) + state.tilt = xp.zeros_like(state.scan.data) return patterns, state diff --git a/phaser/hooks/__init__.py b/phaser/hooks/__init__.py index f861270..f8ebb00 100644 --- a/phaser/hooks/__init__.py +++ b/phaser/hooks/__init__.py @@ -12,7 +12,7 @@ if t.TYPE_CHECKING: from phaser.utils.num import Sampling from phaser.utils.object import ObjectSampling - from ..state import ObjectState, ProbeState, ReconsState, Patterns # noqa: F401 + from ..state import ObjectState, ProbeState, ScanState, ReconsState, Patterns from ..execute import Observer @@ -136,7 +136,7 @@ class RasterScanProps(Dataclass): affine: t.Optional[t.Annotated[NDArray[numpy.floating], annotations.shape((2, 2))]] = None -class ScanHook(Hook[ScanHookArgs, NDArray[numpy.floating]]): +class ScanHook(Hook[ScanHookArgs, ScanState]): known = { 'raster': ('phaser.hooks.scan:raster_scan', RasterScanProps), } diff --git a/phaser/hooks/preprocessing.py b/phaser/hooks/preprocessing.py index 14ba218..2a16338 100644 --- a/phaser/hooks/preprocessing.py +++ b/phaser/hooks/preprocessing.py @@ -14,6 +14,8 @@ logger = logging.getLogger(__name__) +## FIXME: the scan flattening is done here, but doesn't safely handle any row or column metadata. +## should the scan be flattened when initialized and metadata generated in raster scan hook flow? def crop_data(raw_data: RawData, props: CropDataProps) -> RawData: if raw_data['patterns'].ndim != 4: @@ -83,9 +85,11 @@ def add_poisson_noise(raw_data: RawData, props: PoissonProps) -> RawData: def drop_nan_patterns(args: PostInitArgs, props: DropNanProps) -> t.Tuple[Patterns, ReconsState]: xp = get_array_module(args['data'].patterns) - + ## FIXME: should the scan be flattened to begin with? is there any situation where we don't want that? + # flatten scan, tilt, and patterns - scan = args['state'].scan.reshape(-1, 2) + scan_pos = args['state'].scan.data.reshape(-1, 2) + scan_meta = args['state'].scan.metadata tilt = None if args['state'].tilt is None else args['state'].tilt.reshape(-1, 2) patterns = args['data'].patterns.reshape(-1, *args['data'].patterns.shape[-2:]) @@ -97,11 +101,19 @@ def drop_nan_patterns(args: PostInitArgs, props: DropNanProps) -> t.Tuple[Patter logger.info(f"Dropping {n}/{patterns.shape[0]} patterns which are at least {props.threshold:.1%} NaN values") patterns = patterns[~mask] - if scan.shape[0] == xp.size(mask): + if scan_pos.shape[0] == xp.size(mask): # apply mask to scan as well - scan = scan[~mask] - elif scan.shape[0] != patterns.shape[0]: - raise ValueError(f"# of scan positions {scan.shape[0]} doesn't match # of patterns" + scan_pos = scan_pos[~mask] + if scan_meta['type'] == 'raster': + rows = scan_meta['rows'].reshape(-1, 1) + cols = scan_meta['cols'].reshape(-1, 1) + scan_meta['rows'] = rows[~mask] + scan_meta['cols'] = cols[~mask] + assert scan_meta['rows'].shape[0] == scan_pos.shape[0], f"After filtering, # of scan positions {scan_pos.shape[0]} doesn't match # of rows/cols {scan_meta['rows'].shape[0]}" + else: + logger.info("Not raster scan, not updating scan metadata rows/cols") + elif scan_pos.shape[0] != patterns.shape[0]: + raise ValueError(f"# of scan positions {scan_pos.shape[0]} doesn't match # of patterns" f" before ({mask.size}) or after ({patterns.shape[0]}) filtering") # otherwise, we assume the mask has already been applied to the scan @@ -112,7 +124,9 @@ def drop_nan_patterns(args: PostInitArgs, props: DropNanProps) -> t.Tuple[Patter raise ValueError(f"# of tilt positions {tilt.shape[0]} doesn't match # of patterns" f" before ({mask.size}) or after ({patterns.shape[0]}) filtering") - args['state'].scan = scan + args['state'].scan.data = scan_pos + args['state'].scan.metadata = scan_meta + args['state'].tilt = tilt args['data'].patterns = patterns diff --git a/phaser/hooks/regularization.py b/phaser/hooks/regularization.py index 6659d7d..d125fff 100644 --- a/phaser/hooks/regularization.py +++ b/phaser/hooks/regularization.py @@ -32,7 +32,10 @@ def name(self) -> str: def calc_loss_group(self, group: NDArray[numpy.integer], sim: 'ReconsState', state: StateT) -> t.Tuple['Float', StateT]: ... - +class ScanConstraintProps(Dataclass): + type: t.Literal['affine', 'line'] = 'affine' + weight: float = 1.0 + class ClampObjectAmplitudeProps(Dataclass): amplitude: t.Union[float, t.List[t.Optional[float]]] = 1.1 @@ -76,6 +79,7 @@ class NonNegObjectPhaseProps(Dataclass): class IterConstraintHook(Hook[None, IterConstraint]): known = { + 'scan_constraint': ('phaser.engines.common.regularizers:ScanConstraint', ScanConstraintProps), 'clamp_object_amplitude': ('phaser.engines.common.regularizers:ClampObjectAmplitude', ClampObjectAmplitudeProps), 'limit_probe_support': ('phaser.engines.common.regularizers:LimitProbeSupport', LimitProbeSupportProps), 'layers': ('phaser.engines.common.regularizers:RegularizeLayers', RegularizeLayersProps), diff --git a/phaser/hooks/scan.py b/phaser/hooks/scan.py index a0b932f..960fc61 100644 --- a/phaser/hooks/scan.py +++ b/phaser/hooks/scan.py @@ -6,9 +6,9 @@ from phaser.utils.num import cast_array_module from phaser.utils.scan import make_raster_scan from . import ScanHookArgs, RasterScanProps +from ..state import ScanState - -def raster_scan(args: ScanHookArgs, props: RasterScanProps) -> NDArray[numpy.floating]: +def raster_scan(args: ScanHookArgs, props: RasterScanProps) -> ScanState: xp = cast_array_module(args['xp']) logger = logging.getLogger(__name__) @@ -29,9 +29,13 @@ def raster_scan(args: ScanHookArgs, props: RasterScanProps) -> NDArray[numpy.flo f" rotation {rot:.2f} deg" f" affine transformation {affine.ravel() if affine is not None else 'None'}") - scan = make_raster_scan( + scan_pos, rows, cols = make_raster_scan( props.shape, step_size, rot, affine, dtype=args['dtype'], xp=xp, ) - return scan + logger.info(f"Generated scan with {scan_pos.shape[0]}x{scan_pos.shape[1]}={scan_pos.shape[0]*scan_pos.shape[1]} positions.") + logger.info(f"Accompanying scan metadata: type 'raster', {rows.shape} rows (y), {cols.shape} cols (x).") + return ScanState(data = scan_pos, + metadata={'type': 'raster', + 'rows': rows, 'cols': cols}) \ No newline at end of file diff --git a/phaser/state.py b/phaser/state.py index 5502a74..ead495f 100644 --- a/phaser/state.py +++ b/phaser/state.py @@ -13,6 +13,7 @@ from phaser.utils.image import _InterpBoundaryMode from phaser.observer import Observer, ObserverSet +## FIXME: scan metadata format not yet finalized, but scanstate object and references here are updated @tree_dataclass class Patterns(): @@ -137,6 +138,32 @@ def copy(self) -> Self: import copy return copy.deepcopy(self) +@tree_dataclass +class ScanState(): + # sampling: ObjectSampling + # """Object coordinate system. See `ObjectSampling` for more details.""" + data: NDArray[numpy.floating] + """Scan coordinates (y, x), in length units. Shape (..., 2)""" + metadata: t.Dict[str, t.Any] + """Scan row positions (y), in length units. Shape (...)""" + # cols: NDArray[numpy.floating] + # """Scan column positions (x), in length units. Shape (...)""" + + def to_xp(self, xp: t.Any) -> Self: + return self.__class__( + xp.asarray(self.data), + self.metadata, + ) + + def to_numpy(self) -> Self: + return self.__class__( + to_numpy(self.data), + self.metadata, + ) + + def copy(self) -> Self: + import copy + return copy.deepcopy(self) @tree_dataclass(kw_only=True, drop_fields=('progress',)) class ReconsState: @@ -145,7 +172,7 @@ class ReconsState: probe: ProbeState object: ObjectState - scan: NDArray[numpy.floating] + scan: ScanState #NDArray[numpy.floating] """Scan coordinates (y, x), in length units. Shape (..., 2)""" tilt: t.Optional[NDArray[numpy.floating]] = None """Tilt angles (y, x) per scan position, in mrad. Shape (..., 2)""" @@ -156,7 +183,7 @@ def to_xp(self, xp: t.Any) -> Self: iter=self.iter, probe=self.probe.to_xp(xp), object=self.object.to_xp(xp), - scan=xp.asarray(self.scan), + scan=self.scan.to_xp(xp), tilt=None if self.tilt is None else xp.asarray(self.tilt), progress=self.progress, wavelength=self.wavelength, @@ -167,7 +194,7 @@ def to_numpy(self) -> Self: iter=self.iter.to_numpy(), probe=self.probe.to_numpy(), object=self.object.to_numpy(), - scan=to_numpy(self.scan), + scan=self.scan.to_numpy(), tilt=None if self.tilt is None else to_numpy(self.tilt), progress=self.progress, wavelength=float(self.wavelength), @@ -194,7 +221,7 @@ class PartialReconsState: probe: t.Optional[ProbeState] = None object: t.Optional[ObjectState] = None - scan: t.Optional[NDArray[numpy.floating]] = None + scan: t.Optional[ScanState] = None """Scan coordinates (y, x), in length units. Shape (..., 2)""" tilt: t.Optional[NDArray[numpy.floating]] = None progress: t.Optional[t.Dict[str, ProgressState]] = None @@ -204,7 +231,7 @@ def to_numpy(self) -> Self: iter=self.iter.to_numpy() if self.iter is not None else None, probe=self.probe.to_numpy() if self.probe is not None else None, object=self.object.to_numpy() if self.object is not None else None, - scan=to_numpy(self.scan) if self.scan is not None else None, + scan=self.scan.to_numpy() if self.scan is not None else None, tilt=to_numpy(self.tilt) if self.tilt is not None else None, wavelength=float(self.wavelength) if self.wavelength is not None else None, progress=self.progress, @@ -222,7 +249,7 @@ def to_complete(self) -> ReconsState: wavelength=t.cast(Float, self.wavelength), probe=t.cast(ProbeState, self.probe), object=t.cast(ObjectState, self.object), - scan=t.cast(NDArray[numpy.floating], self.scan), + scan=t.cast(ScanState, self.scan), tilt=self.tilt, progress=progress, iter=iter, ) diff --git a/phaser/utils/io.py b/phaser/utils/io.py index 6890807..871d0fb 100644 --- a/phaser/utils/io.py +++ b/phaser/utils/io.py @@ -8,8 +8,9 @@ from phaser.utils.num import Sampling, to_numpy from phaser.utils.object import ObjectSampling -from phaser.state import ReconsState, IterState, ProbeState, ObjectState, ProgressState, PartialReconsState +from phaser.state import ReconsState, IterState, ProbeState, ScanState, ObjectState, ProgressState, PartialReconsState +## FIXME: scan metadata format not yet finalized HdfLike: t.TypeAlias = t.Union[h5py.File, str, Path] OpenMode: t.TypeAlias = t.Literal['r', 'r+', 'w', 'w-', 'x', 'a'] @@ -106,7 +107,8 @@ def hdf5_read_state(file: HdfLike) -> PartialReconsState: probe = hdf5_read_probe_state(_assert_group(file['probe'])) if 'probe' in file else None obj = hdf5_read_object_state(_assert_group(file['object'])) if 'object' in file else None iter = hdf5_read_iter_state(_assert_group(file['iter'])) if 'iter' in file else IterState.empty() - scan = numpy.asarray(_hdf5_read_dataset(file, 'scan', numpy.float64)) if 'scan' in file else None + scan = hdf5_read_scan_state(_assert_group(file['scan'])) if 'scan' in file else None + # scan = numpy.asarray(_hdf5_read_dataset(file, 'scan', numpy.float64)) if 'scan' in file else None tilt = numpy.asarray(_hdf5_read_dataset(file, 'tilt', numpy.float64)) if 'tilt' in file else None if tilt is not None and scan is not None: @@ -131,6 +133,15 @@ def hdf5_read_probe_state(group: h5py.Group) -> ProbeState: data=probes ) +def hdf5_read_scan_state(group: h5py.Group) -> ScanState: + scan = _hdf5_read_dataset(group, 'data', numpy.floating) + assert scan.ndim == 2 + + ## TODO: read metadata properly from whatever write_scan_state does + return ScanState( + data=scan, + metadata={} + ) def hdf5_read_object_state(group: h5py.Group) -> ObjectState: obj = numpy.asarray(_hdf5_read_dataset(group, 'data', numpy.complexfloating)) @@ -199,7 +210,8 @@ def hdf5_write_state(state: t.Union[ReconsState, PartialReconsState], file: HdfL if state.object is not None: hdf5_write_object_state(state.object, file.create_group("object")) if state.scan is not None: - file.create_dataset('scan', data=to_numpy(state.scan).astype(numpy.float64)) + hdf5_write_scan_state(state.scan, file.create_group("scan")) + # file.create_dataset('scan', data=to_numpy(state.scan.astype(numpy.float64))) if state.tilt is not None: file.create_dataset('tilt', data=to_numpy(state.tilt).astype(numpy.float64)) if state.iter is not None: @@ -214,10 +226,24 @@ def hdf5_write_probe_state(state: ProbeState, group: h5py.Group): dataset.dims[0].label = 'mode' dataset.dims[1].label = 'y' dataset.dims[2].label = 'x' - group.create_dataset('sampling', data=state.sampling.sampling.astype(numpy.float64)) group.create_dataset('extent', data=state.sampling.extent.astype(numpy.float64)) +def hdf5_write_scan_state(state: ScanState, group: h5py.Group): + assert state.data.ndim == 2 + dataset = group.create_dataset('data', data=to_numpy(state.data)) + dataset.dims[0].label = 'position' + dataset.dims[1].label = 'yx' + + metagroup = group.require_group('metadata') + for (k, v) in state.metadata.items(): + metasubgroup = metagroup.require_group(k) + ##TODO: directly dump all metadata keys as json? or subclass as rasterscanstate + metakey = metasubgroup.create_dataset("data", data=numpy.array(v.iters, dtype=numpy.int64)) + + # group.create_dataset('sampling', data=state.sampling.sampling.astype(numpy.float64)) + # group.create_dataset('extent', data=state.sampling.extent.astype(numpy.float64)) + def hdf5_write_object_state(state: ObjectState, group: h5py.Group): assert state.data.ndim == 3 diff --git a/phaser/utils/scan.py b/phaser/utils/scan.py index b559f1d..41e1038 100644 --- a/phaser/utils/scan.py +++ b/phaser/utils/scan.py @@ -7,21 +7,24 @@ import numpy from numpy.typing import ArrayLike, DTypeLike, NDArray +from phaser.utils.tree import tree_dataclass + from .num import get_array_module, cast_array_module, NumT +## FIXME: output to Tuple? importance of array number types @t.overload def make_raster_scan(shape: t.Tuple[int, int], scan_step: ArrayLike, # pyright: ignore[reportOverlappingOverload] - rotation: float = 0., affine: t.Union[None, ArrayLike] = None, *, dtype: NumT, xp: t.Any = None) -> NDArray[NumT]: + rotation: float = 0., affine: t.Union[None, ArrayLike] = None, *, dtype: NumT, xp: t.Any = None) -> t.Tuple[NDArray[NumT], NDArray[NumT], NDArray[NumT]]: ... @t.overload def make_raster_scan(shape: t.Tuple[int, int], scan_step: ArrayLike, - rotation: float = 0., affine: t.Union[None, ArrayLike] = None, *, dtype: t.Optional[DTypeLike] = None, xp: t.Any = None) -> NDArray[numpy.floating]: + rotation: float = 0., affine: t.Union[None, ArrayLike] = None, *, dtype: t.Optional[DTypeLike] = None, xp: t.Any = None) -> t.Tuple[NDArray[numpy.floating], NDArray[numpy.floating], NDArray[numpy.floating]]: ... def make_raster_scan(shape: t.Tuple[int, int], scan_step: ArrayLike, - rotation: float = 0., affine: t.Union[None, ArrayLike] = None, *, dtype: t.Any = None, xp: t.Any = None) -> NDArray[numpy.number]: + rotation: float = 0., affine: t.Union[None, ArrayLike] = None, *, dtype: t.Any = None, xp: t.Any = None) -> t.Tuple[NDArray[numpy.number], NDArray[numpy.number], NDArray[numpy.number]]: """ Make a raster scan, centered around the origin. @@ -47,6 +50,10 @@ def make_raster_scan(shape: t.Tuple[int, int], scan_step: ArrayLike, pts = xp2.stack(xp2.meshgrid(yy, xx, indexing='ij'), axis=-1) pts *= xp2.broadcast_to(xp2.asarray(scan_step, dtype=dtype), (2,)) + yy_ind = xp2.arange(shape[0], dtype=dtype) + xx_ind = xp2.arange(shape[1], dtype=dtype) + yy_grid, xx_grid = xp2.meshgrid(yy_ind, xx_ind, indexing='ij') + if affine is not None: affine = xp2.asarray(affine, dtype=dtype) pts = (pts @ affine.T) @@ -56,9 +63,98 @@ def make_raster_scan(shape: t.Tuple[int, int], scan_step: ArrayLike, mat = xp2.asarray([[numpy.cos(theta), -numpy.sin(theta)], [numpy.sin(theta), numpy.cos(theta)]], dtype=dtype) pts = (pts @ mat.T) - return t.cast(NDArray[numpy.number], pts) + return t.cast(NDArray[numpy.number], pts), t.cast(NDArray[numpy.number], yy_grid), t.cast(NDArray[numpy.number], xx_grid) + +# @tree_dataclass(frozen=True, init=False) +# class RasterScanMetadata: +# shape: NDArray[numpy.int_] +# """Sampling shape `(n_y, n_x)`""" +# sampling: NDArray[numpy.float64] +# """Sample spacing `(s_y, s_x)`""" +# corner: NDArray[numpy.float64] +# """Corner of sampling `(y_min, x_min)`""" + +# region_min: t.Optional[NDArray[numpy.float64]] +# region_max: t.Optional[NDArray[numpy.float64]] + +# @property +# def min(self) -> NDArray[numpy.float64]: +# """Minimum object pixel position (y, x). Alias for `corner`.""" +# return self.corner + +# @property +# def max(self) -> NDArray[numpy.float64]: +# """Maximum pixel position (y, x).""" +# return (self.corner + (self.shape - 1) * self.sampling).astype(numpy.float64) + +# @property +# def extent(self) -> NDArray[numpy.float64]: +# return (self.shape * self.sampling).astype(numpy.float64) + +# def __init__(self, shape: t.Tuple[int, int], sampling: ArrayLike, corner: t.Optional[ArrayLike] = None, +# region_min: t.Optional[ArrayLike] = None, region_max: t.Optional[ArrayLike] = None): +# object.__setattr__(self, 'shape', numpy.broadcast_to(as_numpy(shape).astype(numpy.int_), (2,))) +# object.__setattr__(self, 'sampling', numpy.broadcast_to(as_numpy(sampling).astype(numpy.float64), (2,))) +# object.__setattr__(self, 'region_min', numpy.broadcast_to(as_numpy(region_min).astype(numpy.float64), (2,)) if region_min is not None else None) +# object.__setattr__(self, 'region_max', numpy.broadcast_to(as_numpy(region_max).astype(numpy.float64), (2,)) if region_max is not None else None) + +# if corner is None: +# corner = -self.extent / 2. + self.sampling/2. #* (self.shape % 2) +# else: +# corner = numpy.broadcast_to(as_numpy(corner).astype(numpy.float64), (2,)) + +# object.__setattr__(self, 'corner', corner) + +# def __eq__(self, other: t.Any) -> bool: +# if type(self) is not type(other): +# return False +# xp = get_array_module(self.sampling, other.sampling) +# return ( +# xp.array_equal(self.shape, other.shape) and +# xp.array_equal(self.sampling, other.sampling) and +# xp.array_equal(self.corner, other.corner) +# ) + +# @staticmethod +# def _scan_extent(scan_positions: NDArray[numpy.floating]) -> t.Tuple[NDArray[numpy.float64], NDArray[numpy.float64]]: +# xp = get_array_module(scan_positions) +# scan_min = numpy.array(tuple(float(xp.nanmin(scan_positions[..., i])) for i in range(2))) +# scan_max = numpy.array(tuple(float(xp.nanmax(scan_positions[..., i])) for i in range(2))) +# return (scan_min, scan_max) + +# @classmethod +# def from_scan(cls: t.Type[Self], scan_positions: NDArray[numpy.floating], sampling: ArrayLike, pad: ArrayLike = 0) -> Self: +# """Create an ObjectSampling around the given scan positions, padded by at least a radius `pad` in real-space.""" +# sampling = as_numpy(sampling).astype(numpy.float64) +# pad = numpy.broadcast_to(pad, (2,)).astype(numpy.float64) + +# (scan_min, scan_max) = cls._scan_extent(scan_positions) +# n = numpy.ceil((2.*pad + scan_max - scan_min) / sampling).astype(numpy.int_) + 1 + +# return cls((n[0], n[1]), sampling, scan_min - pad, scan_min, scan_max) + +# def expand_to_scan(self, scan_positions: NDArray[numpy.floating], pad: ArrayLike = 0.) -> Self: +# pad = numpy.broadcast_to(pad, (2,)).astype(numpy.float64) + +# (scan_min, scan_max) = self._scan_extent(scan_positions) +# pad_min = numpy.ceil(numpy.maximum(0, self.min - scan_min + pad) / self.sampling).astype(numpy.int_) +# pad_max = numpy.ceil(numpy.maximum(0, scan_max - self.max + pad) / self.sampling).astype(numpy.int_) + +# if numpy.all(pad_min == 0) and numpy.all(pad_max == 0): +# return self + +# region_min = numpy.minimum(self.region_min, scan_min) if self.region_min is not None else None +# region_max = numpy.maximum(self.region_max, scan_max) if self.region_max is not None else None + +# return self.__class__( +# t.cast(t.Tuple[int, int], tuple(self.shape + pad_min + pad_max)), +# self.sampling, +# self.corner - pad_min * self.sampling, +# region_min, region_max +# ) __all__ = [ 'make_raster_scan', + # 'RasterScanMetadata' ] \ No newline at end of file From 0bcdd5b032d40466d0b89fb391eaa166d69577d8 Mon Sep 17 00:00:00 2001 From: michaelxu01 Date: Thu, 8 Jan 2026 16:44:31 -0500 Subject: [PATCH 02/19] bug fixes, functional h5 output --- phaser/engines/common/output.py | 4 ++-- phaser/engines/gradient/run.py | 2 +- phaser/execute.py | 7 +++++++ phaser/hooks/__init__.py | 2 +- phaser/hooks/preprocessing.py | 15 ++++++++------- phaser/state.py | 4 ++-- phaser/utils/io.py | 9 ++++++--- phaser/utils/scan.py | 4 +++- 8 files changed, 30 insertions(+), 17 deletions(-) diff --git a/phaser/engines/common/output.py b/phaser/engines/common/output.py index 7100d18..c8d7321 100644 --- a/phaser/engines/common/output.py +++ b/phaser/engines/common/output.py @@ -186,7 +186,7 @@ def _plot_scan(state: ReconsState, out_path: Path, options: SaveOptions): ax.set_xlim(left, right) ax.set_ylim(bottom, top) - scan = to_numpy(state.scan) + scan = to_numpy(state.scan.data) i = numpy.arange(scan[..., 0].size) ax.scatter(scan[..., 1].ravel(), scan[..., 0].ravel(), c=i, cmap='plasma', s=0.5, edgecolors='none') @@ -209,7 +209,7 @@ def _plot_tilt(state: ReconsState, out_path: Path, options: SaveOptions): ax.set_xlim(left, right) ax.set_ylim(bottom, top) - scan = to_numpy(state.scan) + scan = to_numpy(state.scan.data) tilt = to_numpy(state.tilt) tilt = tilt[..., 1] + tilt[..., 0]*1.j max_tilt = max(numpy.max(numpy.abs(tilt)), 1.0) # at least 1 mrad diff --git a/phaser/engines/gradient/run.py b/phaser/engines/gradient/run.py index 41a36a5..77030da 100644 --- a/phaser/engines/gradient/run.py +++ b/phaser/engines/gradient/run.py @@ -69,7 +69,7 @@ def process_solvers( _PATH_MAP: t.Dict[t.Tuple[str, ...], ReconsVar] = { ('object', 'data'): 'object', ('probe', 'data'): 'probe', - ('scan',): 'positions', + ('scan', 'data'): 'positions', ('tilt',): 'tilt' } diff --git a/phaser/execute.py b/phaser/execute.py index 4de760a..a5867d6 100644 --- a/phaser/execute.py +++ b/phaser/execute.py @@ -149,6 +149,13 @@ def _normalize_scan_shape( patterns.patterns = patterns.patterns.reshape((*new_shape, *patterns.patterns.shape[-2:])) state.scan.data = state.scan.data.reshape((*new_shape, 2)) + ## FIXME: check that this functions as intended, mainly when loading .h5 files and not applying dropnans + if 'raster' == state.scan.metadata.get('type'): + if 'rows' in state.scan.metadata: + state.scan.metadata['rows'] = state.scan.metadata['rows'].reshape((*new_shape, 1)) + if 'cols' in state.scan.metadata: + state.scan.metadata['cols'] = state.scan.metadata['cols'].reshape((*new_shape, 1)) + if state.tilt is not None: n_tilt = math.prod(state.tilt.shape[:-1]) if n_tilt != n_patterns: diff --git a/phaser/hooks/__init__.py b/phaser/hooks/__init__.py index f8ebb00..cca3cb1 100644 --- a/phaser/hooks/__init__.py +++ b/phaser/hooks/__init__.py @@ -136,7 +136,7 @@ class RasterScanProps(Dataclass): affine: t.Optional[t.Annotated[NDArray[numpy.floating], annotations.shape((2, 2))]] = None -class ScanHook(Hook[ScanHookArgs, ScanState]): +class ScanHook(Hook[ScanHookArgs, 'ScanState']): known = { 'raster': ('phaser.hooks.scan:raster_scan', RasterScanProps), } diff --git a/phaser/hooks/preprocessing.py b/phaser/hooks/preprocessing.py index 2a16338..8215c3d 100644 --- a/phaser/hooks/preprocessing.py +++ b/phaser/hooks/preprocessing.py @@ -105,13 +105,14 @@ def drop_nan_patterns(args: PostInitArgs, props: DropNanProps) -> t.Tuple[Patter # apply mask to scan as well scan_pos = scan_pos[~mask] if scan_meta['type'] == 'raster': - rows = scan_meta['rows'].reshape(-1, 1) - cols = scan_meta['cols'].reshape(-1, 1) - scan_meta['rows'] = rows[~mask] - scan_meta['cols'] = cols[~mask] - assert scan_meta['rows'].shape[0] == scan_pos.shape[0], f"After filtering, # of scan positions {scan_pos.shape[0]} doesn't match # of rows/cols {scan_meta['rows'].shape[0]}" - else: - logger.info("Not raster scan, not updating scan metadata rows/cols") + if 'rows' in scan_meta: + if scan_meta['rows'].shape[:-1] != args['state'].scan.data.shape[:-1]: + raise ValueError("Scan 'rows' metadata shape doesn't match scan data shape") + scan_meta['rows'] = scan_meta['rows'].reshape(-1, 1)[~mask] + if 'cols' in scan_meta: + if scan_meta['cols'].shape[:-1] != args['state'].scan.data.shape[:-1]: + raise ValueError("Scan 'cols' metadata shape doesn't match scan data shape") + scan_meta['cols'] = scan_meta['cols'].reshape(-1, 1)[~mask] elif scan_pos.shape[0] != patterns.shape[0]: raise ValueError(f"# of scan positions {scan_pos.shape[0]} doesn't match # of patterns" f" before ({mask.size}) or after ({patterns.shape[0]}) filtering") diff --git a/phaser/state.py b/phaser/state.py index ead495f..e289387 100644 --- a/phaser/state.py +++ b/phaser/state.py @@ -138,8 +138,8 @@ def copy(self) -> Self: import copy return copy.deepcopy(self) -@tree_dataclass -class ScanState(): +@tree_dataclass(static_fields=('metadata',)) +class ScanState: # sampling: ObjectSampling # """Object coordinate system. See `ObjectSampling` for more details.""" data: NDArray[numpy.floating] diff --git a/phaser/utils/io.py b/phaser/utils/io.py index 871d0fb..1503cd9 100644 --- a/phaser/utils/io.py +++ b/phaser/utils/io.py @@ -235,11 +235,14 @@ def hdf5_write_scan_state(state: ScanState, group: h5py.Group): dataset.dims[0].label = 'position' dataset.dims[1].label = 'yx' - metagroup = group.require_group('metadata') + metagroup = group.require_group('metadata') for (k, v) in state.metadata.items(): - metasubgroup = metagroup.require_group(k) + # metasubgroup = metagroup.require_group(k) ##TODO: directly dump all metadata keys as json? or subclass as rasterscanstate - metakey = metasubgroup.create_dataset("data", data=numpy.array(v.iters, dtype=numpy.int64)) + # if isinstance(v, str): + metakey = metagroup.create_dataset(k, data=v) + # else: + # metakey = metagroup.create_dataset(k, data=v)) # group.create_dataset('sampling', data=state.sampling.sampling.astype(numpy.float64)) # group.create_dataset('extent', data=state.sampling.extent.astype(numpy.float64)) diff --git a/phaser/utils/scan.py b/phaser/utils/scan.py index 41e1038..9cca845 100644 --- a/phaser/utils/scan.py +++ b/phaser/utils/scan.py @@ -52,7 +52,9 @@ def make_raster_scan(shape: t.Tuple[int, int], scan_step: ArrayLike, yy_ind = xp2.arange(shape[0], dtype=dtype) xx_ind = xp2.arange(shape[1], dtype=dtype) - yy_grid, xx_grid = xp2.meshgrid(yy_ind, xx_ind, indexing='ij') + grid_inds = xp2.stack(xp2.meshgrid(yy_ind, xx_ind, indexing='ij'), axis=-1) + yy_grid = grid_inds[..., 0] + xx_grid = grid_inds[..., 1] if affine is not None: affine = xp2.asarray(affine, dtype=dtype) From 7d348c2237a073e82fd1fe09344e1c19fdf2699c Mon Sep 17 00:00:00 2001 From: michaelxu01 Date: Fri, 9 Jan 2026 15:11:31 -0500 Subject: [PATCH 03/19] functioning reconstruction with affine constraint... not yet verified --- phaser/engines/common/output.py | 80 ++++++++++++++++++++++- phaser/engines/common/regularizers.py | 93 +++++++++++++++++---------- phaser/engines/gradient/run.py | 3 +- phaser/hooks/preprocessing.py | 6 ++ phaser/hooks/regularization.py | 4 +- phaser/hooks/scan.py | 1 + phaser/plan.py | 2 +- phaser/state.py | 6 +- phaser/utils/io.py | 3 +- 9 files changed, 158 insertions(+), 40 deletions(-) diff --git a/phaser/engines/common/output.py b/phaser/engines/common/output.py index c8d7321..1309942 100644 --- a/phaser/engines/common/output.py +++ b/phaser/engines/common/output.py @@ -12,6 +12,51 @@ from phaser.state import ReconsState from phaser.plan import SaveOptions +## TODO: move these utility functions to something like utils.plotting +def interpolate_values(a1, a2, c, d, cutoff=0.5): + c = max(c, d*cutoff) + c = min(c, d) + interpolated_value1 = a1/(1-cutoff)/d*c+a1-a1/(1-cutoff) + interpolated_value2 = (a2-1)/(1-cutoff)/d * c + a2 -(a2-1)/(1-cutoff) + interpolated_value1 = max(0, min(1, interpolated_value1)) + interpolated_value2 = max(0, min(1, interpolated_value2)) + return (interpolated_value1, interpolated_value2) + +def grad_to_rgb(angle, absolute, max_abs, offset, SL = False): + import matplotlib.colors + """Get the rgb value for the given `angle` and the `absolute` value + Intended usage: grad = numpy.array(list(map(lambda p, q: grad_to_rgb(p, q, max_abs), deg, mag))) + + Parameters + ---------- + angle : float + The angle in radians + absolute : float + The absolute value of the gradient + max_abs : float + The maximum value against which all others are normalized (color-wise). + + Returns + ------- + array_like + The rgb value as a tuple with values [0..1] + """ + + # normalize angle + angle = (angle + offset) % (2 * numpy.pi) + if angle < 0: + angle += 2 * numpy.pi + + if SL: + # print('specify global max_abs!!!') + return matplotlib.colors.hsv_to_rgb((angle / 2 / numpy.pi, + numpy.where(absolute / max_abs>1, 1, absolute / max_abs), + numpy.where(absolute / max_abs>1, 1, absolute / max_abs))) + else: + return matplotlib.colors.hsv_to_rgb((angle / 2 / numpy.pi, + 1, + 1)) + def output_images(state: ReconsState, out_dir: Path, options: SaveOptions): for ty in options.images: @@ -176,6 +221,38 @@ def _save_object_mag(state: ReconsState, out_path: Path, options: SaveOptions, s write_opts['metadata']['axes'] = 'ZYX' if stack else 'YX' w.write(obj_mag, **write_opts) +## TODO: refactor using colorize_complex to match tilt implementation +def _plot_scan_update(state: ReconsState, out_path: Path, options: SaveOptions): + from matplotlib import pyplot + fig, ax = pyplot.subplots(figsize=(4, 4), dpi=options.plot_dpi, constrained_layout=True) + ax.set_aspect(1.) + [left, right, bottom, top] = state.object.sampling.mpl_extent() + ax.set_xlim(left, right) + ax.set_ylim(bottom, top) + + scan = to_numpy(state.scan.data) + # i = numpy.arange(scan[..., 0].size) + # ax.scatter(scan[..., 1].ravel(), scan[..., 0].ravel(), c=i, cmap='plasma', s=0.5, edgecolors='none') + + scan = to_numpy(state.scan.data) + disp = to_numpy(state.scan.data) - to_numpy(state.scan.prev_step) + pos_y, pos_x = scan[..., 1].ravel(), scan[..., 0].ravel() + dY, dX = disp[..., 1].ravel(), disp[..., 0].ravel() + + deg = -numpy.arctan2(dY, dX) + mag = numpy.sqrt(dY ** 2 + dX ** 2) + max_abs = numpy.nanmax(mag) + + grad = numpy.array(list(map(lambda p, q: grad_to_rgb(p, q, max_abs, offset=0), deg, mag))) + + width = 0.1 + # hal = hl = 1. / width * length + headwidth = width + ax.quiver(pos_y, pos_x, dY, dX, color=grad, angles='xy', units='xy', width=width, edgecolor=None, linewidth=1, pivot='mid', scale_units='xy', scale=.1) + # ax.set_axis_off() + + fig.savefig(out_path) + pyplot.close(fig) def _plot_scan(state: ReconsState, out_path: Path, options: SaveOptions): from matplotlib import pyplot @@ -247,7 +324,8 @@ def _plot_tilt(state: ReconsState, out_path: Path, options: SaveOptions): 'object_mag_stack': partial(_save_object_mag, stack=True), 'object_mag_sum': partial(_save_object_mag, stack=False), 'scan': _plot_scan, + 'scan_update': _plot_scan_update, 'tilt': _plot_tilt, } # save functions with special handling of file extensions -_PLOT_FUNCS: t.Set[str] = {'scan', 'tilt'} \ No newline at end of file +_PLOT_FUNCS: t.Set[str] = {'scan', 'tilt', 'scan_update'} \ No newline at end of file diff --git a/phaser/engines/common/regularizers.py b/phaser/engines/common/regularizers.py index 5e9c6c6..ed23400 100644 --- a/phaser/engines/common/regularizers.py +++ b/phaser/engines/common/regularizers.py @@ -16,17 +16,14 @@ RegularizeLayersProps, ObjLowPassProps, GaussianProps, CostRegularizerProps, TVRegularizerProps, UnstructuredGaussianProps ) - +#type: t.Literal['affine', 'line'] = 'affine' + # weight: float = 1.0 class ScanConstraint: def __init__(self, args: None, props: ScanConstraintProps): - self.min: t.Optional[float] - self.max: t.Optional[float] - - if isinstance(props.amplitude, list): - self.min, self.max = props.amplitude - else: - self.min = None - self.max = props.amplitude + self.weight: float = props.weight + self.kind: str = props.kind + # self.weight: t.Optional[float] + # self.type: t.Optional[str] def init_state(self, sim: ReconsState) -> None: return None @@ -35,33 +32,63 @@ def apply_group(self, group: NDArray[numpy.integer], sim: ReconsState, state: No return self.apply_iter(sim, state) def apply_iter(self, sim: ReconsState, state: None) -> t.Tuple[ReconsState, None]: - cast = to_real_dtype(sim.object.data.dtype) - sim.object.data = clamp_amplitude(sim.object.data, None if self.min is None else cast(self.min), None if self.max is None else cast(self.max)) + # cast = to_real_dtype(sim.object.data.dtype) + if self.kind == 'affine': + sim.scan.data = scan_affine(sim.scan.data, sim.scan.prev_step, self.weight) + elif self.kind == 'line': + if (sim.scan.metadata.get('type') != 'raster') | (sim.scan.metadata.get('rows') is None): + raise ValueError("Line scan constraint cannot be applied to scans without row metadata") + # assert type(sim.scan.metadata.get('rows')) is Numeric + sim.scan.data = scan_line(sim.scan.data, sim.scan.prev_step, self.weight, sim.scan.metadata.get('rows')) return (sim, None) - -@partial(jit, donate_argnames=('obj',), cupy_fuse=True) +## double check that if position update is off (scan == prev_step), this doesn't break anything +@partial(jit, donate_argnames=('scan',), cupy_fuse=True) def scan_affine( - obj: NDArray[numpy.complexfloating], - min: t.Union[float, numpy.floating, None], - max: t.Union[float, numpy.floating, None] -) -> NDArray[numpy.complexfloating]: - xp = get_array_module(obj) - - obj_amp = xp.abs(obj) - new_amp = obj_amp - - if min is not None and max is not None: - new_amp = xp.clip(new_amp, min, max) - elif min is not None: - new_amp = xp.maximum(new_amp, min) - elif max is not None: - new_amp = xp.minimum(new_amp, max) - else: - return obj - - scale = xp.where(obj_amp > 0, new_amp / obj_amp, 0.0) #no divide by 0 - return obj * scale + scan: NDArray[numpy.floating], + prev: NDArray[numpy.floating], + weight: t.Union[float, numpy.floating] +) -> NDArray[numpy.floating]: + xp = get_array_module(scan) + + disp_update = scan - prev + ones = xp.ones((scan.shape[0], 1), scan.dtype) + pos_prev = xp.concatenate([scan, ones], axis=1) + left = xp.matmul(pos_prev.T, disp_update) + right = xp.matmul(pos_prev.T, pos_prev) + A = xp.matmul(xp.linalg.inv(right), left) + constraint = xp.matmul(pos_prev, A) + #remove the middle shift, keep the middle unchanged + center_ones = xp.ones((1, 1), scan.dtype) + # center[0, 0:2] = xp.average(scan, axis = 0) + center = xp.concatenate([xp.average(scan, axis = 0, keepdims=True), center_ones], axis=1) + center_shift = xp.matmul(center, A) + constraint -= center_shift + return prev + (constraint * weight + disp_update * (1 - weight)) + +@partial(jit, donate_argnames=('scan',), cupy_fuse=True) +def scan_line( + scan: NDArray[numpy.floating], + prev: NDArray[numpy.floating], + weight: t.Union[float, numpy.floating], + rows: t.Union[t.Any, NDArray[numpy.integer]] +) -> NDArray[numpy.floating]: + xp = get_array_module(scan) + + disp_update = scan - prev + ones = xp.ones((scan.shape[0], 1), scan.dtype) + pos_prev = xp.concatenate([scan, ones], axis=1) + left = xp.matmul(pos_prev.T, disp_update) + right = xp.matmul(pos_prev.T, pos_prev) + A = xp.matmul(xp.linalg.inv(right), left) + constraint = xp.matmul(pos_prev, A) + #remove the middle shift, keep the middle unchanged + center_ones = xp.ones((1, 1), scan.dtype) + # center[0, 0:2] = xp.average(scan, axis = 0) + center = xp.concatenate([xp.average(scan, axis = 0), center_ones], axis=1) + center_shift = xp.matmul(center, A) + constraint -= center_shift + return prev + (constraint * weight + disp_update * (1 - weight)) class ClampObjectAmplitude: def __init__(self, args: None, props: ClampObjectAmplitudeProps): diff --git a/phaser/engines/gradient/run.py b/phaser/engines/gradient/run.py index 77030da..d6eb7ab 100644 --- a/phaser/engines/gradient/run.py +++ b/phaser/engines/gradient/run.py @@ -114,7 +114,8 @@ def apply_update(state: ReconsState, update: t.Dict[ReconsVar, numpy.ndarray]) - # subtract mean position update xp = get_array_module(update['positions']) update['positions'] -= xp.mean(update['positions'], tuple(range(update['positions'].ndim - 1))) - print(f"{state.scan.data.shape} {update['positions'].shape}") + state.scan.prev_step = state.scan.data # check that this functions correctly. positions should be only a per-iteration solver + # print(f"{state.scan.data.shape} {update['positions'].shape}") state.scan.data += update['positions'] return state diff --git a/phaser/hooks/preprocessing.py b/phaser/hooks/preprocessing.py index 8215c3d..1772b49 100644 --- a/phaser/hooks/preprocessing.py +++ b/phaser/hooks/preprocessing.py @@ -17,6 +17,9 @@ ## FIXME: the scan flattening is done here, but doesn't safely handle any row or column metadata. ## should the scan be flattened when initialized and metadata generated in raster scan hook flow? +## TODO: actually test that cropping and dropnans work individually and together +## crop should be functional? might not be on load? probably just leave it until it becomes an issue + def crop_data(raw_data: RawData, props: CropDataProps) -> RawData: if raw_data['patterns'].ndim != 4: raise ValueError(f"'crop_data' expects a 4D array of patterns, got shape {raw_data['patterns'].shape} instead") @@ -89,6 +92,7 @@ def drop_nan_patterns(args: PostInitArgs, props: DropNanProps) -> t.Tuple[Patter # flatten scan, tilt, and patterns scan_pos = args['state'].scan.data.reshape(-1, 2) + prev_step = args['state'].scan.prev_step.reshape(-1, 2) scan_meta = args['state'].scan.metadata tilt = None if args['state'].tilt is None else args['state'].tilt.reshape(-1, 2) patterns = args['data'].patterns.reshape(-1, *args['data'].patterns.shape[-2:]) @@ -104,6 +108,7 @@ def drop_nan_patterns(args: PostInitArgs, props: DropNanProps) -> t.Tuple[Patter if scan_pos.shape[0] == xp.size(mask): # apply mask to scan as well scan_pos = scan_pos[~mask] + prev_step = prev_step[~mask] if scan_meta['type'] == 'raster': if 'rows' in scan_meta: if scan_meta['rows'].shape[:-1] != args['state'].scan.data.shape[:-1]: @@ -126,6 +131,7 @@ def drop_nan_patterns(args: PostInitArgs, props: DropNanProps) -> t.Tuple[Patter f" before ({mask.size}) or after ({patterns.shape[0]}) filtering") args['state'].scan.data = scan_pos + args['state'].scan.prev_step = prev_step ## check that copy is correct or necessary args['state'].scan.metadata = scan_meta args['state'].tilt = tilt diff --git a/phaser/hooks/regularization.py b/phaser/hooks/regularization.py index d125fff..aea9c4b 100644 --- a/phaser/hooks/regularization.py +++ b/phaser/hooks/regularization.py @@ -33,9 +33,9 @@ def calc_loss_group(self, group: NDArray[numpy.integer], sim: 'ReconsState', sta ... class ScanConstraintProps(Dataclass): - type: t.Literal['affine', 'line'] = 'affine' + kind: t.Literal['affine', 'line'] = 'affine' weight: float = 1.0 - + class ClampObjectAmplitudeProps(Dataclass): amplitude: t.Union[float, t.List[t.Optional[float]]] = 1.1 diff --git a/phaser/hooks/scan.py b/phaser/hooks/scan.py index 960fc61..94af1b6 100644 --- a/phaser/hooks/scan.py +++ b/phaser/hooks/scan.py @@ -37,5 +37,6 @@ def raster_scan(args: ScanHookArgs, props: RasterScanProps) -> ScanState: logger.info(f"Generated scan with {scan_pos.shape[0]}x{scan_pos.shape[1]}={scan_pos.shape[0]*scan_pos.shape[1]} positions.") logger.info(f"Accompanying scan metadata: type 'raster', {rows.shape} rows (y), {cols.shape} cols (x).") return ScanState(data = scan_pos, + prev_step = scan_pos.copy(), metadata={'type': 'raster', 'rows': rows, 'cols': cols}) \ No newline at end of file diff --git a/phaser/plan.py b/phaser/plan.py index 8690eae..03e9a5c 100644 --- a/phaser/plan.py +++ b/phaser/plan.py @@ -11,7 +11,7 @@ SaveType: t.TypeAlias = t.Literal[ 'probe', 'probe_mag', 'probe_recip', 'probe_recip_mag', 'object_phase_stack', 'object_phase_sum', - 'object_mag_stack', 'object_mag_sum', + 'object_mag_stack', 'object_mag_sum','scan_update', 'scan', 'tilt', ] diff --git a/phaser/state.py b/phaser/state.py index e289387..a02fa4b 100644 --- a/phaser/state.py +++ b/phaser/state.py @@ -144,6 +144,8 @@ class ScanState: # """Object coordinate system. See `ObjectSampling` for more details.""" data: NDArray[numpy.floating] """Scan coordinates (y, x), in length units. Shape (..., 2)""" + prev_step: NDArray[numpy.floating] + """Previous step Scan coordinates (y, x), in length units. Shape (..., 2)""" metadata: t.Dict[str, t.Any] """Scan row positions (y), in length units. Shape (...)""" # cols: NDArray[numpy.floating] @@ -151,13 +153,15 @@ class ScanState: def to_xp(self, xp: t.Any) -> Self: return self.__class__( - xp.asarray(self.data), + xp.asarray(self.data), + xp.asarray(self.prev_step), self.metadata, ) def to_numpy(self) -> Self: return self.__class__( to_numpy(self.data), + to_numpy(self.prev_step), self.metadata, ) diff --git a/phaser/utils/io.py b/phaser/utils/io.py index 1503cd9..7f8969f 100644 --- a/phaser/utils/io.py +++ b/phaser/utils/io.py @@ -10,7 +10,8 @@ from phaser.utils.object import ObjectSampling from phaser.state import ReconsState, IterState, ProbeState, ScanState, ObjectState, ProgressState, PartialReconsState -## FIXME: scan metadata format not yet finalized +## FIXME: scan metadata format somewhat finalized, though clunky. +## TODO: h5 read function for scan state. should be a conditional and check if scan is a group with metadata, or dataset only HdfLike: t.TypeAlias = t.Union[h5py.File, str, Path] OpenMode: t.TypeAlias = t.Literal['r', 'r+', 'w', 'w-', 'x', 'a'] From 13880c0bca72ba74838edb596569cb979cb486d0 Mon Sep 17 00:00:00 2001 From: michaelxu01 Date: Mon, 12 Jan 2026 10:55:10 -0500 Subject: [PATCH 04/19] before split --- phaser/engines/common/regularizers.py | 71 +++++++++++++++++++-------- phaser/engines/gradient/run.py | 2 +- 2 files changed, 51 insertions(+), 22 deletions(-) diff --git a/phaser/engines/common/regularizers.py b/phaser/engines/common/regularizers.py index ed23400..0698fce 100644 --- a/phaser/engines/common/regularizers.py +++ b/phaser/engines/common/regularizers.py @@ -25,22 +25,29 @@ def __init__(self, args: None, props: ScanConstraintProps): # self.weight: t.Optional[float] # self.type: t.Optional[str] - def init_state(self, sim: ReconsState) -> None: - return None + def init_state(self, sim: ReconsState) -> NDArray[numpy.floating]: + ## init the original scan positions + return sim.scan.data.copy() ## copy not needed for jax, but needed for others? - def apply_group(self, group: NDArray[numpy.integer], sim: ReconsState, state: None) -> t.Tuple[ReconsState, None]: + def apply_group(self, group: NDArray[numpy.integer], sim: ReconsState, state: NDArray[numpy.floating]) -> t.Tuple[ReconsState, NDArray[numpy.floating]]: return self.apply_iter(sim, state) - def apply_iter(self, sim: ReconsState, state: None) -> t.Tuple[ReconsState, None]: + ## TODO: more formally handle applying both affine and line simultaneously + ## I think probably for 'kind' can give a list of constraints to apply, even not in order. and here it will handle it in the best sequence + ## this preserves state better, though the dictionary of the state might be more complicated + def apply_iter(self, sim: ReconsState, state: NDArray[numpy.floating]) -> t.Tuple[ReconsState, NDArray[numpy.floating]]: # cast = to_real_dtype(sim.object.data.dtype) if self.kind == 'affine': - sim.scan.data = scan_affine(sim.scan.data, sim.scan.prev_step, self.weight) - elif self.kind == 'line': + sim.scan.data = scan_affine(sim.scan.data, state, self.weight) + # sims.object.data = ## affine deform object + if self.kind == 'line': if (sim.scan.metadata.get('type') != 'raster') | (sim.scan.metadata.get('rows') is None): raise ValueError("Line scan constraint cannot be applied to scans without row metadata") # assert type(sim.scan.metadata.get('rows')) is Numeric - sim.scan.data = scan_line(sim.scan.data, sim.scan.prev_step, self.weight, sim.scan.metadata.get('rows')) - return (sim, None) + sim.scan.data = scan_line(sim.scan.data, state, self.weight, sim.scan.metadata.get('rows')) + + state = sim.scan.data.copy() + return (sim, state) ## double check that if position update is off (scan == prev_step), this doesn't break anything @partial(jit, donate_argnames=('scan',), cupy_fuse=True) @@ -75,19 +82,41 @@ def scan_line( ) -> NDArray[numpy.floating]: xp = get_array_module(scan) - disp_update = scan - prev - ones = xp.ones((scan.shape[0], 1), scan.dtype) - pos_prev = xp.concatenate([scan, ones], axis=1) - left = xp.matmul(pos_prev.T, disp_update) - right = xp.matmul(pos_prev.T, pos_prev) - A = xp.matmul(xp.linalg.inv(right), left) - constraint = xp.matmul(pos_prev, A) - #remove the middle shift, keep the middle unchanged - center_ones = xp.ones((1, 1), scan.dtype) - # center[0, 0:2] = xp.average(scan, axis = 0) - center = xp.concatenate([xp.average(scan, axis = 0), center_ones], axis=1) - center_shift = xp.matmul(center, A) - constraint -= center_shift + # >>> import numpy +# >>> rows = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,] +# >>> rows = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3] +# >>> numpy.bincount(rows, [0., 0., 0., 1., 1., 1., 1., 2., 2., 2., 2., 3., 4.]) +# array([1., 5., 9., 4.]) +# >>> numpy.bincount(rows, [0., 0., 0., 1., 1., 1., 1., 2., 2., 2., 2., 3., 4.]) / numpy.bincount(rows) +# array([0.25, 1.25, 2.25, 4. ]) +# >>> x_shifts = [0., 0., 0., 1., 1., 1., 1., 2., 2., 2., 2., 3., 4.] +# >>> x_shifts = numpy.array([0., 0., 0., 1., 1., 1., 1., 2., 2., 2., 2., 3., 4.]) +# >>> x_shifts[rows] +# array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]) +# >>> x_shifts[rows] +# KeyboardInterrupt +# >>> (numpy.bincount(rows, [0., 0., 0., 1., 1., 1., 1., 2., 2., 2., 2., 3., 4.]) / numpy.bincount(rows))[rows] +# array([0.25, 0.25, 0.25, 0.25, 1.25, 1.25, 1.25, 1.25, 2.25, 2.25, 2.25, +# 2.25, 4. ]) + + # disp_update = scan - prev + # ones = xp.ones((scan.shape[0], 1), scan.dtype) + # pos_prev = xp.concatenate([scan, ones], axis=1) + # left = xp.matmul(pos_prev.T, disp_update) + # right = xp.matmul(pos_prev.T, pos_prev) + # A = xp.matmul(xp.linalg.inv(right), left) + # constraint = xp.matmul(pos_prev, A) + # #remove the middle shift, keep the middle unchanged + # center_ones = xp.ones((1, 1), scan.dtype) + # # center[0, 0:2] = xp.average(scan, axis = 0) + # center = xp.concatenate([xp.average(scan, axis = 0), center_ones], axis=1) + # center_shift = xp.matmul(center, A) + # constraint -= center_shift + row = int(xp.sqrt(scan_pos.shape[0])) + col = int(scan_pos.shape[0]/row) + constraint = np.zeros_like(disp_val) + for i in range(row): + constraint[i*col:(i+1)*col]= np.average(disp_val[i*col:(i+1)*col], axis=0) return prev + (constraint * weight + disp_update * (1 - weight)) class ClampObjectAmplitude: diff --git a/phaser/engines/gradient/run.py b/phaser/engines/gradient/run.py index d6eb7ab..ea62d09 100644 --- a/phaser/engines/gradient/run.py +++ b/phaser/engines/gradient/run.py @@ -114,7 +114,7 @@ def apply_update(state: ReconsState, update: t.Dict[ReconsVar, numpy.ndarray]) - # subtract mean position update xp = get_array_module(update['positions']) update['positions'] -= xp.mean(update['positions'], tuple(range(update['positions'].ndim - 1))) - state.scan.prev_step = state.scan.data # check that this functions correctly. positions should be only a per-iteration solver + # state.scan.prev_step = state.scan.data # check that this functions correctly. positions should be only a per-iteration solver # print(f"{state.scan.data.shape} {update['positions'].shape}") state.scan.data += update['positions'] From 9ce4dd343860f1720aaf04d812115c9f771c6e6a Mon Sep 17 00:00:00 2001 From: michaelxu01 Date: Mon, 12 Jan 2026 16:31:46 -0500 Subject: [PATCH 05/19] working new scanconstraintprops and bug fixes --- phaser/engines/common/regularizers.py | 135 ++++++++++++++++---------- phaser/hooks/preprocessing.py | 8 +- phaser/hooks/regularization.py | 9 +- phaser/hooks/scan.py | 2 +- phaser/state.py | 4 +- phaser/types.py | 1 - phaser/utils/scan.py | 8 +- 7 files changed, 102 insertions(+), 65 deletions(-) diff --git a/phaser/engines/common/regularizers.py b/phaser/engines/common/regularizers.py index 0698fce..092a471 100644 --- a/phaser/engines/common/regularizers.py +++ b/phaser/engines/common/regularizers.py @@ -16,71 +16,105 @@ RegularizeLayersProps, ObjLowPassProps, GaussianProps, CostRegularizerProps, TVRegularizerProps, UnstructuredGaussianProps ) +# from phaser.types import ScanUpdate +logger = logging.getLogger(__name__) + +class ScanUpdate(t.NamedTuple): + previous: numpy.typing.NDArray[numpy.floating] + row_bins: t.Optional[numpy.typing.NDArray[numpy.integer]] = None + + # def __init__(self, prev: numpy.typing.NDArray[numpy.floating], row_bins: t.Optional[numpy.typing.NDArray[numpy.integer]] = None): + # self.prev = prev + # self.row_bins = row_bins + #type: t.Literal['affine', 'line'] = 'affine' # weight: float = 1.0 class ScanConstraint: def __init__(self, args: None, props: ScanConstraintProps): - self.weight: float = props.weight - self.kind: str = props.kind + self.valid_kinds: t.Set[str] = {'affine', 'line', 'hpf', 'lpf', 'default'} + self.constraints: t.Dict[str, float] = {} #= {'default': 1.0} + + for kind in self.valid_kinds: + if getattr(props, kind) > 0: + self.constraints[kind] = getattr(props, kind) + self.total_weight = sum(self.constraints.values()) # self.weight: t.Optional[float] # self.type: t.Optional[str] - - def init_state(self, sim: ReconsState) -> NDArray[numpy.floating]: + logger.info(f"Initialized scan constraint with kinds {list(self.constraints.keys())} and weights {list(self.constraints.values())} with total weight {self.total_weight:.4f}") + def init_state(self, sim: ReconsState) -> ScanUpdate: ## init the original scan positions - return sim.scan.data.copy() ## copy not needed for jax, but needed for others? + if 'line' in self.constraints: + if (sim.scan.metadata.get('type') != 'raster') | (sim.scan.metadata.get('rows') is None): + raise ValueError("Line scan constraint cannot be applied to scans without row metadata") + state = ScanUpdate(previous=sim.scan.data.copy(), row_bins=sim.scan.metadata.get('rows').ravel()) + else: + state = ScanUpdate(previous=sim.scan.data.copy(), row_bins=None) + return state - def apply_group(self, group: NDArray[numpy.integer], sim: ReconsState, state: NDArray[numpy.floating]) -> t.Tuple[ReconsState, NDArray[numpy.floating]]: + def apply_group(self, group: NDArray[numpy.integer], sim: ReconsState, state: ScanUpdate) -> t.Tuple[ReconsState, ScanUpdate]: return self.apply_iter(sim, state) ## TODO: more formally handle applying both affine and line simultaneously ## I think probably for 'kind' can give a list of constraints to apply, even not in order. and here it will handle it in the best sequence ## this preserves state better, though the dictionary of the state might be more complicated - def apply_iter(self, sim: ReconsState, state: NDArray[numpy.floating]) -> t.Tuple[ReconsState, NDArray[numpy.floating]]: + def apply_iter(self, sim: ReconsState, state: ScanUpdate) -> t.Tuple[ReconsState, ScanUpdate]: # cast = to_real_dtype(sim.object.data.dtype) - if self.kind == 'affine': - sim.scan.data = scan_affine(sim.scan.data, state, self.weight) - # sims.object.data = ## affine deform object - if self.kind == 'line': - if (sim.scan.metadata.get('type') != 'raster') | (sim.scan.metadata.get('rows') is None): - raise ValueError("Line scan constraint cannot be applied to scans without row metadata") - # assert type(sim.scan.metadata.get('rows')) is Numeric - sim.scan.data = scan_line(sim.scan.data, state, self.weight, sim.scan.metadata.get('rows')) - - state = sim.scan.data.copy() + xp = get_array_module(sim.scan.data) + update = xp.zeros_like(sim.scan.data, dtype=sim.scan.data.dtype) + for kind, weight in self.constraints.items(): + weight = weight / self.total_weight + if kind == 'affine': + update += scan_affine(sim.scan.data, state.previous) * weight + # sims.object.data = ## affine deform object + if kind == 'line' and state.row_bins is not None: + update += scan_line(sim.scan.data, state.previous, state.row_bins) * weight + if kind == 'hpf': + pass + if kind == 'lpf': + pass + if kind == 'default': + update += scan_default(sim.scan.data, state.previous) * weight + # assert update.dtype == sim.scan.data.dtype + sim.scan.data = state.previous + update + state = ScanUpdate(previous=sim.scan.data.copy(), row_bins=state.row_bins) return (sim, state) ## double check that if position update is off (scan == prev_step), this doesn't break anything -@partial(jit, donate_argnames=('scan',), cupy_fuse=True) +# @partial(jit, donate_argnames=('pos',), cupy_fuse=True) +def scan_default( + pos: NDArray[numpy.floating], + prev: NDArray[numpy.floating], +) -> NDArray[numpy.floating]: + return pos - prev + def scan_affine( - scan: NDArray[numpy.floating], + pos: NDArray[numpy.floating], prev: NDArray[numpy.floating], - weight: t.Union[float, numpy.floating] ) -> NDArray[numpy.floating]: - xp = get_array_module(scan) + xp = get_array_module(pos) - disp_update = scan - prev - ones = xp.ones((scan.shape[0], 1), scan.dtype) - pos_prev = xp.concatenate([scan, ones], axis=1) + disp_update = pos - prev + ones = xp.ones((pos.shape[0], 1), pos.dtype) + pos_prev = xp.concatenate([pos, ones], axis=1) left = xp.matmul(pos_prev.T, disp_update) right = xp.matmul(pos_prev.T, pos_prev) A = xp.matmul(xp.linalg.inv(right), left) constraint = xp.matmul(pos_prev, A) #remove the middle shift, keep the middle unchanged - center_ones = xp.ones((1, 1), scan.dtype) - # center[0, 0:2] = xp.average(scan, axis = 0) - center = xp.concatenate([xp.average(scan, axis = 0, keepdims=True), center_ones], axis=1) + center_ones = xp.ones((1, 1), pos.dtype) + # center[0, 0:2] = xp.average(pos, axis = 0) + center = xp.concatenate([xp.average(pos, axis = 0, keepdims=True), center_ones], axis=1, dtype=pos.dtype) center_shift = xp.matmul(center, A) constraint -= center_shift - return prev + (constraint * weight + disp_update * (1 - weight)) + return constraint -@partial(jit, donate_argnames=('scan',), cupy_fuse=True) +# @partial(jit, donate_argnames=('pos',), cupy_fuse=True) def scan_line( - scan: NDArray[numpy.floating], + pos: NDArray[numpy.floating], prev: NDArray[numpy.floating], - weight: t.Union[float, numpy.floating], - rows: t.Union[t.Any, NDArray[numpy.integer]] + rows: NDArray[numpy.integer], ) -> NDArray[numpy.floating]: - xp = get_array_module(scan) + xp = get_array_module(pos) # >>> import numpy # >>> rows = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,] @@ -98,26 +132,21 @@ def scan_line( # >>> (numpy.bincount(rows, [0., 0., 0., 1., 1., 1., 1., 2., 2., 2., 2., 3., 4.]) / numpy.bincount(rows))[rows] # array([0.25, 0.25, 0.25, 0.25, 1.25, 1.25, 1.25, 1.25, 2.25, 2.25, 2.25, # 2.25, 4. ]) - - # disp_update = scan - prev - # ones = xp.ones((scan.shape[0], 1), scan.dtype) - # pos_prev = xp.concatenate([scan, ones], axis=1) - # left = xp.matmul(pos_prev.T, disp_update) - # right = xp.matmul(pos_prev.T, pos_prev) - # A = xp.matmul(xp.linalg.inv(right), left) - # constraint = xp.matmul(pos_prev, A) - # #remove the middle shift, keep the middle unchanged - # center_ones = xp.ones((1, 1), scan.dtype) - # # center[0, 0:2] = xp.average(scan, axis = 0) - # center = xp.concatenate([xp.average(scan, axis = 0), center_ones], axis=1) - # center_shift = xp.matmul(center, A) - # constraint -= center_shift - row = int(xp.sqrt(scan_pos.shape[0])) - col = int(scan_pos.shape[0]/row) - constraint = np.zeros_like(disp_val) - for i in range(row): - constraint[i*col:(i+1)*col]= np.average(disp_val[i*col:(i+1)*col], axis=0) - return prev + (constraint * weight + disp_update * (1 - weight)) + # rows is shape (N x 2) + # constraint is shape (N x 2) + disp_val = pos - prev + + y_shifts = xp.bincount(rows, disp_val[:,0]) / xp.bincount(rows) + x_shifts = xp.bincount(rows, disp_val[:,1]) / xp.bincount(rows) + + # scan_pos = scan + # row = int(xp.sqrt(scan_pos.shape[0])) + # col = int(scan_pos.shape[0]/row) + # constraint = np.zeros_like(disp_val) + # for i in range(row): + # constraint[i*col:(i+1)*col]= np.average(disp_val[i*col:(i+1)*col], axis=0) + constraint = xp.stack([y_shifts[rows], x_shifts[rows]], axis=1, dtype=pos.dtype) + return constraint class ClampObjectAmplitude: def __init__(self, args: None, props: ClampObjectAmplitudeProps): diff --git a/phaser/hooks/preprocessing.py b/phaser/hooks/preprocessing.py index 1772b49..371330e 100644 --- a/phaser/hooks/preprocessing.py +++ b/phaser/hooks/preprocessing.py @@ -94,6 +94,9 @@ def drop_nan_patterns(args: PostInitArgs, props: DropNanProps) -> t.Tuple[Patter scan_pos = args['state'].scan.data.reshape(-1, 2) prev_step = args['state'].scan.prev_step.reshape(-1, 2) scan_meta = args['state'].scan.metadata + scan_meta['rows'] = scan_meta['rows'].reshape(-1, 1) + scan_meta['cols'] = scan_meta['cols'].reshape(-1, 1) + tilt = None if args['state'].tilt is None else args['state'].tilt.reshape(-1, 2) patterns = args['data'].patterns.reshape(-1, *args['data'].patterns.shape[-2:]) @@ -113,11 +116,11 @@ def drop_nan_patterns(args: PostInitArgs, props: DropNanProps) -> t.Tuple[Patter if 'rows' in scan_meta: if scan_meta['rows'].shape[:-1] != args['state'].scan.data.shape[:-1]: raise ValueError("Scan 'rows' metadata shape doesn't match scan data shape") - scan_meta['rows'] = scan_meta['rows'].reshape(-1, 1)[~mask] + scan_meta['rows'] = scan_meta['rows'][~mask] if 'cols' in scan_meta: if scan_meta['cols'].shape[:-1] != args['state'].scan.data.shape[:-1]: raise ValueError("Scan 'cols' metadata shape doesn't match scan data shape") - scan_meta['cols'] = scan_meta['cols'].reshape(-1, 1)[~mask] + scan_meta['cols'] = scan_meta['cols'][~mask] elif scan_pos.shape[0] != patterns.shape[0]: raise ValueError(f"# of scan positions {scan_pos.shape[0]} doesn't match # of patterns" f" before ({mask.size}) or after ({patterns.shape[0]}) filtering") @@ -130,6 +133,7 @@ def drop_nan_patterns(args: PostInitArgs, props: DropNanProps) -> t.Tuple[Patter raise ValueError(f"# of tilt positions {tilt.shape[0]} doesn't match # of patterns" f" before ({mask.size}) or after ({patterns.shape[0]}) filtering") + print(scan_pos.shape, scan_meta['cols'].shape) args['state'].scan.data = scan_pos args['state'].scan.prev_step = prev_step ## check that copy is correct or necessary args['state'].scan.metadata = scan_meta diff --git a/phaser/hooks/regularization.py b/phaser/hooks/regularization.py index aea9c4b..cd75b98 100644 --- a/phaser/hooks/regularization.py +++ b/phaser/hooks/regularization.py @@ -33,9 +33,12 @@ def calc_loss_group(self, group: NDArray[numpy.integer], sim: 'ReconsState', sta ... class ScanConstraintProps(Dataclass): - kind: t.Literal['affine', 'line'] = 'affine' - weight: float = 1.0 - + affine: float = 0.0 + line: float = 0.0 + hpf: float = 0.0 + lpf: float = 0.0 + default: float = 1.0 + class ClampObjectAmplitudeProps(Dataclass): amplitude: t.Union[float, t.List[t.Optional[float]]] = 1.1 diff --git a/phaser/hooks/scan.py b/phaser/hooks/scan.py index 94af1b6..ba899e2 100644 --- a/phaser/hooks/scan.py +++ b/phaser/hooks/scan.py @@ -33,7 +33,7 @@ def raster_scan(args: ScanHookArgs, props: RasterScanProps) -> ScanState: props.shape, step_size, rot, affine, dtype=args['dtype'], xp=xp, ) - + print(rows.dtype) logger.info(f"Generated scan with {scan_pos.shape[0]}x{scan_pos.shape[1]}={scan_pos.shape[0]*scan_pos.shape[1]} positions.") logger.info(f"Accompanying scan metadata: type 'raster', {rows.shape} rows (y), {cols.shape} cols (x).") return ScanState(data = scan_pos, diff --git a/phaser/state.py b/phaser/state.py index a02fa4b..46fbeb1 100644 --- a/phaser/state.py +++ b/phaser/state.py @@ -288,4 +288,6 @@ def with_observer(self, observer: t.Union['Observer', t.Iterable['Observer']]) - else: observers.extend(observer) - return self.__class__(self.patterns, self.state, self.name, ObserverSet(observers)) \ No newline at end of file + return self.__class__(self.patterns, self.state, self.name, ObserverSet(observers)) + + diff --git a/phaser/types.py b/phaser/types.py index c6a83b7..4f29fce 100644 --- a/phaser/types.py +++ b/phaser/types.py @@ -91,7 +91,6 @@ def __init__(self, state: 'ReconsState', continue_next_engine: bool = False): class Dataclass(pane.PaneBase, kw_only=True, allow_extra=False): ... - class SliceList(Dataclass): thicknesses: t.List[float] diff --git a/phaser/utils/scan.py b/phaser/utils/scan.py index 9cca845..f1c6d13 100644 --- a/phaser/utils/scan.py +++ b/phaser/utils/scan.py @@ -20,7 +20,7 @@ def make_raster_scan(shape: t.Tuple[int, int], scan_step: ArrayLike, # pyright: @t.overload def make_raster_scan(shape: t.Tuple[int, int], scan_step: ArrayLike, - rotation: float = 0., affine: t.Union[None, ArrayLike] = None, *, dtype: t.Optional[DTypeLike] = None, xp: t.Any = None) -> t.Tuple[NDArray[numpy.floating], NDArray[numpy.floating], NDArray[numpy.floating]]: + rotation: float = 0., affine: t.Union[None, ArrayLike] = None, *, dtype: t.Optional[DTypeLike] = None, xp: t.Any = None) -> t.Tuple[NDArray[numpy.floating], NDArray[numpy.integer], NDArray[numpy.integer]]: ... def make_raster_scan(shape: t.Tuple[int, int], scan_step: ArrayLike, @@ -50,8 +50,8 @@ def make_raster_scan(shape: t.Tuple[int, int], scan_step: ArrayLike, pts = xp2.stack(xp2.meshgrid(yy, xx, indexing='ij'), axis=-1) pts *= xp2.broadcast_to(xp2.asarray(scan_step, dtype=dtype), (2,)) - yy_ind = xp2.arange(shape[0], dtype=dtype) - xx_ind = xp2.arange(shape[1], dtype=dtype) + yy_ind = xp2.arange(shape[0]) + xx_ind = xp2.arange(shape[1]) grid_inds = xp2.stack(xp2.meshgrid(yy_ind, xx_ind, indexing='ij'), axis=-1) yy_grid = grid_inds[..., 0] xx_grid = grid_inds[..., 1] @@ -65,7 +65,7 @@ def make_raster_scan(shape: t.Tuple[int, int], scan_step: ArrayLike, mat = xp2.asarray([[numpy.cos(theta), -numpy.sin(theta)], [numpy.sin(theta), numpy.cos(theta)]], dtype=dtype) pts = (pts @ mat.T) - return t.cast(NDArray[numpy.number], pts), t.cast(NDArray[numpy.number], yy_grid), t.cast(NDArray[numpy.number], xx_grid) + return t.cast(NDArray[numpy.number], pts), t.cast(NDArray[numpy.integer], yy_grid), t.cast(NDArray[numpy.integer], xx_grid) # @tree_dataclass(frozen=True, init=False) # class RasterScanMetadata: From d9558b13ecee5596f9cc4322722d19d4219fff39 Mon Sep 17 00:00:00 2001 From: michaelxu01 Date: Thu, 15 Jan 2026 16:10:12 -0500 Subject: [PATCH 06/19] add reference --- phaser/engines/common/regularizers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/phaser/engines/common/regularizers.py b/phaser/engines/common/regularizers.py index 092a471..1d3b758 100644 --- a/phaser/engines/common/regularizers.py +++ b/phaser/engines/common/regularizers.py @@ -30,6 +30,9 @@ class ScanUpdate(t.NamedTuple): #type: t.Literal['affine', 'line'] = 'affine' # weight: float = 1.0 class ScanConstraint: + """See ref for details: + S. Ning, W. Xu, L. Loh, Z. Lu, M. Bosman, F. Zhang, Q. He, An integrated constrained gradient descent (iCGD) protocol to correct scan-positional errors for electron ptychography with high accuracy and precision. Ultramicroscopy 248, 113716 (2023). + """ def __init__(self, args: None, props: ScanConstraintProps): self.valid_kinds: t.Set[str] = {'affine', 'line', 'hpf', 'lpf', 'default'} self.constraints: t.Dict[str, float] = {} #= {'default': 1.0} From 29cfe0f9b2a3862eb0dc6bcab06718a0293e7bc1 Mon Sep 17 00:00:00 2001 From: michaelxu01 Date: Thu, 15 Jan 2026 18:17:09 -0500 Subject: [PATCH 07/19] write and read h5 functional breaking with all previous h5 files, since they do not contain scan object --- phaser/engines/common/output.py | 2 +- phaser/hooks/preprocessing.py | 4 ++-- phaser/hooks/scan.py | 2 +- phaser/state.py | 6 +++--- phaser/utils/io.py | 35 ++++++++++++++++++++++++++++----- 5 files changed, 37 insertions(+), 12 deletions(-) diff --git a/phaser/engines/common/output.py b/phaser/engines/common/output.py index 1309942..44df43d 100644 --- a/phaser/engines/common/output.py +++ b/phaser/engines/common/output.py @@ -235,7 +235,7 @@ def _plot_scan_update(state: ReconsState, out_path: Path, options: SaveOptions): # ax.scatter(scan[..., 1].ravel(), scan[..., 0].ravel(), c=i, cmap='plasma', s=0.5, edgecolors='none') scan = to_numpy(state.scan.data) - disp = to_numpy(state.scan.data) - to_numpy(state.scan.prev_step) + disp = to_numpy(state.scan.data) - to_numpy(state.scan.initial_scan) pos_y, pos_x = scan[..., 1].ravel(), scan[..., 0].ravel() dY, dX = disp[..., 1].ravel(), disp[..., 0].ravel() diff --git a/phaser/hooks/preprocessing.py b/phaser/hooks/preprocessing.py index 371330e..3f31252 100644 --- a/phaser/hooks/preprocessing.py +++ b/phaser/hooks/preprocessing.py @@ -92,7 +92,7 @@ def drop_nan_patterns(args: PostInitArgs, props: DropNanProps) -> t.Tuple[Patter # flatten scan, tilt, and patterns scan_pos = args['state'].scan.data.reshape(-1, 2) - prev_step = args['state'].scan.prev_step.reshape(-1, 2) + prev_step = args['state'].scan.initial_scan.reshape(-1, 2) scan_meta = args['state'].scan.metadata scan_meta['rows'] = scan_meta['rows'].reshape(-1, 1) scan_meta['cols'] = scan_meta['cols'].reshape(-1, 1) @@ -135,7 +135,7 @@ def drop_nan_patterns(args: PostInitArgs, props: DropNanProps) -> t.Tuple[Patter print(scan_pos.shape, scan_meta['cols'].shape) args['state'].scan.data = scan_pos - args['state'].scan.prev_step = prev_step ## check that copy is correct or necessary + args['state'].scan.initial_scan = prev_step ## check that copy is correct or necessary args['state'].scan.metadata = scan_meta args['state'].tilt = tilt diff --git a/phaser/hooks/scan.py b/phaser/hooks/scan.py index ba899e2..47f08e7 100644 --- a/phaser/hooks/scan.py +++ b/phaser/hooks/scan.py @@ -37,6 +37,6 @@ def raster_scan(args: ScanHookArgs, props: RasterScanProps) -> ScanState: logger.info(f"Generated scan with {scan_pos.shape[0]}x{scan_pos.shape[1]}={scan_pos.shape[0]*scan_pos.shape[1]} positions.") logger.info(f"Accompanying scan metadata: type 'raster', {rows.shape} rows (y), {cols.shape} cols (x).") return ScanState(data = scan_pos, - prev_step = scan_pos.copy(), + initial_scan = scan_pos.copy(), metadata={'type': 'raster', 'rows': rows, 'cols': cols}) \ No newline at end of file diff --git a/phaser/state.py b/phaser/state.py index 46fbeb1..f2c624d 100644 --- a/phaser/state.py +++ b/phaser/state.py @@ -144,7 +144,7 @@ class ScanState: # """Object coordinate system. See `ObjectSampling` for more details.""" data: NDArray[numpy.floating] """Scan coordinates (y, x), in length units. Shape (..., 2)""" - prev_step: NDArray[numpy.floating] + initial_scan: NDArray[numpy.floating] """Previous step Scan coordinates (y, x), in length units. Shape (..., 2)""" metadata: t.Dict[str, t.Any] """Scan row positions (y), in length units. Shape (...)""" @@ -154,14 +154,14 @@ class ScanState: def to_xp(self, xp: t.Any) -> Self: return self.__class__( xp.asarray(self.data), - xp.asarray(self.prev_step), + xp.asarray(self.initial_scan), self.metadata, ) def to_numpy(self) -> Self: return self.__class__( to_numpy(self.data), - to_numpy(self.prev_step), + to_numpy(self.initial_scan), self.metadata, ) diff --git a/phaser/utils/io.py b/phaser/utils/io.py index 7f8969f..0c62f56 100644 --- a/phaser/utils/io.py +++ b/phaser/utils/io.py @@ -113,7 +113,7 @@ def hdf5_read_state(file: HdfLike) -> PartialReconsState: tilt = numpy.asarray(_hdf5_read_dataset(file, 'tilt', numpy.float64)) if 'tilt' in file else None if tilt is not None and scan is not None: - assert tilt.shape == scan.shape + assert tilt.shape == scan.data.shape progress = hdf5_read_progress_state(_assert_group(file['progress'])) if 'progress' in file else None return PartialReconsState( @@ -134,14 +134,36 @@ def hdf5_read_probe_state(group: h5py.Group) -> ProbeState: data=probes ) +# if this is called, already assumed that scan is state and contains metadata group def hdf5_read_scan_state(group: h5py.Group) -> ScanState: scan = _hdf5_read_dataset(group, 'data', numpy.floating) assert scan.ndim == 2 + initial = _hdf5_read_dataset(group, 'initial_scan', numpy.floating) + assert initial.ndim == 2 + + meta_d: t.Dict[str, t.Any] = {} + # iterate to find + for (k, subgroup) in group.items(): + if not isinstance(subgroup, h5py.Group): + continue + elif k == 'metadata': + scan_type = _hdf5_read_string(subgroup, 'type') + meta_d['type'] = scan_type + + if scan_type == 'raster': + rows = _hdf5_read_dataset(subgroup, 'rows', numpy.integer) + assert rows.ndim == 2 + cols = _hdf5_read_dataset(subgroup, 'cols', numpy.integer) + assert rows.ndim == 2 + + meta_d['rows'] = rows + meta_d['cols'] = cols + else: + continue - ## TODO: read metadata properly from whatever write_scan_state does return ScanState( - data=scan, - metadata={} + data=scan, initial_scan=initial, + metadata=meta_d ) def hdf5_read_object_state(group: h5py.Group) -> ObjectState: @@ -198,7 +220,7 @@ def hdf5_read_progress_state(group: h5py.Group) -> t.Dict[str, ProgressState]: d[k] = ProgressState(iters.tolist(), values.tolist()) return d - + def hdf5_write_state(state: t.Union[ReconsState, PartialReconsState], file: HdfLike): file = open_hdf5(file, 'w') # overwrite if existing @@ -235,6 +257,9 @@ def hdf5_write_scan_state(state: ScanState, group: h5py.Group): dataset = group.create_dataset('data', data=to_numpy(state.data)) dataset.dims[0].label = 'position' dataset.dims[1].label = 'yx' + dataset = group.create_dataset('initial', data=to_numpy(state.initial_scan)) + dataset.dims[0].label = 'position' + dataset.dims[1].label = 'yx' metagroup = group.require_group('metadata') for (k, v) in state.metadata.items(): From f5039fd7469d4a88cdf3c58592fb55f47b64853f Mon Sep 17 00:00:00 2001 From: michaelxu01 Date: Fri, 16 Jan 2026 15:06:16 -0500 Subject: [PATCH 08/19] bug fix in read h5 --- phaser/utils/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phaser/utils/io.py b/phaser/utils/io.py index 0c62f56..8a1bc03 100644 --- a/phaser/utils/io.py +++ b/phaser/utils/io.py @@ -138,7 +138,7 @@ def hdf5_read_probe_state(group: h5py.Group) -> ProbeState: def hdf5_read_scan_state(group: h5py.Group) -> ScanState: scan = _hdf5_read_dataset(group, 'data', numpy.floating) assert scan.ndim == 2 - initial = _hdf5_read_dataset(group, 'initial_scan', numpy.floating) + initial = _hdf5_read_dataset(group, 'initial', numpy.floating) assert initial.ndim == 2 meta_d: t.Dict[str, t.Any] = {} From 3bca912400a5f196b00a12368dfa061c581fce41 Mon Sep 17 00:00:00 2001 From: michaelxu01 Date: Tue, 20 Jan 2026 10:43:49 -0500 Subject: [PATCH 09/19] initial changes to conventional --- phaser/engines/conventional/run.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/phaser/engines/conventional/run.py b/phaser/engines/conventional/run.py index 195b001..8f295e6 100644 --- a/phaser/engines/conventional/run.py +++ b/phaser/engines/conventional/run.py @@ -100,7 +100,7 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: ) assert_dtype(sim.state.object.data, cdtype) assert_dtype(sim.state.probe.data, cdtype) - + ## FIXME: can this apply iter constraints be moved after the position update? sim = sim.apply_iter_constraints() if iter_update_positions: @@ -109,12 +109,12 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: # subtract mean position update pos_update -= xp.mean(pos_update, tuple(range(pos_update.ndim - 1))) - pos_update, position_solver_state = position_solver.perform_update(sim.state.scan, pos_update, position_solver_state) + pos_update, position_solver_state = position_solver.perform_update(sim.state.scan.data, pos_update, position_solver_state) # subtract mean again (this can change with momentum) pos_update -= xp.mean(pos_update, tuple(range(pos_update.ndim - 1))) pos_update_rms = float(xp.mean(xp.linalg.norm(pos_update, axis=-1, keepdims=True))) logger.info(f"Position update: mean {pos_update_rms}") - sim.state.scan += pos_update + sim.state.scan.data += pos_update assert_dtype(sim.state.scan, dtype) # check positions are at least overlapping object From c80d3822297860bf1f8d424692f375e8bfac5953 Mon Sep 17 00:00:00 2001 From: michaelxu01 Date: Tue, 20 Jan 2026 12:01:27 -0500 Subject: [PATCH 10/19] Update test_initialization.py --- tests/test_initialization.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/test_initialization.py b/tests/test_initialization.py index b8494b4..0669d48 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -11,7 +11,7 @@ from phaser.hooks import RawData from phaser.plan import ReconsPlan from phaser.execute import load_raw_data, initialize_reconstruction -from phaser.state import PartialReconsState, ProbeState +from phaser.state import PartialReconsState, ProbeState, ScanState def load_empty(args, props) -> RawData: @@ -110,7 +110,10 @@ def test_load_raw_data_prev_state(caplog): } probe_state = ProbeState(Sampling((64, 64), sampling=(1.0, 1.0)), numpy.zeros((64, 64), dtype=numpy.complex64)) - scan_state = numpy.zeros((32, 32, 2)) + scan_pos = numpy.zeros((32, 32, 2)) + scan_state = ScanState(data=scan_pos, initial_scan=scan_pos, metadata={'type': 'raster', + 'rows': scan_pos.reshape(-1, 2)[:,0], + 'cols': scan_pos.reshape(-1, 2)[:,1]}) xp = numpy with caplog.at_level(logging.WARNING): @@ -133,7 +136,7 @@ def test_load_raw_data_prev_state(caplog): # probe from state overrides probe from raw data assert numpy.all(numpy.isclose(recons.state.probe.data, probe_state.data)) # but scan should be modeled - assert ~numpy.all(numpy.isclose(recons.state.scan, scan_state)) + assert ~numpy.all(numpy.isclose(recons.state.scan.data, scan_state.data)) plan['init'] = { 'scan': {}, @@ -150,7 +153,7 @@ def test_load_raw_data_prev_state(caplog): # both should be modeled assert ~numpy.all(numpy.isclose(recons.state.probe.data, probe_state.data)) - assert ~numpy.all(numpy.isclose(recons.state.scan, scan_state)) + assert ~numpy.all(numpy.isclose(recons.state.scan.data, scan_state.data)) def test_load_3d_raw_data(): @@ -180,5 +183,5 @@ def test_load_3d_raw_data(): }) recons = initialize_reconstruction(plan) - assert recons.state.scan.shape == (*scan_shape, 2) + assert recons.state.scan.data.shape == (*scan_shape, 2) assert recons.patterns.patterns.shape == (*scan_shape, *det_shape) \ No newline at end of file From 56d5a7430733233f9da4ced7eb1d6786be6a10cf Mon Sep 17 00:00:00 2001 From: michaelxu01 Date: Tue, 20 Jan 2026 12:19:57 -0500 Subject: [PATCH 11/19] match statement, matmul, private, and cleanup of comments --- phaser/engines/common/regularizers.py | 80 +++++++-------------------- phaser/execute.py | 3 +- phaser/state.py | 2 - 3 files changed, 22 insertions(+), 63 deletions(-) diff --git a/phaser/engines/common/regularizers.py b/phaser/engines/common/regularizers.py index 1d3b758..2fa92e9 100644 --- a/phaser/engines/common/regularizers.py +++ b/phaser/engines/common/regularizers.py @@ -16,19 +16,12 @@ RegularizeLayersProps, ObjLowPassProps, GaussianProps, CostRegularizerProps, TVRegularizerProps, UnstructuredGaussianProps ) -# from phaser.types import ScanUpdate + logger = logging.getLogger(__name__) class ScanUpdate(t.NamedTuple): previous: numpy.typing.NDArray[numpy.floating] row_bins: t.Optional[numpy.typing.NDArray[numpy.integer]] = None - - # def __init__(self, prev: numpy.typing.NDArray[numpy.floating], row_bins: t.Optional[numpy.typing.NDArray[numpy.integer]] = None): - # self.prev = prev - # self.row_bins = row_bins - -#type: t.Literal['affine', 'line'] = 'affine' - # weight: float = 1.0 class ScanConstraint: """See ref for details: S. Ning, W. Xu, L. Loh, Z. Lu, M. Bosman, F. Zhang, Q. He, An integrated constrained gradient descent (iCGD) protocol to correct scan-positional errors for electron ptychography with high accuracy and precision. Ultramicroscopy 248, 113716 (2023). @@ -41,11 +34,10 @@ def __init__(self, args: None, props: ScanConstraintProps): if getattr(props, kind) > 0: self.constraints[kind] = getattr(props, kind) self.total_weight = sum(self.constraints.values()) - # self.weight: t.Optional[float] - # self.type: t.Optional[str] + logger.info(f"Initialized scan constraint with kinds {list(self.constraints.keys())} and weights {list(self.constraints.values())} with total weight {self.total_weight:.4f}") + def init_state(self, sim: ReconsState) -> ScanUpdate: - ## init the original scan positions if 'line' in self.constraints: if (sim.scan.metadata.get('type') != 'raster') | (sim.scan.metadata.get('rows') is None): raise ValueError("Line scan constraint cannot be applied to scans without row metadata") @@ -57,40 +49,37 @@ def init_state(self, sim: ReconsState) -> ScanUpdate: def apply_group(self, group: NDArray[numpy.integer], sim: ReconsState, state: ScanUpdate) -> t.Tuple[ReconsState, ScanUpdate]: return self.apply_iter(sim, state) - ## TODO: more formally handle applying both affine and line simultaneously - ## I think probably for 'kind' can give a list of constraints to apply, even not in order. and here it will handle it in the best sequence - ## this preserves state better, though the dictionary of the state might be more complicated def apply_iter(self, sim: ReconsState, state: ScanUpdate) -> t.Tuple[ReconsState, ScanUpdate]: # cast = to_real_dtype(sim.object.data.dtype) xp = get_array_module(sim.scan.data) update = xp.zeros_like(sim.scan.data, dtype=sim.scan.data.dtype) for kind, weight in self.constraints.items(): weight = weight / self.total_weight - if kind == 'affine': - update += scan_affine(sim.scan.data, state.previous) * weight - # sims.object.data = ## affine deform object - if kind == 'line' and state.row_bins is not None: - update += scan_line(sim.scan.data, state.previous, state.row_bins) * weight - if kind == 'hpf': - pass - if kind == 'lpf': - pass - if kind == 'default': - update += scan_default(sim.scan.data, state.previous) * weight - # assert update.dtype == sim.scan.data.dtype + match kind: + case 'affine': + update += _scan_affine(sim.scan.data, state.previous) * weight + case 'line': + if state.row_bins is not None: + update += _scan_line(sim.scan.data, state.previous, state.row_bins) * weight + case 'hpf': + pass + case 'lpf': + pass + case 'default': + update += _scan_default(sim.scan.data, state.previous) * weight + sim.scan.data = state.previous + update state = ScanUpdate(previous=sim.scan.data.copy(), row_bins=state.row_bins) return (sim, state) -## double check that if position update is off (scan == prev_step), this doesn't break anything # @partial(jit, donate_argnames=('pos',), cupy_fuse=True) -def scan_default( +def _scan_default( pos: NDArray[numpy.floating], prev: NDArray[numpy.floating], ) -> NDArray[numpy.floating]: return pos - prev -def scan_affine( +def _scan_affine( pos: NDArray[numpy.floating], prev: NDArray[numpy.floating], ) -> NDArray[numpy.floating]: @@ -103,51 +92,24 @@ def scan_affine( right = xp.matmul(pos_prev.T, pos_prev) A = xp.matmul(xp.linalg.inv(right), left) constraint = xp.matmul(pos_prev, A) - #remove the middle shift, keep the middle unchanged center_ones = xp.ones((1, 1), pos.dtype) - # center[0, 0:2] = xp.average(pos, axis = 0) center = xp.concatenate([xp.average(pos, axis = 0, keepdims=True), center_ones], axis=1, dtype=pos.dtype) - center_shift = xp.matmul(center, A) + center_shift = center @ A constraint -= center_shift return constraint # @partial(jit, donate_argnames=('pos',), cupy_fuse=True) -def scan_line( +def _scan_line( pos: NDArray[numpy.floating], prev: NDArray[numpy.floating], rows: NDArray[numpy.integer], ) -> NDArray[numpy.floating]: xp = get_array_module(pos) - - # >>> import numpy -# >>> rows = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,] -# >>> rows = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3] -# >>> numpy.bincount(rows, [0., 0., 0., 1., 1., 1., 1., 2., 2., 2., 2., 3., 4.]) -# array([1., 5., 9., 4.]) -# >>> numpy.bincount(rows, [0., 0., 0., 1., 1., 1., 1., 2., 2., 2., 2., 3., 4.]) / numpy.bincount(rows) -# array([0.25, 1.25, 2.25, 4. ]) -# >>> x_shifts = [0., 0., 0., 1., 1., 1., 1., 2., 2., 2., 2., 3., 4.] -# >>> x_shifts = numpy.array([0., 0., 0., 1., 1., 1., 1., 2., 2., 2., 2., 3., 4.]) -# >>> x_shifts[rows] -# array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]) -# >>> x_shifts[rows] -# KeyboardInterrupt -# >>> (numpy.bincount(rows, [0., 0., 0., 1., 1., 1., 1., 2., 2., 2., 2., 3., 4.]) / numpy.bincount(rows))[rows] -# array([0.25, 0.25, 0.25, 0.25, 1.25, 1.25, 1.25, 1.25, 2.25, 2.25, 2.25, -# 2.25, 4. ]) - # rows is shape (N x 2) - # constraint is shape (N x 2) + disp_val = pos - prev y_shifts = xp.bincount(rows, disp_val[:,0]) / xp.bincount(rows) x_shifts = xp.bincount(rows, disp_val[:,1]) / xp.bincount(rows) - - # scan_pos = scan - # row = int(xp.sqrt(scan_pos.shape[0])) - # col = int(scan_pos.shape[0]/row) - # constraint = np.zeros_like(disp_val) - # for i in range(row): - # constraint[i*col:(i+1)*col]= np.average(disp_val[i*col:(i+1)*col], axis=0) constraint = xp.stack([y_shifts[rows], x_shifts[rows]], axis=1, dtype=pos.dtype) return constraint diff --git a/phaser/execute.py b/phaser/execute.py index a5867d6..00f24ec 100644 --- a/phaser/execute.py +++ b/phaser/execute.py @@ -124,7 +124,6 @@ def _normalize_observers( return ObserverSet(obs) -## FIXME: the scan normalization here - happens before dropnans and scan data flattening, but may alter shape and therefore rows/cols? why is this needed def _normalize_scan_shape( patterns: Patterns, state: ReconsState ) -> t.Tuple[Patterns, ReconsState]: @@ -149,7 +148,7 @@ def _normalize_scan_shape( patterns.patterns = patterns.patterns.reshape((*new_shape, *patterns.patterns.shape[-2:])) state.scan.data = state.scan.data.reshape((*new_shape, 2)) - ## FIXME: check that this functions as intended, mainly when loading .h5 files and not applying dropnans + ## TODO: check that this functions as intended, mainly when loading .h5 files and not applying dropnans if 'raster' == state.scan.metadata.get('type'): if 'rows' in state.scan.metadata: state.scan.metadata['rows'] = state.scan.metadata['rows'].reshape((*new_shape, 1)) diff --git a/phaser/state.py b/phaser/state.py index f2c624d..4986a3d 100644 --- a/phaser/state.py +++ b/phaser/state.py @@ -13,8 +13,6 @@ from phaser.utils.image import _InterpBoundaryMode from phaser.observer import Observer, ObserverSet -## FIXME: scan metadata format not yet finalized, but scanstate object and references here are updated - @tree_dataclass class Patterns(): patterns: NDArray[numpy.floating] From a2566cba465f383d3ec6d21e2600ad63e5e7f0c3 Mon Sep 17 00:00:00 2001 From: michaelxu01 Date: Tue, 20 Jan 2026 12:39:13 -0500 Subject: [PATCH 12/19] additional cleanup --- phaser/engines/common/output.py | 2 -- phaser/hooks/preprocessing.py | 4 ---- phaser/utils/io.py | 3 --- 3 files changed, 9 deletions(-) diff --git a/phaser/engines/common/output.py b/phaser/engines/common/output.py index 44df43d..31d641d 100644 --- a/phaser/engines/common/output.py +++ b/phaser/engines/common/output.py @@ -231,8 +231,6 @@ def _plot_scan_update(state: ReconsState, out_path: Path, options: SaveOptions): ax.set_ylim(bottom, top) scan = to_numpy(state.scan.data) - # i = numpy.arange(scan[..., 0].size) - # ax.scatter(scan[..., 1].ravel(), scan[..., 0].ravel(), c=i, cmap='plasma', s=0.5, edgecolors='none') scan = to_numpy(state.scan.data) disp = to_numpy(state.scan.data) - to_numpy(state.scan.initial_scan) diff --git a/phaser/hooks/preprocessing.py b/phaser/hooks/preprocessing.py index 3f31252..41beabf 100644 --- a/phaser/hooks/preprocessing.py +++ b/phaser/hooks/preprocessing.py @@ -14,9 +14,6 @@ logger = logging.getLogger(__name__) -## FIXME: the scan flattening is done here, but doesn't safely handle any row or column metadata. -## should the scan be flattened when initialized and metadata generated in raster scan hook flow? - ## TODO: actually test that cropping and dropnans work individually and together ## crop should be functional? might not be on load? probably just leave it until it becomes an issue @@ -88,7 +85,6 @@ def add_poisson_noise(raw_data: RawData, props: PoissonProps) -> RawData: def drop_nan_patterns(args: PostInitArgs, props: DropNanProps) -> t.Tuple[Patterns, ReconsState]: xp = get_array_module(args['data'].patterns) - ## FIXME: should the scan be flattened to begin with? is there any situation where we don't want that? # flatten scan, tilt, and patterns scan_pos = args['state'].scan.data.reshape(-1, 2) diff --git a/phaser/utils/io.py b/phaser/utils/io.py index 8a1bc03..789d68e 100644 --- a/phaser/utils/io.py +++ b/phaser/utils/io.py @@ -10,9 +10,6 @@ from phaser.utils.object import ObjectSampling from phaser.state import ReconsState, IterState, ProbeState, ScanState, ObjectState, ProgressState, PartialReconsState -## FIXME: scan metadata format somewhat finalized, though clunky. -## TODO: h5 read function for scan state. should be a conditional and check if scan is a group with metadata, or dataset only - HdfLike: t.TypeAlias = t.Union[h5py.File, str, Path] OpenMode: t.TypeAlias = t.Literal['r', 'r+', 'w', 'w-', 'x', 'a'] DTypeT = t.TypeVar('DTypeT', bound=numpy.generic) From 6d94fd20146e26814e374824213b1002e943d2bb Mon Sep 17 00:00:00 2001 From: michaelxu01 Date: Tue, 20 Jan 2026 17:06:12 -0500 Subject: [PATCH 13/19] lsqml working --- phaser/engines/conventional/run.py | 8 ++++---- phaser/engines/conventional/solvers.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/phaser/engines/conventional/run.py b/phaser/engines/conventional/run.py index 8f295e6..37945d7 100644 --- a/phaser/engines/conventional/run.py +++ b/phaser/engines/conventional/run.py @@ -100,8 +100,6 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: ) assert_dtype(sim.state.object.data, cdtype) assert_dtype(sim.state.probe.data, cdtype) - ## FIXME: can this apply iter constraints be moved after the position update? - sim = sim.apply_iter_constraints() if iter_update_positions: if not position_solver: @@ -115,14 +113,16 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: pos_update_rms = float(xp.mean(xp.linalg.norm(pos_update, axis=-1, keepdims=True))) logger.info(f"Position update: mean {pos_update_rms}") sim.state.scan.data += pos_update - assert_dtype(sim.state.scan, dtype) + assert_dtype(sim.state.scan.data, dtype) # check positions are at least overlapping object - sim.state.object.sampling.check_scan(sim.state.scan, sim.state.probe.sampling.extent / 2.) + sim.state.object.sampling.check_scan(sim.state.scan.data, sim.state.probe.sampling.extent / 2.) progress['pos_update_rms'].iters.append(i + start_i) progress['pos_update_rms'].values.append(pos_update_rms) + sim = sim.apply_iter_constraints() + error = None if group_errors is not None and len(group_errors): error = float(to_numpy(xp.nanmean(xp.concatenate(group_errors)))) diff --git a/phaser/engines/conventional/solvers.py b/phaser/engines/conventional/solvers.py index d689d95..f3bb0b0 100644 --- a/phaser/engines/conventional/solvers.py +++ b/phaser/engines/conventional/solvers.py @@ -86,7 +86,7 @@ def run_iteration( new_obj_mag = xp.zeros_like(self.obj_mag) new_probe_mag = xp.zeros_like(self.probe_mag) - pos_update = xp.zeros_like(sim.state.scan, dtype=sim.dtype) + pos_update = xp.zeros_like(sim.state.scan.data, dtype=sim.dtype) iter_errors = [] for (group_i, (group, group_patterns)) in enumerate(stream_patterns(groups, patterns, xp=xp, @@ -345,7 +345,7 @@ def run_iteration( xp = sim.xp # TODO: ePIE position update - pos_update = xp.zeros_like(sim.state.scan) + pos_update = xp.zeros_like(sim.state.scan.data) iter_errors = [] beta_object = process_schedule(self.plan.beta_object)({'state': sim.state, 'niter': self.engine_plan.niter}) From 3e596d02952774c17e4f6bb9c2c11b83373799d9 Mon Sep 17 00:00:00 2001 From: michaelxu01 Date: Tue, 20 Jan 2026 17:29:30 -0500 Subject: [PATCH 14/19] fixed constraint weights, no renormalization --- phaser/engines/common/regularizers.py | 19 ++++++++++++++----- phaser/hooks/regularization.py | 1 - 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/phaser/engines/common/regularizers.py b/phaser/engines/common/regularizers.py index 2fa92e9..38d302b 100644 --- a/phaser/engines/common/regularizers.py +++ b/phaser/engines/common/regularizers.py @@ -27,15 +27,25 @@ class ScanConstraint: S. Ning, W. Xu, L. Loh, Z. Lu, M. Bosman, F. Zhang, Q. He, An integrated constrained gradient descent (iCGD) protocol to correct scan-positional errors for electron ptychography with high accuracy and precision. Ultramicroscopy 248, 113716 (2023). """ def __init__(self, args: None, props: ScanConstraintProps): - self.valid_kinds: t.Set[str] = {'affine', 'line', 'hpf', 'lpf', 'default'} + self.valid_kinds: t.Set[str] = {'affine', 'line', 'hpf', 'lpf'} self.constraints: t.Dict[str, float] = {} #= {'default': 1.0} + total_constraint_weight = 0 for kind in self.valid_kinds: if getattr(props, kind) > 0: - self.constraints[kind] = getattr(props, kind) - self.total_weight = sum(self.constraints.values()) + val = getattr(props, kind) + self.constraints[kind] = val - logger.info(f"Initialized scan constraint with kinds {list(self.constraints.keys())} and weights {list(self.constraints.values())} with total weight {self.total_weight:.4f}") + total_constraint_weight += val + + if total_constraint_weight > 1.0: + raise ValueError("Sum of scan constraint weights cannot exceed 1.0") + + self.constraints['default'] = 1-total_constraint_weight + + # self.total_weight = sum(self.constraints.values()) + + logger.info(f"Initialized scan constraint with kinds {list(self.constraints.keys())} and weights {list(self.constraints.values()):.4f}") def init_state(self, sim: ReconsState) -> ScanUpdate: if 'line' in self.constraints: @@ -54,7 +64,6 @@ def apply_iter(self, sim: ReconsState, state: ScanUpdate) -> t.Tuple[ReconsState xp = get_array_module(sim.scan.data) update = xp.zeros_like(sim.scan.data, dtype=sim.scan.data.dtype) for kind, weight in self.constraints.items(): - weight = weight / self.total_weight match kind: case 'affine': update += _scan_affine(sim.scan.data, state.previous) * weight diff --git a/phaser/hooks/regularization.py b/phaser/hooks/regularization.py index cd75b98..29c787f 100644 --- a/phaser/hooks/regularization.py +++ b/phaser/hooks/regularization.py @@ -37,7 +37,6 @@ class ScanConstraintProps(Dataclass): line: float = 0.0 hpf: float = 0.0 lpf: float = 0.0 - default: float = 1.0 class ClampObjectAmplitudeProps(Dataclass): amplitude: t.Union[float, t.List[t.Optional[float]]] = 1.1 From a4de8a89805448adb8a608aaef9378330b12b2d1 Mon Sep 17 00:00:00 2001 From: michaelxu01 Date: Fri, 23 Jan 2026 13:04:53 -0500 Subject: [PATCH 15/19] added docs for scan constraint reg, additional cleanup --- phaser/engines/common/regularizers.py | 43 ++++++++++++++++++++++++++- phaser/engines/conventional/run.py | 1 - 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/phaser/engines/common/regularizers.py b/phaser/engines/common/regularizers.py index 38d302b..77cc465 100644 --- a/phaser/engines/common/regularizers.py +++ b/phaser/engines/common/regularizers.py @@ -20,10 +20,19 @@ logger = logging.getLogger(__name__) class ScanUpdate(t.NamedTuple): + """ + Scan update object for holding the scan constraint row index and previous position arrays + """ previous: numpy.typing.NDArray[numpy.floating] row_bins: t.Optional[numpy.typing.NDArray[numpy.integer]] = None class ScanConstraint: - """See ref for details: + """ + Constraints for the scan position updates. + This per iteration regularizer takes the unconstrained position updates and applies a + weighted average of its affine, line (row) averaged, high pass filtered, or low pass filtered components. + Currently, only the affine and line averaged constraints are implemented. + + See ref for details: S. Ning, W. Xu, L. Loh, Z. Lu, M. Bosman, F. Zhang, Q. He, An integrated constrained gradient descent (iCGD) protocol to correct scan-positional errors for electron ptychography with high accuracy and precision. Ultramicroscopy 248, 113716 (2023). """ def __init__(self, args: None, props: ScanConstraintProps): @@ -86,12 +95,32 @@ def _scan_default( pos: NDArray[numpy.floating], prev: NDArray[numpy.floating], ) -> NDArray[numpy.floating]: + """ + Pass through function for calculating the scan update from final and initial scan positions. + + :param pos: N x 2 array of unconstrained updated scan positions + :type pos: NDArray[numpy.floating] + :param prev: N x 2 array of scan positions, before update (previous iteration) + :type prev: NDArray[numpy.floating] + :return: N x 2 array of updates to the scan positions + :rtype: NDArray[floating[Any]] + """ return pos - prev def _scan_affine( pos: NDArray[numpy.floating], prev: NDArray[numpy.floating], ) -> NDArray[numpy.floating]: + """ + Calculates and returns the affine component of the update between final and initial scan positions. + + :param pos: N x 2 array of unconstrained updated scan positions + :type pos: NDArray[numpy.floating] + :param prev: N x 2 array of scan positions, before update (previous iteration) + :type prev: NDArray[numpy.floating] + :return: N x 2 array of updates to the scan positions (affine only) + :rtype: NDArray[floating[Any]] + """ xp = get_array_module(pos) disp_update = pos - prev @@ -113,6 +142,18 @@ def _scan_line( prev: NDArray[numpy.floating], rows: NDArray[numpy.integer], ) -> NDArray[numpy.floating]: + """ + Calculates and returns a line (row) averaged update from the unconstrained final and initial scan positions. + + :param pos: N x 2 array of unconstrained updated scan positions + :type pos: NDArray[numpy.floating] + :param prev: N x 2 array of scan positions, before update (previous iteration) + :type prev: NDArray[numpy.floating] + :param rows: an array of row indices corresponding to the N positions given by pos and prev + :type rows: NDArray[numpy.integer] + :return: N x 2 array of updates to the scan positions (row averaged) + :rtype: NDArray[floating[Any]] + """ xp = get_array_module(pos) disp_val = pos - prev diff --git a/phaser/engines/conventional/run.py b/phaser/engines/conventional/run.py index 37945d7..ce11008 100644 --- a/phaser/engines/conventional/run.py +++ b/phaser/engines/conventional/run.py @@ -8,7 +8,6 @@ from phaser.state import ReconsState, ProgressState from phaser.types import process_flag, flag_any_true from ..common.simulation import SimulationState, make_propagators, GroupManager -## FIXME: not yet updated for conventional engines def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: logger = logging.getLogger(__name__) From b50f062fbddf2e25219656d03af908e0ad4f768f Mon Sep 17 00:00:00 2001 From: michaelxu01 Date: Fri, 23 Jan 2026 13:16:13 -0500 Subject: [PATCH 16/19] cleanup --- phaser/utils/scan.py | 88 -------------------------------------------- 1 file changed, 88 deletions(-) diff --git a/phaser/utils/scan.py b/phaser/utils/scan.py index f1c6d13..f54cffe 100644 --- a/phaser/utils/scan.py +++ b/phaser/utils/scan.py @@ -67,94 +67,6 @@ def make_raster_scan(shape: t.Tuple[int, int], scan_step: ArrayLike, return t.cast(NDArray[numpy.number], pts), t.cast(NDArray[numpy.integer], yy_grid), t.cast(NDArray[numpy.integer], xx_grid) -# @tree_dataclass(frozen=True, init=False) -# class RasterScanMetadata: -# shape: NDArray[numpy.int_] -# """Sampling shape `(n_y, n_x)`""" -# sampling: NDArray[numpy.float64] -# """Sample spacing `(s_y, s_x)`""" -# corner: NDArray[numpy.float64] -# """Corner of sampling `(y_min, x_min)`""" - -# region_min: t.Optional[NDArray[numpy.float64]] -# region_max: t.Optional[NDArray[numpy.float64]] - -# @property -# def min(self) -> NDArray[numpy.float64]: -# """Minimum object pixel position (y, x). Alias for `corner`.""" -# return self.corner - -# @property -# def max(self) -> NDArray[numpy.float64]: -# """Maximum pixel position (y, x).""" -# return (self.corner + (self.shape - 1) * self.sampling).astype(numpy.float64) - -# @property -# def extent(self) -> NDArray[numpy.float64]: -# return (self.shape * self.sampling).astype(numpy.float64) - -# def __init__(self, shape: t.Tuple[int, int], sampling: ArrayLike, corner: t.Optional[ArrayLike] = None, -# region_min: t.Optional[ArrayLike] = None, region_max: t.Optional[ArrayLike] = None): -# object.__setattr__(self, 'shape', numpy.broadcast_to(as_numpy(shape).astype(numpy.int_), (2,))) -# object.__setattr__(self, 'sampling', numpy.broadcast_to(as_numpy(sampling).astype(numpy.float64), (2,))) -# object.__setattr__(self, 'region_min', numpy.broadcast_to(as_numpy(region_min).astype(numpy.float64), (2,)) if region_min is not None else None) -# object.__setattr__(self, 'region_max', numpy.broadcast_to(as_numpy(region_max).astype(numpy.float64), (2,)) if region_max is not None else None) - -# if corner is None: -# corner = -self.extent / 2. + self.sampling/2. #* (self.shape % 2) -# else: -# corner = numpy.broadcast_to(as_numpy(corner).astype(numpy.float64), (2,)) - -# object.__setattr__(self, 'corner', corner) - -# def __eq__(self, other: t.Any) -> bool: -# if type(self) is not type(other): -# return False -# xp = get_array_module(self.sampling, other.sampling) -# return ( -# xp.array_equal(self.shape, other.shape) and -# xp.array_equal(self.sampling, other.sampling) and -# xp.array_equal(self.corner, other.corner) -# ) - -# @staticmethod -# def _scan_extent(scan_positions: NDArray[numpy.floating]) -> t.Tuple[NDArray[numpy.float64], NDArray[numpy.float64]]: -# xp = get_array_module(scan_positions) -# scan_min = numpy.array(tuple(float(xp.nanmin(scan_positions[..., i])) for i in range(2))) -# scan_max = numpy.array(tuple(float(xp.nanmax(scan_positions[..., i])) for i in range(2))) -# return (scan_min, scan_max) - -# @classmethod -# def from_scan(cls: t.Type[Self], scan_positions: NDArray[numpy.floating], sampling: ArrayLike, pad: ArrayLike = 0) -> Self: -# """Create an ObjectSampling around the given scan positions, padded by at least a radius `pad` in real-space.""" -# sampling = as_numpy(sampling).astype(numpy.float64) -# pad = numpy.broadcast_to(pad, (2,)).astype(numpy.float64) - -# (scan_min, scan_max) = cls._scan_extent(scan_positions) -# n = numpy.ceil((2.*pad + scan_max - scan_min) / sampling).astype(numpy.int_) + 1 - -# return cls((n[0], n[1]), sampling, scan_min - pad, scan_min, scan_max) - -# def expand_to_scan(self, scan_positions: NDArray[numpy.floating], pad: ArrayLike = 0.) -> Self: -# pad = numpy.broadcast_to(pad, (2,)).astype(numpy.float64) - -# (scan_min, scan_max) = self._scan_extent(scan_positions) -# pad_min = numpy.ceil(numpy.maximum(0, self.min - scan_min + pad) / self.sampling).astype(numpy.int_) -# pad_max = numpy.ceil(numpy.maximum(0, scan_max - self.max + pad) / self.sampling).astype(numpy.int_) - -# if numpy.all(pad_min == 0) and numpy.all(pad_max == 0): -# return self - -# region_min = numpy.minimum(self.region_min, scan_min) if self.region_min is not None else None -# region_max = numpy.maximum(self.region_max, scan_max) if self.region_max is not None else None - -# return self.__class__( -# t.cast(t.Tuple[int, int], tuple(self.shape + pad_min + pad_max)), -# self.sampling, -# self.corner - pad_min * self.sampling, -# region_min, region_max -# ) - __all__ = [ 'make_raster_scan', From 1bc614e7da6d90eb3f9d6164e522ed52d7ec7ca5 Mon Sep 17 00:00:00 2001 From: michaelxu01 Date: Mon, 26 Jan 2026 13:49:19 -0500 Subject: [PATCH 17/19] bug fix --- phaser/engines/common/regularizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phaser/engines/common/regularizers.py b/phaser/engines/common/regularizers.py index 77cc465..17f3722 100644 --- a/phaser/engines/common/regularizers.py +++ b/phaser/engines/common/regularizers.py @@ -54,7 +54,7 @@ def __init__(self, args: None, props: ScanConstraintProps): # self.total_weight = sum(self.constraints.values()) - logger.info(f"Initialized scan constraint with kinds {list(self.constraints.keys())} and weights {list(self.constraints.values()):.4f}") + logger.info(f"Initialized scan constraint with kinds {list(self.constraints.keys())} and weights {list(self.constraints.values())}") def init_state(self, sim: ReconsState) -> ScanUpdate: if 'line' in self.constraints: From 04c45f312548fca05d4468d1ce796c793bc3eb06 Mon Sep 17 00:00:00 2001 From: michaelxu01 Date: Mon, 26 Jan 2026 16:53:55 -0500 Subject: [PATCH 18/19] change metadata read/write to json - map all array-likes to numpy and then to list before json dump - coerce to numpy array when needed after loading --- phaser/engines/common/regularizers.py | 10 +++---- phaser/execute.py | 8 +++-- phaser/state.py | 13 ++------ phaser/utils/io.py | 43 ++++++--------------------- 4 files changed, 23 insertions(+), 51 deletions(-) diff --git a/phaser/engines/common/regularizers.py b/phaser/engines/common/regularizers.py index 17f3722..1f03349 100644 --- a/phaser/engines/common/regularizers.py +++ b/phaser/engines/common/regularizers.py @@ -51,16 +51,17 @@ def __init__(self, args: None, props: ScanConstraintProps): raise ValueError("Sum of scan constraint weights cannot exceed 1.0") self.constraints['default'] = 1-total_constraint_weight - - # self.total_weight = sum(self.constraints.values()) - + logger.info(f"Initialized scan constraint with kinds {list(self.constraints.keys())} and weights {list(self.constraints.values())}") def init_state(self, sim: ReconsState) -> ScanUpdate: if 'line' in self.constraints: if (sim.scan.metadata.get('type') != 'raster') | (sim.scan.metadata.get('rows') is None): raise ValueError("Line scan constraint cannot be applied to scans without row metadata") - state = ScanUpdate(previous=sim.scan.data.copy(), row_bins=sim.scan.metadata.get('rows').ravel()) + row_vals = sim.scan.metadata.get('rows') + if isinstance(row_vals, list): + row_vals = numpy.array(row_vals, dtype=numpy.integer) + state = ScanUpdate(previous=sim.scan.data.copy(), row_bins=row_vals.ravel()) else: state = ScanUpdate(previous=sim.scan.data.copy(), row_bins=None) return state @@ -69,7 +70,6 @@ def apply_group(self, group: NDArray[numpy.integer], sim: ReconsState, state: Sc return self.apply_iter(sim, state) def apply_iter(self, sim: ReconsState, state: ScanUpdate) -> t.Tuple[ReconsState, ScanUpdate]: - # cast = to_real_dtype(sim.object.data.dtype) xp = get_array_module(sim.scan.data) update = xp.zeros_like(sim.scan.data, dtype=sim.scan.data.dtype) for kind, weight in self.constraints.items(): diff --git a/phaser/execute.py b/phaser/execute.py index 00f24ec..05b5cce 100644 --- a/phaser/execute.py +++ b/phaser/execute.py @@ -151,9 +151,13 @@ def _normalize_scan_shape( ## TODO: check that this functions as intended, mainly when loading .h5 files and not applying dropnans if 'raster' == state.scan.metadata.get('type'): if 'rows' in state.scan.metadata: - state.scan.metadata['rows'] = state.scan.metadata['rows'].reshape((*new_shape, 1)) + if isinstance(state.scan.metadata.get('rows'), list): + state.scan.metadata['rows'] = numpy.array(state.scan.metadata.get('rows'), dtype=numpy.integer) + state.scan.metadata['rows'] = (state.scan.metadata['rows']).reshape((*new_shape, 1)) if 'cols' in state.scan.metadata: - state.scan.metadata['cols'] = state.scan.metadata['cols'].reshape((*new_shape, 1)) + if isinstance(state.scan.metadata.get('cols'), list): + state.scan.metadata['cols'] = numpy.array(state.scan.metadata.get('cols'), dtype=numpy.integer) + state.scan.metadata['cols'] = (state.scan.metadata['cols']).reshape((*new_shape, 1)) if state.tilt is not None: n_tilt = math.prod(state.tilt.shape[:-1]) diff --git a/phaser/state.py b/phaser/state.py index 4986a3d..c5ffb8d 100644 --- a/phaser/state.py +++ b/phaser/state.py @@ -1,5 +1,4 @@ import typing as t - import numpy from numpy.typing import NDArray from typing_extensions import Self @@ -138,16 +137,12 @@ def copy(self) -> Self: @tree_dataclass(static_fields=('metadata',)) class ScanState: - # sampling: ObjectSampling - # """Object coordinate system. See `ObjectSampling` for more details.""" data: NDArray[numpy.floating] """Scan coordinates (y, x), in length units. Shape (..., 2)""" initial_scan: NDArray[numpy.floating] - """Previous step Scan coordinates (y, x), in length units. Shape (..., 2)""" - metadata: t.Dict[str, t.Any] - """Scan row positions (y), in length units. Shape (...)""" - # cols: NDArray[numpy.floating] - # """Scan column positions (x), in length units. Shape (...)""" + """Initial scan coordinates (y, x), in length units. Shape (..., 2)""" + metadata: t.Union[t.Dict[str, t.Any], str] + """Metadata dictionary. Contains scan type and other info, such as rows and cols for raster scan positions""" def to_xp(self, xp: t.Any) -> Self: return self.__class__( @@ -175,7 +170,6 @@ class ReconsState: probe: ProbeState object: ObjectState scan: ScanState #NDArray[numpy.floating] - """Scan coordinates (y, x), in length units. Shape (..., 2)""" tilt: t.Optional[NDArray[numpy.floating]] = None """Tilt angles (y, x) per scan position, in mrad. Shape (..., 2)""" progress: t.Dict[str, ProgressState] = field(default_factory=dict) @@ -224,7 +218,6 @@ class PartialReconsState: probe: t.Optional[ProbeState] = None object: t.Optional[ObjectState] = None scan: t.Optional[ScanState] = None - """Scan coordinates (y, x), in length units. Shape (..., 2)""" tilt: t.Optional[NDArray[numpy.floating]] = None progress: t.Optional[t.Dict[str, ProgressState]] = None diff --git a/phaser/utils/io.py b/phaser/utils/io.py index 789d68e..9371dfb 100644 --- a/phaser/utils/io.py +++ b/phaser/utils/io.py @@ -1,10 +1,12 @@ import contextlib from pathlib import Path import typing as t - +import json +# import pane import numpy from numpy.typing import NDArray import h5py +from phaser.utils.tree import map from phaser.utils.num import Sampling, to_numpy from phaser.utils.object import ObjectSampling @@ -138,25 +140,7 @@ def hdf5_read_scan_state(group: h5py.Group) -> ScanState: initial = _hdf5_read_dataset(group, 'initial', numpy.floating) assert initial.ndim == 2 - meta_d: t.Dict[str, t.Any] = {} - # iterate to find - for (k, subgroup) in group.items(): - if not isinstance(subgroup, h5py.Group): - continue - elif k == 'metadata': - scan_type = _hdf5_read_string(subgroup, 'type') - meta_d['type'] = scan_type - - if scan_type == 'raster': - rows = _hdf5_read_dataset(subgroup, 'rows', numpy.integer) - assert rows.ndim == 2 - cols = _hdf5_read_dataset(subgroup, 'cols', numpy.integer) - assert rows.ndim == 2 - - meta_d['rows'] = rows - meta_d['cols'] = cols - else: - continue + meta_d = json.loads(_hdf5_read_string(group, 'metadata')) return ScanState( data=scan, initial_scan=initial, @@ -250,26 +234,17 @@ def hdf5_write_probe_state(state: ProbeState, group: h5py.Group): group.create_dataset('extent', data=state.sampling.extent.astype(numpy.float64)) def hdf5_write_scan_state(state: ScanState, group: h5py.Group): + state = state.to_numpy() assert state.data.ndim == 2 - dataset = group.create_dataset('data', data=to_numpy(state.data)) + dataset = group.create_dataset('data', data=state.data) dataset.dims[0].label = 'position' dataset.dims[1].label = 'yx' - dataset = group.create_dataset('initial', data=to_numpy(state.initial_scan)) + dataset = group.create_dataset('initial', data=state.initial_scan) dataset.dims[0].label = 'position' dataset.dims[1].label = 'yx' + meta_to_dump = map(lambda v: to_numpy(v).tolist() if hasattr(v, '__array_priority__') else v, state.metadata) - metagroup = group.require_group('metadata') - for (k, v) in state.metadata.items(): - # metasubgroup = metagroup.require_group(k) - ##TODO: directly dump all metadata keys as json? or subclass as rasterscanstate - # if isinstance(v, str): - metakey = metagroup.create_dataset(k, data=v) - # else: - # metakey = metagroup.create_dataset(k, data=v)) - - # group.create_dataset('sampling', data=state.sampling.sampling.astype(numpy.float64)) - # group.create_dataset('extent', data=state.sampling.extent.astype(numpy.float64)) - + dataset = group.create_dataset('metadata', data=json.dumps(meta_to_dump)) def hdf5_write_object_state(state: ObjectState, group: h5py.Group): assert state.data.ndim == 3 From 48ad8f73bb91087c559235f457a74a5db47596dc Mon Sep 17 00:00:00 2001 From: michaelxu01 Date: Tue, 27 Jan 2026 12:54:17 -0500 Subject: [PATCH 19/19] fixed import of same function name --- phaser/utils/io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/phaser/utils/io.py b/phaser/utils/io.py index 9371dfb..b31ae7c 100644 --- a/phaser/utils/io.py +++ b/phaser/utils/io.py @@ -6,7 +6,7 @@ import numpy from numpy.typing import NDArray import h5py -from phaser.utils.tree import map +from phaser.utils import tree from phaser.utils.num import Sampling, to_numpy from phaser.utils.object import ObjectSampling @@ -242,7 +242,7 @@ def hdf5_write_scan_state(state: ScanState, group: h5py.Group): dataset = group.create_dataset('initial', data=state.initial_scan) dataset.dims[0].label = 'position' dataset.dims[1].label = 'yx' - meta_to_dump = map(lambda v: to_numpy(v).tolist() if hasattr(v, '__array_priority__') else v, state.metadata) + meta_to_dump = tree.map(lambda v: to_numpy(v).tolist() if hasattr(v, '__array_priority__') else v, state.metadata) dataset = group.create_dataset('metadata', data=json.dumps(meta_to_dump))