From e008d9f0b151630b8b910e3f10a983a069a086c7 Mon Sep 17 00:00:00 2001 From: Colin Gilgenbach Date: Tue, 13 Jan 2026 07:07:39 -0800 Subject: [PATCH 1/9] Some fixes for torch[mps] --- phaser/plan.py | 2 +- phaser/utils/_torch_kernels.py | 34 +++++++++++++++++++++++++--------- phaser/utils/io.py | 6 +++--- phaser/utils/num.py | 12 +++++++++++- phaser/utils/object.py | 11 +++++++---- phaser/utils/scan.py | 2 +- 6 files changed, 48 insertions(+), 19 deletions(-) diff --git a/phaser/plan.py b/phaser/plan.py index 8690eae..37e4a30 100644 --- a/phaser/plan.py +++ b/phaser/plan.py @@ -153,7 +153,7 @@ class AdamSolverPlan(Dataclass, kw_only=True): b1: float = 0.9 b2: float = 0.999 eps: float = 1.0e-8 - eps_root: float = 0.0 + eps_root: float = 1.0e-8 # needed on mps w/ adam nesterov: bool = False diff --git a/phaser/utils/_torch_kernels.py b/phaser/utils/_torch_kernels.py index d8d2675..a10895c 100644 --- a/phaser/utils/_torch_kernels.py +++ b/phaser/utils/_torch_kernels.py @@ -1,7 +1,6 @@ import functools import itertools import operator -from types import ModuleType import typing as t import numpy @@ -233,19 +232,22 @@ def unwrap(arr: torch.Tensor, discont: t.Optional[float] = None, axis: int = -1, def indices( - shape: t.Tuple[int, ...], dtype: t.Union[str, None, t.Type[numpy.generic], torch.dtype] = None, sparse: bool = False + shape: t.Tuple[int, ...], + dtype: t.Union[str, None, t.Type[numpy.generic], torch.dtype] = None, + sparse: bool = False, + device: t.Optional[torch.device] = None, ) -> t.Union[torch.Tensor, t.Tuple[torch.Tensor, ...]]: - dtype = to_torch_dtype(dtype) if dtype is not None else torch.int64 + dtype = to_torch_dtype(dtype) if dtype is not None else torch.int32 n = len(shape) if sparse: return tuple( - _MockTensor(torch.arange(s, dtype=dtype).reshape((1,) * i + (s,) + (1,) * (n - i - 1))) + _MockTensor(torch.arange(s, dtype=dtype, device=device).reshape((1,) * i + (s,) + (1,) * (n - i - 1))) for (i, s) in enumerate(shape) ) - arrs = tuple(torch.arange(s, dtype=dtype) for s in shape) + arrs = tuple(torch.arange(s, dtype=dtype, device=device) for s in shape) return _MockTensor(torch.stack(torch.meshgrid(*arrs, indexing='ij'), dim=0)) @@ -275,24 +277,27 @@ def affine_transform( order: int = 1, mode: _InterpBoundaryMode = 'grid-constant', cval: ArrayLike = 0.0, ) -> torch.Tensor: + float_dtype = torch.get_default_dtype() if output_shape is None: output_shape = input.shape n_axes = len(output_shape) # num axes to transform over - idxs = t.cast(torch.Tensor, indices(output_shape, dtype=torch.float64)) + idxs = t.cast(torch.Tensor, indices(output_shape, dtype=float_dtype, device=input.device)) - matrix = asarray(matrix) + matrix = asarray(matrix, dtype=float_dtype) if matrix.size() == (n_axes + 1, n_axes + 1): # homogenous transform matrix coords = torch.tensordot( matrix, torch.stack((*idxs, torch.ones_like(idxs[0])), dim=0), dims=1 )[:-1] elif matrix.size() == (n_axes,): - coords = (idxs.T * matrix + asarray(offset)).T + coords = (idxs.T * matrix + asarray(offset, dtype=float_dtype)).T else: raise ValueError(f"Expected matrix of shape ({n_axes + 1}, {n_axes + 1}) or ({n_axes},), instead got shape {matrix.shape}") + cval = torch.asarray(cval, dtype=input.dtype) + return _MockTensor(torch.vmap( lambda a: map_coordinates(a, coords, order=order, mode=mode, cval=cval) )(input.reshape(-1, *input.shape[-n_axes:])).reshape((*input.shape[:-n_axes], *output_shape))) @@ -352,7 +357,7 @@ def _map_coordinates_constant( ) -> torch.Tensor: from phaser.utils.num import to_real_dtype weight_dtype = to_torch_dtype(to_real_dtype(to_numpy_dtype(arr.dtype))) - cval = torch.tensor(cval) + cval = torch.tensor(cval, device=arr.device) is_valid = lambda idx, size: (0 <= idx) & (idx < size) # noqa: E731 clip = lambda idx, size: torch.clip(idx, 0, size - 1) # noqa: E731 @@ -403,6 +408,17 @@ def set_default_device(device: torch.device): raise TypeError(f"Invalid device '{device}' for backend torch") torch.set_default_device(device) + default_dtype = to_torch_dtype(max_supported_float(device)) + torch.set_default_dtype(default_dtype) + + +def max_supported_float( + device: t.Optional[torch.device] = None +) -> t.Union[t.Type[numpy.float32], t.Type[numpy.float64]]: + if device is None: + device = torch.get_default_device() + return numpy.float32 if device.type in ('mps', 'xpu') else numpy.float64 + def _wrap_call(f, *args: t.Any, **kwargs: t.Any) -> t.Any: try: diff --git a/phaser/utils/io.py b/phaser/utils/io.py index 2488099..6890807 100644 --- a/phaser/utils/io.py +++ b/phaser/utils/io.py @@ -199,9 +199,9 @@ def hdf5_write_state(state: t.Union[ReconsState, PartialReconsState], file: HdfL if state.object is not None: hdf5_write_object_state(state.object, file.create_group("object")) if state.scan is not None: - file.create_dataset('scan', data=to_numpy(state.scan.astype(numpy.float64))) + file.create_dataset('scan', data=to_numpy(state.scan).astype(numpy.float64)) if state.tilt is not None: - file.create_dataset('tilt', data=to_numpy(state.tilt.astype(numpy.float64))) + file.create_dataset('tilt', data=to_numpy(state.tilt).astype(numpy.float64)) if state.iter is not None: hdf5_write_iter_state(state.iter, file.create_group("iter")) if state.progress is not None: @@ -339,7 +339,7 @@ def _hdf5_read_string(group: h5py.Group, path: str) -> str: def _hdf5_write_nullable_dataset(group: h5py.Group, name: str, data: t.Optional[numpy.ndarray], dtype: t.Any): if data is not None: - group.create_dataset(name, data=to_numpy(data.astype(dtype))) + group.create_dataset(name, data=to_numpy(data).astype(dtype)) else: group.create_dataset(name, dtype=h5py.Empty(dtype)) diff --git a/phaser/utils/num.py b/phaser/utils/num.py index 6e13374..8ef0e7a 100644 --- a/phaser/utils/num.py +++ b/phaser/utils/num.py @@ -235,6 +235,16 @@ def set_default_device(device: Device, xp: t.Any): raise ValueError(f"Invalid device '{device}' for backend 'numpy'") +def max_supported_float( + xp: t.Any, + device: t.Optional[Device] = None, +) -> t.Union[t.Type[numpy.float32], t.Type[numpy.float64]]: + if xp_is_torch(xp): + from ._torch_kernels import max_supported_float + return max_supported_float(device) + return numpy.float64 + + def get_array_module(*arrs: t.Optional[ArrayLike]): if (xp := _BACKEND_LOADER.get('jax')) is not None: import jax.tree @@ -855,7 +865,7 @@ def recip_grid( xp2 = get_array_module(self.shape, self.extent, self.sampling) if xp is None else cast_array_module(xp) if dtype is None: - dtype = numpy.common_type(self.extent, self.sampling) + dtype = max_supported_float(xp2) ky: NDArray[numpy.number] = xp2.fft.fftfreq(self.shape[0], self.sampling[0]).astype(dtype) kx: NDArray[numpy.number] = xp2.fft.fftfreq(self.shape[1], self.sampling[1]).astype(dtype) diff --git a/phaser/utils/object.py b/phaser/utils/object.py index 2dc41b4..e9b1387 100644 --- a/phaser/utils/object.py +++ b/phaser/utils/object.py @@ -13,7 +13,7 @@ from typing_extensions import Self from .num import get_array_module, cast_array_module, is_torch, to_real_dtype, as_numpy, at -from .num import as_array, is_cupy, is_jax, NumT, ComplexT, DTypeT +from .num import as_array, is_cupy, is_jax, max_supported_float, NumT, ComplexT, DTypeT from .tree import tree_dataclass from .misc import create_rng @@ -292,12 +292,14 @@ def check_scan(self, scan_positions: NDArray[numpy.floating], pad: ArrayLike = 0 def _pos_to_object_idx(self, pos: ArrayLike, cutout_shape: t.Tuple[int, ...]) -> NDArray[numpy.float64]: """Return starting index for the cutout closest to centered around `pos` (`(y, x)`)""" xp = get_array_module(pos) + dtype = max_supported_float(xp) # for a given cutout, shift to the top left pixel of that cutout # e.g. a 2x2 cutout needs shifted by s/2 - shift = -xp.maximum(0., (xp.array(cutout_shape[-2:]) - 1.)) / 2. + shift = -xp.maximum(0., (xp.array(cutout_shape[-2:], dtype=dtype) - 1.)) / 2. - return ((pos - xp.array(self.corner.copy())) / xp.array(self.sampling.copy()) + shift).astype(numpy.float64) # type: ignore + return ((xp.asarray(pos) - xp.array(self.corner.copy(), dtype=dtype)) + / xp.array(self.sampling.copy(), dtype=dtype) + shift).astype(dtype) # type: ignore def slice_at_pos(self, pos: ArrayLike, cutout_shape: t.Tuple[int, ...]) -> t.Tuple[slice, slice]: """ @@ -328,9 +330,10 @@ def get_subpx_shifts(self, pos: ArrayLike, cutout_shape: t.Tuple[int, ...]) -> N Returns the shift from the rounded position towards the actual position, in length units. """ xp = get_array_module(pos) + dtype = max_supported_float(xp) pos = self._pos_to_object_idx(as_array(pos), cutout_shape) - return (pos - xp.round(pos)).astype(numpy.float64) * xp.asarray(self.sampling, copy=True) + return (pos - xp.round(pos)) * xp.asarray(self.sampling, dtype=dtype, copy=True) # type: ignore @t.overload def cutout( # pyright: ignore[reportOverlappingOverload] diff --git a/phaser/utils/scan.py b/phaser/utils/scan.py index 1ee5bbb..b559f1d 100644 --- a/phaser/utils/scan.py +++ b/phaser/utils/scan.py @@ -45,7 +45,7 @@ def make_raster_scan(shape: t.Tuple[int, int], scan_step: ArrayLike, yy = xp2.arange(shape[0], dtype=dtype) - xp2.asarray(shape[0] / 2., dtype=dtype) xx = xp2.arange(shape[1], dtype=dtype) - xp2.asarray(shape[1] / 2., dtype=dtype) pts = xp2.stack(xp2.meshgrid(yy, xx, indexing='ij'), axis=-1) - pts *= xp2.broadcast_to(xp2.asarray(scan_step), (2,)).astype(dtype) + pts *= xp2.broadcast_to(xp2.asarray(scan_step, dtype=dtype), (2,)) if affine is not None: affine = xp2.asarray(affine, dtype=dtype) From 3fbd851a4c8de1aa335908727c11e5061d422406 Mon Sep 17 00:00:00 2001 From: Colin Gilgenbach Date: Tue, 13 Jan 2026 07:21:05 -0800 Subject: [PATCH 2/9] Change dtype of nu in adam solver --- phaser/engines/gradient/solvers.py | 4 ++-- phaser/hooks/__init__.py | 2 +- phaser/hooks/preprocessing.py | 2 +- phaser/hooks/solver.py | 2 +- phaser/plan.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/phaser/engines/gradient/solvers.py b/phaser/engines/gradient/solvers.py index 4ddccdd..6a58ec1 100644 --- a/phaser/engines/gradient/solvers.py +++ b/phaser/engines/gradient/solvers.py @@ -26,7 +26,7 @@ import numpy from numpy.typing import NDArray -from phaser.utils.num import get_array_module +from phaser.utils.num import get_array_module, to_real_dtype import phaser.utils.tree as tree from phaser.hooks.solver import GradientSolver, GradientSolverArgs from phaser.hooks.schedule import ScheduleLike, Schedule @@ -229,7 +229,7 @@ def scale_by_adam( def init_fn(params: Params) -> ScaleByAdamState: xp = get_array_module(params) mu = tree.zeros_like(params, dtype=mu_dtype) # First moment - nu = tree.zeros_like(params) # Second moment + nu = tree.map(lambda x: xp.zeros_like(x, dtype=to_real_dtype(x.dtype)), params) # Second moment return ScaleByAdamState(n=xp.zeros((), dtype=xp.int32), mu=mu, nu=nu) def update_fn( diff --git a/phaser/hooks/__init__.py b/phaser/hooks/__init__.py index 82841f8..5d7b27d 100644 --- a/phaser/hooks/__init__.py +++ b/phaser/hooks/__init__.py @@ -13,7 +13,7 @@ if t.TYPE_CHECKING: from phaser.utils.num import Sampling from phaser.utils.object import ObjectSampling - from ..state import ObjectState, ProbeState, ReconsState, Patterns + from ..state import ObjectState, ProbeState, ReconsState, Patterns # noqa: F401 from ..execute import Observer diff --git a/phaser/hooks/preprocessing.py b/phaser/hooks/preprocessing.py index 5fc6dd3..14ba218 100644 --- a/phaser/hooks/preprocessing.py +++ b/phaser/hooks/preprocessing.py @@ -43,7 +43,7 @@ def offset_patterns(raw_data: RawData, props: OffsetProps) -> RawData: return raw_data def bin_patterns(raw_data: RawData, props: BinProps) -> RawData: - xp = get_array_module(raw_data['patterns']) + #xp = get_array_module(raw_data['patterns']) bin_factor = props.bin patterns = raw_data['patterns'] Ny, Nx = patterns.shape[-2:] diff --git a/phaser/hooks/solver.py b/phaser/hooks/solver.py index 87d1a05..dd89152 100644 --- a/phaser/hooks/solver.py +++ b/phaser/hooks/solver.py @@ -13,7 +13,7 @@ if t.TYPE_CHECKING: from phaser.engines.common.simulation import SimulationState from phaser.execute import Observer - from phaser.plan import ConventionalEnginePlan, GradientEnginePlan + from phaser.plan import ConventionalEnginePlan, GradientEnginePlan # noqa: F401 from phaser.state import ReconsState diff --git a/phaser/plan.py b/phaser/plan.py index 37e4a30..8690eae 100644 --- a/phaser/plan.py +++ b/phaser/plan.py @@ -153,7 +153,7 @@ class AdamSolverPlan(Dataclass, kw_only=True): b1: float = 0.9 b2: float = 0.999 eps: float = 1.0e-8 - eps_root: float = 1.0e-8 # needed on mps w/ adam + eps_root: float = 0.0 nesterov: bool = False From 83a5982559e88f6aa5fc3e8ca5886fa99db09e3e Mon Sep 17 00:00:00 2001 From: Colin Gilgenbach Date: Wed, 14 Jan 2026 15:55:38 -0800 Subject: [PATCH 3/9] Fix underflow NaNs on torch[mps] --- phaser/engines/gradient/solvers.py | 11 ++++++++++- phaser/utils/_torch_kernels.py | 20 ++++++++++++++++++++ phaser/utils/tree.py | 17 +++++++++-------- 3 files changed, 39 insertions(+), 9 deletions(-) diff --git a/phaser/engines/gradient/solvers.py b/phaser/engines/gradient/solvers.py index 6a58ec1..aa0059d 100644 --- a/phaser/engines/gradient/solvers.py +++ b/phaser/engines/gradient/solvers.py @@ -26,7 +26,7 @@ import numpy from numpy.typing import NDArray -from phaser.utils.num import get_array_module, to_real_dtype +from phaser.utils.num import get_array_module, to_real_dtype, xp_is_torch import phaser.utils.tree as tree from phaser.hooks.solver import GradientSolver, GradientSolverArgs from phaser.hooks.schedule import ScheduleLike, Schedule @@ -241,6 +241,15 @@ def update_fn( nu = tree.update_moment_per_elem_norm(updates, state.nu, b2, 2) n_inc = safe_increment(state.n) + # HACK: on mps we need to prevent small mu values from returning nan + if xp_is_torch(xp) and any( + leaf.device.type == 'mps' for leaf in tree.leaves(updates) + ): + mu = tree.map( + lambda arr: xp.nan_to_num(arr, nan=0.), + mu, is_leaf=lambda x: x is None + ) + if nesterov: mu_hat = tree.map( lambda m, g: b1 * m + (1 - b1) * g, diff --git a/phaser/utils/_torch_kernels.py b/phaser/utils/_torch_kernels.py index a10895c..a353ebd 100644 --- a/phaser/utils/_torch_kernels.py +++ b/phaser/utils/_torch_kernels.py @@ -114,6 +114,16 @@ def _mirror(idx: torch.Tensor, size: int) -> torch.Tensor: 'wrap': 'circular', } + +def nan_to_num(arr: torch.Tensor, **kwargs: t.Any) -> torch.Tensor: + if torch.is_complex(arr): + return torch.view_as_complex( + torch.nan_to_num(torch.view_as_real(arr), **kwargs) + ) + + return torch.nan_to_num(arr, **kwargs) + + def min( arr: torch.Tensor, axis: t.Union[int, t.Tuple[int, ...], None] = None, *, keepdims: bool = False @@ -172,6 +182,14 @@ def maximum( return torch.maximum(x1, x2) +def cumsum( + arr: torch.Tensor, axis: t.Optional[int] = None, +) -> torch.Tensor: + if axis is None: + return torch.cumsum(arr.ravel(), 0) + return torch.cumsum(arr, axis) + + def split( arr: torch.Tensor, sections: int, *, axis: int = 0 ) -> t.Tuple[torch.Tensor, ...]: @@ -452,9 +470,11 @@ def _wrap_call(f, *args: t.Any, **kwargs: t.Any) -> t.Any: 'torch.mod': functools.update_wrapper(lambda *args, **kwargs: _MockTensor(_wrap_call(torch.remainder, *args, **kwargs)), torch.remainder), # type: ignore 'torch.split': split, 'torch.pad': pad, + 'torch.nan_to_num': nan_to_num, 'torch.min': min, 'torch.max': max, 'torch.nanmin': nanmin, 'torch.nanmax': nanmax, 'torch.minimum': minimum, 'torch.maximum': maximum, + 'torch.cumsum': cumsum, 'torch.unwrap': unwrap, 'torch.indices': indices, 'torch.size': size, diff --git a/phaser/utils/tree.py b/phaser/utils/tree.py index b29fd07..044209a 100644 --- a/phaser/utils/tree.py +++ b/phaser/utils/tree.py @@ -306,14 +306,15 @@ def conj( def update_moment(updates: Tree, moments: Tree, decay: float, order: int) -> Tree: - return map( - lambda g, t: ( - (1 - decay) * (g**order) + decay * t if g is not None else None - ), - updates, - moments, - is_leaf=lambda x: x is None, - ) + return map( + lambda g, t: ( + (1 - decay) * (g**order) + decay * t + if g is not None else None + ), + updates, + moments, + is_leaf=lambda x: x is None, + ) def update_moment_per_elem_norm(updates: Tree, moments: Tree, decay: float, order: int) -> Tree: From 3320962b7539a5229ba52a66efc5fb324ccb1ff3 Mon Sep 17 00:00:00 2001 From: Colin Gilgenbach Date: Wed, 14 Jan 2026 15:56:30 -0800 Subject: [PATCH 4/9] Allow det_rotation to be None --- phaser/hooks/io/empad.py | 4 ++-- phaser/io/empad.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/phaser/hooks/io/empad.py b/phaser/hooks/io/empad.py index 488eafc..a7b6823 100644 --- a/phaser/hooks/io/empad.py +++ b/phaser/hooks/io/empad.py @@ -46,7 +46,7 @@ def load_empad(args: None, props: LoadEmpadProps) -> RawData: 'shape': scan_shape, 'step_size': tuple(s*1e10 for s in reversed(meta.scan_step)), # m to A 'affine': meta.scan_correction[::-1, ::-1] if meta.scan_correction is not None else None, - 'rotation': meta.det_rotation - meta.scan_rotation, + 'rotation': (meta.det_rotation or 0.0) - meta.scan_rotation, } #TODO: add tilt to metafile @@ -95,4 +95,4 @@ def load_empad(args: None, props: LoadEmpadProps) -> RawData: 'scan_hook': scan_hook, 'tilt_hook': tilt_hook, 'seed': None, - } \ No newline at end of file + } diff --git a/phaser/io/empad.py b/phaser/io/empad.py index 2be86a0..d646ae8 100644 --- a/phaser/io/empad.py +++ b/phaser/io/empad.py @@ -58,7 +58,7 @@ def __post_init__(self): Flips to apply to the raw diffraction patterns, (flip_y, flip_x, transpose). Defaults to `(True, False, False)` (appears to be the most common orientation). """ - det_rotation: float = 0.0 + det_rotation: t.Optional[float] = None """Detector rotation (degrees).""" orig_path: t.Optional[Path] = None From dcd527946e157cd6ee810f1ba6af213c72fd6656 Mon Sep 17 00:00:00 2001 From: Colin Gilgenbach Date: Wed, 14 Jan 2026 16:27:42 -0800 Subject: [PATCH 5/9] Fix plan validation w/ complex numbers --- phaser/main.py | 30 ++++++++++++++++++++++-------- phaser/types.py | 17 +++++++++++++++++ phaser/utils/optics.py | 28 ++++++---------------------- tests/test_optics.py | 14 +++++++------- 4 files changed, 52 insertions(+), 37 deletions(-) diff --git a/phaser/main.py b/phaser/main.py index 03125f4..91928e2 100644 --- a/phaser/main.py +++ b/phaser/main.py @@ -65,19 +65,33 @@ def validate(path: t.Union[str, Path], json: bool = False): sys.exit(1) + if json: + from json import dump, dumps + + def _serialize_complex(val: t.Any) -> t.Any: + if isinstance(val, complex): + return {'re': val.real, 'im': val.imag} + raise TypeError() + + try: + s = dumps({ + 'result': 'success', + 'plans': [(plan.name, plan.into_data()) for plan in plans], + }, default=_serialize_complex) + except Exception as e: + print(f"Failed to serialize validated plans: {e}", file=sys.stderr) + dump({'result': 'error', 'error': str(e)}, sys.stdout) + print() + sys.exit(2) + + sys.stdout.write(s) + print() + if len(plans) == 1: print("Validation of plan successful!", file=sys.stderr) else: print(f"Validation of {len(plans)} plans successful!", file=sys.stderr) - if json: - from json import dump - dump({ - 'result': 'success', - 'plans': [(plan.name, plan.into_data()) for plan in plans], - }, sys.stdout) - print() - @cli.command('worker') @click.argument('url', type=str, required=True) diff --git a/phaser/types.py b/phaser/types.py index 3f1006f..b27a903 100644 --- a/phaser/types.py +++ b/phaser/types.py @@ -302,6 +302,23 @@ def collect_errors(self, val: t.Any) -> t.Optional[ErrorNode]: return self.inner.collect_errors(val) +class ComplexCartesian(pane.PaneBase, kw_only=True): + re: float + im: float = 0.0 + + def __complex__(self) -> complex: + return complex(self.re, self.im) + + +class ComplexPolar(pane.PaneBase, kw_only=True): + mag: float + angle: float = 0.0 # degrees + + def __complex__(self) -> complex: + theta = numpy.deg2rad(self.angle) + return self.mag * complex(numpy.cos(theta), numpy.sin(theta)) + + __all__ = [ 'BackendName', 'Dataclass', 'Slices', 'Flag', 'process_flag', 'flag_any_true', diff --git a/phaser/utils/optics.py b/phaser/utils/optics.py index 2cb8e05..ba6fdb5 100644 --- a/phaser/utils/optics.py +++ b/phaser/utils/optics.py @@ -11,6 +11,7 @@ from pane.annotations import Condition from pane.util import pluralize +from phaser.types import ComplexCartesian, ComplexPolar from .num import get_array_module, ifft2, abs2, NumT, ufunc_outer, is_jax, cast_array_module from .num import Float, Sampling, to_complex_dtype, to_real_dtype, split_array, to_numpy @@ -57,38 +58,21 @@ def from_known(s: str) -> 'Krivanek': )] -class Cartesian(pane.PaneBase, kw_only=True): - a: float - b: float = 0.0 - - def __complex__(self) -> complex: - return complex(self.a, self.b) - - -class Polar(pane.PaneBase, kw_only=True): - mag: float - angle: float = 0.0 # degrees - - def __complex__(self) -> complex: - theta = numpy.deg2rad(self.angle) - return self.mag * complex(numpy.cos(theta), numpy.sin(theta)) - - class KrivanekComplex(Krivanek, kw_only=True): - val: complex + val: t.Union[complex, ComplexCartesian, ComplexPolar] def __complex__(self) -> complex: - return self.val + return complex(self.val) -class KrivanekCartesian(Krivanek, Cartesian, kw_only=True): +class KrivanekCartesian(Krivanek, ComplexCartesian, kw_only=True): ... -class KrivanekPolar(Krivanek, Polar, kw_only=True): +class KrivanekPolar(Krivanek, ComplexPolar, kw_only=True): ... Aberration: t.TypeAlias = t.Union[ - t.Dict[KnownAberration, t.Union[complex, Cartesian, Polar]], + t.Dict[KnownAberration, t.Union[complex, ComplexCartesian, ComplexPolar]], KrivanekComplex, KrivanekCartesian, KrivanekPolar, ] AberrationList: t.TypeAlias = t.List[Aberration] diff --git a/tests/test_optics.py b/tests/test_optics.py index 8866bc3..5e7affc 100644 --- a/tests/test_optics.py +++ b/tests/test_optics.py @@ -10,7 +10,7 @@ from phaser.utils.optics import ( make_focused_probe, fresnel_propagator, AberrationList, _normalize_aberrations, - Cartesian, Polar, KrivanekComplex, KrivanekCartesian, KrivanekPolar, + ComplexCartesian, ComplexPolar, KrivanekComplex, KrivanekCartesian, KrivanekPolar, ) @@ -69,7 +69,7 @@ def test_spherical_probe(backend: BackendName) -> numpy.ndarray: aberrations=[ {'c3': 1.0e+7}, {'a1': 20.0+20.0j}, - KrivanekCartesian(3, 2, a=1.5e6, b=2.0e6), + KrivanekCartesian(3, 2, re=1.5e6, im=2.0e6), ] ) return to_numpy(probe) @@ -108,19 +108,19 @@ def test_parse_aberrations(): import pane result = pane.convert([ {'c3': 5.0}, # haider complex - {'b2': {'a': 5.0, 'b': -2.0}}, # haider cartesian + {'b2': {'re': 5.0, 'im': -2.0}}, # haider cartesian {'a1': {'mag': 5.0, 'angle': 90.0}}, # haider polar {'n': 4, 'm': 1, 'val': 1+1.j}, # krivanek complex - {'n': 1, 'm': 0, 'a': 5.0}, # krivanek cartesian + {'n': 1, 'm': 0, 're': 5.0}, # krivanek cartesian {'n': 5, 'm': 0, 'mag': 5.0}, # krivanek polar ], AberrationList) assert result == [ {'c3': complex(5.0)}, - {'b2': Cartesian(a=5.0, b=-2.0)}, - {'a1': Polar(mag=5.0, angle=90.0)}, + {'b2': ComplexCartesian(re=5.0, im=-2.0)}, + {'a1': ComplexPolar(mag=5.0, angle=90.0)}, KrivanekComplex(4, 1, val=1+1.j), - KrivanekCartesian(1, 0, a=5.0, b=0.0), + KrivanekCartesian(1, 0, re=5.0, im=0.0), KrivanekPolar(5, 0, mag=5.0, angle=0.0), ] From 83adf39bded2d38b003cb70743363a50c5e3e4cd Mon Sep 17 00:00:00 2001 From: Colin Gilgenbach Date: Fri, 16 Jan 2026 14:54:42 -0800 Subject: [PATCH 6/9] layers regularization support on pytorch --- phaser/engines/common/regularizers.py | 16 ++---- phaser/utils/_torch_kernels.py | 41 +++++++++++++++- phaser/utils/image.py | 71 +++++++++++++++++++++++++++ tests/test_image.py | 39 ++++++++++++++- 4 files changed, 152 insertions(+), 15 deletions(-) diff --git a/phaser/engines/common/regularizers.py b/phaser/engines/common/regularizers.py index aa4b689..416186e 100644 --- a/phaser/engines/common/regularizers.py +++ b/phaser/engines/common/regularizers.py @@ -9,6 +9,7 @@ get_array_module, get_scipy_module, Float, unstack, jit, fft2, ifft2, abs2, xp_is_jax, to_real_dtype, to_numpy ) +from phaser.utils.image import convolve1d from phaser.state import ReconsState from phaser.hooks.regularization import ( ClampObjectAmplitudeProps, LimitProbeSupportProps, NonNegObjectPhaseProps, @@ -142,7 +143,6 @@ def init_state(self, sim: ReconsState) -> None: def apply_iter(self, sim: ReconsState, state: None) -> t.Tuple[ReconsState, None]: xp = get_array_module(sim.object.data) - scipy = get_scipy_module(sim.object.data) dtype = to_real_dtype(sim.object.data) if len(sim.object.thicknesses) < 2: @@ -161,17 +161,9 @@ def apply_iter(self, sim: ReconsState, state: None) -> t.Tuple[ReconsState, None # we convolve the log of object, because the transmission # function is multiplicative, not additive - - if xp_is_jax(xp): - new_obj = xp.exp(scipy.signal.convolve( - xp.pad(xp.log(sim.object.data), ((r, r), (0, 0), (0, 0)), mode='edge'), - kernel[:, None, None], - mode="valid" - )) - else: - new_obj = xp.exp(scipy.ndimage.convolve1d(xp.log( - sim.object.data - ), kernel, axis=0, mode='nearest')) + new_obj = xp.exp(convolve1d(xp.log( + sim.object.data + ), kernel, axis=0, mode='nearest')) assert new_obj.shape == sim.object.data.shape assert new_obj.dtype == sim.object.data.dtype diff --git a/phaser/utils/_torch_kernels.py b/phaser/utils/_torch_kernels.py index a353ebd..0c15c9a 100644 --- a/phaser/utils/_torch_kernels.py +++ b/phaser/utils/_torch_kernels.py @@ -6,6 +6,7 @@ import numpy from numpy.typing import ArrayLike import torch +import torch.nn.functional as F from phaser.utils.num import _PadMode from phaser.utils.image import _InterpBoundaryMode @@ -218,7 +219,7 @@ def pad( pad = tuple(itertools.chain.from_iterable(t.cast(t.Sequence[t.Tuple[int, int]], reversed(pad)))) kwargs = {'value': cval} if mode == 'constant' else {} - return _MockTensor(torch.nn.functional.pad(arr, pad, mode=_PAD_MODE_MAP[mode], **kwargs)) + return _MockTensor(F.pad(arr, pad, mode=_PAD_MODE_MAP[mode], **kwargs)) def unwrap(arr: torch.Tensor, discont: t.Optional[float] = None, axis: int = -1, *, @@ -405,6 +406,44 @@ def _map_coordinates_constant( return result.type(arr.dtype) +_INTERP_TO_TORCH_PAD: t.Dict[_InterpBoundaryMode, str] = { + 'nearest': 'replicate', + 'wrap': 'circular', + 'grid-wrap': 'circular', + 'constant': 'constant', + 'grid-constant': 'constant', + 'mirror': 'reflect', +} + + +def _convolve1d( + arr: torch.Tensor, weights: torch.Tensor, axis: int, *, + mode: _InterpBoundaryMode, cval: float = 0. +) -> torch.Tensor: + pad_mode = _INTERP_TO_TORCH_PAD.get(mode) + if pad_mode is None: + raise ValueError(f"Pad mode '{mode}' not implemented for torch backend") + + # reorder to last axis + reorder = axis != arr.ndim - 1 + if reorder: + arr = torch.moveaxis(arr, axis, -1) + leading_shape = arr.shape[:-1] + arr = arr.reshape((-1, arr.shape[-1])) + r = len(weights) // 2 + + # torch's conv1d is actually a correlation + weights = weights.flip(0) + + # TODO: this will fail for some pads where weights is large, investigate further + arr = F.pad(arr, (len(weights) - r - 1, r), mode=pad_mode, value=cval) + arr = F.conv1d( + arr[:, None, :], weights[None, None, :] + )[:, 0].reshape((*leading_shape, -1)) + + return torch.moveaxis(arr, -1, axis) if reorder else arr + + def get_devices() -> t.Tuple[torch.device, ...]: devices = [] devices.extend(f'cuda:{i}' for i in range(torch.cuda.device_count())) diff --git a/phaser/utils/image.py b/phaser/utils/image.py index 522a67e..5c06fd4 100644 --- a/phaser/utils/image.py +++ b/phaser/utils/image.py @@ -278,10 +278,81 @@ def square_pixel_transfer(shape: t.Tuple[int, int], *, xp: t.Any = None) -> NDAr return xp.sinc(ky) * xp.sinc(kx) +# convert scipy boundary mode to numpy.pad mode +_INTERP_TO_PAD: t.Dict[_InterpBoundaryMode, str] = { + 'reflect': 'symmetric', + 'mirror': 'reflect', + 'nearest': 'edge', + 'grid-mirror': 'reflect', + 'grid-wrap': 'wrap', + 'grid-constant': 'constant', +} + + +def _canonicalize_axis(axis: int, num_dims: int) -> int: + """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims).""" + axis = axis.__index__() + if not -num_dims <= axis < num_dims: + raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}") + if axis < 0: + axis = axis + num_dims + return axis + + +def convolve1d( + arr: NDArray[NumT], weights: ArrayLike, axis: int = -1, *, + mode: _InterpBoundaryMode = 'reflect', cval: float = 0. +) -> NDArray[NumT]: + xp = get_array_module(arr, weights) + + arr = xp.asarray(arr) + weights = xp.asarray(weights) + if weights.ndim != 1: + raise ValueError("Expected 'weights' to be 1D") + axis = _canonicalize_axis(axis, arr.ndim) + + if xp_is_torch(xp): + from ._torch_kernels import _convolve1d, _MockTensor + + return t.cast(NDArray[NumT], _convolve1d( + t.cast(_MockTensor, arr), + t.cast(_MockTensor, weights), + axis=axis, mode=mode, cval=cval + )) + + scipy = get_scipy_module(arr, weights) + + if xp_is_jax(xp): + r = len(weights) // 2 + pad_mode = _INTERP_TO_PAD.get(mode, mode) + pad_kwargs = {'constant_values': cval} if pad_mode == 'constant' else {} + + pad = tuple( + (len(weights) - r - 1, r) if i == axis else (0, 0) + for i in range(arr.ndim) + ) + weights = weights[tuple( + slice(None) if i == axis else None + for i in range(arr.ndim) + )] + # TODO: use jax.lax.conv_general_dilated directly + return scipy.signal.convolve( + xp.pad(arr, pad, mode=pad_mode, **pad_kwargs), # type: ignore + weights, mode='valid', method='direct' + ).astype(arr.dtype) + + return scipy.ndimage.convolve1d( + arr, weights, axis, mode=mode, cval=cval + ) + + + + __all__ = [ 'apply_flips', 'remove_linear_ramp', 'colorize_complex', 'scale_to_integral_type', 'affine_transform', 'to_affine_matrix', + 'convolve1d', 'scale_matrix', 'rotation_matrix', 'translation_matrix', 'gaussian_transfer', 'square_pixel_transfer', ] \ No newline at end of file diff --git a/tests/test_image.py b/tests/test_image.py index a65b6be..d3e5199 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -3,13 +3,14 @@ import numpy from numpy.typing import ArrayLike, NDArray from numpy.testing import assert_array_almost_equal +import scipy.ndimage as osp import pytest from .utils import with_backends, check_array_equals_file from phaser.utils.num import get_backend_module, BackendName, to_numpy, Sampling from phaser.utils.image import ( - affine_transform, _InterpBoundaryMode + affine_transform, _InterpBoundaryMode, convolve1d ) @@ -74,4 +75,38 @@ def test_resample( (checker, old_samp) = checkerboard - return to_numpy(old_samp.resample(xp.array(checker), sampling, rotation=rotation, order=order)) \ No newline at end of file + return to_numpy(old_samp.resample(xp.array(checker), sampling, rotation=rotation, order=order)) + +@with_backends('numpy', 'jax', 'cupy', 'torch') +@pytest.mark.parametrize(('arr', 'weights', 'axis'), [ + ([1, 2, 3, 4, 5], [1, 2], 0), + ([1.0, 2.0, 3.0, 4.0, 5.0], [1.0, 2.0], 0), + ([[[1, 2], [3, 4]], [[2, 3], [4, 5]], [[3, 4], [5, 6]]], [1, 2, 3], 0), + ([[[1, 2], [3, 4]], [[2, 3], [4, 5]], [[3, 4], [5, 6]]], [1, 2, 3], -1), + ([1+1.j, 2+2.j, 3+3.j], [1-1.j, 2-1.j], 0), +]) +@pytest.mark.parametrize(('mode', 'cval'), [ + ('constant', 1.0), ('nearest', 0.0), ('mirror', 0.0), + ('reflect', 0.0), ('wrap', 0.0), +]) +def test_convolve1d( + arr, weights, axis, mode, cval, + backend: BackendName, +): + if mode == 'reflect' and backend == 'torch': + pytest.xfail("'reflect' not supported on torch") + + arr = numpy.asarray(arr) + weights = numpy.asarray(weights) + + xp = get_backend_module(backend) + + expected = osp.convolve1d( + arr, weights, axis=axis, mode=mode, cval=cval + ) + actual = to_numpy(convolve1d( + xp.array(arr), xp.array(weights), axis=axis, mode=mode, cval=cval + )) + assert actual.dtype == expected.dtype + + assert_array_almost_equal(actual, expected, decimal=6) \ No newline at end of file From fde3ec58e4ca9fd37025474b2d8035d28b00f7f0 Mon Sep 17 00:00:00 2001 From: Colin Gilgenbach Date: Fri, 16 Jan 2026 16:28:20 -0800 Subject: [PATCH 7/9] Fix warnings w/ torch --- phaser/main.py | 8 +++++++- phaser/utils/_torch_kernels.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/phaser/main.py b/phaser/main.py index 91928e2..e70f087 100644 --- a/phaser/main.py +++ b/phaser/main.py @@ -11,9 +11,15 @@ def cli(): @cli.command('run') @click.argument('path', type=click.Path(exists=True, dir_okay=False)) -def run(path: t.Union[str, Path]): +@click.option('--raise-on-warn/--no-raise-on-warn') +def run(path: t.Union[str, Path], *, raise_on_warn: bool = False): from .plan import ReconsPlan from .execute import execute_plan + + if raise_on_warn: + import warnings + warnings.simplefilter('error') + plans = ReconsPlan.from_yaml_all(path) for plan in plans: diff --git a/phaser/utils/_torch_kernels.py b/phaser/utils/_torch_kernels.py index 0c15c9a..2ae09c3 100644 --- a/phaser/utils/_torch_kernels.py +++ b/phaser/utils/_torch_kernels.py @@ -376,7 +376,7 @@ def _map_coordinates_constant( ) -> torch.Tensor: from phaser.utils.num import to_real_dtype weight_dtype = to_torch_dtype(to_real_dtype(to_numpy_dtype(arr.dtype))) - cval = torch.tensor(cval, device=arr.device) + cval = torch.asarray(cval, device=arr.device) is_valid = lambda idx, size: (0 <= idx) & (idx < size) # noqa: E731 clip = lambda idx, size: torch.clip(idx, 0, size - 1) # noqa: E731 From 80a2666f01ac5c8b600c5dde7e1b53c9487e2f47 Mon Sep 17 00:00:00 2001 From: Colin Gilgenbach Date: Fri, 16 Jan 2026 16:33:39 -0800 Subject: [PATCH 8/9] Fix remove_linear_ramp on torch[mps] --- phaser/utils/image.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/phaser/utils/image.py b/phaser/utils/image.py index 5c06fd4..b63d25a 100644 --- a/phaser/utils/image.py +++ b/phaser/utils/image.py @@ -9,7 +9,7 @@ from numpy.typing import ArrayLike, NDArray from .num import ( - Sampling, cast_array_module, get_array_module, get_scipy_module, to_numpy, at, abs2, + Sampling, cast_array_module, get_array_module, get_scipy_module, max_supported_float, to_numpy, at, abs2, xp_is_jax, xp_is_torch, Float ) @@ -63,13 +63,13 @@ def remove_linear_ramp( """ Removes a linear 'ramp' from an image or stack of images. """ - xp = get_array_module(data) + float_dtype = max_supported_float(xp) output = xp.empty_like(data) data = xp.array(data) - (yy, xx) = (arr.flatten() for arr in xp.indices(data.shape[-2:], dtype=float)) + (yy, xx) = (arr.flatten() for arr in xp.indices(data.shape[-2:], dtype=float_dtype)) pts = xp.stack((xp.ones_like(xx), xx, yy), axis=-1) if mask is None: @@ -78,7 +78,7 @@ def remove_linear_ramp( mask = mask.flatten() for idx in numpy.ndindex(data.shape[:-2]): - layer = data[tuple(idx)].astype(numpy.float64) + layer = data[tuple(idx)].astype(float_dtype) p, residues, rank, singular = xp.linalg.lstsq(pts[mask], layer.flatten()[mask], rcond=None) output = at(output, idx).set((layer - (p @ pts.T).reshape(layer.shape)).astype(output.dtype)) From ce529d1a599add035b36d48c39cc4a228ca500fa Mon Sep 17 00:00:00 2001 From: Colin Gilgenbach Date: Sat, 17 Jan 2026 15:29:42 -0500 Subject: [PATCH 9/9] Fix tests on pytorch --- phaser/utils/_torch_kernels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phaser/utils/_torch_kernels.py b/phaser/utils/_torch_kernels.py index 2ae09c3..ba21d3f 100644 --- a/phaser/utils/_torch_kernels.py +++ b/phaser/utils/_torch_kernels.py @@ -296,7 +296,7 @@ def affine_transform( order: int = 1, mode: _InterpBoundaryMode = 'grid-constant', cval: ArrayLike = 0.0, ) -> torch.Tensor: - float_dtype = torch.get_default_dtype() + float_dtype = max_supported_float(input.device) if output_shape is None: output_shape = input.shape