Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion phaser/__main__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3

if __name__ == '__main__':
from phaser.main import cli
from phaser.cli import cli
cli()
109 changes: 109 additions & 0 deletions phaser/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from pathlib import Path
import sys
import typing as t

import click


class MainCommand(click.MultiCommand):
def __init__(self, commands: t.Iterable[t.Union[click.Command, t.Tuple[str, str]]], **kwargs):
super().__init__(**kwargs)
# name: command or short_help
self.commands: t.Dict[str, t.Union[click.Command, str, None]]

self.commands = dict(
(t.cast(str, c.name), c) if isinstance(c, click.Command) else (c[0], c[1])
for c in commands
)

def list_commands(self, ctx: click.Context):
return list(self.commands.keys())

def format_commands(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
from gettext import gettext

if len(self.commands):
limit = formatter.width - 6 - max(map(len, self.commands.keys()))

rows = []
for name, cmd in self.commands.items():
help = cmd.get_short_help_str(limit) if isinstance(cmd, click.Command) else cmd
rows.append((name, help))

if rows:
with formatter.section(gettext("Commands")):
formatter.write_dl(rows)

def get_command(self, ctx: click.Context, cmd_name: str) -> t.Optional[click.Command]:
name = cmd_name.lower()
val = (self.commands.get(name) or
self.commands.get(name.replace('-', '_')))
if val is None:
return None
if isinstance(val, click.BaseCommand):
return val

# for now, assume `validate` command is `cli.validate:validate`
# TODO, incorporate some feature checks here
module = func = name
mod = __import__(module, globals(), fromlist=[func], level=1)
return getattr(mod, func)


@click.command()
@click.argument('path', type=click.Path(exists=True, dir_okay=False))
def run(path: t.Union[str, Path]):
"""Execute a reconstruction plan"""
from phaser.plan import ReconsPlan
from phaser.execute import execute_plan
plans = ReconsPlan.from_yaml_all(path)

for plan in plans:
execute_plan(plan)


@click.command()
@click.option('--host', type=str, default='localhost', help="Host to serve on")
@click.option('--port', type=int, help="Port to serve on")
@click.option('-v', '--verbose', count=True, help="Increase verbosity")
def serve(host: str = 'localhost', port: t.Optional[int] = None, verbose: int = 0):
"""Run phaser server"""
from phaser.web.server import server

if ':' in host:
(host, port_from_host) = host.rsplit(':', maxsplit=1)
try:
port_from_host = int(port_from_host)
except ValueError:
print(f"Invalid host '{host}:{port_from_host}'", file=sys.stderr)
sys.exit(1)

port = port or port_from_host

server.run(hostname=host, port=port, verbosity=verbose)


@click.command()
@click.argument('url', type=str, required=True)
@click.option('--quiet/--loud', default=False, help="Whether to print output to stdout")
def worker(url: str, quiet: bool = False):
"""
Run phaser worker.

URL is the server URL to connect to.
"""
from phaser.web.worker import run_worker

run_worker(url, quiet=quiet)


commands: t.List[t.Union[click.Command, t.Tuple[str, str]]] = [
run, serve, worker,
# these will be looked up in the cli folder
('validate', "Validate reconstruction plan file"),
]


@click.command(cls=MainCommand, commands=commands)
def cli():
pass
51 changes: 51 additions & 0 deletions phaser/cli/validate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from contextlib import nullcontext
from pathlib import Path
import sys
import typing as t

import click


@click.command()
@click.argument('path', type=click.Path(allow_dash=True), default='-')
@click.option('--json/--no-json', default=False,
help="Output validation result in JSON format")
def validate(path: t.Union[str, Path], json: bool = False):
"""
Validate reconstruction plan file.

PATH is the path to a YAML reconstruction plan,
or '-' (default) to read from stdin.
"""
from phaser.plan import ReconsPlan

try:
if path == '-':
file = nullcontext(sys.stdin)
else:
file = open(Path(path).expanduser(), 'r')

with file as file:
plans = ReconsPlan.from_yaml_all(file)
except Exception as e:
print(f"Validation failed:\n{e}", file=sys.stderr)

if json:
from json import dump
dump({'result': 'error', 'error': str(e)}, sys.stdout)
print()

sys.exit(1)

if len(plans) == 1:
print("Validation of plan successful!", file=sys.stderr)
else:
print(f"Validation of {len(plans)} plans successful!", file=sys.stderr)

if json:
from json import dump
dump({
'result': 'success',
'plans': [(plan.name, plan.into_data()) for plan in plans],
}, sys.stdout)
print()
3 changes: 1 addition & 2 deletions phaser/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
import pane.annotations as annotations
from typing_extensions import NotRequired

from ..types import Dataclass, Slices
from ..types import Dataclass, Slices, Aberration
from .hook import Hook
from ..utils.optics import Aberration

if t.TYPE_CHECKING:
from phaser.utils.num import Sampling
Expand Down
4 changes: 2 additions & 2 deletions phaser/hooks/regularization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from numpy.typing import NDArray

from phaser.types import Dataclass
from phaser.utils.num import Float
from . import Hook

from .solver import HasState, StateT

if t.TYPE_CHECKING:
from phaser.utils.num import Float
from phaser.state import ReconsState


Expand All @@ -29,7 +29,7 @@ class CostRegularizer(HasState[StateT], t.Protocol[StateT]):
def name(self) -> str:
...

def calc_loss_group(self, group: NDArray[numpy.integer], sim: 'ReconsState', state: StateT) -> t.Tuple[Float, StateT]:
def calc_loss_group(self, group: NDArray[numpy.integer], sim: 'ReconsState', state: StateT) -> t.Tuple['Float', StateT]:
...


Expand Down
9 changes: 5 additions & 4 deletions phaser/hooks/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@
from numpy.typing import NDArray

from phaser.types import Dataclass, ReconsVar
from phaser.utils.num import Float
from . import Hook

StateT = t.TypeVar('StateT')

if t.TYPE_CHECKING:
from phaser.utils.num import Float
from phaser.engines.common.simulation import SimulationState
from phaser.execute import Observer
from phaser.plan import ConventionalEnginePlan, GradientEnginePlan # noqa: F401
from phaser.state import ReconsState


StateT = t.TypeVar('StateT')


class HasState(t.Protocol[StateT]): # type: ignore
def init_state(self, sim: 'ReconsState') -> StateT:
...
Expand All @@ -34,7 +35,7 @@ def calc_loss(
exp_patterns: NDArray[numpy.floating],
mask: NDArray[numpy.floating],
state: StateT,
) -> t.Tuple[Float, StateT]:
) -> t.Tuple['Float', StateT]:
"""
Return the calculated loss, summed across the detector and averaged across the scan.

Expand Down
108 changes: 89 additions & 19 deletions phaser/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy
import pane
from pane.converters import Converter, make_converter, ConverterHandlers, ErrorNode
from pane.annotations import ConvertAnnotation
from pane.annotations import ConvertAnnotation, Condition
from pane.errors import ParseInterrupt, WrongTypeError
from pane.util import pluralize, list_phrase
from typing_extensions import Self
Expand Down Expand Up @@ -122,6 +122,90 @@ def thicknesses(self) -> t.List[float]:
Slices: t.TypeAlias = t.Union[SliceList, SliceStep, SliceTotal]


class ComplexCartesian(pane.PaneBase, kw_only=True):
re: float
im: float = 0.0

def __complex__(self) -> complex:
return complex(self.re, self.im)

class ComplexPolar(pane.PaneBase, kw_only=True):
mag: float
angle: float = 0.0 # degrees

def __complex__(self) -> complex:
theta = numpy.deg2rad(self.angle)
return self.mag * complex(numpy.cos(theta), numpy.sin(theta))


class Krivanek(pane.PaneBase):
n: int
m: int
scale_factor: float = 1.0

def __post_init__(self):
if (
self.n < 0 or self.m < 0 or
self.m > self.n + 1 or
self.m % 2 + self.n % 2 != 1
):
raise ValueError(f"Invalid Krivanek aberration n={self.n} m={self.m}")

@staticmethod
def from_known(s: str) -> 'Krivanek':
try:
return _KNOWN_ABERRATIONS[s.lower()]
except (KeyError, TypeError):
raise ValueError(f"Unknown aberration '{s}'") from None

class KrivanekComplex(Krivanek, kw_only=True):
val: complex

def __complex__(self) -> complex:
return self.val

class KrivanekCartesian(Krivanek, ComplexCartesian, kw_only=True):
...

class KrivanekPolar(Krivanek, ComplexPolar, kw_only=True):
...


_KNOWN_ABERRATIONS: t.Dict[str, Krivanek] = {
'c1': Krivanek.make_unchecked(1, 0),
'a1': Krivanek.make_unchecked(1, 2),
'b2': Krivanek.make_unchecked(2, 1, 3.0), # C_21 = 3*B2
'a2': Krivanek.make_unchecked(2, 3),
'c3': Krivanek.make_unchecked(3, 0),
's3': Krivanek.make_unchecked(3, 2, 3.0), # C_32 = 3*S3
'a3': Krivanek.make_unchecked(3, 4),
'b4': Krivanek.make_unchecked(4, 1, 4.0), # C_41 = 4*B4
'd4': Krivanek.make_unchecked(4, 3, 4.0), # C_43 = 4*D4
'a4': Krivanek.make_unchecked(4, 5),
'c5': Krivanek.make_unchecked(5, 0),
}

KnownAberration: t.TypeAlias = t.Annotated[str, Condition(
lambda s: s.lower() in _KNOWN_ABERRATIONS,
'known aberration',
lambda exp, plural: pluralize('known aberration', plural)
)]

Aberration: t.TypeAlias = t.Union[
t.Dict[KnownAberration, t.Union[complex, ComplexCartesian, ComplexPolar]],
KrivanekComplex, KrivanekCartesian, KrivanekPolar,
]

def process_aberrations(aberrations: t.Iterable[Aberration]) -> t.Iterator[KrivanekComplex]:
for ab in aberrations:
if isinstance(ab, dict):
for known, val in ab.items():
ty = Krivanek.from_known(known)
yield KrivanekComplex(ty.n, ty.m, val=ty.scale_factor * complex(val))
else:
yield KrivanekComplex(ab.n, ab.m, val=complex(ab))


class SimpleFlag(Dataclass):
after: int = 0
every: int = 1
Expand Down Expand Up @@ -302,24 +386,10 @@ def collect_errors(self, val: t.Any) -> t.Optional[ErrorNode]:
return self.inner.collect_errors(val)


class ComplexCartesian(pane.PaneBase, kw_only=True):
re: float
im: float = 0.0

def __complex__(self) -> complex:
return complex(self.re, self.im)


class ComplexPolar(pane.PaneBase, kw_only=True):
mag: float
angle: float = 0.0 # degrees

def __complex__(self) -> complex:
theta = numpy.deg2rad(self.angle)
return self.mag * complex(numpy.cos(theta), numpy.sin(theta))


__all__ = [
'BackendName', 'Dataclass', 'Slices', 'Flag',
'process_flag', 'flag_any_true',
'ComplexCartesian', 'ComplexPolar',
'Krivanek', 'KrivanekComplex', 'KrivanekCartesian',
'KrivanekPolar', 'KnownAberration', 'Aberration',
'process_aberrations', 'process_flag', 'flag_any_true',
]
13 changes: 13 additions & 0 deletions phaser/utils/num.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def _load_torch() -> ModuleType:
class _BackendLoader:
def __init__(self):
self.inner: t.Dict[BackendName, t.Optional[ModuleType]] = {}
self.cbs: t.Dict[BackendName, t.List[t.Callable[[], t.Any]]] = {}

def _normalize(self, backend: BackendName) -> BackendName:
name = t.cast(BackendName, backend.lower())
Expand All @@ -87,9 +88,21 @@ def _normalize(self, backend: BackendName) -> BackendName:
def _load(self, name: BackendName):
try:
self.inner[name] = _LOAD_FNS[name]()
for cb in self.cbs.pop(name, ()):
cb()
except ImportError:
self.inner[name] = None

def _schedule_on_load(self, backend: BackendName, fn: t.Callable[[], t.Any]):
name = self._normalize(backend)
if self.inner.get(name):
# already loaded, run immediately
fn()
else:
# otherwise schedule for when (if) we load
cbs = self.cbs.setdefault(name, list())
cbs.append(fn)

def get(self, name: BackendName):
name = self._normalize(name)
if name == 'numpy':
Expand Down
Loading