Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions phaser/engines/common/regularizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
get_array_module, get_scipy_module, Float, unstack,
jit, fft2, ifft2, abs2, xp_is_jax, to_real_dtype, to_numpy
)
from phaser.utils.image import convolve1d
from phaser.state import ReconsState
from phaser.hooks.regularization import (
ClampObjectAmplitudeProps, LimitProbeSupportProps, NonNegObjectPhaseProps,
Expand Down Expand Up @@ -142,7 +143,6 @@ def init_state(self, sim: ReconsState) -> None:

def apply_iter(self, sim: ReconsState, state: None) -> t.Tuple[ReconsState, None]:
xp = get_array_module(sim.object.data)
scipy = get_scipy_module(sim.object.data)
dtype = to_real_dtype(sim.object.data)

if len(sim.object.thicknesses) < 2:
Expand All @@ -161,17 +161,9 @@ def apply_iter(self, sim: ReconsState, state: None) -> t.Tuple[ReconsState, None

# we convolve the log of object, because the transmission
# function is multiplicative, not additive

if xp_is_jax(xp):
new_obj = xp.exp(scipy.signal.convolve(
xp.pad(xp.log(sim.object.data), ((r, r), (0, 0), (0, 0)), mode='edge'),
kernel[:, None, None],
mode="valid"
))
else:
new_obj = xp.exp(scipy.ndimage.convolve1d(xp.log(
sim.object.data
), kernel, axis=0, mode='nearest'))
new_obj = xp.exp(convolve1d(xp.log(
sim.object.data
), kernel, axis=0, mode='nearest'))

assert new_obj.shape == sim.object.data.shape
assert new_obj.dtype == sim.object.data.dtype
Expand Down
13 changes: 11 additions & 2 deletions phaser/engines/gradient/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import numpy
from numpy.typing import NDArray

from phaser.utils.num import get_array_module
from phaser.utils.num import get_array_module, to_real_dtype, xp_is_torch
import phaser.utils.tree as tree
from phaser.hooks.solver import GradientSolver, GradientSolverArgs
from phaser.hooks.schedule import ScheduleLike, Schedule
Expand Down Expand Up @@ -229,7 +229,7 @@ def scale_by_adam(
def init_fn(params: Params) -> ScaleByAdamState:
xp = get_array_module(params)
mu = tree.zeros_like(params, dtype=mu_dtype) # First moment
nu = tree.zeros_like(params) # Second moment
nu = tree.map(lambda x: xp.zeros_like(x, dtype=to_real_dtype(x.dtype)), params) # Second moment
return ScaleByAdamState(n=xp.zeros((), dtype=xp.int32), mu=mu, nu=nu)

def update_fn(
Expand All @@ -241,6 +241,15 @@ def update_fn(
nu = tree.update_moment_per_elem_norm(updates, state.nu, b2, 2)
n_inc = safe_increment(state.n)

# HACK: on mps we need to prevent small mu values from returning nan
if xp_is_torch(xp) and any(
leaf.device.type == 'mps' for leaf in tree.leaves(updates)
):
mu = tree.map(
lambda arr: xp.nan_to_num(arr, nan=0.),
mu, is_leaf=lambda x: x is None
)

if nesterov:
mu_hat = tree.map(
lambda m, g: b1 * m + (1 - b1) * g,
Expand Down
2 changes: 1 addition & 1 deletion phaser/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
if t.TYPE_CHECKING:
from phaser.utils.num import Sampling
from phaser.utils.object import ObjectSampling
from ..state import ObjectState, ProbeState, ReconsState, Patterns
from ..state import ObjectState, ProbeState, ReconsState, Patterns # noqa: F401
from ..execute import Observer


Expand Down
4 changes: 2 additions & 2 deletions phaser/hooks/io/empad.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def load_empad(args: None, props: LoadEmpadProps) -> RawData:
'shape': scan_shape,
'step_size': tuple(s*1e10 for s in reversed(meta.scan_step)), # m to A
'affine': meta.scan_correction[::-1, ::-1] if meta.scan_correction is not None else None,
'rotation': meta.det_rotation - meta.scan_rotation,
'rotation': (meta.det_rotation or 0.0) - meta.scan_rotation,
}

#TODO: add tilt to metafile
Expand Down Expand Up @@ -95,4 +95,4 @@ def load_empad(args: None, props: LoadEmpadProps) -> RawData:
'scan_hook': scan_hook,
'tilt_hook': tilt_hook,
'seed': None,
}
}
2 changes: 1 addition & 1 deletion phaser/hooks/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def offset_patterns(raw_data: RawData, props: OffsetProps) -> RawData:
return raw_data

def bin_patterns(raw_data: RawData, props: BinProps) -> RawData:
xp = get_array_module(raw_data['patterns'])
#xp = get_array_module(raw_data['patterns'])
bin_factor = props.bin
patterns = raw_data['patterns']
Ny, Nx = patterns.shape[-2:]
Expand Down
2 changes: 1 addition & 1 deletion phaser/hooks/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
if t.TYPE_CHECKING:
from phaser.engines.common.simulation import SimulationState
from phaser.execute import Observer
from phaser.plan import ConventionalEnginePlan, GradientEnginePlan
from phaser.plan import ConventionalEnginePlan, GradientEnginePlan # noqa: F401
from phaser.state import ReconsState


Expand Down
2 changes: 1 addition & 1 deletion phaser/io/empad.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __post_init__(self):
Flips to apply to the raw diffraction patterns, (flip_y, flip_x, transpose).
Defaults to `(True, False, False)` (appears to be the most common orientation).
"""
det_rotation: float = 0.0
det_rotation: t.Optional[float] = None
"""Detector rotation (degrees)."""

orig_path: t.Optional[Path] = None
Expand Down
38 changes: 29 additions & 9 deletions phaser/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,15 @@ def cli():

@cli.command('run')
@click.argument('path', type=click.Path(exists=True, dir_okay=False))
def run(path: t.Union[str, Path]):
@click.option('--raise-on-warn/--no-raise-on-warn')
def run(path: t.Union[str, Path], *, raise_on_warn: bool = False):
from .plan import ReconsPlan
from .execute import execute_plan

if raise_on_warn:
import warnings
warnings.simplefilter('error')

plans = ReconsPlan.from_yaml_all(path)

for plan in plans:
Expand Down Expand Up @@ -65,19 +71,33 @@ def validate(path: t.Union[str, Path], json: bool = False):

sys.exit(1)

if json:
from json import dump, dumps

def _serialize_complex(val: t.Any) -> t.Any:
if isinstance(val, complex):
return {'re': val.real, 'im': val.imag}
raise TypeError()

try:
s = dumps({
'result': 'success',
'plans': [(plan.name, plan.into_data()) for plan in plans],
}, default=_serialize_complex)
except Exception as e:
print(f"Failed to serialize validated plans: {e}", file=sys.stderr)
dump({'result': 'error', 'error': str(e)}, sys.stdout)
print()
sys.exit(2)

sys.stdout.write(s)
print()

if len(plans) == 1:
print("Validation of plan successful!", file=sys.stderr)
else:
print(f"Validation of {len(plans)} plans successful!", file=sys.stderr)

if json:
from json import dump
dump({
'result': 'success',
'plans': [(plan.name, plan.into_data()) for plan in plans],
}, sys.stdout)
print()


@cli.command('worker')
@click.argument('url', type=str, required=True)
Expand Down
17 changes: 17 additions & 0 deletions phaser/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,23 @@ def collect_errors(self, val: t.Any) -> t.Optional[ErrorNode]:
return self.inner.collect_errors(val)


class ComplexCartesian(pane.PaneBase, kw_only=True):
re: float
im: float = 0.0

def __complex__(self) -> complex:
return complex(self.re, self.im)


class ComplexPolar(pane.PaneBase, kw_only=True):
mag: float
angle: float = 0.0 # degrees

def __complex__(self) -> complex:
theta = numpy.deg2rad(self.angle)
return self.mag * complex(numpy.cos(theta), numpy.sin(theta))


__all__ = [
'BackendName', 'Dataclass', 'Slices', 'Flag',
'process_flag', 'flag_any_true',
Expand Down
Loading