From e585389b1ce15d374dcaf4b54c7bbba7ae37c81d Mon Sep 17 00:00:00 2001 From: Colin Gilgenbach Date: Sat, 17 Jan 2026 14:34:02 -0500 Subject: [PATCH 1/3] Move dependencies around, trying to speed import of ReconsPlan --- phaser/hooks/__init__.py | 3 +- phaser/hooks/regularization.py | 4 +- phaser/hooks/solver.py | 9 +-- phaser/types.py | 108 +++++++++++++++++++++++++++------ phaser/utils/optics.py | 81 +------------------------ tests/test_optics.py | 12 ++-- 6 files changed, 106 insertions(+), 111 deletions(-) diff --git a/phaser/hooks/__init__.py b/phaser/hooks/__init__.py index 5d7b27d..f861270 100644 --- a/phaser/hooks/__init__.py +++ b/phaser/hooks/__init__.py @@ -6,9 +6,8 @@ import pane.annotations as annotations from typing_extensions import NotRequired -from ..types import Dataclass, Slices +from ..types import Dataclass, Slices, Aberration from .hook import Hook -from ..utils.optics import Aberration if t.TYPE_CHECKING: from phaser.utils.num import Sampling diff --git a/phaser/hooks/regularization.py b/phaser/hooks/regularization.py index 5f32c27..6659d7d 100644 --- a/phaser/hooks/regularization.py +++ b/phaser/hooks/regularization.py @@ -3,12 +3,12 @@ from numpy.typing import NDArray from phaser.types import Dataclass -from phaser.utils.num import Float from . import Hook from .solver import HasState, StateT if t.TYPE_CHECKING: + from phaser.utils.num import Float from phaser.state import ReconsState @@ -29,7 +29,7 @@ class CostRegularizer(HasState[StateT], t.Protocol[StateT]): def name(self) -> str: ... - def calc_loss_group(self, group: NDArray[numpy.integer], sim: 'ReconsState', state: StateT) -> t.Tuple[Float, StateT]: + def calc_loss_group(self, group: NDArray[numpy.integer], sim: 'ReconsState', state: StateT) -> t.Tuple['Float', StateT]: ... diff --git a/phaser/hooks/solver.py b/phaser/hooks/solver.py index dd89152..4b402d0 100644 --- a/phaser/hooks/solver.py +++ b/phaser/hooks/solver.py @@ -5,18 +5,19 @@ from numpy.typing import NDArray from phaser.types import Dataclass, ReconsVar -from phaser.utils.num import Float from . import Hook -StateT = t.TypeVar('StateT') - if t.TYPE_CHECKING: + from phaser.utils.num import Float from phaser.engines.common.simulation import SimulationState from phaser.execute import Observer from phaser.plan import ConventionalEnginePlan, GradientEnginePlan # noqa: F401 from phaser.state import ReconsState +StateT = t.TypeVar('StateT') + + class HasState(t.Protocol[StateT]): # type: ignore def init_state(self, sim: 'ReconsState') -> StateT: ... @@ -34,7 +35,7 @@ def calc_loss( exp_patterns: NDArray[numpy.floating], mask: NDArray[numpy.floating], state: StateT, - ) -> t.Tuple[Float, StateT]: + ) -> t.Tuple['Float', StateT]: """ Return the calculated loss, summed across the detector and averaged across the scan. diff --git a/phaser/types.py b/phaser/types.py index b27a903..c6a83b7 100644 --- a/phaser/types.py +++ b/phaser/types.py @@ -5,7 +5,7 @@ import numpy import pane from pane.converters import Converter, make_converter, ConverterHandlers, ErrorNode -from pane.annotations import ConvertAnnotation +from pane.annotations import ConvertAnnotation, Condition from pane.errors import ParseInterrupt, WrongTypeError from pane.util import pluralize, list_phrase from typing_extensions import Self @@ -122,6 +122,90 @@ def thicknesses(self) -> t.List[float]: Slices: t.TypeAlias = t.Union[SliceList, SliceStep, SliceTotal] +class ComplexCartesian(pane.PaneBase, kw_only=True): + re: float + im: float = 0.0 + + def __complex__(self) -> complex: + return complex(self.re, self.im) + +class ComplexPolar(pane.PaneBase, kw_only=True): + mag: float + angle: float = 0.0 # degrees + + def __complex__(self) -> complex: + theta = numpy.deg2rad(self.angle) + return self.mag * complex(numpy.cos(theta), numpy.sin(theta)) + + +class Krivanek(pane.PaneBase): + n: int + m: int + scale_factor: float = 1.0 + + def __post_init__(self): + if ( + self.n < 0 or self.m < 0 or + self.m > self.n + 1 or + self.m % 2 + self.n % 2 != 1 + ): + raise ValueError(f"Invalid Krivanek aberration n={self.n} m={self.m}") + + @staticmethod + def from_known(s: str) -> 'Krivanek': + try: + return _KNOWN_ABERRATIONS[s.lower()] + except (KeyError, TypeError): + raise ValueError(f"Unknown aberration '{s}'") from None + +class KrivanekComplex(Krivanek, kw_only=True): + val: complex + + def __complex__(self) -> complex: + return self.val + +class KrivanekCartesian(Krivanek, ComplexCartesian, kw_only=True): + ... + +class KrivanekPolar(Krivanek, ComplexPolar, kw_only=True): + ... + + +_KNOWN_ABERRATIONS: t.Dict[str, Krivanek] = { + 'c1': Krivanek.make_unchecked(1, 0), + 'a1': Krivanek.make_unchecked(1, 2), + 'b2': Krivanek.make_unchecked(2, 1, 3.0), # C_21 = 3*B2 + 'a2': Krivanek.make_unchecked(2, 3), + 'c3': Krivanek.make_unchecked(3, 0), + 's3': Krivanek.make_unchecked(3, 2, 3.0), # C_32 = 3*S3 + 'a3': Krivanek.make_unchecked(3, 4), + 'b4': Krivanek.make_unchecked(4, 1, 4.0), # C_41 = 4*B4 + 'd4': Krivanek.make_unchecked(4, 3, 4.0), # C_43 = 4*D4 + 'a4': Krivanek.make_unchecked(4, 5), + 'c5': Krivanek.make_unchecked(5, 0), +} + +KnownAberration: t.TypeAlias = t.Annotated[str, Condition( + lambda s: s.lower() in _KNOWN_ABERRATIONS, + 'known aberration', + lambda exp, plural: pluralize('known aberration', plural) +)] + +Aberration: t.TypeAlias = t.Union[ + t.Dict[KnownAberration, t.Union[complex, ComplexCartesian, ComplexPolar]], + KrivanekComplex, KrivanekCartesian, KrivanekPolar, +] + +def process_aberrations(aberrations: t.Iterable[Aberration]) -> t.Iterator[KrivanekComplex]: + for ab in aberrations: + if isinstance(ab, dict): + for known, val in ab.items(): + ty = Krivanek.from_known(known) + yield KrivanekComplex(ty.n, ty.m, val=ty.scale_factor * complex(val)) + else: + yield KrivanekComplex(ab.n, ab.m, val=complex(ab)) + + class SimpleFlag(Dataclass): after: int = 0 every: int = 1 @@ -302,24 +386,10 @@ def collect_errors(self, val: t.Any) -> t.Optional[ErrorNode]: return self.inner.collect_errors(val) -class ComplexCartesian(pane.PaneBase, kw_only=True): - re: float - im: float = 0.0 - - def __complex__(self) -> complex: - return complex(self.re, self.im) - - -class ComplexPolar(pane.PaneBase, kw_only=True): - mag: float - angle: float = 0.0 # degrees - - def __complex__(self) -> complex: - theta = numpy.deg2rad(self.angle) - return self.mag * complex(numpy.cos(theta), numpy.sin(theta)) - - __all__ = [ 'BackendName', 'Dataclass', 'Slices', 'Flag', - 'process_flag', 'flag_any_true', + 'ComplexCartesian', 'ComplexPolar', + 'Krivanek', 'KrivanekComplex', 'KrivanekCartesian', + 'KrivanekPolar', 'KnownAberration', 'Aberration', + 'process_aberrations', 'process_flag', 'flag_any_true', ] \ No newline at end of file diff --git a/phaser/utils/optics.py b/phaser/utils/optics.py index ba6fdb5..08fd60c 100644 --- a/phaser/utils/optics.py +++ b/phaser/utils/optics.py @@ -7,96 +7,21 @@ import numpy from numpy.typing import NDArray, ArrayLike -import pane -from pane.annotations import Condition -from pane.util import pluralize -from phaser.types import ComplexCartesian, ComplexPolar +from ..types import Aberration, process_aberrations from .num import get_array_module, ifft2, abs2, NumT, ufunc_outer, is_jax, cast_array_module from .num import Float, Sampling, to_complex_dtype, to_real_dtype, split_array, to_numpy -class Krivanek(pane.PaneBase): - n: int - m: int - scale_factor: float = 1.0 - - def __post_init__(self): - if ( - self.n < 0 or self.m < 0 or - self.m > self.n + 1 or - self.m % 2 + self.n % 2 != 1 - ): - raise ValueError(f"Invalid Krivanek aberration n={self.n} m={self.m}") - - @staticmethod - def from_known(s: str) -> 'Krivanek': - try: - return _KNOWN_ABERRATIONS[s.lower()] - except (KeyError, TypeError): - raise ValueError(f"Unknown aberration '{s}'") from None - - -_KNOWN_ABERRATIONS: t.Dict[str, Krivanek] = { - 'c1': Krivanek.make_unchecked(1, 0), - 'a1': Krivanek.make_unchecked(1, 2), - 'b2': Krivanek.make_unchecked(2, 1, 3.0), # C_21 = 3*B2 - 'a2': Krivanek.make_unchecked(2, 3), - 'c3': Krivanek.make_unchecked(3, 0), - 's3': Krivanek.make_unchecked(3, 2, 3.0), # C_32 = 3*S3 - 'a3': Krivanek.make_unchecked(3, 4), - 'b4': Krivanek.make_unchecked(4, 1, 4.0), # C_41 = 4*B4 - 'd4': Krivanek.make_unchecked(4, 3, 4.0), # C_43 = 4*D4 - 'a4': Krivanek.make_unchecked(4, 5), - 'c5': Krivanek.make_unchecked(5, 0), -} - -KnownAberration: t.TypeAlias = t.Annotated[str, Condition( - lambda s: s.lower() in _KNOWN_ABERRATIONS, - 'known aberration', - lambda exp, plural: pluralize('known aberration', plural) -)] - - -class KrivanekComplex(Krivanek, kw_only=True): - val: t.Union[complex, ComplexCartesian, ComplexPolar] - - def __complex__(self) -> complex: - return complex(self.val) - -class KrivanekCartesian(Krivanek, ComplexCartesian, kw_only=True): - ... - -class KrivanekPolar(Krivanek, ComplexPolar, kw_only=True): - ... - - -Aberration: t.TypeAlias = t.Union[ - t.Dict[KnownAberration, t.Union[complex, ComplexCartesian, ComplexPolar]], - KrivanekComplex, KrivanekCartesian, KrivanekPolar, -] -AberrationList: t.TypeAlias = t.List[Aberration] - - -def _normalize_aberrations(aberrations: t.Iterable[Aberration]) -> t.Iterator[KrivanekComplex]: - for ab in aberrations: - if isinstance(ab, dict): - for known, val in ab.items(): - ty = Krivanek.from_known(known) - yield KrivanekComplex(ty.n, ty.m, val=ty.scale_factor * complex(val)) - else: - yield KrivanekComplex(ab.n, ab.m, val=complex(ab)) - - def aberration_surface( - thetay: NDArray[numpy.float64], thetax: NDArray[numpy.float64], + thetay: NDArray[numpy.floating], thetax: NDArray[numpy.floating], aberrations: t.Iterable[Aberration] ) -> NDArray[numpy.floating]: xp = get_array_module(thetay, thetax) chi = xp.zeros_like(thetay) omega = thetax + thetay*1.j - for ab in _normalize_aberrations(aberrations): + for ab in process_aberrations(aberrations): p = (ab.n + 1 + ab.m) // 2 q = ab.n + 1 - p prod = omega**p * omega.conj()**q diff --git a/tests/test_optics.py b/tests/test_optics.py index 5e7affc..231f58e 100644 --- a/tests/test_optics.py +++ b/tests/test_optics.py @@ -6,12 +6,12 @@ from .utils import with_backends, check_array_equals_file -from phaser.utils.num import get_backend_module, BackendName, Sampling, to_numpy, fft2, ifft2 -from phaser.utils.optics import ( - make_focused_probe, fresnel_propagator, - AberrationList, _normalize_aberrations, +from phaser.types import ( ComplexCartesian, ComplexPolar, KrivanekComplex, KrivanekCartesian, KrivanekPolar, + Aberration, process_aberrations ) +from phaser.utils.num import get_backend_module, BackendName, Sampling, to_numpy, fft2, ifft2 +from phaser.utils.optics import make_focused_probe, fresnel_propagator @with_backends('numpy', 'jax', 'cupy', 'torch') @@ -113,7 +113,7 @@ def test_parse_aberrations(): {'n': 4, 'm': 1, 'val': 1+1.j}, # krivanek complex {'n': 1, 'm': 0, 're': 5.0}, # krivanek cartesian {'n': 5, 'm': 0, 'mag': 5.0}, # krivanek polar - ], AberrationList) + ], list[Aberration]) assert result == [ {'c3': complex(5.0)}, @@ -124,7 +124,7 @@ def test_parse_aberrations(): KrivanekPolar(5, 0, mag=5.0, angle=0.0), ] - assert list(_normalize_aberrations(result)) == [ + assert list(process_aberrations(result)) == [ KrivanekComplex.make_unchecked(3, 0, val=complex(5.0)), KrivanekComplex.make_unchecked(2, 1, val=15.0-6.0j), KrivanekComplex.make_unchecked(1, 2, val=pytest.approx(5.0j)), From 33d58732601a768937fd5dde1ed504b9f4505514 Mon Sep 17 00:00:00 2001 From: Colin Gilgenbach Date: Sat, 17 Jan 2026 14:47:02 -0500 Subject: [PATCH 2/3] Basic structure for CLI utils --- phaser/__main__.py | 2 +- phaser/cli/__init__.py | 109 +++++++++++++++++++++++++++++++++++++++++ phaser/cli/validate.py | 51 +++++++++++++++++++ pyproject.toml | 6 +-- 4 files changed, 164 insertions(+), 4 deletions(-) create mode 100644 phaser/cli/__init__.py create mode 100644 phaser/cli/validate.py diff --git a/phaser/__main__.py b/phaser/__main__.py index 7124ae2..8d30fc7 100755 --- a/phaser/__main__.py +++ b/phaser/__main__.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 if __name__ == '__main__': - from phaser.main import cli + from phaser.cli import cli cli() diff --git a/phaser/cli/__init__.py b/phaser/cli/__init__.py new file mode 100644 index 0000000..cc31f40 --- /dev/null +++ b/phaser/cli/__init__.py @@ -0,0 +1,109 @@ +from pathlib import Path +import sys +import typing as t + +import click + + +class MainCommand(click.MultiCommand): + def __init__(self, commands: t.Iterable[t.Union[click.Command, t.Tuple[str, str]]], **kwargs): + super().__init__(**kwargs) + # name: command or short_help + self.commands: t.Dict[str, t.Union[click.Command, str, None]] + + self.commands = dict( + (t.cast(str, c.name), c) if isinstance(c, click.Command) else (c[0], c[1]) + for c in commands + ) + + def list_commands(self, ctx: click.Context): + return list(self.commands.keys()) + + def format_commands(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: + from gettext import gettext + + if len(self.commands): + limit = formatter.width - 6 - max(map(len, self.commands.keys())) + + rows = [] + for name, cmd in self.commands.items(): + help = cmd.get_short_help_str(limit) if isinstance(cmd, click.Command) else cmd + rows.append((name, help)) + + if rows: + with formatter.section(gettext("Commands")): + formatter.write_dl(rows) + + def get_command(self, ctx: click.Context, cmd_name: str) -> t.Optional[click.Command]: + name = cmd_name.lower() + val = (self.commands.get(name) or + self.commands.get(name.replace('-', '_'))) + if val is None: + return None + if isinstance(val, click.BaseCommand): + return val + + # for now, assume `validate` command is `cli.validate:validate` + # TODO, incorporate some feature checks here + module = func = name + mod = __import__(module, globals(), fromlist=[func], level=1) + return getattr(mod, func) + + +@click.command() +@click.argument('path', type=click.Path(exists=True, dir_okay=False)) +def run(path: t.Union[str, Path]): + """Execute a reconstruction plan""" + from phaser.plan import ReconsPlan + from phaser.execute import execute_plan + plans = ReconsPlan.from_yaml_all(path) + + for plan in plans: + execute_plan(plan) + + +@click.command() +@click.option('--host', type=str, default='localhost', help="Host to serve on") +@click.option('--port', type=int, help="Port to serve on") +@click.option('-v', '--verbose', count=True, help="Increase verbosity") +def serve(host: str = 'localhost', port: t.Optional[int] = None, verbose: int = 0): + """Run phaser server""" + from phaser.web.server import server + + if ':' in host: + (host, port_from_host) = host.rsplit(':', maxsplit=1) + try: + port_from_host = int(port_from_host) + except ValueError: + print(f"Invalid host '{host}:{port_from_host}'", file=sys.stderr) + sys.exit(1) + + port = port or port_from_host + + server.run(hostname=host, port=port, verbosity=verbose) + + +@click.command() +@click.argument('url', type=str, required=True) +@click.option('--quiet/--loud', default=False, help="Whether to print output to stdout") +def worker(url: str, quiet: bool = False): + """ + Run phaser worker. + + URL is the server URL to connect to. + """ + from phaser.web.worker import run_worker + + run_worker(url, quiet=quiet) + + +commands: t.List[t.Union[click.Command, t.Tuple[str, str]]] = [ + run, serve, worker, + # these will be looked up in the cli folder + ('validate', "Validate reconstruction plan file"), +] + + +@click.command(cls=MainCommand, commands=commands) +def cli(): + pass \ No newline at end of file diff --git a/phaser/cli/validate.py b/phaser/cli/validate.py new file mode 100644 index 0000000..d5111cf --- /dev/null +++ b/phaser/cli/validate.py @@ -0,0 +1,51 @@ +from contextlib import nullcontext +from pathlib import Path +import sys +import typing as t + +import click + + +@click.command() +@click.argument('path', type=click.Path(allow_dash=True), default='-') +@click.option('--json/--no-json', default=False, + help="Output validation result in JSON format") +def validate(path: t.Union[str, Path], json: bool = False): + """ + Validate reconstruction plan file. + + PATH is the path to a YAML reconstruction plan, + or '-' (default) to read from stdin. + """ + from phaser.plan import ReconsPlan + + try: + if path == '-': + file = nullcontext(sys.stdin) + else: + file = open(Path(path).expanduser(), 'r') + + with file as file: + plans = ReconsPlan.from_yaml_all(file) + except Exception as e: + print(f"Validation failed:\n{e}", file=sys.stderr) + + if json: + from json import dump + dump({'result': 'error', 'error': str(e)}, sys.stdout) + print() + + sys.exit(1) + + if len(plans) == 1: + print("Validation of plan successful!", file=sys.stderr) + else: + print(f"Validation of {len(plans)} plans successful!", file=sys.stderr) + + if json: + from json import dump + dump({ + 'result': 'success', + 'plans': [(plan.name, plan.into_data()) for plan in plans], + }, sys.stdout) + print() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f135da3..12e5477 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ jax = [ "optax>=0.2.2", ] torch = [ - "torch>=2.8.0", + "torch>=2.8.0", ] web = [ "Quart>=0.20.0", @@ -83,7 +83,7 @@ Repository = "https://github.com/hexane360/phaser.git" Documentation = "https://hexane360.github.io/phaser/dev" [project.scripts] -phaser = "phaser.main:cli" +phaser = "phaser.cli:cli" [tool.setuptools.packages.find] include = ["phaser*"] @@ -109,7 +109,7 @@ typeCheckingMode = "standard" exclude = [ "tests", "notebooks/", "examples", ".github", "build", "venv", ".venv", - "**/node_modules", "**/__pycache__" + "**/node_modules", "**/__pycache__", "**/.*", ] reportIncompatibleMethodOverride = "information" From 15581a2d533f4c422a50fb058f0a9437888f190b Mon Sep 17 00:00:00 2001 From: Colin Gilgenbach Date: Sat, 17 Jan 2026 15:10:56 -0500 Subject: [PATCH 3/3] More lazy loading --- phaser/utils/num.py | 13 +++++++++++++ phaser/utils/tree.py | 23 ++++++++++++----------- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/phaser/utils/num.py b/phaser/utils/num.py index 8ef0e7a..3a82158 100644 --- a/phaser/utils/num.py +++ b/phaser/utils/num.py @@ -75,6 +75,7 @@ def _load_torch() -> ModuleType: class _BackendLoader: def __init__(self): self.inner: t.Dict[BackendName, t.Optional[ModuleType]] = {} + self.cbs: t.Dict[BackendName, t.List[t.Callable[[], t.Any]]] = {} def _normalize(self, backend: BackendName) -> BackendName: name = t.cast(BackendName, backend.lower()) @@ -87,9 +88,21 @@ def _normalize(self, backend: BackendName) -> BackendName: def _load(self, name: BackendName): try: self.inner[name] = _LOAD_FNS[name]() + for cb in self.cbs.pop(name, ()): + cb() except ImportError: self.inner[name] = None + def _schedule_on_load(self, backend: BackendName, fn: t.Callable[[], t.Any]): + name = self._normalize(backend) + if self.inner.get(name): + # already loaded, run immediately + fn() + else: + # otherwise schedule for when (if) we load + cbs = self.cbs.setdefault(name, list()) + cbs.append(fn) + def get(self, name: BackendName): name = self._normalize(name) if name == 'numpy': diff --git a/phaser/utils/tree.py b/phaser/utils/tree.py index 044209a..402f0eb 100644 --- a/phaser/utils/tree.py +++ b/phaser/utils/tree.py @@ -11,6 +11,7 @@ Tree: t.TypeAlias = t.Any field = dataclasses.field + class TreeSpec(t.Protocol): @property def num_leaves(self) -> int: @@ -252,7 +253,7 @@ def leaves_with_path( def zeros_like( - tree: Tree, dtype: DTypeLike = None, + tree: Tree, dtype: t.Optional[DTypeLike] = None, ) -> Tree: from phaser.utils.num import get_array_module xp = get_array_module(tree) @@ -261,7 +262,7 @@ def zeros_like( def ones_like( - tree: Tree, dtype: DTypeLike = None, + tree: Tree, dtype: t.Optional[DTypeLike] = None, ) -> Tree: from phaser.utils.num import get_array_module xp = get_array_module(tree) @@ -271,7 +272,7 @@ def ones_like( def full_like( tree: Tree, fill_value: ArrayLike, - dtype: DTypeLike = None, + dtype: t.Optional[DTypeLike] = None, ) -> Tree: from phaser.utils.num import get_array_module xp = get_array_module(tree) @@ -421,25 +422,25 @@ def flatten(x: t.Any, /) -> tuple[list[t.Any], t.Hashable]: trees = list(getattr(x, name) for name in data_fields) return trees, hashed - try: + def _register_jax(): from jax.tree_util import register_pytree_with_keys, GetAttrKey - except ImportError: - pass - else: + flatten_with_keys = make_flatten_with_keys(GetAttrKey) register_pytree_with_keys(cls, flatten_with_keys, unflatten, flatten) - try: + def _register_torch(): from torch.utils._pytree import register_pytree_node, GetAttrKey - except ImportError: - pass - else: + flatten_with_keys = make_flatten_with_keys(GetAttrKey) register_pytree_node( cls, flatten, lambda trees, meta: unflatten(meta, trees), flatten_with_keys_fn=flatten_with_keys, # type: ignore ) + from phaser.utils.num import _BACKEND_LOADER + _BACKEND_LOADER._schedule_on_load('jax', _register_jax) + _BACKEND_LOADER._schedule_on_load('torch', _register_torch) + __all__ = [ 'flatten', 'flatten_with_path', 'unflatten', 'map', 'reduce', 'sum',